parent
e04cd18846
commit
8f5e61090b
@ -0,0 +1,68 @@
|
|||||||
|
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
#
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .cache import (
|
||||||
|
cached_tarfile_samples,
|
||||||
|
cached_tarfile_to_samples,
|
||||||
|
lru_cleanup,
|
||||||
|
pipe_cleaner,
|
||||||
|
)
|
||||||
|
from .compat import WebDataset, WebLoader, FluidWrapper
|
||||||
|
from webdataset.extradatasets import MockDataset, with_epoch, with_length
|
||||||
|
from .filters import (
|
||||||
|
associate,
|
||||||
|
batched,
|
||||||
|
decode,
|
||||||
|
detshuffle,
|
||||||
|
extract_keys,
|
||||||
|
getfirst,
|
||||||
|
info,
|
||||||
|
map,
|
||||||
|
map_dict,
|
||||||
|
map_tuple,
|
||||||
|
pipelinefilter,
|
||||||
|
rename,
|
||||||
|
rename_keys,
|
||||||
|
rsample,
|
||||||
|
select,
|
||||||
|
shuffle,
|
||||||
|
slice,
|
||||||
|
to_tuple,
|
||||||
|
transform_with,
|
||||||
|
unbatched,
|
||||||
|
xdecode,
|
||||||
|
data_filter,
|
||||||
|
tokenize,
|
||||||
|
resample,
|
||||||
|
compute_fbank,
|
||||||
|
spec_aug,
|
||||||
|
sort,
|
||||||
|
padding,
|
||||||
|
cmvn
|
||||||
|
)
|
||||||
|
from webdataset.handlers import (
|
||||||
|
ignore_and_continue,
|
||||||
|
ignore_and_stop,
|
||||||
|
reraise_exception,
|
||||||
|
warn_and_continue,
|
||||||
|
warn_and_stop,
|
||||||
|
)
|
||||||
|
from .pipeline import DataPipeline
|
||||||
|
from .shardlists import (
|
||||||
|
MultiShardSample,
|
||||||
|
ResampledShards,
|
||||||
|
SimpleShardList,
|
||||||
|
non_empty,
|
||||||
|
resampled,
|
||||||
|
shardspec,
|
||||||
|
single_node_only,
|
||||||
|
split_by_node,
|
||||||
|
split_by_worker,
|
||||||
|
)
|
||||||
|
from .tariterators import tarfile_samples, tarfile_to_samples
|
||||||
|
from .utils import PipelineStage, repeatedly
|
||||||
|
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
|
||||||
|
from webdataset.mix import RandomMix, RoundRobin
|
@ -0,0 +1,190 @@
|
|||||||
|
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
import itertools, os, random, re, sys
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from . import filters
|
||||||
|
from webdataset import gopen
|
||||||
|
from webdataset.handlers import reraise_exception
|
||||||
|
from .tariterators import tar_file_and_group_expander
|
||||||
|
|
||||||
|
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
|
||||||
|
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Delete files from the cache, oldest first, until the total size of the
    remaining files is at most `cache_size`.

    :param cache_dir: root directory of the file cache; a missing dir is a no-op
    :param cache_size: maximum total size in bytes to keep
    :param keyfn: maps a path to its age key (default: creation time)
    :param verbose: print each deleted file name to stderr
    """
    if not os.path.exists(cache_dir):
        return
    # One walk collects both the paths and the total size (the original
    # walked the tree twice, once for the size and once for the file list).
    files = []
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            files.append(path)
            total_size += os.path.getsize(path)
    if total_size <= cache_size:
        return
    # Newest first, so pop() from the end yields the oldest file.
    files.sort(key=keyfn, reverse=True)
    while files and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)
|
||||||
|
|
||||||
|
|
||||||
|
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`.

    The data is first written to a pid-suffixed temp file, then renamed
    into place, so a partially-downloaded file never appears at `dest`.
    """
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream, open(temp, "wb") as sink:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            sink.write(chunk)
    os.rename(temp, dest)
|
||||||
|
|
||||||
|
|
||||||
|
def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if not spec.startswith("pipe:"):
        return spec
    stripped = spec[5:]
    # return the first shell word that looks like a URL scheme we know
    for token in stripped.split(" "):
        if re.match(r"^(https?|gs|ais|s3)", token):
            return token
    return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_cached(
    spec,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    verbose=False,
):
    """Return a local path for `spec`, downloading it into the cache if absent.

    :param spec: URL or "pipe:" specification of the shard
    :param cache_size: cache budget in bytes; -1 selects the module default
    :param cache_dir: cache root; None selects the module default
    :param url_to_name: maps the spec to the URL used to derive the cache path
    :param verbose: print download progress to stderr
    :returns: path of the (possibly freshly downloaded) cached file
    """
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    # flatten the URL path into a single safe directory name under the cache
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        # make room first, then fetch; note the download itself is not
        # counted against cache_size until the next cleanup
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest
|
||||||
|
|
||||||
|
|
||||||
|
def get_filetype(fname):
    """Return the output of the `file` command run on `fname`.

    The path is shell-quoted: the original interpolated it into single
    quotes, so a file name containing a quote broke the command (and
    could inject shell syntax).
    """
    import shlex  # local import keeps this fix self-contained

    with os.popen("file %s" % shlex.quote(fname)) as f:
        return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def check_tar_format(fname):
    """Check whether a file is a tar archive (possibly gzip-compressed)."""
    description = get_filetype(fname)
    return any(marker in description for marker in ("tar archive", "gzip compressed"))
|
||||||
|
|
||||||
|
|
||||||
|
verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
|
||||||
|
|
||||||
|
|
||||||
|
def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.

    Local, existing paths are opened directly unless `always` forces the
    cache; everything else goes through `get_file_cached`.  Files that
    fail `validator` are removed from the cache and reported via `handler`.
    """
    import time  # bug fix: `time` was used below but never imported by this module

    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                # read a prefix for the error message; renamed from `data`,
                # which shadowed the source-iterator parameter
                with open(dest, "rb") as f:
                    head = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(head))
                )
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup
                # NOTE(review): `continue` advances to the NEXT sample rather
                # than retrying this url — confirm this is the intended
                # best-effort behavior (it matches upstream webdataset)
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break
|
||||||
|
|
||||||
|
|
||||||
|
def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    """Open (possibly cached) tar shards named by `src` and expand them into samples."""
    opened = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    return tar_file_and_group_expander(opened, handler=handler)
|
||||||
|
|
||||||
|
|
||||||
|
cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
|
@ -0,0 +1,170 @@
|
|||||||
|
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from itertools import islice
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import braceexpand, yaml
|
||||||
|
|
||||||
|
from webdataset import autodecode
|
||||||
|
from . import cache, filters, shardlists, tariterators
|
||||||
|
from .filters import reraise_exception
|
||||||
|
from .pipeline import DataPipeline
|
||||||
|
from .paddle_utils import DataLoader, IterableDataset
|
||||||
|
|
||||||
|
|
||||||
|
class FluidInterface:
    """Mixin adding a fluent (chainable) filter API to pipeline classes.

    Every method wraps a stage from the `filters` module via
    `self.compose(...)` (provided by the pipeline base class) and returns
    the composed pipeline, so calls can be chained.
    """

    def batched(self, batchsize):
        """Collect samples into batches of size `batchsize`."""
        return self.compose(filters.batched(batchsize))

    def dynamic_batched(self, max_frames_in_batch):
        """Batch samples by a total-frame budget rather than a fixed count.

        Bug fix: this previously called `filter.dynamic_batched` — an
        attribute of the builtin `filter` — raising AttributeError.
        """
        return self.compose(filters.dynamic_batched(max_frames_in_batch))

    def unbatched(self):
        """Split batches back into individual samples."""
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        """Collect samples into plain lists (no collation).

        Bug fix: `batchsize`/`collation_fn` were previously passed as
        keyword arguments to `compose` instead of to `filters.batched`,
        and `partial` was silently ignored.
        """
        return self.compose(filters.batched(batchsize, collation_fn=None, partial=partial))

    def unlisted(self):
        """Split listed batches back into individual samples."""
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        """Log sample keys to `logfile` while passing samples through."""
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        """Shuffle with a buffer of `size`; size < 1 disables shuffling."""
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        """Apply `f` to every sample."""
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
        """Decode samples; string arguments name image-handler specs."""
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        """Apply per-key functions to dict samples."""
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        """Keep only samples for which `predicate` is true."""
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        """Convert dict samples to tuples of the named fields."""
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        """Apply per-position functions to tuple samples."""
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        """Take an islice of the sample stream."""
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        """Rename sample fields (new=old keyword arguments)."""
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        """Randomly subsample the stream with probability `p`."""
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        """Rename keys by glob patterns."""
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        """Extract values whose keys match glob patterns, as tuples."""
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        """Apply `filters.xdecode` (decode by file extension)."""
        return self.compose(filters.xdecode(*args, **kw))

    def data_filter(self, *args, **kw):
        """Apply `filters.data_filter` to the stream."""
        return self.compose(filters.data_filter(*args, **kw))

    def tokenize(self, *args, **kw):
        """Apply `filters.tokenize` to the stream."""
        return self.compose(filters.tokenize(*args, **kw))

    def resample(self, *args, **kw):
        """Apply `filters.resample` to the stream."""
        return self.compose(filters.resample(*args, **kw))

    def compute_fbank(self, *args, **kw):
        """Apply `filters.compute_fbank` to the stream."""
        return self.compose(filters.compute_fbank(*args, **kw))

    def spec_aug(self, *args, **kw):
        """Apply `filters.spec_aug` to the stream."""
        return self.compose(filters.spec_aug(*args, **kw))

    def sort(self, size=500):
        """Apply `filters.sort` with a buffer of `size`."""
        return self.compose(filters.sort(size))

    def padding(self):
        """Apply `filters.padding` to the stream."""
        return self.compose(filters.padding())

    def cmvn(self, cmvn_file):
        """Apply `filters.cmvn` using statistics from `cmvn_file`."""
        return self.compose(filters.cmvn(cmvn_file))
|
||||||
|
|
||||||
|
class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=0,
        cache_dir=None,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        verbose=False,
    ):
        # NOTE(review): `repeat` is accepted but never used below — confirm
        # whether it should drive an epoch-repeat stage.
        super().__init__()
        # Stage 1: turn `urls` into a shard-list pipeline stage.
        if isinstance(urls, IterableDataset):
            # already a dataset; resampling a fixed dataset is not supported
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            # a YAML spec file describing multiple shard sets
            with (open(urls)) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
        elif isinstance(urls, dict):
            # an already-parsed multi-shard spec
            assert "datasets" in urls
            self.append(shardlists.MultiShardSample(urls))
        elif resampled:
            self.append(shardlists.ResampledShards(urls))
        else:
            self.append(shardlists.SimpleShardList(urls))
        # Stage 2: partition shards across nodes and dataloader workers.
        self.append(nodesplitter)
        self.append(shardlists.split_by_worker)
        # Stage 3: optional shard-level shuffling (True selects a default
        # buffer of 100 shards; `detshuffle` makes it epoch-deterministic).
        if shardshuffle is True:
            shardshuffle = 100
        if shardshuffle is not None:
            if detshuffle:
                self.append(filters.detshuffle(shardshuffle))
            else:
                self.append(filters.shuffle(shardshuffle))
        # Stage 4: expand tar shards into samples, via the file cache when
        # cache_size is nonzero (-1 means "default cache budget").
        if cache_size == 0:
            self.append(tariterators.tarfile_to_samples(handler=handler))
        else:
            assert cache_size == -1 or cache_size > 0
            self.append(
                cache.cached_tarfile_to_samples(
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                )
            )
|
||||||
|
|
||||||
|
|
||||||
|
class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        # wrap an existing stage/dataset so the fluent methods can be
        # chained onto it
        super().__init__()
        self.append(initial)
|
||||||
|
|
||||||
|
|
||||||
|
class WebLoader(DataPipeline, FluidInterface):
    # Wraps a DataLoader (from paddle_utils) as the first pipeline stage so
    # the fluent filter methods can be applied to loader output.
    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
|
@ -0,0 +1,912 @@
|
|||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
#
|
||||||
|
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
||||||
|
"""A collection of iterators for data transformations.
|
||||||
|
|
||||||
|
These functions are plain iterator functions. You can find curried versions
|
||||||
|
in webdataset.filters, and you can find IterableDataset wrappers in
|
||||||
|
webdataset.processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from fnmatch import fnmatch
|
||||||
|
import re
|
||||||
|
import itertools, os, random, sys, time
|
||||||
|
from functools import reduce, wraps
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from webdataset import autodecode
|
||||||
|
from . import utils
|
||||||
|
from .paddle_utils import PaddleTensor
|
||||||
|
from .utils import PipelineStage
|
||||||
|
|
||||||
|
from .. import backends
|
||||||
|
from ..compliance import kaldi
|
||||||
|
import paddle
|
||||||
|
from ..transform.cmvn import GlobalCMVN
|
||||||
|
from ..utils.tensor_utils import pad_sequence
|
||||||
|
from ..transform.spec_augment import time_warp
|
||||||
|
from ..transform.spec_augment import time_mask
|
||||||
|
from ..transform.spec_augment import freq_mask
|
||||||
|
|
||||||
|
class FilterFunction(object):
    """Picklable curried pipeline stage.

    Stores a function plus fixed arguments; calling the instance applies
    the function to a data argument followed by the stored arguments.
    (A closure would not be picklable, hence this class.)
    """

    def __init__(self, f, *args, **kw):
        """Remember the function and the arguments to curry."""
        self.f = f
        self.args = args
        self.kw = kw

    def __call__(self, data):
        """Apply the stored function to `data` plus the curried arguments."""
        return self.f(data, *self.args, **self.kw)

    def __str__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"

    # identical representation for repr(); avoids the duplicated body
    __repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
|
class RestCurried(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f):
        """Store the function for future currying."""
        self.f = f

    def __call__(self, *args, **kw):
        """Curry with the given arguments.

        Returns a FilterFunction that still expects the data argument.
        """
        return FilterFunction(self.f, *args, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def pipelinefilter(f):
    """Turn the decorated function into one that is partially applied for
    all arguments other than the first (the data stream)."""
    return RestCurried(f)
|
||||||
|
|
||||||
|
|
||||||
|
def reraise_exception(exn):
    """Reraises the given exception; used as a handler.

    This is the default `handler=` throughout this module: any error in a
    pipeline stage propagates instead of being skipped.

    :param exn: exception
    """
    raise exn
|
||||||
|
|
||||||
|
|
||||||
|
def identity(x):
    """Return the argument unchanged (no-op transformer)."""
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def compose2(f, g):
    """Compose two functions: the result applies f first, then g."""
    def composed(x):
        return g(f(x))
    return composed
|
||||||
|
|
||||||
|
|
||||||
|
def compose(*args):
    """Compose a sequence of functions (left-to-right).

    compose(f, g)(x) == g(f(x)); reduce raises TypeError for zero args.
    """
    return reduce(compose2, args)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline(source, *args):
    """Write an input pipeline; first argument is source, rest are filters."""
    if not args:
        return source
    return compose(*args)(source)
|
||||||
|
|
||||||
|
|
||||||
|
def getfirst(a, keys, default=None, missing_is_error=True):
    """Return the value for the first key in `keys` that is present in dict `a`.

    Keys can be specified as a list, or as a single string of keys
    separated by ';'.  When nothing matches, raises ValueError if
    missing_is_error is set, otherwise returns `default`.
    """
    if isinstance(keys, str):
        assert " " not in keys
        keys = keys.split(";")
    for candidate in keys:
        if candidate in a:
            return a[candidate]
    if missing_is_error:
        raise ValueError(f"didn't find {keys} in {list(a.keys())}")
    return default
|
||||||
|
|
||||||
|
|
||||||
|
def parse_field_spec(fields):
    """Parse a field-extraction spec into a list of key-alternative lists.

    Keys are separated by spaces in the spec; each key may itself list
    alternatives separated by ';'.
    """
    specs = fields.split() if isinstance(fields, str) else fields
    return [spec.split(";") for spec in specs]
|
||||||
|
|
||||||
|
|
||||||
|
def transform_with(sample, transformers):
    """Apply a list of functions elementwise to a list of values.

    If there are fewer transformers than inputs, or a transformer is None,
    the corresponding fields pass through unchanged.  Returns a new list.
    """
    if transformers is None or len(transformers) == 0:
        return sample
    assert len(transformers) <= len(sample)
    result = list(sample)
    for index, fn in enumerate(transformers):
        if fn is not None:
            result[index] = fn(sample[index])
    return result
|
||||||
|
|
||||||
|
###
|
||||||
|
# Iterators
|
||||||
|
###
|
||||||
|
|
||||||
|
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
|
||||||
|
"""Print information about the samples that are passing through.
|
||||||
|
|
||||||
|
:param data: source iterator
|
||||||
|
:param fmt: format statement (using sample dict as keyword)
|
||||||
|
:param n: when to stop
|
||||||
|
:param every: how often to print
|
||||||
|
:param width: maximum width
|
||||||
|
:param stream: output stream
|
||||||
|
:param name: identifier printed before any output
|
||||||
|
"""
|
||||||
|
for i, sample in enumerate(data):
|
||||||
|
if i < n or (every > 0 and (i + 1) % every == 0):
|
||||||
|
if fmt is None:
|
||||||
|
print("---", name, file=stream)
|
||||||
|
for k, v in sample.items():
|
||||||
|
print(k, repr(v)[:width], file=stream)
|
||||||
|
else:
|
||||||
|
print(fmt.format(**sample), file=stream)
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
info = pipelinefilter(_info)
|
||||||
|
|
||||||
|
|
||||||
|
def pick(buf, rng):
    """Remove and return a random element of `buf` in O(1).

    The removed slot is backfilled with the last element, so the buffer
    order is not preserved.
    """
    index = rng.randint(0, len(buf) - 1)
    chosen = buf[index]
    buf[index] = buf[-1]
    buf.pop()
    return chosen
|
||||||
|
|
||||||
|
|
||||||
|
def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle the data in the stream.

    This uses a buffer of size `bufsize`. Shuffling at
    startup is less random; this is traded off against
    yielding samples quickly.

    data: iterator
    bufsize: buffer size for shuffling
    returns: iterator
    rng: either random module or random.Random instance

    NOTE(review): `handler` is accepted but never used — confirm whether
    it should wrap errors here like in the other filters.
    """
    if rng is None:
        # seed from pid + wall clock so workers shuffle differently
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
    initial = min(initial, bufsize)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) < bufsize:
            try:
                # pull an extra item to fill the buffer faster;
                # NOTE(review): assumes `data` is an iterator (supports
                # next()), not merely an iterable — confirm callers
                buf.append(next(data))  # skipcq: PYL-R1708
            except StopIteration:
                pass
        if len(buf) >= initial:
            yield pick(buf, rng)
    # drain the remaining buffered samples
    while len(buf) > 0:
        yield pick(buf, rng)
|
||||||
|
|
||||||
|
|
||||||
|
shuffle = pipelinefilter(_shuffle)
|
||||||
|
|
||||||
|
|
||||||
|
class detshuffle(PipelineStage):
    # Deterministic shuffle: the rng is reseeded from (seed, epoch) on every
    # run, so runs sharing the same seed shuffle identically per epoch.

    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize  # shuffle buffer size
        self.initial = initial  # samples buffered before the first yield
        self.seed = seed  # base seed
        self.epoch = epoch  # incremented at the start of each run()

    def run(self, src):
        self.epoch += 1
        rng = random.Random()
        # seeding with a tuple mixes the base seed and the epoch counter
        rng.seed((self.seed, self.epoch))
        return _shuffle(src, self.bufsize, self.initial, rng)
|
||||||
|
|
||||||
|
|
||||||
|
def _select(data, predicate):
|
||||||
|
"""Select samples based on a predicate.
|
||||||
|
|
||||||
|
:param data: source iterator
|
||||||
|
:param predicate: predicate (function)
|
||||||
|
"""
|
||||||
|
for sample in data:
|
||||||
|
if predicate(sample):
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
select = pipelinefilter(_select)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_keys(data, logfile=None):
    """Pass samples through unchanged, appending their keys to a log file.

    Each line records index, worker, rank, and sample key; an exclusive
    flock guards against interleaved writes from concurrent workers.
    POSIX-only (uses fcntl).  An empty/None logfile disables logging.
    """
    import fcntl

    if logfile is None or logfile == "":
        for sample in data:
            yield sample
    else:
        with open(logfile, "a") as stream:
            for i, sample in enumerate(data):
                buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
                try:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
                    stream.write(buf)
                finally:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
                yield sample
|
||||||
|
|
||||||
|
|
||||||
|
log_keys = pipelinefilter(_log_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode(data, *args, handler=reraise_exception, **kw):
    """Decode data based on the decoding functions given as arguments.

    String arguments are interpreted as image-handler specs; anything else
    is used as a decoder callable directly.
    """
    handlers = [autodecode.imagehandler(arg) if isinstance(arg, str) else arg for arg in args]
    decoder = autodecode.Decoder(handlers, **kw)

    for sample in data:
        assert isinstance(sample, dict), sample
        try:
            decoded = decoder(sample)
        except Exception as exn:  # skipcq: PYL-W0703
            if handler(exn):
                continue
            else:
                break
        yield decoded
|
||||||
|
|
||||||
|
|
||||||
|
decode = pipelinefilter(_decode)
|
||||||
|
|
||||||
|
|
||||||
|
def _map(data, f, handler=reraise_exception):
    """Apply `f` to every sample; None results are dropped.

    When a dict maps to a dict, the original __key__ is carried over.
    """
    for sample in data:
        try:
            result = f(sample)
        except Exception as exn:
            if handler(exn):
                continue
            break
        if result is None:
            continue
        if isinstance(sample, dict) and isinstance(result, dict):
            result["__key__"] = sample.get("__key__")
        yield result
|
||||||
|
|
||||||
|
|
||||||
|
map = pipelinefilter(_map)
|
||||||
|
|
||||||
|
|
||||||
|
def _rename(data, handler=reraise_exception, keep=True, **kw):
    """Rename sample fields based on keyword arguments (new=old).

    With keep=True the untouched fields are preserved and only the source
    fields of each rename are dropped; with keep=False only the renamed
    fields are emitted.
    """
    def _listify(v):
        return v.split(";") if isinstance(v, str) else v

    for sample in data:
        try:
            renamed = {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
            if not keep:
                yield renamed
            else:
                consumed = {x for v in kw.values() for x in _listify(v)}
                result = {k: v for k, v in sample.items() if k not in consumed}
                result.update(renamed)
                yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
|
||||||
|
|
||||||
|
|
||||||
|
rename = pipelinefilter(_rename)
|
||||||
|
|
||||||
|
|
||||||
|
def _associate(data, associator, **kw):
|
||||||
|
"""Associate additional data with samples."""
|
||||||
|
for sample in data:
|
||||||
|
if callable(associator):
|
||||||
|
extra = associator(sample["__key__"])
|
||||||
|
else:
|
||||||
|
extra = associator.get(sample["__key__"], {})
|
||||||
|
sample.update(extra) # destructive
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
associate = pipelinefilter(_associate)
|
||||||
|
|
||||||
|
|
||||||
|
def _map_dict(data, handler=reraise_exception, **kw):
    """Map the entries in a dict sample with individual per-key functions."""
    assert len(list(kw.keys())) > 0
    for key, fn in kw.items():
        assert callable(fn), (key, fn)

    for sample in data:
        assert isinstance(sample, dict)
        try:
            for key, fn in kw.items():
                sample[key] = fn(sample[key])
        except Exception as exn:
            if handler(exn):
                continue
            break
        yield sample
|
||||||
|
|
||||||
|
|
||||||
|
map_dict = pipelinefilter(_map_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
    """Convert dict samples to tuples of the requested fields.

    A single space-separated string argument is split into several specs.
    None values raise when none_is_error (defaulting to missing_is_error).
    """
    if none_is_error is None:
        none_is_error = missing_is_error
    if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
        args = args[0].split()

    for sample in data:
        try:
            fields = tuple(getfirst(sample, spec, missing_is_error=missing_is_error) for spec in args)
            if none_is_error and any(field is None for field in fields):
                raise ValueError(f"to_tuple {args} got {sample.keys()}")
            yield fields
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
|
||||||
|
|
||||||
|
|
||||||
|
to_tuple = pipelinefilter(_to_tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def _map_tuple(data, *args, handler=reraise_exception):
    """Map the entries of each tuple sample with individual functions.

    None arguments act as identity; extra tuple entries pass through.
    """
    funcs = [f if f is not None else utils.identity for f in args]
    for fn in funcs:
        assert callable(fn), fn
    for sample in data:
        assert isinstance(sample, (list, tuple))
        values = list(sample)
        count = min(len(funcs), len(values))
        try:
            for i in range(count):
                values[i] = funcs[i](values[i])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield tuple(values)
|
||||||
|
|
||||||
|
|
||||||
|
map_tuple = pipelinefilter(_map_tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def _unlisted(data):
|
||||||
|
"""Turn batched data back into unbatched data."""
|
||||||
|
for batch in data:
|
||||||
|
assert isinstance(batch, list), sample
|
||||||
|
for sample in batch:
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
unlisted = pipelinefilter(_unlisted)
|
||||||
|
|
||||||
|
|
||||||
|
def _unbatched(data):
|
||||||
|
"""Turn batched data back into unbatched data."""
|
||||||
|
for sample in data:
|
||||||
|
assert isinstance(sample, (tuple, list)), sample
|
||||||
|
assert len(sample) > 0
|
||||||
|
for i in range(len(sample[0])):
|
||||||
|
yield tuple(x[i] for x in sample)
|
||||||
|
|
||||||
|
|
||||||
|
unbatched = pipelinefilter(_unbatched)
|
||||||
|
|
||||||
|
|
||||||
|
def _rsample(data, p=0.5):
|
||||||
|
"""Randomly subsample a stream of data."""
|
||||||
|
assert p >= 0.0 and p <= 1.0
|
||||||
|
for sample in data:
|
||||||
|
if random.uniform(0.0, 1.0) < p:
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
rsample = pipelinefilter(_rsample)
|
||||||
|
|
||||||
|
slice = pipelinefilter(itertools.islice)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
|
||||||
|
for sample in source:
|
||||||
|
result = []
|
||||||
|
for pattern in patterns:
|
||||||
|
pattern = pattern.split(";") if isinstance(pattern, str) else pattern
|
||||||
|
matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
|
||||||
|
if len(matches) == 0:
|
||||||
|
if ignore_missing:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
|
||||||
|
if len(matches) > 1 and duplicate_is_error:
|
||||||
|
raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
|
||||||
|
value = sample[matches[0]]
|
||||||
|
result.append(value)
|
||||||
|
yield tuple(result)
|
||||||
|
|
||||||
|
|
||||||
|
extract_keys = pipelinefilter(_extract_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
    """Rename sample keys by glob patterns.

    Renamings are given as (output, pattern) pairs in `args` and as
    output=pattern keyword arguments; later entries take precedence.

    Args:
        source: iterable of sample dicts
        keep_unselected: copy through keys that match no pattern
        must_match: raise if any pattern matched no key in a sample
        duplicate_is_error: raise if two keys map to the same output name
    """
    renamings = [(pattern, output) for output, pattern in args]
    renamings += [(pattern, output) for output, pattern in kw.items()]
    for sample in source:
        new_sample = {}
        matched = {k: False for k, _ in renamings}
        for path, value in sample.items():
            # Match the basename only, case-insensitively.
            fname = re.sub(r".*/", "", path)
            new_name = None
            # Scan in reverse so later renamings win over earlier ones.
            for pattern, name in renamings[::-1]:
                if fnmatch(fname.lower(), pattern):
                    matched[pattern] = True
                    new_name = name
                    break
            if new_name is None:
                if keep_unselected:
                    new_sample[path] = value
                continue
            if new_name in new_sample:
                if duplicate_is_error:
                    raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
                continue
            new_sample[new_name] = value
        if must_match and not all(matched.values()):
            raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")

        yield new_sample


rename_keys = pipelinefilter(_rename_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_bin(stream):
    """Read the entire stream and return its raw bytes unchanged."""
    data = stream.read()
    return data
|
||||||
|
|
||||||
|
|
||||||
|
def decode_text(stream):
    """Read the entire stream and decode it as UTF-8 text."""
    raw = stream.read()
    text = raw.decode("utf-8")
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def decode_pickle(stream):
    """Unpickle one object from the stream.

    SECURITY NOTE: pickle.load executes arbitrary code during
    deserialization; only use this decoder on trusted archives.
    """
    return pickle.load(stream)
|
||||||
|
|
||||||
|
|
||||||
|
# Default (glob pattern, decoder) table used by _xdecode; entries later in
# the list take precedence when patterns overlap.
default_decoders = [
    ("*.bin", decode_bin),
    ("*.txt", decode_text),
    ("*.pyd", decode_pickle),
]
|
||||||
|
|
||||||
|
|
||||||
|
def find_decoder(decoders, path):
    """Return the decoder registered for `path`, identity for metadata keys, else None."""
    basename = re.sub(r".*/", "", path)
    # Metadata entries such as "__key__" pass through unchanged.
    if basename.startswith("__"):
        return lambda x: x
    lowered = basename.lower()
    # Later registrations take precedence, hence the reversed scan; patterns
    # may match either the bare name or the dot-prefixed name.
    for pattern, decoder in reversed(decoders):
        if fnmatch(lowered, pattern) or fnmatch("." + lowered, pattern):
            return decoder
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _xdecode(
    source,
    *args,
    must_decode=True,
    defaults=default_decoders,
    **kw,
):
    """Decode raw sample fields using glob-pattern decoder tables.

    Extra (pattern, decoder) pairs come from `args`; keyword arguments map a
    suffix to a decoder ("*.<key>"). Metadata keys ("__...") pass through.
    """
    table = list(defaults) + list(args)
    table.extend(("*." + suffix, fun) for suffix, fun in kw.items())
    for sample in source:
        decoded = {}
        for path, payload in sample.items():
            # Metadata keys are copied through untouched.
            if path.startswith("__"):
                decoded[path] = payload
                continue
            decoder = find_decoder(table, path)
            if decoder is False:
                # Explicitly disabled decoding: keep the raw payload.
                decoded[path] = payload
            elif decoder is None:
                if must_decode:
                    raise ValueError(f"No decoder found for {path}.")
                decoded[path] = payload
            else:
                # Decoders expect a stream; wrap raw bytes on the fly.
                stream = io.BytesIO(payload) if isinstance(payload, bytes) else payload
                decoded[path] = decoder(stream)
        yield decoded
|
||||||
|
|
||||||
|
xdecode = pipelinefilter(_xdecode)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _data_filter(source,
                 frame_shift=10,
                 max_length=10240,
                 min_length=10,
                 token_max_length=200,
                 token_min_length=1,
                 min_output_input_ratio=0.0005,
                 max_output_input_ratio=1):
    """ Filter sample according to feature and label length
    Inplace operation.

    Args:
        source: Iterable[{fname, wav, label, sample_rate}]
        frame_shift: length of frame shift (ms)
        max_length: drop utterance which is greater than max_length(10ms)
        min_length: drop utterance which is less than min_length(10ms)
        token_max_length: drop utterance which is greater than
            token_max_length, especially when use char unit for
            english modeling
        token_min_length: drop utterance which is
            less than token_max_length
        min_output_input_ratio: minimal ration of
            token_length / feats_length(10ms)
        max_output_input_ratio: maximum ration of
            token_length / feats_length(10ms)

    Returns:
        Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
        # Duration gates, measured in frames of `frame_shift` ms.
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        # Transcript-length gates.
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            # Ratio gates guard against degenerate audio/transcript alignments.
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample


data_filter = pipelinefilter(_data_filter)
|
||||||
|
|
||||||
|
def _tokenize(source,
              symbol_table,
              bpe_model=None,
              non_lang_syms=None,
              split_with_space=False):
    """ Decode text to chars or BPE
    Inplace operation

    Args:
        source: Iterable[{fname, wav, txt, sample_rate}]
        symbol_table: mapping token -> integer id; '<unk>' is used as fallback
        bpe_model: optional sentencepiece model path for BPE tokenization
        non_lang_syms: collection of non-linguistic symbols (e.g. "<noise>")
            that are kept as single tokens
        split_with_space: split on spaces (word-level) instead of characters

    Returns:
        Iterable[{fname, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        # Matches bracketed special symbols: [x], <x>, {x}.
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    for sample in source:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        if non_lang_syms_pattern is not None:
            # Split the uppercased text around special symbols so they stay whole.
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    # __tokenize_by_bpe_model is a module-level helper defined
                    # elsewhere in this file.
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == ' ':
                            ch = "<space>"
                        tokens.append(ch)

        # Map tokens to ids; unknown tokens fall back to '<unk>' if present,
        # otherwise they are silently dropped.
        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                label.append(symbol_table['<unk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample


tokenize = pipelinefilter(_tokenize)
|
||||||
|
|
||||||
|
def _resample(source, resample_rate=16000):
    """ Resample data.
    Inplace operation.

    Args:
        data: Iterable[{fname, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            # Resampling runs on numpy via the soundfile backend, then the
            # result is converted back into a paddle tensor.
            sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
                waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate
            ))
        yield sample


resample = pipelinefilter(_resample)
|
||||||
|
|
||||||
|
def _compute_fbank(source,
                   num_mel_bins=80,
                   frame_length=25,
                   frame_shift=10,
                   dither=0.0):
    """ Extract fbank

    Args:
        source: Iterable[{fname, wav, label, sample_rate}]
        num_mel_bins: number of mel filter bank
        frame_length: length of one frame (ms)
        frame_shift: length of frame shift (ms)
        dither: value of dither

    Returns:
        Iterable[{fname, feat, label}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'fname' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        # Scale waveform to 16-bit integer range, as expected by Kaldi fbank.
        waveform = waveform * (1 << 15)
        # Only keep fname, feat, label
        mat = kaldi.fbank(waveform,
                          n_mels=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sr=sample_rate)
        yield dict(fname=sample['fname'], label=sample['label'], feat=mat)


compute_fbank = pipelinefilter(_compute_fbank)
|
||||||
|
|
||||||
|
def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80):
    """ Do spec augmentation
    Inplace operation

    Args:
        source: Iterable[{fname, feat, label}]
        num_t_mask: number of time mask to apply
        num_f_mask: number of freq mask to apply
        max_t: max width of time mask
        max_f: max width of freq mask
        max_w: max width of time warp

    Returns
        Iterable[{fname, feat, label}]
    """
    for sample in source:
        x = sample['feat']
        # time_warp/freq_mask/time_mask are project helpers that operate on
        # numpy arrays, hence the numpy round-trip around them.
        x = x.numpy()
        x = time_warp(x, max_time_warp=max_w, inplace = True, mode= "PIL")
        x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = True, replace_with_zero = False)
        x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = True, replace_with_zero = False)
        sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
        yield sample


spec_aug = pipelinefilter(_spec_aug)
|
||||||
|
|
||||||
|
|
||||||
|
def _sort(source, sort_size=500):
|
||||||
|
""" Sort the data by feature length.
|
||||||
|
Sort is used after shuffle and before batch, so we can group
|
||||||
|
utts with similar lengths into a batch, and `sort_size` should
|
||||||
|
be less than `shuffle_size`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: Iterable[{fname, feat, label}]
|
||||||
|
sort_size: buffer size for sort
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[{fname, feat, label}]
|
||||||
|
"""
|
||||||
|
|
||||||
|
buf = []
|
||||||
|
for sample in source:
|
||||||
|
buf.append(sample)
|
||||||
|
if len(buf) >= sort_size:
|
||||||
|
buf.sort(key=lambda x: x['feat'].shape[0])
|
||||||
|
for x in buf:
|
||||||
|
yield x
|
||||||
|
buf = []
|
||||||
|
# The sample left over
|
||||||
|
buf.sort(key=lambda x: x['feat'].shape[0])
|
||||||
|
for x in buf:
|
||||||
|
yield x
|
||||||
|
|
||||||
|
sort = pipelinefilter(_sort)
|
||||||
|
|
||||||
|
def _batched(source, batch_size=16):
|
||||||
|
""" Static batch the data by `batch_size`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Iterable[{fname, feat, label}]
|
||||||
|
batch_size: batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[List[{fname, feat, label}]]
|
||||||
|
"""
|
||||||
|
buf = []
|
||||||
|
for sample in source:
|
||||||
|
buf.append(sample)
|
||||||
|
if len(buf) >= batch_size:
|
||||||
|
yield buf
|
||||||
|
buf = []
|
||||||
|
if len(buf) > 0:
|
||||||
|
yield buf
|
||||||
|
|
||||||
|
batched = pipelinefilter(_batched)
|
||||||
|
|
||||||
|
def dynamic_batched(source, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
    reach `max_frames_in_batch`

    Args:
        source: Iterable[{fname, feat, label}]
        max_frames_in_batch: max_frames in one batch

    Returns:
        Iterable[List[{fname, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in source:
        assert 'feat' in sample
        assert isinstance(sample['feat'], paddle.Tensor)
        # Bug fix: paddle.Tensor has no torch-style `.size(0)` method; the
        # frame count is the first element of `.shape`.
        new_sample_frames = sample['feat'].shape[0]
        longest_frames = max(longest_frames, new_sample_frames)
        # Cost after padding: every sample in the batch is padded to the
        # longest one.
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # Robustness fix: if a single sample alone exceeds the budget the
            # buffer is still empty; don't emit an empty batch.
            if buf:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
|
||||||
|
|
||||||
|
|
||||||
|
def _padding(source):
    """ Padding the data into training data

    Args:
        source: Iterable[List[{fname, feat, label}]]

    Returns:
        Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
    """
    for sample in source:
        assert isinstance(sample, list)
        feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
                                        dtype="int64")
        # Sort the batch by decreasing feature length.
        order = paddle.argsort(feats_length, descending=True)
        feats_lengths = paddle.to_tensor(
            [sample[i]['feat'].shape[0] for i in order], dtype="int64")
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['fname'] for i in order]
        sorted_labels = [
            paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
        ]
        label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
                                         dtype="int64")
        # Pad features with 0; labels with -1 (treated as the ignore index).
        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        padding_labels = pad_sequence(sorted_labels,
                                      batch_first=True,
                                      padding_value=-1)

        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


padding = pipelinefilter(_padding)
|
||||||
|
|
||||||
|
def _cmvn(source, cmvn_file):
    """ Apply global CMVN (cepstral mean and variance normalization) to padded feats.

    Args:
        source: Iterable[Tuple(keys, padded_feats, feats_lengths, padding_labels, label_lengths)]
        cmvn_file: path to the global CMVN statistics loaded by GlobalCMVN

    Returns:
        Iterable of the same tuples with normalized features.
    """
    global_cmvn = GlobalCMVN(cmvn_file)
    for batch in source:
        sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
        # GlobalCMVN operates on numpy arrays; round-trip through numpy.
        padded_feats = padded_feats.numpy()
        padded_feats = global_cmvn(padded_feats)
        padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


cmvn = pipelinefilter(_cmvn)
|
@ -0,0 +1,33 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
#
|
||||||
|
|
||||||
|
"""Mock implementations of paddle interfaces when paddle is not available."""
|
||||||
|
|
||||||
|
|
||||||
|
# Fall back to no-op stand-ins so this package can be imported without paddle.
try:
    from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
        """Empty implementation of IterableDataset when paddle is not available."""

        pass

    class DataLoader:
        """Empty implementation of DataLoader when paddle is not available."""

        pass
|
||||||
|
|
||||||
|
try:
    from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:

    class PaddleTensor:
        """Empty implementation of PaddleTensor when paddle is not available."""

        # Bug fix: the fallback class was named TorchTensor, which left the
        # imported name `PaddleTensor` undefined whenever paddle was missing.
        pass
|
@ -0,0 +1,127 @@
|
|||||||
|
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
#%%
|
||||||
|
import copy, os, random, sys, time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from itertools import islice
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import braceexpand, yaml
|
||||||
|
|
||||||
|
from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
|
||||||
|
from webdataset.handlers import reraise_exception
|
||||||
|
from .paddle_utils import DataLoader, IterableDataset
|
||||||
|
from .utils import PipelineStage
|
||||||
|
|
||||||
|
|
||||||
|
def add_length_method(obj):
    """Give `obj` a __len__ returning `obj.size` by swapping in a dynamic subclass.

    The generated class derives from the object's own class and
    IterableDataset, so the result still behaves as a dataset.
    """
    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    # Rebind the instance's class in place; all existing state is preserved.
    obj.__class__ = Combined
    return obj
|
||||||
|
|
||||||
|
|
||||||
|
class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        # Stages are stored in order; list arguments are flattened one level.
        super().__init__()
        self.pipeline = []
        self.length = -1
        # repetitions / nsamples control epoch repetition and slicing in __iter__.
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage."""
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        # A bare dataset/loader/list can only act as a source (no inputs).
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        if self.repetitions != 1:
            if self.nsamples > 0:
                return islice(self.iterator(), self.nsamples)
            else:
                return self.iterator()
        else:
            return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)

    def compose(self, *args):
        """Append a pipeline stage to a copy of the pipeline and returns the copy."""
        result = copy.copy(self)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self
|
@ -0,0 +1,257 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
#
|
||||||
|
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
|
||||||
|
"""Train PyTorch models directly from POSIX tar archive.
|
||||||
|
|
||||||
|
Code works locally or over HTTP connections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os, random, sys, time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from itertools import islice
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import braceexpand, yaml
|
||||||
|
|
||||||
|
from . import utils
|
||||||
|
from .filters import pipelinefilter
|
||||||
|
from .paddle_utils import IterableDataset
|
||||||
|
|
||||||
|
|
||||||
|
def expand_urls(urls):
    """Expand a shard spec into a flat list of URLs.

    A string may contain several brace-notation patterns joined by "::"; any
    other iterable is simply materialized as a list.
    """
    if not isinstance(urls, str):
        return list(urls)
    expanded = []
    for piece in urls.split("::"):
        expanded.extend(braceexpand.braceexpand(piece))
    return expanded
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleShardList(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(self, urls, seed=None):
        """Iterate through the list of shards.

        :param urls: a list of URLs as a Python list or brace notation string
        :param seed: optional seed for deterministic shuffling
        """
        super().__init__()
        self.urls = expand_urls(urls)
        assert isinstance(self.urls[0], str)
        self.seed = seed

    def __len__(self):
        return len(self.urls)

    def __iter__(self):
        """Return an iterator over the shards."""
        shards = self.urls.copy()
        # Seeded shuffling keeps the order reproducible across runs.
        if self.seed is not None:
            random.Random(self.seed).shuffle(shards)
        for shard in shards:
            yield dict(url=shard)
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_node(src, group=None):
    """Partition a shard stream across distributed nodes.

    Node `rank` takes every `world_size`-th sample starting at its own rank;
    with a single node the stream passes through unchanged.
    """
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
            yield s
    else:
        for s in src:
            yield s
|
||||||
|
|
||||||
|
|
||||||
|
def single_node_only(src, group=None):
    """Pass the stream through, but fail fast if running multi-node.

    Use this for pipelines that are not safe to split across nodes.
    """
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        raise ValueError("input pipeline needs to be reconfigured for multinode training")
    for s in src:
        yield s
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_worker(src):
    """Partition a shard stream across DataLoader workers on one node.

    Worker `worker` takes every `num_workers`-th sample starting at its own
    index; with a single worker the stream passes through unchanged.
    """
    rank, world_size, worker, num_workers = utils.paddle_worker_info()
    if num_workers > 1:
        for s in islice(src, worker, None, num_workers):
            yield s
    else:
        for s in src:
            yield s
|
||||||
|
|
||||||
|
|
||||||
|
def resampled_(src, n=sys.maxsize):
    """Yield `n` samples drawn with replacement from the fully materialized `src`.

    NOTE(review): this loads the entire source into memory before sampling.
    """
    import random

    # Prefer OS entropy for the seed; fall back to the wall clock when
    # /dev/random is unavailable (e.g. non-POSIX platforms).
    seed = time.time()
    try:
        seed = open("/dev/random", "rb").read(20)
    except Exception as exn:
        print(repr(exn)[:50], file=sys.stderr)
    rng = random.Random(seed)
    print("# resampled loading", file=sys.stderr)
    items = list(src)
    print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
    for i in range(n):
        yield rng.choice(items)


resampled = pipelinefilter(resampled_)
|
||||||
|
|
||||||
|
|
||||||
|
def non_empty(src):
    """Pass samples through unchanged, raising if the stream yielded nothing."""
    seen = 0
    for sample in src:
        yield sample
        seen += 1
    if seen == 0:
        raise ValueError("pipeline stage received no data at all and this was declared as an error")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MSSource:
    """Class representing a data source."""

    # Human-readable name of the source (defaults to "@<bucket>").
    name: str = ""
    # If > 0, sample this many shards per epoch without replacement.
    perepoch: int = -1
    # If > 0, sample that many shards with replacement.
    # NOTE(review): annotated bool but assigned integer counts by parse_spec.
    resample: bool = False
    # Fully expanded shard URLs belonging to this source.
    urls: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared fallback RNG, captured by MultiShardSample until set_epoch replaces it.
default_rng = random.Random()


def expand(s):
    """Expand environment variables and a leading ~ in a path string."""
    with_vars = os.path.expandvars(s)
    return os.path.expanduser(with_vars)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiShardSample(IterableDataset):
    """Shard list drawn from multiple sources described by a YAML spec.

    Bug fix: the original file declared this class twice back to back (a
    truncated stub immediately shadowed by the full definition); the stray
    duplicate has been removed and only the complete class is kept.
    """

    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
        self.epoch = -1
        self.parse_spec(fname)

    def parse_spec(self, fname):
        """Parse a YAML spec file (or an equivalent dict) into MSSource entries."""
        self.rng = default_rng  # capture default_rng if we fork
        if isinstance(fname, dict):
            spec = fname
            fname = "{dict}"
        else:
            with open(fname) as stream:
                spec = yaml.safe_load(stream)
        assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
        prefix = expand(spec.get("prefix", ""))
        self.sources = []
        for ds in spec["datasets"]:
            assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
                ds.keys()
            )
            # Per-dataset bucket overrides the global one; only one supported.
            buckets = ds.get("buckets", spec.get("buckets", []))
            if isinstance(buckets, str):
                buckets = [buckets]
            buckets = [expand(s) for s in buckets]
            if buckets == []:
                buckets = [""]
            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
            bucket = buckets[0]
            name = ds.get("name", "@" + bucket)
            urls = ds["shards"]
            if isinstance(urls, str):
                urls = [urls]
            # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
            urls = [
                prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
            ]
            resample = ds.get("resample", -1)
            nsample = ds.get("choose", -1)
            if nsample > len(urls):
                raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
            if (nsample > 0) and (resample > 0):
                raise ValueError("specify only one of perepoch or choose")
            entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
            self.sources.append(entry)
            print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)

    def set_epoch(self, seed):
        """Set the current epoch (for consistent shard selection among nodes)."""
        self.rng = random.Random(seed)

    def get_shards_for_epoch(self):
        """Draw this epoch's shard list from all sources and shuffle it."""
        result = []
        for source in self.sources:
            if source.resample > 0:
                # sample with replacement
                l = self.rng.choices(source.urls, k=source.resample)
            elif source.perepoch > 0:
                # sample without replacement
                l = list(source.urls)
                self.rng.shuffle(l)
                l = l[: source.perepoch]
            else:
                l = list(source.urls)
            result += l
        self.rng.shuffle(result)
        return result

    def __iter__(self):
        shards = self.get_shards_for_epoch()
        for shard in shards:
            yield dict(url=shard)
|
||||||
|
|
||||||
|
|
||||||
|
def shardspec(spec):
    """Build a shard list: a YAML file becomes a MultiShardSample, anything
    else a SimpleShardList."""
    if spec.endswith(".yaml"):
        return MultiShardSample(spec)
    return SimpleShardList(spec)
|
||||||
|
|
||||||
|
|
||||||
|
class ResampledShards(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        :param nshards: number of shards drawn per epoch
        :param worker_seed: callable returning a per-worker seed
            (defaults to utils.paddle_worker_seed)
        :param deterministic: seed only from (worker, epoch) so reruns repeat
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        """Return an iterator over the shards."""
        self.epoch += 1
        if self.deterministic:
            seed = utils.make_seed(self.worker_seed(), self.epoch)
        else:
            # Mix in pid, wall clock and OS entropy so workers diverge.
            seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
        if os.environ.get("WDS_SHOW_SEED", "0") == "1":
            print(f"# ResampledShards seed {seed}")
        self.rng = random.Random(seed)
        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.urls) - 1)
            yield dict(url=self.urls[index])
|
@ -0,0 +1,283 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
||||||
|
|
||||||
|
"""Low level iteration functions for tar archives."""
|
||||||
|
|
||||||
|
import random, re, tarfile
|
||||||
|
|
||||||
|
import braceexpand
|
||||||
|
|
||||||
|
from . import filters
|
||||||
|
from webdataset import gopen
|
||||||
|
from webdataset.handlers import reraise_exception
|
||||||
|
|
||||||
|
# When True, modules may emit per-file tracing output (debug switch).
trace = False
# Keys shaped "__name__" are treated as sample metadata, not payload.
meta_prefix = "__"
meta_suffix = "__"

from ... import audio as paddleaudio
import paddle
import numpy as np

# Audio container formats recognized when expanding tar members into samples.
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
||||||
|
|
||||||
|
def base_plus_ext(path):
    """Split off all file extensions.

    Returns base, allext.

    :param path: path with extensions
    :param returns: path with all extensions removed
    """
    # The base is everything up to the first dot after the final slash;
    # the extension is everything after that dot.
    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if match is None:
        return None, None
    base, allext = match.groups()
    return base, allext
|
||||||
|
|
||||||
|
|
||||||
|
def valid_sample(sample):
    """Return True if `sample` is a usable, non-empty sample dict.

    :param sample: sample to be checked
    """
    if sample is None or not isinstance(sample, dict):
        return False
    if len(list(sample.keys())) == 0:
        return False
    # samples explicitly flagged as bad are rejected
    return not sample.get("__bad__", False)
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
    """Yield `dict(url=...)` records for each URL, optionally shuffled.

    :param urls: list of URLs, or a brace-expandable string
    :param shuffle: shuffle the expanded URL list in place before yielding
    """
    if isinstance(urls, str):
        expanded = braceexpand.braceexpand(urls)
    else:
        expanded = list(urls)
    if shuffle:
        random.shuffle(expanded)
    for u in expanded:
        yield dict(url=u)
|
||||||
|
|
||||||
|
|
||||||
|
def url_opener(data, handler=reraise_exception, **kw):
    """Open each URL record (`dict(url=url)`) and attach its stream.

    :param data: iterator of dicts carrying a "url" key
    :param handler: exception handler; truthy return means skip and continue
    :param kw: extra keyword arguments forwarded to gopen
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            sample.update(stream=gopen.gopen(url, **kw))
            yield sample
        except Exception as exn:
            # record which URL failed before delegating to the handler
            exn.args = exn.args + (url,)
            if not handler(exn):
                break
|
||||||
|
|
||||||
|
|
||||||
|
def tar_file_iterator(
    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
    """Iterate over tar file, yielding per-member dicts for the given tar stream.

    Unlike upstream webdataset, members are decoded here: a ``.wav`` member
    yields ``dict(fname, wav, sample_rate)``; any other member is read as
    UTF-8 text and yields ``dict(fname, txt)``.

    :param fileobj: byte stream suitable for tarfile
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param handler: exception handler; truthy return means skip member and continue
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    for tarinfo in stream:
        fname = tarinfo.name
        try:
            # only regular files with a usable name are considered
            if not tarinfo.isreg():
                continue
            if fname is None:
                continue
            if (
                "/" not in fname
                and fname.startswith(meta_prefix)
                and fname.endswith(meta_suffix)
            ):
                # skipping metadata for now
                continue
            if skip_meta is not None and re.match(skip_meta, fname):
                continue

            # split "prefix.ext"; members without a dot fail the assert
            name = tarinfo.name
            pos = name.rfind('.')
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if postfix == 'wav':
                # NOTE(review): only the literal "wav" suffix is decoded as
                # audio; other audio formats fall into the text branch and
                # will fail UTF-8 decoding -- confirm this is intended.
                waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False)
                result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
            else:
                txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
                result = dict(fname=prefix, txt=txt)
            #result = dict(fname=fname, data=data)
            yield result
            # drop accumulated member metadata to keep memory bounded
            stream.members = []
        except Exception as exn:
            if hasattr(exn, "args") and len(exn.args) > 0:
                # tag the failing source so the handler can report it
                exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
            if handler(exn):
                continue
            else:
                break
    del stream
|
||||||
|
|
||||||
|
def tar_file_and_group_iterator(
    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
    """ Expand a stream of open tar files into a stream of tar file contents.
    And groups the file with same prefix

    Members sharing the same filename prefix (everything before the last
    dot) are merged into one example dict.

    Args:
        fileobj: byte stream suitable for tarfile
        skip_meta: regexp for keys skipped entirely (unused here -- TODO confirm)
        handler: exception handler; truthy return means continue

    Returns:
        Iterable[{fname, wav, txt, sample_rate}]
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    prev_prefix = None
    example = {}
    # `valid` is meant to mark groups whose members failed to parse
    valid = True
    for tarinfo in stream:
        name = tarinfo.name
        pos = name.rfind('.')
        assert pos > 0
        prefix, postfix = name[:pos], name[pos + 1:]
        # a new prefix means the previous group is complete: emit it
        if prev_prefix is not None and prefix != prev_prefix:
            example['fname'] = prev_prefix
            if valid:
                yield example
            example = {}
            valid = True
        with stream.extractfile(tarinfo) as file_obj:
            try:
                if postfix == 'txt':
                    example['txt'] = file_obj.read().decode('utf8').strip()
                elif postfix in AUDIO_FORMAT_SETS:
                    waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
                    # prepend a channel dimension before converting to tensor
                    waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)

                    example['wav'] = waveform
                    example['sample_rate'] = sample_rate
                else:
                    # unknown extension: keep raw bytes under the extension key
                    example[postfix] = file_obj.read()
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
                # NOTE(review): unreachable -- both branches above leave the
                # block, so `valid` is never set False on parse errors.
                valid = False
                # logging.warning('error to parse {}'.format(name))
        prev_prefix = prefix
    # NOTE(review): the final group is yielded without checking `valid`,
    # unlike groups inside the loop -- confirm this asymmetry is intended.
    if prev_prefix is not None:
        example['fname'] = prev_prefix
        yield example
    stream.close()
|
||||||
|
|
||||||
|
def tar_file_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over per-member sample dicts, each tagged
    with the originating shard URL under "__url__".

    :param data: iterator of dicts with "url" and "stream" keys
    :param handler: exception handler; truthy return means skip shard and continue
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator(source["stream"]):
                # BUGFIX: tar_file_iterator yields dicts keyed by "fname"
                # plus "wav"/"txt" (see the commented-out dict(fname, data)
                # there); the old assertion required a "data" key and thus
                # rejected every sample.
                assert isinstance(sample, dict) and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            # record the failing stream/url before delegating to the handler
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def tar_file_and_group_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of grouped samples.

    Each yielded sample dict carries wav/txt/fname plus the originating
    shard URL under "__url__".

    :param data: iterator of dicts with "url" and "stream" keys
    :param handler: exception handler; truthy return means skip shard and continue
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_and_group_iterator(source["stream"]):
                assert (
                    isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
                )
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            # record the failing stream/url before delegating to the handler
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if not handler(exn):
                break
|
||||||
|
|
||||||
|
|
||||||
|
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param data: iterator over {fname, data, __url__} file entries
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only these suffixes are stored in the sample
    :param handler: unused here; presumably kept for interface symmetry -- TODO confirm
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if trace:
            # module-level debug switch
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # a new prefix starts a new sample; flush the previous one if valid
        if current_sample is None or prefix != current_sample["__key__"]:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            raise ValueError(
                f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
            )
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    # flush the trailing sample
    if valid_sample(current_sample):
        yield current_sample
|
||||||
|
|
||||||
|
|
||||||
|
def tarfile_samples(src, handler=reraise_exception):
    """Turn a stream of shard URL records into a stream of grouped samples.

    :param src: iterator of dict(url=...) shard records
    :param handler: exception handler shared by both pipeline stages
    """
    opened_streams = url_opener(src, handler=handler)
    return tar_file_and_group_expander(opened_streams, handler=handler)
|
||||||
|
|
||||||
|
|
||||||
|
# Curried/pipeline-stage form of tarfile_samples for use in DataPipeline.
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
|
@ -0,0 +1,128 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
#
|
||||||
|
|
||||||
|
# Modified from https://github.com/webdataset/webdataset
|
||||||
|
|
||||||
|
"""Miscellaneous utility functions."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import itertools as itt
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from typing import Any, Callable, Iterator, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
def make_seed(*args):
    """Combine the hashes of `args` into a single non-negative 31-bit seed.

    NOTE: `hash()` of str/bytes is randomized per interpreter run unless
    PYTHONHASHSEED is fixed, so string arguments do not give reproducible
    seeds across runs.
    """
    accumulator = 0
    for value in args:
        mixed = accumulator * 31 + hash(value)
        accumulator = mixed & 0x7FFFFFFF
    return accumulator
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage:
    """Abstract base class for stages in a DataPipeline."""

    def invoke(self, *args, **kw):
        """Run this stage on the given arguments; subclasses must override."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def identity(x: Any) -> Any:
    """Return the argument as is (no-op transform for pipelines)."""
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def safe_eval(s: str, expr: str = "{}"):
    """Substitute `s` into `expr` and evaluate, restricting `s` to word characters.

    :param s: string substituted into `expr`; only [A-Za-z0-9_] allowed
    :param expr: format template evaluated after substitution
    :raises ValueError: if `s` contains any other character
    NOTE(security): `expr` itself is passed to eval() unchecked -- callers
    must never supply an untrusted `expr`.
    """
    sanitized = re.sub("[^A-Za-z0-9_]", "", s)
    if sanitized != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_sym(sym: str, modules: list):
    """Return the first attribute named `sym` found among `modules`, else None.

    :param sym: attribute name to look up
    :param modules: module names, resolved relative to the "webdataset" package
    """
    for module_name in modules:
        mod = importlib.import_module(module_name, package="webdataset")
        found = getattr(mod, sym, None)
        if found is not None:
            return found
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Yield up to `nbatches` items from `loader` per epoch, for `nepochs` epochs.

    :param loader: re-iterable source of batches (e.g. a DataLoader)
    :param nepochs: number of passes over the loader
    :param nbatches: maximum batches taken from each pass
    """
    for _ in range(nepochs):
        yield from itt.islice(loader, nbatches)
|
||||||
|
|
||||||
|
|
||||||
|
def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield batches from `source`, bounded by epochs/batches/samples.

    :param source: re-iterable source of batches
    :param nepochs: stop after this many full passes over `source` (None = unlimited)
    :param nbatches: stop after this many batches in total (None = unlimited)
    :param nsamples: stop once this many samples have been yielded (None = unlimited)
    :param batchsize: callable estimating the number of samples in one batch
    """
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # BUGFIX: honor the caller-supplied `batchsize` callable;
                # previously guess_batchsize was hard-coded here, silently
                # ignoring the parameter.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return
|
||||||
|
|
||||||
|
def paddle_worker_info(group=None):
    """Return (rank, world_size, worker, num_workers) for this process.

    Resolution order: RANK/WORLD_SIZE environment variables, otherwise
    paddle.distributed; WORKER/NUM_WORKERS environment variables, otherwise
    paddle.io.get_worker_info(). Falls back to single-process defaults
    (0, 1, 0, 1) when paddle is unavailable.

    :param group: optional distributed group passed to paddle.distributed
    """
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed

            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            # BUGFIX: `import paddle.io.get_worker_info` is not a valid
            # module import (get_worker_info is a function), so it always
            # raised ModuleNotFoundError and silently disabled dataloader
            # worker detection.
            import paddle.io

            worker_info = paddle.io.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers
|
||||||
|
|
||||||
|
def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, _world_size, worker, _num_workers = paddle_worker_info(group=group)
    # one seed slot per worker within each rank's block of 1000
    return rank * 1000 + worker
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def delta(feat, window):
    """Regression (delta) features over `feat` with edge frames replicated.

    :param feat: np.ndarray of shape (T, D)
    :param window: positive regression window size
    :returns: np.ndarray of the same shape as `feat`
    """
    assert window > 0
    out = np.zeros_like(feat)
    normalizer = 2 * sum(k**2 for k in range(1, window + 1))
    for k in range(1, window + 1):
        # weighted central difference; first/last frame padded at the edges
        out[:-k] += k * feat[k:]
        out[k:] -= k * feat[:-k]
        out[-k:] += k * feat[-1]
        out[:k] -= k * feat[0]
    out /= normalizer
    return out


def add_deltas(x, window=2, order=2):
    """Append `order` levels of delta features to `x`.

    Args:
        x (np.ndarray): speech feat, (T, D).

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    outputs = [x]
    current = x
    for _ in range(order):
        current = delta(current, window)
        outputs.append(current)
    return np.concatenate(outputs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class AddDeltas():
    """Transform that appends delta features to an input feature matrix."""

    def __init__(self, window=2, order=2):
        # window: regression window used for each delta pass
        # order: number of delta levels appended
        self.window = window
        self.order = order

    def __repr__(self):
        # BUGFIX: the format string was missing its closing parenthesis,
        # producing e.g. "AddDeltas(window=2, order=2".
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)
|
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelSelector():
    """Select 1ch from multi-channel signal"""

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        # train_channel: channel index used in training, or "random"
        # eval_channel: channel index used in evaluation
        # axis: axis along which channels are laid out
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return ("{name}(train_channel={train_channel}, "
                "eval_channel={eval_channel}, axis={axis})".format(
                    name=self.__class__.__name__,
                    train_channel=self.train_channel,
                    eval_channel=self.eval_channel,
                    axis=self.axis, ))

    def __call__(self, x, train=True):
        # Assuming x: [Time, Channel] by default
        if x.ndim <= self.axis:
            # Not enough dimensions: append singleton axes up to `axis`
            # (e.g. [Time] -> [Time, 1])
            expander = tuple(
                slice(None) if dim < x.ndim else None
                for dim in range(self.axis + 1))
            x = x[expander]

        channel = self.train_channel if train else self.eval_channel
        if channel == "random":
            picked = numpy.random.randint(0, x.shape[self.axis])
        else:
            picked = channel

        # index the chosen channel along `axis`, keeping all other axes
        selector = tuple(
            slice(None) if dim != self.axis else picked for dim in range(x.ndim))
        return x[selector]
|
@ -0,0 +1,201 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import kaldiio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class CMVN():
    """Apply global or per-speaker CMVN, or its inverse.

    Statistics may come from a dict, a Kaldi matrix file ("mat"), a numpy
    file ("npy"), a Kaldi archive ("ark", per-speaker), or an HDF5 file
    ("hdf5", per-speaker).
    """

    def __init__(
            self,
            stats,
            norm_means=True,
            norm_vars=False,
            filetype="mat",
            utt2spk=None,
            spk2utt=None,
            reverse=False,
            std_floor=1.0e-20, ):
        # stats path/dict is kept only for __repr__
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                stats_dict = h5py.File(stats)
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        # Build the utterance -> speaker mapping from either direction.
        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            assert len(stats) == 2, stats.shape

            count = stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = stats[1, :-1] / count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            # precompute additive bias and multiplicative scale per speaker
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        # Use per-speaker stats when a mapping exists; else key directly
        # (None for global stats).
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])

        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class UtteranceCMVN():
    """Per-utterance cepstral mean (and optionally variance) normalization."""

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        # lower bound on the std to avoid division by ~0
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]; statistics are taken from the *original* x
        n_frames = x.shape[0]
        mean = x.mean(axis=0)
        square_sums = (x**2).sum(axis=0)

        out = x
        if self.norm_means:
            out = np.subtract(out, mean)
        if self.norm_vars:
            # V(x) = E(x^2) - (E(x))^2
            var = square_sums / n_frames - mean**2
            std = np.maximum(np.sqrt(var), self.std_floor)
            out = np.divide(out, std)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalCMVN():
    """Global CMVN using precomputed corpus statistics (JSON file or dict)."""

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        # cmvn_path: Option[str, dict]
        self.cmvn = cmvn_path
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

        if isinstance(cmvn_path, dict):
            stats = cmvn_path
        else:
            with open(cmvn_path) as f:
                stats = json.load(f)

        # derive mean/std from accumulated first/second-order statistics
        self.count = stats['frame_num']
        self.mean = np.array(stats['mean_stat']) / self.count
        self.square_sums = np.array(stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        self.std = np.maximum(np.sqrt(self.var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        out = x
        if self.norm_means:
            out = np.subtract(out, self.mean)
        if self.norm_vars:
            out = np.divide(out, self.std)
        return out
|
@ -0,0 +1,86 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from paddlespeech.audio.transform.transform_interface import TransformInterface
|
||||||
|
from paddlespeech.audio.utils.check_kwargs import check_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class FuncTrans(TransformInterface):
    """Functional Transformation

    Wraps a plain function as a transform class; the wrapped function is
    stored in the class attribute `_func` and its keyword defaults become
    configurable parameters.

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:

    >>> def foo_bar(x, a=1, b=2):
    ...     '''Foo bar
    ...     :param x: input
    ...     :param int a: default 1
    ...     :param int b: default 2
    ...     '''
    ...     return x + a - b


    >>> class FooBar(FuncTrans):
    ...     _func = foo_bar
    ...     __doc__ = foo_bar.__doc__
    """

    # Subclasses set this to the function being wrapped.
    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        # fail fast on kwargs the wrapped function does not accept
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        """Expose each default parameter of the wrapped function as a CLI flag."""
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        # access via type(self) so subclass overrides of _func are honored
        return type(self)._func

    @classmethod
    def default_params(cls):
        """Return the wrapped function's parameters that carry defaults."""
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            # some callables (builtins) have no introspectable signature
            d = dict()
        return {
            k: v.default
            for k, v in d.items() if v.default != inspect.Parameter.empty
        }

    def __repr__(self):
        # show defaults overridden by any kwargs passed at construction
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        return ret[:-2] + ")"
|
@ -0,0 +1,561 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import librosa
|
||||||
|
import numpy
|
||||||
|
import scipy
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SoundHDF5File():
    """Collecting sound files to a HDF5 file

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']


    :param: str filepath:
    :param: str mode: h5py file mode (e.g. "r", "r+", "a")
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
    :param: str dtype: sample dtype used when reading audio back

    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
        if format.upper() not in soundfile.available_formats():
            # If not found, flac is selected
            format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        """Encode (array, rate) with soundfile and store it under `name`."""
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        # data is an (array, rate) pair
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        """Decode the stored bytes back into an (array, rate) pair."""
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        for k in self.file:
            yield self[k]

    def items(self):
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        # BUGFIX: __len__ must take only `self`; the stray `item`
        # parameter made every len(f) call raise TypeError.
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()
|
||||||
|
|
||||||
|
class SpeedPerturbation():
    """SpeedPerturbation

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    Warning:
        This function is very slow because of resampling.
        I recommend to apply speed-perturb outside the training using sox.
    """

    def __init__(self,
                 lower=0.9,
                 upper=1.1,
                 utt2ratio=None,
                 keep_length=True,
                 res_type="kaiser_best",
                 seed=None):
        """
        :param float lower: lower bound of the random resampling ratio
        :param float upper: upper bound of the random resampling ratio
        :param str utt2ratio: path to a "<uttid> <ratio>" schedule file;
            when given, the ratio is looked up per utterance instead of drawn
        :param bool keep_length: pad/truncate the output back to len(x)
        :param str res_type: resampling algorithm passed to librosa.resample
        :param seed: seed for the internal RandomState
        """
        self.res_type = res_type
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # Bug fix: must default to False; previously the attribute was only
        # set in the utt2ratio branch, so __call__ raised AttributeError
        # whenever no schedule file was given.
        self.accept_uttid = False

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
                self.__class__.__name__,
                self.lower,
                self.upper,
                self.keep_length,
                self.res_type, )
        else:
            return "{}({}, res_type={})".format(
                self.__class__.__name__, self.utt2ratio_file, self.res_type)

    def __call__(self, x, uttid=None, train=True):
        """Resample ``x`` by a random (or scheduled) ratio; identity if not train."""
        if not train:
            return x
        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        # Note1: resample requires the sampling-rate of input and output,
        # but actually only the ratio is used.
        y = librosa.resample(
            x, orig_sr=ratio, target_sr=1, res_type=self.res_type)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")
        return y
|
|
||||||
|
class SpeedPerturbationSox():
    """SpeedPerturbationSox

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    To speed up or slow down the sound of a file,
    use speed to modify the pitch and the duration of the file.
    This raises the speed and reduces the time.
    The default factor is 1.0 which makes no change to the audio.
    2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    tempo option:
        sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9

    speed option:
        sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9

    If we use speed option like above, the pitch of audio also will be changed,
    but the tempo option does not change the pitch.
    """

    def __init__(self,
                 lower=0.9,
                 upper=1.1,
                 utt2ratio=None,
                 keep_length=True,
                 sr=16000,
                 seed=None):
        """
        :param float lower: lower bound of the random speed ratio
        :param float upper: upper bound of the random speed ratio
        :param str utt2ratio: path to a "<uttid> <ratio>" schedule file
        :param bool keep_length: pad/truncate the output back to len(x)
        :param int sr: sample rate passed to the sox transformer
        :param seed: seed for the internal RandomState
        """
        self.sr = sr
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # Bug fix: must default to False; previously the attribute was only
        # set in the utt2ratio branch, so __call__ raised AttributeError
        # whenever no schedule file was given.
        self.accept_uttid = False

        try:
            import soxbindings as sox
        except ImportError:
            # Best-effort runtime install; keep RuntimeError on failure so
            # the caller gets a clear message.
            try:
                from paddlespeech.s2t.utils import dynamic_pip_install
                package = "sox"
                dynamic_pip_install.install(package)
                package = "soxbindings"
                if sys.platform != "win32":
                    dynamic_pip_install.install(package)
                import soxbindings as sox
            except Exception:
                raise RuntimeError(
                    "Can not install soxbindings on your system.")
        self.sox = sox

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return f"""{self.__class__.__name__}(
                lower={self.lower},
                upper={self.upper},
                keep_length={self.keep_length},
                sample_rate={self.sr})"""
        else:
            return f"""{self.__class__.__name__}(
                utt2ratio={self.utt2ratio_file},
                sample_rate={self.sr})"""

    def __call__(self, x, uttid=None, train=True):
        """Apply sox ``speed`` with a random (or scheduled) ratio."""
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        tfm = self.sox.Transformer()
        tfm.set_globals(multithread=False)
        tfm.speed(ratio)
        y = tfm.build_array(input_array=x, sample_rate_in=self.sr)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")

        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T,)
            # Bug fix: ndarray has no ``sequence`` method; ``squeeze`` is
            # the intended call to drop the channel axis.
            y = y.squeeze(1)
        return y
||||||
|
|
||||||
|
class BandpassPerturbation():
    """BandpassPerturbation

    Randomly dropout along the frequency axis.

    The original idea comes from the following:
        "randomly-selected frequency band was cut off under the constraint of
        leaving at least 1,000 Hz band within the range of less than 4,000Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
        everyday home environments using multiple microphone arrays;
        http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        """
        :param float lower: lower bound of the random dropout threshold
        :param float upper: upper bound of the random dropout threshold
        :param seed: seed for the internal RandomState
        :param tuple axes: axes along which the mask varies
            (x_stft: (Time, Channel, Freq); default: frequency axis)
        """
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        """Zero out random frequency bins of ``x_stft`` (in place)."""
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        # Bug fix: a negative axis i maps to ndim + i. The original
        # ``ndim - i`` produced an out-of-range index (e.g. 3 for i=-1 on a
        # 2-D input), so the mask degenerated to a single scalar instead of
        # varying along the requested axis.
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        # Broadcastable boolean mask: True keeps the bin, False drops it.
        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft
|
||||||
|
|
||||||
|
class VolumePerturbation():
    """Randomly scale the amplitude of the input signal.

    :param float lower: lower bound of the ratio (dB if ``dbunit``)
    :param float upper: upper bound of the ratio (dB if ``dbunit``)
    :param str utt2ratio: path to a "<uttid> <ratio>" schedule file
    :param bool dbunit: interpret the ratio in decibels
    :param seed: seed for the internal RandomState
    """

    def __init__(self,
                 lower=-1.6,
                 upper=1.6,
                 utt2ratio=None,
                 dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # Bug fix: must default to False; previously the attribute was only
        # set in the utt2ratio branch, so __call__ raised AttributeError
        # whenever no schedule file was given.
        self.accept_uttid = False

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio = {}
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        """Scale ``x`` by a random (or scheduled) gain; identity if not train."""
        if not train:
            return x

        x = x.astype(numpy.float32)

        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            # Convert dB to a linear amplitude factor.
            ratio = 10**(ratio / 20)
        return x * ratio
|
||||||
|
|
||||||
|
class NoiseInjection():
    """Add isotropic noise

    :param str utt2noise: "<uttid> <wav>" list file or a sound.hdf5 file of
        noise sources; if None, white noise is generated on the fly
    :param float lower: lower SNR bound (dB if ``dbunit``)
    :param float upper: upper SNR bound (dB if ``dbunit``)
    :param str utt2ratio: path to a "<uttid> <snr>" schedule file
    :param str filetype: "list" or "sound.hdf5"
    :param bool dbunit: interpret the ratio in decibels
    :param seed: seed for the internal RandomState
    """

    def __init__(self,
                 utt2noise=None,
                 lower=-20,
                 upper=-5,
                 utt2ratio=None,
                 filetype="list",
                 dbunit=True,
                 seed=None):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances.
            # Bug fix: read the SNR schedule from utt2ratio; the original
            # opened utt2noise here, which maps uttids to *filenames*.
            self.utt2ratio = {}
            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, snr = line.rstrip().split(None, 1)
                    snr = float(snr)
                    self.utt2ratio[utt] = snr
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)

            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        """Mix noise into ``x`` at a random (or scheduled) SNR."""
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source.
                # Bug fix: sample a key; numpy's choice() cannot draw from a
                # list of (signal, rate) tuples.
                key = self.state.choice(list(self.utt2noise.keys()))
                noise, rate = self.utt2noise[key]
            # Bug fix: normalize a float copy; the int16 sources would
            # overflow in ``noise**2`` and fail under in-place true division.
            noise = noise.astype(numpy.float32)
            # Normalize the level
            noise /= numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            # Bug fix: randint(0, 0) raises when the lengths already match.
            offset = self.state.randint(0, diff) if diff > 0 else 0
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:-(diff - offset)]
            else:
                noise = numpy.pad(
                    noise, pad_width=[offset, diff - offset], mode="wrap")

        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale
|
||||||
|
|
||||||
|
class RIRConvolve():
    """Convolve the input signal with a per-utterance room impulse response.

    :param str utt2rir: "<uttid> <wav>" list file or a sound.hdf5 file
        mapping each uttid to its impulse response
    :param str filetype: "list" or "sound.hdf5"
    """

    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        # uttid -> (impulse-response array, sample rate)
        self.utt2rir = {}
        if filetype == "list":
            # Load every RIR into memory up front.
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)

        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        """Return ``x`` convolved with the RIR registered for ``uttid``.

        Identity when ``train`` is False. Raises RuntimeError for
        multi-channel input.
        """
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        # NOTE(review): raises KeyError when uttid is None or unknown —
        # callers must supply a registered uttid.
        rir, rate = self.utt2rir[uttid]
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            return numpy.stack(
                [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
        else:
            return scipy.convolve(x, rir, mode="same")
||||||
|
|
@ -0,0 +1,214 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import BICUBIC
|
||||||
|
|
||||||
|
from .functional import FuncTrans
|
||||||
|
|
||||||
|
|
||||||
|
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """time warp for spec augment

    Moves a randomly chosen center frame by a random width drawn from
    uniform(-window, window).

    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp
    :param bool inplace: overwrite x with the result
    :param str mode: "PIL" (default, fast, not differentiable) or
        "sparse_image_warp" (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq)
    """
    window = max_time_warp
    if window == 0:
        return x

    if mode == "PIL":
        n_frames = x.shape[0]
        # Too short to pick a center frame strictly inside the window.
        if n_frames - window <= window:
            return x
        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
        pivot = random.randrange(window, n_frames - window)
        moved = random.randrange(pivot - window,
                                 pivot + window) + 1  # 1 ... t - 1

        left = Image.fromarray(x[:pivot]).resize((x.shape[1], moved), BICUBIC)
        right = Image.fromarray(x[pivot:]).resize(
            (x.shape[1], n_frames - moved), BICUBIC)
        if not inplace:
            return numpy.concatenate((left, right), 0)
        x[:moved] = left
        x[moved:] = right
        return x
    if mode == "sparse_image_warp":
        import paddle

        from espnet.utils import spec_augment

        # TODO(karita): make this differentiable again
        return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
    raise NotImplementedError("unknown resize mode: " + mode +
                              ", choose one from (PIL, sparse_image_warp).")
|
||||||
|
|
||||||
|
class TimeWarp(FuncTrans):
    # Transform wrapper around time_warp; the docstring is inherited from it.
    _func = time_warp
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        # No-op at evaluation time.
        return x if not train else super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """Apply ``n_mask`` random frequency masks to a spectrogram.

    :param numpy.ndarray x: (time, freq)
    :param int F: maximum mask-width draw (frequency bins)
    :param int n_mask: the number of masks
    :param bool replace_with_zero: fill the mask with zeros if true,
        else with the current array mean
    :param bool inplace: overwrite ``x`` instead of working on a copy
    """
    out = x if inplace else x.copy()
    n_freq = out.shape[1]
    # Two random draws per mask: a width bound and a mask extent.
    draws = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, span in draws:
        start = random.randrange(0, n_freq - f)
        span += start
        # A zero-width draw masks nothing.
        if f == 0:
            continue
        out[:, start:span] = 0 if replace_with_zero else out.mean()
    return out
|
||||||
|
|
||||||
|
|
||||||
|
class FreqMask(FuncTrans):
    # Transform wrapper around freq_mask; the docstring is inherited from it.
    _func = freq_mask
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        # No-op at evaluation time.
        return x if not train else super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """Apply ``n_mask`` random time masks to a spectrogram.

    :param numpy.ndarray spec: (time, freq)
    :param int T: maximum mask-width draw (time frames)
    :param int n_mask: the number of masks
    :param bool replace_with_zero: fill the mask with zeros if true,
        else with the current array mean
    :param bool inplace: overwrite ``spec`` instead of working on a copy
    """
    out = spec if inplace else spec.copy()
    n_frames = out.shape[0]
    # Two random draws per mask: a width bound and a mask extent.
    draws = numpy.random.randint(0, T, size=(n_mask, 2))
    for t, span in draws:
        # avoid randrange range error on too-short inputs
        if n_frames - t <= 0:
            continue
        start = random.randrange(0, n_frames - t)
        # A zero-width draw masks nothing.
        if t == 0:
            continue
        span += start
        out[start:span] = 0 if replace_with_zero else out.mean()
    return out
|
||||||
|
|
||||||
|
|
||||||
|
class TimeMask(FuncTrans):
    # Transform wrapper around time_mask; the docstring is inherited from it.
    _func = time_mask
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        # No-op at evaluation time.
        return x if not train else super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def spec_augment(
        x,
        resize_mode="PIL",
        max_time_warp=80,
        max_freq_width=27,
        n_freq_mask=2,
        max_time_width=100,
        n_time_mask=2,
        inplace=True,
        replace_with_zero=True, ):
    """spec augment

    apply random time warping and time/freq masking
    default setting is based on LD (Librispeech double) in Table 2
    https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or
        "sparse_image_warp" (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2
    # Warp first, then mask frequency bins, then mask time frames.
    warped = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    freq_masked = freq_mask(
        warped,
        max_freq_width,
        n_freq_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    return time_mask(
        freq_masked,
        max_time_width,
        n_time_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugment(FuncTrans):
    # Transform wrapper around spec_augment; the docstring is inherited.
    _func = spec_augment
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        # No-op at evaluation time.
        return x if not train else super().__call__(x)
|
@ -0,0 +1,475 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from python_speech_features import logfbank
|
||||||
|
|
||||||
|
from ..compliance import kaldi
|
||||||
|
|
||||||
|
|
||||||
|
def stft(x,
         n_fft,
         n_shift,
         win_length=None,
         window="hann",
         center=True,
         pad_mode="reflect"):
    """Short-time Fourier transform of a (Time,) or (Time, Channel) signal.

    Returns (Time, Freq) for mono input, (Time, Channel, Freq) otherwise.
    """
    # x: [Time, Channel]
    mono = x.ndim == 1
    if mono:
        # x: [Time] -> [Time, Channel]
        x = x[:, None]
    x = x.astype(np.float32)

    # FIXME(kamo): librosa.stft can't use multi-channel?
    # Transform each channel separately, then stack: [Time, Channel, Freq]
    per_channel = [
        librosa.stft(
            y=x[:, ch],
            n_fft=n_fft,
            hop_length=n_shift,
            win_length=win_length,
            window=window,
            center=center,
            pad_mode=pad_mode, ).T for ch in range(x.shape[1])
    ]
    x = np.stack(per_channel, axis=1)

    if mono:
        # x: [Time, Channel, Freq] -> [Time, Freq]
        x = x[:, 0]
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def istft(x, n_shift, win_length=None, window="hann", center=True):
    """Inverse STFT of a (Time, Freq) or (Time, Channel, Freq) spectrogram.

    Returns (Time,) for single-channel input, (Time, Channel) otherwise.
    """
    # x: [Time, Channel, Freq]
    mono = x.ndim == 2
    if mono:
        # x: [Time, Freq] -> [Time, Channel, Freq]
        x = x[:, None, :]

    # Invert each channel separately, then stack: [Time, Channel]
    per_channel = [
        librosa.istft(
            stft_matrix=x[:, ch].T,  # [Time, Freq] -> [Freq, Time]
            hop_length=n_shift,
            win_length=win_length,
            window=window,
            center=center, ) for ch in range(x.shape[1])
    ]
    x = np.stack(per_channel, axis=1)

    if mono:
        # x: [Time, Channel] -> [Time]
        x = x[:, 0]
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def stft2logmelspectrogram(x_stft,
                           fs,
                           n_mels,
                           n_fft,
                           fmin=None,
                           fmax=None,
                           eps=1e-10):
    """Convert a complex STFT into a log10 mel spectrogram.

    :param x_stft: (Time, Channel, Freq) or (Time, Freq)
    :param fs: sample rate, used for the mel filterbank
    :param eps: floor applied before the log to avoid log(0)
    """
    if fmin is None:
        fmin = 0
    if fmax is None:
        fmax = fs / 2

    # spc: magnitude spectrogram, same layout as the input
    spc = np.abs(x_stft)
    # mel_basis: (Mel_freq, Freq)
    mel_basis = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
    return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
    """Magnitude spectrogram: (Time, Channel) -> (Time, Channel, Freq)."""
    complex_spec = stft(x, n_fft, n_shift, win_length, window=window)
    return np.abs(complex_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def logmelspectrogram(
        x,
        fs,
        n_mels,
        n_fft,
        n_shift,
        win_length=None,
        window="hann",
        fmin=None,
        fmax=None,
        eps=1e-10,
        pad_mode="reflect", ):
    """Log-mel spectrogram of a raw signal: STFT followed by mel projection."""
    # stft: (Time, Channel, Freq) or (Time, Freq)
    return stft2logmelspectrogram(
        stft(
            x,
            n_fft=n_fft,
            n_shift=n_shift,
            win_length=win_length,
            window=window,
            pad_mode=pad_mode, ),
        fs=fs,
        n_mels=n_mels,
        n_fft=n_fft,
        fmin=fmin,
        fmax=fmax,
        eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Spectrogram():
    """Callable wrapper computing a magnitude spectrogram via ``spectrogram``."""

    def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window

    def __repr__(self):
        return (f"{self.__class__.__name__}(n_fft={self.n_fft}, "
                f"n_shift={self.n_shift}, win_length={self.win_length}, "
                f"window={self.window})")

    def __call__(self, x):
        return spectrogram(
            x,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogram():
    """Callable wrapper computing a log-mel spectrogram via ``logmelspectrogram``.

    :param fs: sample rate
    :param n_mels: number of mel bins
    :param n_fft: FFT size
    :param n_shift: hop length
    :param win_length: analysis window length (defaults to n_fft)
    :param window: window type
    :param fmin: lowest mel filter frequency (defaults to 0)
    :param fmax: highest mel filter frequency (defaults to fs / 2)
    :param eps: floor applied before the log
    """

    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        # Bug fix: the format string ended with an unbalanced "))".
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "n_shift={n_shift}, win_length={win_length}, window={window}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        # Bug fix: fmin/fmax/eps were stored in __init__ but never forwarded,
        # so configuring them had no effect.
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
|
||||||
|
|
||||||
|
|
||||||
|
class Stft2LogMelSpectrogram():
    """Callable wrapper converting an STFT into a log-mel spectrogram.

    :param fs: sample rate
    :param n_mels: number of mel bins
    :param n_fft: FFT size used to build the mel filterbank
    :param fmin: lowest mel filter frequency (defaults to 0)
    :param fmax: highest mel filter frequency (defaults to fs / 2)
    :param eps: floor applied before the log
    """

    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        # Bug fix: the format string ended with an unbalanced "))".
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        # Bug fix: eps was stored in __init__ but never forwarded.
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
|
||||||
|
|
||||||
|
|
||||||
|
class Stft():
    """Short-time Fourier transform wrapper holding analysis parameters."""

    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

    def __repr__(self):
        # Bug fix: the original joined "window={window}," directly to
        # "center=..." with no separating space.
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center}, pad_mode={pad_mode})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode=self.pad_mode, ))

    def __call__(self, x):
        """Return the STFT of waveform ``x``."""
        return stft(
            x,
            self.n_fft,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )
|
||||||
|
|
||||||
|
|
||||||
|
class IStft():
    """Inverse short-time Fourier transform wrapper."""

    def __init__(self, n_shift, win_length=None, window="hann", center=True):
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center

    def __repr__(self):
        # Bug fix: the original joined "window={window}," directly to
        # "center=..." with no separating space.
        return ("{name}(n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center})".format(
                    name=self.__class__.__name__,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center, ))

    def __call__(self, x):
        """Return the time-domain reconstruction of STFT matrix ``x``."""
        return istft(
            x,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center, )
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogramKaldi():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            energy_floor=0.0,
            dither=0.1):
        """
        The Kaldi implementation of LogMelSpectrogram
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame windows
            energy_floor (float): Floor on energy in Spectrogram computation (absolute)
            dither (float): Dithering constant

        Returns:
            LogMelSpectrogramKaldi
        """

        self.fs = fs
        self.n_mels = n_mels
        # Convert sample counts to milliseconds, as kaldi.fbank expects
        # frame_length/frame_shift in ms.
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms
        self.n_frame_shift = n_shift / num_point_ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        # Bug fix: removed the stray extra ")" from the format string
        # ("dither={dither}))" -> "dither={dither})").
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        # Dithering adds noise for regularization; disable it at inference.
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
        mat = kaldi.fbank(
            waveform,
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=dither,
            energy_floor=self.energy_floor,
            sr=self.fs)
        # Drop the leading batch dimension added above.
        mat = np.squeeze(mat.numpy())
        return mat
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogramKaldi_decay():
    """Deprecated log-mel filterbank extractor built on ``logfbank``.

    Kept for backward compatibility; prefer ``LogMelSpectrogramKaldi``.
    """

    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_fft=512,  # fft point
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            window="povey",
            fmin=20,
            fmax=None,
            eps=1e-10,
            dither=1.0):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        if n_shift > win_length:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # Converted from samples to seconds (e.g. 160/16000 = 0.01 s);
        # the original comment said "ms", which was misleading.
        self.n_shift = n_shift / fs
        self.win_length = win_length / fs

        self.window = window
        self.fmin = fmin
        if fmax is None:
            # Default to the Nyquist frequency.
            fmax_ = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        else:
            # Bug fix: the original left fmax_ unbound (NameError) when a
            # valid explicit fmax was passed.
            fmax_ = fmax
        self.fmax = fmax_

        self.eps = eps
        self.remove_dc_offset = True
        self.preemph = 0.97
        self.dither = dither  # only work in train mode

    def __repr__(self):
        # Bug fix: removed the stray extra ")" from the format string.
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                preemph=self.preemph,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")

        # np.sctypes was removed in NumPy 2.0; use issubdtype instead.
        if np.issubdtype(x.dtype, np.floating):
            # PCM32 -> PCM16
            bits = np.iinfo(np.int16).bits
            x = x * 2**(bits - 1)

        # logfbank need PCM16 input
        y = logfbank(
            signal=x,
            samplerate=self.fs,
            winlen=self.win_length,  # unit: second
            winstep=self.n_shift,  # unit: second
            nfilt=self.n_mels,
            nfft=self.n_fft,
            lowfreq=self.fmin,
            highfreq=self.fmax,
            dither=dither,
            remove_dc_offset=self.remove_dc_offset,
            preemph=self.preemph,
            wintype=self.window)
        return y
|
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformInterface:
    """Base interface that every pipeline transform implements."""

    def __call__(self, x):
        # Concrete transforms must override this.
        raise NotImplementedError("__call__ method is not implemented")

    @classmethod
    def add_arguments(cls, parser):
        """Hook for subclasses to register CLI arguments; no-op by default."""
        return parser

    def __repr__(self):
        return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(TransformInterface):
    """Pass-through transform: returns its input unchanged."""

    def __call__(self, x):
        # Useful as a placeholder/default entry in transform configs.
        return x
|
@ -0,0 +1,158 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
"""Transformation module."""
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from ..utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
# Maps the short "type" names accepted in transform config files to the
# "module:Class" import paths that dynamic_import() resolves at runtime.
import_alias = dict(
    identity="paddlespeech.audio.transform.transform_interface:Identity",
    time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
    time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
    freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
    spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
    speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
    speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
    volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
    noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
    bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
    rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
    delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
    cmvn="paddlespeech.audio.transform.cmvn:CMVN",
    utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
    fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
    spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
    stft="paddlespeech.audio.transform.spectrogram:Stft",
    istft="paddlespeech.audio.transform.spectrogram:IStft",
    stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
    wpe="paddlespeech.audio.transform.wpe:WPE",
    channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
    fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
    cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")
|
||||||
|
|
||||||
|
|
||||||
|
class Transformation():
    """Apply some functions to the mini-batch

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        # conffile may be a dict (used as-is, deep-copied) or a path to a
        # YAML file with the same structure; None means an empty pipeline.
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        # Instantiate one transform object per "process" entry, keyed by
        # its position so __call__ can apply them in order.
        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    # Log the expected constructor signature to make config
                    # mistakes easier to diagnose, then re-raise.
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # Some function, e.g. built-in function, are failed
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

    def __repr__(self):
        rep = "\n" + "\n".join(" {}: {}".format(k, v)
                               for k, v in self.functions.items())
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        :param Union[Sequence[np.ndarray], np.ndarray] xs:
        :param Union[Sequence[str], str] uttid_list:
        :return: batch:
        :rtype: List[np.ndarray]
        """
        # A single array is wrapped into a batch of one and unwrapped at
        # the end, so callers may pass either form.
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx in range(len(self.conf["process"])):
                func = self.functions[idx]
                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
                # Derive only the args which the func has
                try:
                    param = signature(func).parameters
                except ValueError:
                    # Some function, e.g. built-in function, are failed
                    param = {}
                # Forward only the kwargs this transform actually accepts.
                _kwargs = {k: v for k, v in kwargs.items() if k in param}
                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    logging.fatal("Catch a exception from {}th func: {}".format(
                        idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
|
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
from nara_wpe.wpe import wpe
|
||||||
|
|
||||||
|
|
||||||
|
class WPE(object):
    """Weighted prediction error (WPE) dereverberation via ``nara_wpe``."""

    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        # Bug fix: the original joined "delay={delay}" directly to
        # "iterations=..." with no ", " separator.
        return ("{name}(taps={taps}, delay={delay}, "
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray

        """
        # nara_wpe.wpe: (F, C, T)
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        # Transpose back to (T, C, F).
        return xs.transpose(2, 1, 0)
|
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise TypeError as same as python default
    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        # Builtins and some C-level callables expose no signature;
        # nothing to validate in that case.
        return
    label = func.__name__ if name is None else name
    unknown = [k for k in kwargs.keys() if k not in params]
    if unknown:
        raise TypeError(
            f"{label}() got an unexpected keyword argument '{unknown[0]}'")
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
__all__ = ["dynamic_import"]
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import(import_path, alias=dict()):
    """dynamic import module and class

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'paddlespeech.s2t.models.u2:U2Model'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    has_sep = ":" in import_path
    if not has_sep and import_path not in alias:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    # Resolve a registered shortcut to its full "module:object" path.
    target = import_path if has_sep else alias[import_path]

    module_name, objname = target.split(":")
    module = importlib.import_module(module_name)
    return getattr(module, objname)
|
@ -0,0 +1,195 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Unility functions for Transformer."""
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from .log import Logger
|
||||||
|
|
||||||
|
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
|
||||||
|
|
||||||
|
logger = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def has_tensor(val):
    """Recursively check whether ``val`` contains a paddle Tensor.

    Args:
        val: a tensor, or an arbitrarily nested list/tuple/dict of values.

    Returns:
        bool: True if any leaf value is a paddle Tensor, False otherwise.
        (Bug fix: the original fell off the end of the container branches
        and returned None on a miss, and left a debug ``print(k)`` in the
        dict branch.)
    """
    if isinstance(val, (list, tuple)):
        return any(has_tensor(item) for item in val)
    if isinstance(val, dict):
        # Only values can hold tensors; keys are hashable scalars.
        return any(has_tensor(v) for v in val.values())
    return paddle.is_tensor(val)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is list of
    sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
    otherwise.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from paddle.nn.utils.rnn import pad_sequence
        >>> a = paddle.ones(25, 300)
        >>> b = paddle.ones(22, 300)
        >>> c = paddle.ones(15, 300)
        >>> pad_sequence([a, b, c]).shape
        paddle.Tensor([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = paddle.shape(sequences[0])
    # (TODO Hui Zhang): slice not supprot `end==start`
    # trailing_dims = max_size[1:]
    trailing_dims = tuple(
        max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
    max_len = max([s.shape[0] for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
    for i, tensor in enumerate(sequences):
        length = tensor.shape[0]
        # use index notation to prevent duplicate references to the tensor
        # NOTE(review): this per-sequence info log fires on every batch;
        # looks like a debug leftover — consider removing or demoting.
        logger.info(
            f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
        )
        if batch_first:
            # TODO (Hui Zhang): set_value op not supprot `end==start`
            # TODO (Hui Zhang): set_value op not support int16
            # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
            # out_tensor[i, :length, ...] = tensor
            if length != 0:
                out_tensor[i, :length] = tensor
            else:
                out_tensor[i, length] = tensor
        else:
            # TODO (Hui Zhang): set_value op not supprot `end==start`
            # out_tensor[:length, i, ...] = tensor
            if length != 0:
                out_tensor[:length, i] = tensor
            else:
                out_tensor[length, i] = tensor

    return out_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Add <sos> and <eos> labels.
    Args:
        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eeos>
        ignore_id (int): index of padding
    Returns:
        ys_in (paddle.Tensor) : (B, Lmax + 1)
        ys_out (paddle.Tensor) : (B, Lmax + 1)
    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    # TODO(Hui Zhang): using comment code,
    #_sos = paddle.to_tensor(
    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #_eos = paddle.to_tensor(
    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
    # NOTE(review): paddle.cat / Tensor.masked_fill are not stock paddle
    # APIs — presumably patched elsewhere in this package; confirm.
    B = ys_pad.shape[0]
    # Column vectors of <sos>/<eos> to prepend/append per sequence.
    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
    ys_in = paddle.cat([_sos, ys_pad], dim=1)
    # Replace padding positions with <eos> in the decoder input.
    mask_pad = (ys_in == ignore_id)
    ys_in = ys_in.masked_fill(mask_pad, eos)

    ys_out = paddle.cat([ys_pad, _eos], dim=1)
    ys_out = ys_out.masked_fill(mask_pad, eos)
    # In the target, the first pad slot becomes <eos>, the rest stay ignore_id.
    mask_eos = (ys_out == ignore_id)
    ys_out = ys_out.masked_fill(mask_eos, eos)
    ys_out = ys_out.masked_fill(mask_pad, ignore_id)
    return ys_in, ys_out
|
||||||
|
|
||||||
|
|
||||||
|
def th_accuracy(pad_outputs: paddle.Tensor,
                pad_targets: paddle.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    # Reshape flat logits back to (B, Lmax, D) and take the argmax class.
    # NOTE(review): Tensor.view is not a stock paddle API — presumably
    # patched to reshape elsewhere in this package; confirm.
    pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
                                pad_outputs.shape[1]).argmax(2)
    # Only score positions that are not padding.
    mask = pad_targets != ignore_label
    #TODO(Hui Zhang): sum not support bool type
    # numerator = paddle.sum(
    #     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = (
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = paddle.sum(numerator.type_as(pad_targets))
    #TODO(Hui Zhang): sum not support bool type
    # denominator = paddle.sum(mask)
    denominator = paddle.sum(mask.type_as(pad_targets))
    return float(numerator) / float(denominator)
|
Loading…
Reference in new issue