new feature: Add webdataset in audio

pull/2062/head
huangyuxin 2 years ago
parent e04cd18846
commit 8f5e61090b

@@ -0,0 +1,68 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from .cache import (
cached_tarfile_samples,
cached_tarfile_to_samples,
lru_cleanup,
pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from webdataset.extradatasets import MockDataset, with_epoch, with_length
from .filters import (
associate,
batched,
decode,
detshuffle,
extract_keys,
getfirst,
info,
map,
map_dict,
map_tuple,
pipelinefilter,
rename,
rename_keys,
rsample,
select,
shuffle,
slice,
to_tuple,
transform_with,
unbatched,
xdecode,
data_filter,
tokenize,
resample,
compute_fbank,
spec_aug,
sort,
padding,
cmvn
)
from webdataset.handlers import (
ignore_and_continue,
ignore_and_stop,
reraise_exception,
warn_and_continue,
warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
MultiShardSample,
ResampledShards,
SimpleShardList,
non_empty,
resampled,
shardspec,
single_node_only,
split_by_node,
split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
from webdataset.mix import RandomMix, RoundRobin
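# Illustrative usage sketch: how the pieces exported above are meant to chain
# for ASR training. The shard pattern, symbol table, and CMVN stats path are
# hypothetical placeholders, and the stage order follows the wenet-style recipe.
def _example_audio_pipeline(symbol_table, cmvn_file):
    """Build a tar-shard pipeline: raw audio -> padded, CMVN'd fbank batches."""
    return (
        WebDataset("train-{000..099}.tar", shardshuffle=True)
        .tokenize(symbol_table)                # txt -> tokens/label
        .data_filter(max_length=10240, min_length=10)
        .resample(resample_rate=16000)
        .compute_fbank(num_mel_bins=80)
        .spec_aug()                            # time/freq masking
        .shuffle(1000)
        .sort(500)                             # group similar lengths
        .batched(16)
        .padding()                             # -> (keys, feats, ...) tuples
        .cmvn(cmvn_file)
    )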

@@ -0,0 +1,190 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time
from urllib.parse import urlparse
from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from .tariterators import tar_file_and_group_expander
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
"""Performs cleanup of the file cache in cache_dir using an LRU strategy,
keeping the total size of all remaining files below cache_size."""
if not os.path.exists(cache_dir):
return
total_size = 0
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
total_size += os.path.getsize(os.path.join(dirpath, filename))
if total_size <= cache_size:
return
# sort files by last access time
files = []
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
files.append(os.path.join(dirpath, filename))
files.sort(key=keyfn, reverse=True)
# delete files until we're under the cache size
while len(files) > 0 and total_size > cache_size:
fname = files.pop()
total_size -= os.path.getsize(fname)
if verbose:
print("# deleting %s" % fname, file=sys.stderr)
os.remove(fname)
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
"""Download a file from `url` to `dest`."""
temp = dest + f".temp{os.getpid()}"
with gopen.gopen(url) as stream:
with open(temp, "wb") as f:
while True:
data = stream.read(chunk_size)
if not data:
break
f.write(data)
os.rename(temp, dest)
def pipe_cleaner(spec):
"""Guess the actual URL from a "pipe:" specification."""
if spec.startswith("pipe:"):
spec = spec[5:]
words = spec.split(" ")
for word in words:
if re.match(r"^(https?|gs|ais|s3)", word):
return word
return spec
def get_file_cached(
spec,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
verbose=False,
):
if cache_size == -1:
cache_size = default_cache_size
if cache_dir is None:
cache_dir = default_cache_dir
url = url_to_name(spec)
parsed = urlparse(url)
dirname, filename = os.path.split(parsed.path)
dirname = dirname.lstrip("/")
dirname = re.sub(r"[:/|;]", "_", dirname)
destdir = os.path.join(cache_dir, dirname)
os.makedirs(destdir, exist_ok=True)
dest = os.path.join(cache_dir, dirname, filename)
if not os.path.exists(dest):
if verbose:
print("# downloading %s to %s" % (url, dest), file=sys.stderr)
lru_cleanup(cache_dir, cache_size, verbose=verbose)
download(spec, dest, verbose=verbose)
return dest
def get_filetype(fname):
with os.popen("file '%s'" % fname) as f:
ftype = f.read()
return ftype
def check_tar_format(fname):
"""Check whether a file is a tar archive."""
ftype = get_filetype(fname)
return "tar archive" in ftype or "gzip compressed" in ftype
verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def cached_url_opener(
data,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
validator=check_tar_format,
verbose=False,
always=False,
):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
verbose = verbose or verbose_cache
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
attempts = 5
try:
if not always and os.path.exists(url):
dest = url
else:
dest = get_file_cached(
url,
cache_size=cache_size,
cache_dir=cache_dir,
url_to_name=url_to_name,
verbose=verbose,
)
if verbose:
print("# opening %s" % dest, file=sys.stderr)
assert os.path.exists(dest)
if not validator(dest):
ftype = get_filetype(dest)
with open(dest, "rb") as f:
data = f.read(200)
os.remove(dest)
raise ValueError(
"%s (%s) is not a tar archive, but a %s, contains %s"
% (dest, url, ftype, repr(data))
)
try:
stream = open(dest, "rb")
sample.update(stream=stream)
yield sample
except FileNotFoundError as exn:
# dealing with race conditions in lru_cleanup
attempts -= 1
if attempts > 0:
time.sleep(random.random() * 10)
continue
raise exn
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def cached_tarfile_samples(
src,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
verbose=False,
url_to_name=pipe_cleaner,
always=False,
):
streams = cached_url_opener(
src,
handler=handler,
cache_size=cache_size,
cache_dir=cache_dir,
verbose=verbose,
url_to_name=url_to_name,
always=always,
)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
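# Illustrative usage sketch: wiring the cached reader into a pipeline so remote
# shards are downloaded once into ./_cache (or $WDS_CACHE) and reused on later
# epochs. The URL pattern and cache size are hypothetical placeholders.
def _cached_pipeline_demo():
    from .pipeline import DataPipeline
    from .shardlists import SimpleShardList
    return DataPipeline(
        SimpleShardList("https://example.com/shards/train-{000..009}.tar"),
        cached_tarfile_to_samples(cache_dir="./_cache", cache_size=10e9),
    )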

@@ -0,0 +1,170 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from webdataset import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset
class FluidInterface:
def batched(self, batchsize):
return self.compose(filters.batched(batchsize))
def dynamic_batched(self, max_frames_in_batch):
return self.compose(filters.dynamic_batched(max_frames_in_batch))
def unbatched(self):
return self.compose(filters.unbatched())
def listed(self, batchsize, partial=True):
return self.compose(filters.batched(batchsize))
def unlisted(self):
return self.compose(filters.unlisted())
def log_keys(self, logfile=None):
return self.compose(filters.log_keys(logfile))
def shuffle(self, size, **kw):
if size < 1:
return self
else:
return self.compose(filters.shuffle(size, **kw))
def map(self, f, handler=reraise_exception):
return self.compose(filters.map(f, handler=handler))
def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
return self.map(decoder, handler=handler)
def map_dict(self, handler=reraise_exception, **kw):
return self.compose(filters.map_dict(handler=handler, **kw))
def select(self, predicate, **kw):
return self.compose(filters.select(predicate, **kw))
def to_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.to_tuple(*args, handler=handler))
def map_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.map_tuple(*args, handler=handler))
def slice(self, *args):
return self.compose(filters.slice(*args))
def rename(self, **kw):
return self.compose(filters.rename(**kw))
def rsample(self, p=0.5):
return self.compose(filters.rsample(p))
def rename_keys(self, *args, **kw):
return self.compose(filters.rename_keys(*args, **kw))
def extract_keys(self, *args, **kw):
return self.compose(filters.extract_keys(*args, **kw))
def xdecode(self, *args, **kw):
return self.compose(filters.xdecode(*args, **kw))
def data_filter(self, *args, **kw):
return self.compose(filters.data_filter(*args, **kw))
def tokenize(self, *args, **kw):
return self.compose(filters.tokenize(*args, **kw))
def resample(self, *args, **kw):
return self.compose(filters.resample(*args, **kw))
def compute_fbank(self, *args, **kw):
return self.compose(filters.compute_fbank(*args, **kw))
def spec_aug(self, *args, **kw):
return self.compose(filters.spec_aug(*args, **kw))
def sort(self, size=500):
return self.compose(filters.sort(size))
def padding(self):
return self.compose(filters.padding())
def cmvn(self, cmvn_file):
return self.compose(filters.cmvn(cmvn_file))
class WebDataset(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(
self,
urls,
handler=reraise_exception,
resampled=False,
repeat=False,
shardshuffle=None,
cache_size=0,
cache_dir=None,
detshuffle=False,
nodesplitter=shardlists.single_node_only,
verbose=False,
):
super().__init__()
if isinstance(urls, IterableDataset):
assert not resampled
self.append(urls)
elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
with (open(urls)) as stream:
spec = yaml.safe_load(stream)
assert "datasets" in spec
self.append(shardlists.MultiShardSample(spec))
elif isinstance(urls, dict):
assert "datasets" in urls
self.append(shardlists.MultiShardSample(urls))
elif resampled:
self.append(shardlists.ResampledShards(urls))
else:
self.append(shardlists.SimpleShardList(urls))
self.append(nodesplitter)
self.append(shardlists.split_by_worker)
if shardshuffle is True:
shardshuffle = 100
if shardshuffle is not None:
if detshuffle:
self.append(filters.detshuffle(shardshuffle))
else:
self.append(filters.shuffle(shardshuffle))
if cache_size == 0:
self.append(tariterators.tarfile_to_samples(handler=handler))
else:
assert cache_size == -1 or cache_size > 0
self.append(
cache.cached_tarfile_to_samples(
handler=handler,
verbose=verbose,
cache_size=cache_size,
cache_dir=cache_dir,
)
)
class FluidWrapper(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(self, initial):
super().__init__()
self.append(initial)
class WebLoader(DataPipeline, FluidInterface):
def __init__(self, *args, **kw):
super().__init__(DataLoader(*args, **kw))
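# Illustrative usage sketch: WebLoader wraps a paddle DataLoader in the same
# fluid interface, so post-loading stages can be chained onto it. num_workers=4
# is an arbitrary example value, and batch_size=None assumes a paddle version
# that supports disabling auto-batching in DataLoader.
def _webloader_demo(dataset):
    loader = WebLoader(dataset, batch_size=None, num_workers=4)
    # re-batch across workers: unbatch, shuffle across batches, batch again
    return loader.unbatched().shuffle(1000).batched(16)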

@@ -0,0 +1,912 @@
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import io
from fnmatch import fnmatch
import re
import itertools, os, pickle, random, sys, time
from functools import reduce, wraps
import numpy as np
from webdataset import autodecode
from . import utils
from .paddle_utils import PaddleTensor
from .utils import PipelineStage
from .. import backends
from ..compliance import kaldi
import paddle
from ..transform.cmvn import GlobalCMVN
from ..utils.tensor_utils import pad_sequence
from ..transform.spec_augment import time_warp
from ..transform.spec_augment import time_mask
from ..transform.spec_augment import freq_mask
class FilterFunction(object):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def __init__(self, f, *args, **kw):
"""Create a curried function."""
self.f = f
self.args = args
self.kw = kw
def __call__(self, data):
"""Call the curried function with the given argument."""
return self.f(data, *self.args, **self.kw)
def __str__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
def __repr__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
class RestCurried(object):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def __init__(self, f):
"""Store the function for future currying."""
self.f = f
def __call__(self, *args, **kw):
"""Curry with the given arguments."""
return FilterFunction(self.f, *args, **kw)
def pipelinefilter(f):
"""Turn the decorated function into one that is partially applied for
all arguments other than the first."""
result = RestCurried(f)
return result
def reraise_exception(exn):
"""Reraises the given exception; used as a handler.
:param exn: exception
"""
raise exn
def identity(x):
"""Return the argument."""
return x
def compose2(f, g):
"""Compose two functions, g(f(x))."""
return lambda x: g(f(x))
def compose(*args):
"""Compose a sequence of functions (left-to-right)."""
return reduce(compose2, args)
def pipeline(source, *args):
"""Write an input pipeline; first argument is source, rest are filters."""
if len(args) == 0:
return source
return compose(*args)(source)
def getfirst(a, keys, default=None, missing_is_error=True):
"""Get the first matching key from a dictionary.
Keys can be specified as a list, or as a string of keys separated by ';'.
"""
if isinstance(keys, str):
assert " " not in keys
keys = keys.split(";")
for k in keys:
if k in a:
return a[k]
if missing_is_error:
raise ValueError(f"didn't find {keys} in {list(a.keys())}")
return default
def parse_field_spec(fields):
"""Parse a specification for a list of fields to be extracted.
Keys are separated by spaces in the spec. Each key can itself
be composed of key alternatives separated by ';'.
"""
if isinstance(fields, str):
fields = fields.split()
return [field.split(";") for field in fields]
def transform_with(sample, transformers):
"""Transform a list of values using a list of functions.
sample: list of values
transformers: list of functions
If there are fewer transformers than inputs, or if a transformer
function is None, then the identity function is used for the
corresponding sample fields.
"""
if transformers is None or len(transformers) == 0:
return sample
result = list(sample)
assert len(transformers) <= len(sample)
for i in range(len(transformers)): # skipcq: PYL-C0200
f = transformers[i]
if f is not None:
result[i] = f(sample[i])
return result
###
# Iterators
###
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
"""Print information about the samples that are passing through.
:param data: source iterator
:param fmt: format statement (using sample dict as keyword)
:param n: print the first n samples
:param every: also print every `every`-th sample thereafter
:param width: maximum width
:param stream: output stream
:param name: identifier printed before any output
"""
for i, sample in enumerate(data):
if i < n or (every > 0 and (i + 1) % every == 0):
if fmt is None:
print("---", name, file=stream)
for k, v in sample.items():
print(k, repr(v)[:width], file=stream)
else:
print(fmt.format(**sample), file=stream)
yield sample
info = pipelinefilter(_info)
def pick(buf, rng):
k = rng.randint(0, len(buf) - 1)
sample = buf[k]
buf[k] = buf[-1]
buf.pop()
return sample
def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
"""Shuffle the data in the stream.
This uses a buffer of size `bufsize`. Shuffling at
startup is less random; this is traded off against
yielding samples quickly.
data: iterator
bufsize: buffer size for shuffling
rng: either random module or random.Random instance
returns: iterator
"""
if rng is None:
rng = random.Random(int((os.getpid() + time.time()) * 1e9))
initial = min(initial, bufsize)
buf = []
for sample in data:
buf.append(sample)
if len(buf) < bufsize:
try:
buf.append(next(data)) # skipcq: PYL-R1708
except StopIteration:
pass
if len(buf) >= initial:
yield pick(buf, rng)
while len(buf) > 0:
yield pick(buf, rng)
shuffle = pipelinefilter(_shuffle)
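# Worked example: pipelinefilter turns _shuffle(data, ...) into a curried
# stage; calling shuffle(...) fixes the parameters and returns a picklable
# callable that is applied to an iterator later. Values here are arbitrary.
def _shuffle_stage_demo():
    stage = shuffle(bufsize=10, initial=5)   # curried: no data yet
    return list(stage(iter(range(20))))      # the numbers 0..19, shuffled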
class detshuffle(PipelineStage):
def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
self.epoch += 1
rng = random.Random()
rng.seed((self.seed, self.epoch))
return _shuffle(src, self.bufsize, self.initial, rng)
def _select(data, predicate):
"""Select samples based on a predicate.
:param data: source iterator
:param predicate: predicate (function)
"""
for sample in data:
if predicate(sample):
yield sample
select = pipelinefilter(_select)
def _log_keys(data, logfile=None):
import fcntl
if logfile is None or logfile == "":
for sample in data:
yield sample
else:
with open(logfile, "a") as stream:
for i, sample in enumerate(data):
buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
try:
fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
stream.write(buf)
finally:
fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
yield sample
log_keys = pipelinefilter(_log_keys)
def _decode(data, *args, handler=reraise_exception, **kw):
"""Decode data based on the decoding functions given as arguments."""
decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
handlers = [decoder(x) for x in args]
f = autodecode.Decoder(handlers, **kw)
for sample in data:
assert isinstance(sample, dict), sample
try:
decoded = f(sample)
except Exception as exn: # skipcq: PYL-W0703
if handler(exn):
continue
else:
break
yield decoded
decode = pipelinefilter(_decode)
def _map(data, f, handler=reraise_exception):
"""Map samples."""
for sample in data:
try:
result = f(sample)
except Exception as exn:
if handler(exn):
continue
else:
break
if result is None:
continue
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
map = pipelinefilter(_map)
def _rename(data, handler=reraise_exception, keep=True, **kw):
"""Rename samples based on keyword arguments."""
for sample in data:
try:
if not keep:
yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
else:
def listify(v):
return v.split(";") if isinstance(v, str) else v
to_be_replaced = {x for v in kw.values() for x in listify(v)}
result = {k: v for k, v in sample.items() if k not in to_be_replaced}
result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
rename = pipelinefilter(_rename)
def _associate(data, associator, **kw):
"""Associate additional data with samples."""
for sample in data:
if callable(associator):
extra = associator(sample["__key__"])
else:
extra = associator.get(sample["__key__"], {})
sample.update(extra) # destructive
yield sample
associate = pipelinefilter(_associate)
def _map_dict(data, handler=reraise_exception, **kw):
"""Map the entries in a dict sample with individual functions."""
assert len(list(kw.keys())) > 0
for key, f in kw.items():
assert callable(f), (key, f)
for sample in data:
assert isinstance(sample, dict)
try:
for k, f in kw.items():
sample[k] = f(sample[k])
except Exception as exn:
if handler(exn):
continue
else:
break
yield sample
map_dict = pipelinefilter(_map_dict)
def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
"""Convert dict samples to tuples."""
if none_is_error is None:
none_is_error = missing_is_error
if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
args = args[0].split()
for sample in data:
try:
result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
if none_is_error and any(x is None for x in result):
raise ValueError(f"to_tuple {args} got {sample.keys()}")
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
to_tuple = pipelinefilter(_to_tuple)
def _map_tuple(data, *args, handler=reraise_exception):
"""Map the entries of a tuple with individual functions."""
args = [f if f is not None else utils.identity for f in args]
for f in args:
assert callable(f), f
for sample in data:
assert isinstance(sample, (list, tuple))
sample = list(sample)
n = min(len(args), len(sample))
try:
for i in range(n):
sample[i] = args[i](sample[i])
except Exception as exn:
if handler(exn):
continue
else:
break
yield tuple(sample)
map_tuple = pipelinefilter(_map_tuple)
def _unlisted(data):
"""Turn batched data back into unbatched data."""
for batch in data:
assert isinstance(batch, list), batch
for sample in batch:
yield sample
unlisted = pipelinefilter(_unlisted)
def _unbatched(data):
"""Turn batched data back into unbatched data."""
for sample in data:
assert isinstance(sample, (tuple, list)), sample
assert len(sample) > 0
for i in range(len(sample[0])):
yield tuple(x[i] for x in sample)
unbatched = pipelinefilter(_unbatched)
def _rsample(data, p=0.5):
"""Randomly subsample a stream of data."""
assert p >= 0.0 and p <= 1.0
for sample in data:
if random.uniform(0.0, 1.0) < p:
yield sample
rsample = pipelinefilter(_rsample)
slice = pipelinefilter(itertools.islice)
def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
for sample in source:
result = []
for pattern in patterns:
pattern = pattern.split(";") if isinstance(pattern, str) else pattern
matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
if len(matches) == 0:
if ignore_missing:
continue
else:
raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
if len(matches) > 1 and duplicate_is_error:
raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
value = sample[matches[0]]
result.append(value)
yield tuple(result)
extract_keys = pipelinefilter(_extract_keys)
def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
renamings = [(pattern, output) for output, pattern in args]
renamings += [(pattern, output) for output, pattern in kw.items()]
for sample in source:
new_sample = {}
matched = {k: False for k, _ in renamings}
for path, value in sample.items():
fname = re.sub(r".*/", "", path)
new_name = None
for pattern, name in renamings[::-1]:
if fnmatch(fname.lower(), pattern):
matched[pattern] = True
new_name = name
break
if new_name is None:
if keep_unselected:
new_sample[path] = value
continue
if new_name in new_sample:
if duplicate_is_error:
raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
continue
new_sample[new_name] = value
if must_match and not all(matched.values()):
raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
yield new_sample
rename_keys = pipelinefilter(_rename_keys)
def decode_bin(stream):
return stream.read()
def decode_text(stream):
binary = stream.read()
return binary.decode("utf-8")
def decode_pickle(stream):
return pickle.load(stream)
default_decoders = [
("*.bin", decode_bin),
("*.txt", decode_text),
("*.pyd", decode_pickle),
]
def find_decoder(decoders, path):
fname = re.sub(r".*/", "", path)
if fname.startswith("__"):
return lambda x: x
for pattern, fun in decoders[::-1]:
if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
return fun
return None
def _xdecode(
source,
*args,
must_decode=True,
defaults=default_decoders,
**kw,
):
decoders = list(defaults) + list(args)
decoders += [("*." + k, v) for k, v in kw.items()]
for sample in source:
new_sample = {}
for path, data in sample.items():
if path.startswith("__"):
new_sample[path] = data
continue
decoder = find_decoder(decoders, path)
if decoder is False:
value = data
elif decoder is None:
if must_decode:
raise ValueError(f"No decoder found for {path}.")
value = data
else:
if isinstance(data, bytes):
data = io.BytesIO(data)
value = decoder(data)
new_sample[path] = value
yield new_sample
xdecode = pipelinefilter(_xdecode)
def _data_filter(source,
frame_shift=10,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1):
""" Filter sample according to feature and label length
Inplace operation.
Args::
source: Iterable[{fname, wav, label, sample_rate}]
frame_shift: length of frame shift (ms)
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'label' in sample
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['label']) < token_min_length:
continue
if len(sample['label']) > token_max_length:
continue
if num_frames != 0:
if len(sample['label']) / num_frames < min_output_input_ratio:
continue
if len(sample['label']) / num_frames > max_output_input_ratio:
continue
yield sample
data_filter = pipelinefilter(_data_filter)
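# Worked example: a 5 s utterance at 16 kHz has 80000 samples; with a 10 ms
# frame shift that is 80000 / 16000 * 100 = 500 frames, within the default
# [min_length=10, max_length=10240] bounds, so the sample passes the filter.
# The sample contents below are toy placeholders.
def _data_filter_demo():
    sample = dict(
        fname="utt1",
        wav=paddle.zeros([1, 16000 * 5]),    # (channel, samples)
        label=[2, 3, 4],                     # toy token ids
        sample_rate=16000,
    )
    return list(_data_filter(iter([sample])))  # -> [sample]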
def _tokenize(source,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
""" Decode text to chars or BPE
Inplace operation
Args:
source: Iterable[{fname, wav, txt, sample_rate}]
Returns:
Iterable[{fname, wav, txt, tokens, label, sample_rate}]
"""
if non_lang_syms is not None:
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
else:
non_lang_syms = {}
non_lang_syms_pattern = None
if bpe_model is not None:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
else:
sp = None
for sample in source:
assert 'txt' in sample
txt = sample['txt'].strip()
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
else:
parts = [txt]
label = []
tokens = []
for part in parts:
if part in non_lang_syms:
tokens.append(part)
else:
if bpe_model is not None:
tokens.extend(__tokenize_by_bpe_model(sp, part))
else:
if split_with_space:
part = part.split(" ")
for ch in part:
if ch == ' ':
ch = "<space>"
tokens.append(ch)
for ch in tokens:
if ch in symbol_table:
label.append(symbol_table[ch])
elif '<unk>' in symbol_table:
label.append(symbol_table['<unk>'])
sample['tokens'] = tokens
sample['label'] = label
yield sample
tokenize = pipelinefilter(_tokenize)
def _resample(source, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
source: Iterable[{fname, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
waveform.numpy(), src_sr=sample_rate, target_sr=resample_rate
))
yield sample
resample = pipelinefilter(_resample)
def _compute_fbank(source,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
source: Iterable[{fname, wav, label, sample_rate}]
num_mel_bins: number of mel filter bank
frame_length: length of one frame (ms)
frame_shift: length of frame shift (ms)
dither: value of dither
Returns:
Iterable[{fname, feat, label}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'fname' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
waveform = waveform * (1 << 15)
# Only keep fname, feat, label
mat = kaldi.fbank(waveform,
n_mels=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sr=sample_rate)
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
compute_fbank = pipelinefilter(_compute_fbank)
def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80):
""" Do spec augmentation
Inplace operation
Args:
source: Iterable[{fname, feat, label}]
num_t_mask: number of time mask to apply
num_f_mask: number of freq mask to apply
max_t: max width of time mask
max_f: max width of freq mask
max_w: max width of time warp
Returns
Iterable[{fname, feat, label}]
"""
for sample in source:
x = sample['feat']
x = x.numpy()
x = time_warp(x, max_time_warp=max_w, inplace=True, mode="PIL")
x = freq_mask(x, F=max_f, n_mask=num_f_mask, inplace=True, replace_with_zero=False)
x = time_mask(x, T=max_t, n_mask=num_t_mask, inplace=True, replace_with_zero=False)
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample
spec_aug = pipelinefilter(_spec_aug)
def _sort(source, sort_size=500):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
source: Iterable[{fname, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{fname, feat, label}]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
sort = pipelinefilter(_sort)
def _batched(source, batch_size=16):
""" Static batch the data by `batch_size`
Args:
source: Iterable[{fname, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
batched = pipelinefilter(_batched)
def _dynamic_batched(source, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
source: Iterable[{fname, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in source:
assert 'feat' in sample
assert isinstance(sample['feat'], paddle.Tensor)
new_sample_frames = sample['feat'].shape[0]
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
dynamic_batched = pipelinefilter(_dynamic_batched)
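# Worked example: with max_frames_in_batch=1000, feats of 300/400/500 frames
# group as [300, 400] (padded cost 2 * 400 = 800 frames) and then [500],
# because adding the third utterance would cost 3 * 500 = 1500 frames after
# padding. Feature shapes are toy placeholders.
def _dynamic_batched_demo():
    samples = [dict(feat=paddle.zeros([n, 80])) for n in (300, 400, 500)]
    return [len(b) for b in _dynamic_batched(iter(samples), 1000)]  # [2, 1]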
def _padding(source):
""" Padding the data into training data
Args:
source: Iterable[List[{fname, feat, label}]]
Returns:
Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
"""
for sample in source:
assert isinstance(sample, list)
feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
dtype="int64")
order = paddle.argsort(feats_length, descending=True)
feats_lengths = paddle.to_tensor(
[sample[i]['feat'].shape[0] for i in order], dtype="int64")
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['fname'] for i in order]
sorted_labels = [
paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
]
label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
dtype="int64")
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
padding_labels = pad_sequence(sorted_labels,
batch_first=True,
padding_value=-1)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
padding = pipelinefilter(_padding)
def _cmvn(source, cmvn_file):
global_cmvn = GlobalCMVN(cmvn_file)
for batch in source:
sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
padded_feats = padded_feats.numpy()
padded_feats = global_cmvn(padded_feats)
padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
cmvn = pipelinefilter(_cmvn)

@@ -0,0 +1,33 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try:
from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:
class IterableDataset:
"""Empty implementation of IterableDataset when paddle is not available."""
pass
class DataLoader:
"""Empty implementation of DataLoader when paddle is not available."""
pass
try:
from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:
class PaddleTensor:
"""Empty implementation of PaddleTensor when paddle is not available."""
pass

@@ -0,0 +1,127 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage
def add_length_method(obj):
def length(self):
return self.size
Combined = type(
obj.__class__.__name__ + "_Length",
(obj.__class__, IterableDataset),
{"__len__": length},
)
obj.__class__ = Combined
return obj
class DataPipeline(IterableDataset, PipelineStage):
"""A pipeline starting with an IterableDataset and a series of filters."""
def __init__(self, *args, **kwargs):
super().__init__()
self.pipeline = []
self.length = -1
self.repetitions = 1
self.nsamples = -1
for arg in args:
if arg is None:
continue
if isinstance(arg, list):
self.pipeline.extend(arg)
else:
self.pipeline.append(arg)
def invoke(self, f, *args, **kwargs):
"""Apply a pipeline stage, possibly to the output of a previous stage."""
if isinstance(f, PipelineStage):
return f.run(*args, **kwargs)
if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
return iter(f)
if isinstance(f, list):
return iter(f)
if callable(f):
result = f(*args, **kwargs)
return result
raise ValueError(f"{f}: not a valid pipeline stage")
def iterator1(self):
"""Create an iterator through one epoch in the pipeline."""
source = self.invoke(self.pipeline[0])
for step in self.pipeline[1:]:
source = self.invoke(step, source)
return source
def iterator(self):
"""Create an iterator through the entire dataset, using the given number of repetitions."""
for i in range(self.repetitions):
for sample in self.iterator1():
yield sample
def __iter__(self):
"""Create an iterator through the pipeline, repeating and slicing as requested."""
if self.repetitions != 1:
if self.nsamples > 0:
return islice(self.iterator(), self.nsamples)
else:
return self.iterator()
else:
return self.iterator()
def stage(self, i):
"""Return pipeline stage i."""
return self.pipeline[i]
def append(self, f):
"""Append a pipeline stage (modifies the object)."""
self.pipeline.append(f)
def compose(self, *args):
"""Append a pipeline stage to a copy of the pipeline and returns the copy."""
result = copy.copy(self)
for arg in args:
result.append(arg)
return result
def with_length(self, n):
"""Add a __len__ method returning the desired value.
This does not change the actual number of samples in an epoch.
PyTorch IterableDataset should not have a __len__ method.
This is provided only as a workaround for some broken training environments
that require a __len__ method.
"""
self.size = n
return add_length_method(self)
def with_epoch(self, nsamples=-1, nbatches=-1):
"""Change the epoch to return the given number of samples/batches.
The two arguments mean the same thing."""
self.repetitions = sys.maxsize
self.nsamples = max(nsamples, nbatches)
return self
def repeat(self, nepochs=-1, nbatches=-1):
"""Repeat iterating through the dataset for the given #epochs up to the given #samples."""
if nepochs > 0:
self.repetitions = nepochs
self.nsamples = nbatches
else:
self.repetitions = sys.maxsize
self.nsamples = nbatches
return self
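# Illustrative usage sketch: a DataPipeline is just a sequence of stages, each
# consuming the previous stage's iterator. The shard pattern is a hypothetical
# placeholder; the stages come from the webdataset modules imported above.
def _datapipeline_demo():
    return DataPipeline(
        shardlists.SimpleShardList("train-{000..009}.tar"),
        shardlists.split_by_worker,
        tariterators.tarfile_to_samples(),
        filters.shuffle(1000),
    )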

@@ -0,0 +1,257 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List
import braceexpand, yaml
from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset
def expand_urls(urls):
if isinstance(urls, str):
urllist = urls.split("::")
result = []
for url in urllist:
result.extend(braceexpand.braceexpand(url))
return result
else:
return list(urls)
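# Worked example: "::" separates shard lists and each list may use brace
# notation; the file names here are arbitrary.
def _expand_urls_demo():
    urls = expand_urls("a-{00..02}.tar::b-{0..1}.tar")
    assert urls == ["a-00.tar", "a-01.tar", "a-02.tar", "b-0.tar", "b-1.tar"]
    return urls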
class SimpleShardList(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(self, urls, seed=None):
"""Iterate through the list of shards.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.seed = seed
def __len__(self):
return len(self.urls)
def __iter__(self):
"""Return an iterator over the shards."""
urls = self.urls.copy()
if self.seed is not None:
random.Random(self.seed).shuffle(urls)
for url in urls:
yield dict(url=url)
def split_by_node(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
if world_size > 1:
for s in islice(src, rank, None, world_size):
yield s
else:
for s in src:
yield s
def single_node_only(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
if world_size > 1:
raise ValueError("input pipeline needs to be reconfigured for multinode training")
for s in src:
yield s
def split_by_worker(src):
rank, world_size, worker, num_workers = utils.paddle_worker_info()
if num_workers > 1:
for s in islice(src, worker, None, num_workers):
yield s
else:
for s in src:
yield s
def resampled_(src, n=sys.maxsize):
import random
seed = time.time()
try:
seed = open("/dev/random", "rb").read(20)
except Exception as exn:
print(repr(exn)[:50], file=sys.stderr)
rng = random.Random(seed)
print("# resampled loading", file=sys.stderr)
items = list(src)
print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
for i in range(n):
yield rng.choice(items)
resampled = pipelinefilter(resampled_)
def non_empty(src):
count = 0
for s in src:
yield s
count += 1
if count == 0:
raise ValueError("pipeline stage received no data at all and this was declared as an error")
@dataclass
class MSSource:
"""Class representing a data source."""
name: str = ""
perepoch: int = -1
resample: bool = False
urls: List[str] = field(default_factory=list)
default_rng = random.Random()
def expand(s):
return os.path.expanduser(os.path.expandvars(s))
class MultiShardSample(IterableDataset):
def __init__(self, fname):
"""Construct a shardlist from multiple sources using a YAML spec."""
self.epoch = -1
self.parse_spec(fname)
def parse_spec(self, fname):
self.rng = default_rng # capture default_rng if we fork
if isinstance(fname, dict):
spec = fname
fname = "{dict}"
else:
with open(fname) as stream:
spec = yaml.safe_load(stream)
assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
prefix = expand(spec.get("prefix", ""))
self.sources = []
for ds in spec["datasets"]:
assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
ds.keys()
)
buckets = ds.get("buckets", spec.get("buckets", []))
if isinstance(buckets, str):
buckets = [buckets]
buckets = [expand(s) for s in buckets]
if buckets == []:
buckets = [""]
assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
bucket = buckets[0]
name = ds.get("name", "@" + bucket)
urls = ds["shards"]
if isinstance(urls, str):
urls = [urls]
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
urls = [
prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
]
resample = ds.get("resample", -1)
nsample = ds.get("choose", -1)
if nsample > len(urls):
raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
if (nsample > 0) and (resample > 0):
raise ValueError("specify only one of perepoch or choose")
entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
self.sources.append(entry)
print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
def set_epoch(self, seed):
"""Set the current epoch (for consistent shard selection among nodes)."""
self.rng = random.Random(seed)
def get_shards_for_epoch(self):
result = []
for source in self.sources:
if source.resample > 0:
# sample with replacement
l = self.rng.choices(source.urls, k=source.resample)
elif source.perepoch > 0:
# sample without replacement
l = list(source.urls)
self.rng.shuffle(l)
l = l[: source.perepoch]
else:
l = list(source.urls)
result += l
self.rng.shuffle(result)
return result
def __iter__(self):
shards = self.get_shards_for_epoch()
for shard in shards:
yield dict(url=shard)
def shardspec(spec):
if spec.endswith(".yaml"):
return MultiShardSample(spec)
else:
return SimpleShardList(spec)
class ResampledShards(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
self.deterministic = deterministic
self.epoch = -1
def __iter__(self):
"""Return an iterator over the shards."""
self.epoch += 1
if self.deterministic:
seed = utils.make_seed(self.worker_seed(), self.epoch)
else:
seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
if os.environ.get("WDS_SHOW_SEED", "0") == "1":
print(f"# ResampledShards seed {seed}")
self.rng = random.Random(seed)
for _ in range(self.nshards):
index = self.rng.randint(0, len(self.urls) - 1)
yield dict(url=self.urls[index])

@@ -0,0 +1,283 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import random, re, tarfile
import braceexpand
from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
trace = False
meta_prefix = "__"
meta_suffix = "__"
from ... import audio as paddleaudio
import paddle
import numpy as np
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def base_plus_ext(path):
"""Split off all file extensions.
Returns base, allext.
:param path: path with extensions
:param returns: path with all extensions removed
"""
match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
if not match:
return None, None
return match.group(1), match.group(2)
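# Worked example: everything after the first dot of the basename is treated
# as the extension, so multi-suffix files group under the same key.
def _base_plus_ext_demo():
    assert base_plus_ext("dir/utt1.wav") == ("dir/utt1", "wav")
    assert base_plus_ext("dir/utt1.fbank.npy") == ("dir/utt1", "fbank.npy")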
def valid_sample(sample):
"""Check whether a sample is valid.
:param sample: sample to be checked
"""
return (
sample is not None
and isinstance(sample, dict)
and len(list(sample.keys())) > 0
and not sample.get("__bad__", False)
)
# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
"""Given a list of URLs, yields that list, possibly shuffled."""
if isinstance(urls, str):
urls = braceexpand.braceexpand(urls)
else:
urls = list(urls)
if shuffle:
random.shuffle(urls)
for url in urls:
yield dict(url=url)
def url_opener(data, handler=reraise_exception, **kw):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
try:
stream = gopen.gopen(url, **kw)
sample.update(stream=stream)
yield sample
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def tar_file_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
for tarinfo in stream:
fname = tarinfo.name
try:
if not tarinfo.isreg():
continue
if fname is None:
continue
if (
"/" not in fname
and fname.startswith(meta_prefix)
and fname.endswith(meta_suffix)
):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
continue
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False)
result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
else:
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
result = dict(fname=prefix, txt=txt)
#result = dict(fname=fname, data=data)
yield result
stream.members = []
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
del stream
def tar_file_and_group_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
""" Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example['fname'] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
example['wav'] = waveform
example['sample_rate'] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
valid = False
# logging.warning('error to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['fname'] = prev_prefix
yield example
stream.close()
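# Worked example: build a tiny in-memory tar with two text members and watch
# the iterator group them by prefix; file names and contents are toy values.
def _group_iterator_demo():
    import io
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, text in [("utt1.txt", "hello"), ("utt2.txt", "world")]:
            data = text.encode("utf8")
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    buf.seek(0)
    # yields {'txt': 'hello', 'fname': 'utt1'}, then {'txt': 'world', 'fname': 'utt2'}
    return list(tar_file_and_group_iterator(buf))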
def tar_file_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over (filename, file_contents).
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "data" in sample and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def tar_file_and_group_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over (filename, file_contents).
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_and_group_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if trace:
print(
prefix,
suffix,
current_sample.keys() if isinstance(current_sample, dict) else None,
)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
if current_sample is None or prefix != current_sample["__key__"]:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffix in current_sample:
raise ValueError(
f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
)
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_samples(src, handler=reraise_exception):
streams = url_opener(src, handler=handler)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)

@@ -0,0 +1,128 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union
def make_seed(*args):
seed = 0
for arg in args:
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
return seed
class PipelineStage:
def invoke(self, *args, **kw):
raise NotImplementedError
def identity(x: Any) -> Any:
"""Return the argument as is."""
return x
def safe_eval(s: str, expr: str = "{}"):
"""Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
return eval(expr.format(s))
def lookup_sym(sym: str, modules: list):
"""Look up a symbol in a list of modules."""
for mname in modules:
module = importlib.import_module(mname, package="webdataset")
result = getattr(module, sym, None)
if result is not None:
return result
return None
def repeatedly0(
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
"""Repeatedly returns batches from a DataLoader."""
for epoch in range(nepochs):
for sample in itt.islice(loader, nbatches):
yield sample
def guess_batchsize(batch: Union[tuple, list]):
"""Guess the batch size by looking at the length of the first element in a tuple."""
return len(batch[0])
def repeatedly(
source: Iterator,
nepochs: int = None,
nbatches: int = None,
nsamples: int = None,
batchsize: Callable[..., int] = guess_batchsize,
):
"""Repeatedly yield samples from an iterator."""
epoch = 0
batch = 0
total = 0
while True:
for sample in source:
yield sample
batch += 1
if nbatches is not None and batch >= nbatches:
return
if nsamples is not None:
total += batchsize(sample)
if total >= nsamples:
return
epoch += 1
if nepochs is not None and epoch >= nepochs:
return
def paddle_worker_info(group=None):
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
else:
try:
import paddle.distributed
group = group or paddle.distributed.get_group()
rank = paddle.distributed.get_rank()
world_size = paddle.distributed.get_world_size()
except ModuleNotFoundError:
pass
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
worker = int(os.environ["WORKER"])
num_workers = int(os.environ["NUM_WORKERS"])
else:
try:
import paddle.io
worker_info = paddle.io.get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError:
pass
return rank, world_size, worker, num_workers
def paddle_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = paddle_worker_info(group=group)
return rank * 1000 + worker

@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np
def delta(feat, window):
assert window > 0
delta_feat = np.zeros_like(feat)
for i in range(1, window + 1):
delta_feat[:-i] += i * feat[i:]
delta_feat[i:] += -i * feat[:-i]
delta_feat[-i:] += i * feat[-1]
delta_feat[:i] += -i * feat[0]
delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
return delta_feat
def add_deltas(x, window=2, order=2):
"""
Args:
x (np.ndarray): speech feat, (T, D).
Return:
np.ndarray: (T, (1+order)*D)
"""
feats = [x]
for _ in range(order):
feats.append(delta(feats[-1], window))
return np.concatenate(feats, axis=1)
class AddDeltas():
def __init__(self, window=2, order=2):
self.window = window
self.order = order
def __repr__(self):
return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order)
def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order)
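# Worked example: deltas stack along the feature axis, so order=2 triples the
# feature dimension. The (100, 80) fbank-like input is a toy placeholder.
def _add_deltas_demo():
    x = np.random.randn(100, 80).astype("float32")  # (T, D)
    y = AddDeltas(window=2, order=2)(x)
    assert y.shape == (100, 240)                     # (T, (1 + order) * D)
    return y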

@@ -0,0 +1,57 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy
class ChannelSelector():
"""Select 1ch from multi-channel signal"""
def __init__(self, train_channel="random", eval_channel=0, axis=1):
self.train_channel = train_channel
self.eval_channel = eval_channel
self.axis = axis
def __repr__(self):
return ("{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__,
train_channel=self.train_channel,
eval_channel=self.eval_channel,
axis=self.axis, ))
def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default
if x.ndim <= self.axis:
# If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1])
ind = tuple(
slice(None) if i < x.ndim else None
for i in range(self.axis + 1))
x = x[ind]
if train:
channel = self.train_channel
else:
channel = self.eval_channel
if channel == "random":
ch = numpy.random.randint(0, x.shape[self.axis])
else:
ch = channel
ind = tuple(
slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind]
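A usage sketch for ChannelSelector on a synthetic stereo input (the class above is assumed in scope):

import numpy as np

x = np.random.randn(1000, 2)  # (Time, Channel) stereo signal
selector = ChannelSelector(train_channel="random", eval_channel=0, axis=1)
y_train = selector(x, train=True)   # a randomly chosen channel, shape (1000,)
y_eval = selector(x, train=False)   # always channel 0, shape (1000,)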

@ -0,0 +1,201 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
def __init__(
self,
stats,
norm_means=True,
norm_vars=False,
filetype="mat",
utt2spk=None,
spk2utt=None,
reverse=False,
std_floor=1.0e-20, ):
self.stats_file = stats
self.norm_means = norm_means
self.norm_vars = norm_vars
self.reverse = reverse
if isinstance(stats, dict):
stats_dict = dict(stats)
else:
# Use for global CMVN
if filetype == "mat":
stats_dict = {None: kaldiio.load_mat(stats)}
# Use for global CMVN
elif filetype == "npy":
stats_dict = {None: np.load(stats)}
# Use for speaker CMVN
elif filetype == "ark":
self.accept_uttid = True
stats_dict = dict(kaldiio.load_ark(stats))
# Use for speaker CMVN
elif filetype == "hdf5":
self.accept_uttid = True
stats_dict = h5py.File(stats)
else:
raise ValueError("Not supporting filetype={}".format(filetype))
if utt2spk is not None:
self.utt2spk = {}
with io.open(utt2spk, "r", encoding="utf-8") as f:
for line in f:
utt, spk = line.rstrip().split(None, 1)
self.utt2spk[utt] = spk
elif spk2utt is not None:
self.utt2spk = {}
with io.open(spk2utt, "r", encoding="utf-8") as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
self.utt2spk[utt] = spk
else:
self.utt2spk = None
        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # where the first vector contains the sum of feats and the second
        # the sum of squares. The last value of the first vector, i.e.
        # stats[0, -1], is the number of frames used for these statistics.
self.bias = {}
self.scale = {}
for spk, stats in stats_dict.items():
assert len(stats) == 2, stats.shape
count = stats[0, -1]
# If the feature has two or more dimensions
if not (np.isscalar(count) or isinstance(count, (int, float))):
                # Only the first value is used
count = count.flatten()[0]
mean = stats[0, :-1] / count
# V(x) = E(x^2) - (E(x))^2
var = stats[1, :-1] / count - mean * mean
std = np.maximum(np.sqrt(var), std_floor)
self.bias[spk] = -mean
self.scale[spk] = 1 / std
def __repr__(self):
return ("{name}(stats_file={stats_file}, "
"norm_means={norm_means}, norm_vars={norm_vars}, "
"reverse={reverse})".format(
name=self.__class__.__name__,
stats_file=self.stats_file,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
reverse=self.reverse, ))
def __call__(self, x, uttid=None):
if self.utt2spk is not None:
spk = self.utt2spk[uttid]
else:
spk = uttid
if not self.reverse:
# apply cmvn
if self.norm_means:
x = np.add(x, self.bias[spk])
if self.norm_vars:
x = np.multiply(x, self.scale[spk])
else:
# apply reverse cmvn
if self.norm_vars:
x = np.divide(x, self.scale[spk])
if self.norm_means:
x = np.subtract(x, self.bias[spk])
return x
class UtteranceCMVN():
"Apply Utterance CMVN"
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
def __repr__(self):
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
name=self.__class__.__name__,
norm_means=self.norm_means,
norm_vars=self.norm_vars, )
def __call__(self, x, uttid=None):
# x: [Time, Dim]
square_sums = (x**2).sum(axis=0)
mean = x.mean(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.maximum(np.sqrt(var), self.std_floor)
x = np.divide(x, std)
return x
class GlobalCMVN():
"Apply Global CMVN"
def __init__(self,
cmvn_path,
norm_means=True,
norm_vars=True,
std_floor=1.0e-20):
        # cmvn_path: Union[str, dict]
cmvn = cmvn_path
self.cmvn = cmvn
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
if isinstance(cmvn, dict):
cmvn_stats = cmvn
else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
self.count = cmvn_stats['frame_num']
self.mean = np.array(cmvn_stats['mean_stat']) / self.count
self.square_sums = np.array(cmvn_stats['var_stat'])
self.var = self.square_sums / self.count - self.mean**2
self.std = np.maximum(np.sqrt(self.var), self.std_floor)
def __repr__(self):
return f"""{self.__class__.__name__}(
cmvn_path={self.cmvn},
norm_means={self.norm_means},
norm_vars={self.norm_vars},)"""
def __call__(self, x, uttid=None):
# x: [Time, Dim]
if self.norm_means:
x = np.subtract(x, self.mean)
if self.norm_vars:
x = np.divide(x, self.std)
return x
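A sketch of GlobalCMVN driven by an in-memory stats dict instead of a JSON file; the key names match the ones read in __init__ above, and the numbers are synthetic:

import numpy as np

frames = np.random.randn(500, 4)
stats = {
    "frame_num": frames.shape[0],
    "mean_stat": frames.sum(axis=0).tolist(),      # per-dim sum of feats
    "var_stat": (frames**2).sum(axis=0).tolist(),  # per-dim sum of squares
}
cmvn = GlobalCMVN(stats, norm_means=True, norm_vars=True)
normalized = cmvn(frames)  # roughly zero mean, unit variance per dimension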

@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect
from paddlespeech.audio.transform.transform_interface import TransformInterface
from paddlespeech.audio.utils.check_kwargs import check_kwargs
class FuncTrans(TransformInterface):
"""Functional Transformation
WARNING:
Builtin or C/C++ functions may not work properly
because this class heavily depends on the `inspect` module.
Usage:
>>> def foo_bar(x, a=1, b=2):
... '''Foo bar
... :param x: input
... :param int a: default 1
... :param int b: default 2
... '''
... return x + a - b
>>> class FooBar(FuncTrans):
... _func = foo_bar
... __doc__ = foo_bar.__doc__
"""
_func = None
def __init__(self, **kwargs):
self.kwargs = kwargs
check_kwargs(self.func, kwargs)
def __call__(self, x):
return self.func(x, **self.kwargs)
@classmethod
def add_arguments(cls, parser):
fname = cls._func.__name__.replace("_", "-")
group = parser.add_argument_group(fname + " transformation setting")
for k, v in cls.default_params().items():
# TODO(karita): get help and choices from docstring?
attr = k.replace("_", "-")
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
return parser
@property
def func(self):
return type(self)._func
@classmethod
def default_params(cls):
try:
d = dict(inspect.signature(cls._func).parameters)
except ValueError:
d = dict()
return {
k: v.default
for k, v in d.items() if v.default != inspect.Parameter.empty
}
def __repr__(self):
params = self.default_params()
params.update(**self.kwargs)
ret = self.__class__.__name__ + "("
if len(params) == 0:
return ret + ")"
for k, v in params.items():
ret += "{}={}, ".format(k, v)
return ret[:-2] + ")"
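A sketch of defining a concrete FuncTrans, mirroring the class docstring above; scale_shift is a made-up function, and this assumes the module's paddlespeech imports resolve:

def scale_shift(x, a=1.0, b=0.0):
    """Scale the input by a, then shift it by b."""
    return x * a + b

class ScaleShift(FuncTrans):
    _func = scale_shift
    __doc__ = scale_shift.__doc__

t = ScaleShift(a=2.0, b=1.0)
print(t)       # ScaleShift(a=2.0, b=1.0)
print(t(3.0))  # 7.0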

@ -0,0 +1,561 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import os
import sys

import h5py
import librosa
import numpy
import numpy as np
import scipy.signal
import soundfile
class SoundHDF5File():
"""Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param: str filepath:
:param: str mode:
:param: str format: The type used when saving wav. flac, nist, htk, etc.
:param: str dtype:
"""
def __init__(self,
filepath,
mode="r+",
format=None,
dtype="int16",
**kwargs):
self.filepath = filepath
self.mode = mode
self.dtype = dtype
self.file = h5py.File(filepath, mode, **kwargs)
if format is None:
# filepath = a.flac.h5 -> format = flac
second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
format = second_ext[1:]
if format.upper() not in soundfile.available_formats():
# If not found, flac is selected
format = "flac"
# This format affects only saving
self.format = format
def __repr__(self):
return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
self.filepath, self.mode, self.format, self.dtype)
def create_dataset(self, name, shape=None, data=None, **kwds):
f = io.BytesIO()
array, rate = data
soundfile.write(f, array, rate, format=self.format)
self.file.create_dataset(
name, shape=shape, data=np.void(f.getvalue()), **kwds)
def __setitem__(self, name, data):
self.create_dataset(name, data=data)
def __getitem__(self, key):
data = self.file[key][()]
f = io.BytesIO(data.tobytes())
array, rate = soundfile.read(f, dtype=self.dtype)
return array, rate
def keys(self):
return self.file.keys()
def values(self):
for k in self.file:
yield self[k]
def items(self):
for k in self.file:
yield k, self[k]
def __iter__(self):
return iter(self.file)
def __contains__(self, item):
return item in self.file
    def __len__(self):
return len(self.file)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def close(self):
self.file.close()
class SpeedPerturbation():
"""SpeedPerturbation
    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed simply resamples the input,
    i.e. both pitch and tempo are changed.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
Warning:
This function is very slow because of resampling.
        I recommend applying speed perturbation outside of training using sox.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
res_type="kaiser_best",
seed=None, ):
self.res_type = res_type
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
self.utt2ratio = {}
            # Use the scheduled ratio for each utterance
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            self.accept_uttid = False
            # The ratio is drawn randomly at runtime
            self.lower = lower
            self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
self.__class__.__name__,
self.lower,
self.upper,
self.keep_length,
self.res_type, )
else:
return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
# Note1: resample requires the sampling-rate of input and output,
# but actually only the ratio is used.
y = librosa.resample(
x, orig_sr=ratio, target_sr=1, res_type=self.res_type)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2:-((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant")
return y
class SpeedPerturbationSox():
"""SpeedPerturbationSox
    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed simply resamples the input,
    i.e. both pitch and tempo are changed.
To speed up or slow down the sound of a file,
use speed to modify the pitch and the duration of the file.
    Raising the speed reduces the duration.
    The default factor is 1.0, which makes no change to the audio.
    A factor of 2.0 doubles the speed: the duration is halved and the pitch is one octave higher.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
tempo option:
sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9
speed option:
sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9
    If we use the speed option as above, the pitch of the audio is also changed,
    while the tempo option does not change the pitch.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
sr=16000,
seed=None, ):
self.sr = sr
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
try:
import soxbindings as sox
except ImportError:
try:
from paddlespeech.s2t.utils import dynamic_pip_install
package = "sox"
dynamic_pip_install.install(package)
package = "soxbindings"
if sys.platform != "win32":
dynamic_pip_install.install(package)
import soxbindings as sox
except Exception:
raise RuntimeError(
"Can not install soxbindings on your system.")
self.sox = sox
if utt2ratio is not None:
self.utt2ratio = {}
            # Use the scheduled ratio for each utterance
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            self.accept_uttid = False
            # The ratio is drawn randomly at runtime
            self.lower = lower
            self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return f"""{self.__class__.__name__}(
lower={self.lower},
upper={self.upper},
keep_length={self.keep_length},
sample_rate={self.sr})"""
else:
return f"""{self.__class__.__name__}(
utt2ratio={self.utt2ratio_file},
sample_rate={self.sr})"""
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
tfm = self.sox.Transformer()
tfm.set_globals(multithread=False)
tfm.speed(ratio)
y = tfm.build_array(input_array=x, sample_rate_in=self.sr)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2:-((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant")
        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T,)
            y = y.squeeze(1)
return y
class BandpassPerturbation():
"""BandpassPerturbation
    Randomly drop out bands along the frequency axis.
The original idea comes from the following:
"randomly-selected frequency band was cut off under the constraint of
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
everyday home environments using multiple microphone arrays;
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
"""
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
# x_stft: (Time, Channel, Freq)
self.axes = axes
def __repr__(self):
return "{}(lower={}, upper={})".format(self.__class__.__name__,
self.lower, self.upper)
def __call__(self, x_stft, uttid=None, train=True):
if not train:
return x_stft
if x_stft.ndim == 1:
raise RuntimeError("Input in time-freq domain: "
"(Time, Channel, Freq) or (Time, Freq)")
ratio = self.state.uniform(self.lower, self.upper)
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
mask = self.state.randn(*shape) > ratio
x_stft *= mask
return x_stft
class VolumePerturbation():
def __init__(self,
lower=-1.6,
upper=1.6,
utt2ratio=None,
dbunit=True,
seed=None):
self.dbunit = dbunit
self.utt2ratio_file = utt2ratio
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
self.utt2ratio = {}
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
        else:
            # The ratio is drawn randomly at runtime
            self.utt2ratio = None
            self.accept_uttid = False
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10**(ratio / 20)
return x * ratio
class NoiseInjection():
"""Add isotropic noise"""
def __init__(
self,
utt2noise=None,
lower=-20,
upper=-5,
utt2ratio=None,
filetype="list",
dbunit=True,
seed=None, ):
self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio
self.filetype = filetype
self.dbunit = dbunit
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
            # Use the scheduled SNR for each utterance
            self.utt2ratio = {}
            with open(utt2ratio, "r") as f:
for line in f:
utt, snr = line.rstrip().split(None, 1)
snr = float(snr)
self.utt2ratio[utt] = snr
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
if utt2noise is not None:
self.utt2noise = {}
if filetype == "list":
with open(utt2noise, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
# Load all files in memory
self.utt2noise[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2noise = SoundHDF5File(utt2noise, "r")
else:
raise ValueError(filetype)
else:
self.utt2noise = None
if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError("The uttids mismatch between {} and {}".
format(utt2ratio, utt2noise))
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
# 1. Get ratio of noise to signal in sound pressure level
if uttid is not None and self.utt2ratio is not None:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10**(ratio / 20)
scale = ratio * numpy.sqrt((x**2).mean())
# 2. Get noise
if self.utt2noise is not None:
# Get noise from the external source
if uttid is not None:
noise, rate = self.utt2noise[uttid]
else:
# Randomly select the noise source
noise = self.state.choice(list(self.utt2noise.values()))
            # Normalize the level (use a float copy to avoid in-place
            # integer division on int16 noise)
            noise = noise / numpy.sqrt((noise**2).mean())
# Adjust the noise length
diff = abs(len(x) - len(noise))
offset = self.state.randint(0, diff)
if len(noise) > len(x):
# Truncate noise
noise = noise[offset:-(diff - offset)]
else:
noise = numpy.pad(
noise, pad_width=[offset, diff - offset], mode="wrap")
else:
# Generate white noise
noise = self.state.normal(0, 1, x.shape)
# 3. Add noise to signal
return x + noise * scale
class RIRConvolve():
def __init__(self, utt2rir, filetype="list"):
self.utt2rir_file = utt2rir
self.filetype = filetype
self.utt2rir = {}
if filetype == "list":
with open(utt2rir, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
self.utt2rir[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2rir = SoundHDF5File(utt2rir, "r")
else:
raise NotImplementedError(filetype)
def __repr__(self):
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if x.ndim != 1:
# Must be single channel
raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(
x.shape))
rir, rate = self.utt2rir[uttid]
if rir.ndim == 2:
# FIXME(kamo): Use chainer.convolution_1d?
# return [Time, Channel]
            return numpy.stack(
                [scipy.signal.convolve(x, r, mode="same") for r in rir],
                axis=-1)
        else:
            return scipy.signal.convolve(x, rir, mode="same")
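A usage sketch for the SpeedPerturbation transform defined earlier in this file (synthetic waveform; librosa/resampy must be installed):

import numpy as np

wav = np.random.randn(16000).astype(np.float32)  # 1 s of fake 16 kHz audio
perturb = SpeedPerturbation(lower=0.9, upper=1.1, keep_length=True, seed=0)
out = perturb(wav, train=True)    # resampled by a random ratio in [0.9, 1.1]
assert out.shape == wav.shape     # keep_length restores the input length
same = perturb(wav, train=False)  # eval mode is a no-op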

@ -0,0 +1,214 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
import numpy
from PIL import Image
from PIL.Image import BICUBIC
from .functional import FuncTrans
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
:param numpy.ndarray x: spectrogram (time, freq)
:param int max_time_warp: maximum time frames to warp
:param bool inplace: overwrite x with the result
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
(slow, differentiable)
:returns numpy.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right
return x
return numpy.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
import paddle
from espnet.utils import spec_augment
# TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else:
raise NotImplementedError("unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
class TimeWarp(FuncTrans):
_func = time_warp
__doc__ = time_warp.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray x: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = x
else:
cloned = x.copy()
num_mel_channels = cloned.shape[1]
fs = numpy.random.randint(0, F, size=(n_mask, 2))
for f, mask_end in fs:
f_zero = random.randrange(0, num_mel_channels - f)
mask_end += f_zero
# avoids randrange error if values are equal and range is empty
if f_zero == f_zero + f:
continue
if replace_with_zero:
cloned[:, f_zero:mask_end] = 0
else:
cloned[:, f_zero:mask_end] = cloned.mean()
return cloned
class FreqMask(FuncTrans):
_func = freq_mask
__doc__ = freq_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray spec: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = spec
else:
cloned = spec.copy()
len_spectro = cloned.shape[0]
ts = numpy.random.randint(0, T, size=(n_mask, 2))
for t, mask_end in ts:
# avoid randint range error
if len_spectro - t <= 0:
continue
t_zero = random.randrange(0, len_spectro - t)
# avoids randrange error if values are equal and range is empty
if t_zero == t_zero + t:
continue
mask_end += t_zero
if replace_with_zero:
cloned[t_zero:mask_end] = 0
else:
cloned[t_zero:mask_end] = cloned.mean()
return cloned
class TimeMask(FuncTrans):
_func = time_mask
__doc__ = time_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def spec_augment(
x,
resize_mode="PIL",
max_time_warp=80,
max_freq_width=27,
n_freq_mask=2,
max_time_width=100,
n_time_mask=2,
inplace=True,
replace_with_zero=True, ):
"""spec agument
apply random time warping and time/freq masking
default setting is based on LD (Librispeech double) in Table 2
https://arxiv.org/pdf/1904.08779.pdf
:param numpy.ndarray x: (time, freq)
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
(slow, differentiable)
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
:param bool inplace: overwrite intermediate array
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
assert isinstance(x, numpy.ndarray)
assert x.ndim == 2
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
x = freq_mask(
x,
max_freq_width,
n_freq_mask,
inplace=inplace,
replace_with_zero=replace_with_zero, )
x = time_mask(
x,
max_time_width,
n_time_mask,
inplace=inplace,
replace_with_zero=replace_with_zero, )
return x
class SpecAugment(FuncTrans):
_func = spec_augment
__doc__ = spec_augment.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
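A quick sketch applying spec_augment to a synthetic log-mel spectrogram (Pillow is required for the default PIL time-warp path):

import numpy as np

spec = np.random.randn(200, 80).astype(np.float32)  # (time, freq)
out = spec_augment(
    spec.copy(),          # copy, since inplace=True by default
    max_time_warp=5,      # small warp for a short 200-frame example
    max_freq_width=27,
    n_freq_mask=2,
    max_time_width=40,
    n_time_mask=2, )
print(out.shape)  # (200, 80): same shape, with warped and masked regions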

@ -0,0 +1,475 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank
from ..compliance import kaldi
def stft(x,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
# x: [Time] -> [Time, Channel]
x = x[:, None]
else:
single_channel = False
x = x.astype(np.float32)
# FIXME(kamo): librosa.stft can't use multi-channel?
# x: [Time, Channel, Freq]
x = np.stack(
[
librosa.stft(
y=x[:, ch],
n_fft=n_fft,
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
x = x[:, 0]
return x
def istft(x, n_shift, win_length=None, window="hann", center=True):
# x: [Time, Channel, Freq]
if x.ndim == 2:
single_channel = True
# x: [Time, Freq] -> [Time, Channel, Freq]
x = x[:, None, :]
else:
single_channel = False
# x: [Time, Channel]
x = np.stack(
[
librosa.istft(
stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time]
hop_length=n_shift,
win_length=win_length,
window=window,
center=center, ) for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel] -> [Time]
x = x[:, 0]
return x
def stft2logmelspectrogram(x_stft,
fs,
n_mels,
n_fft,
fmin=None,
fmax=None,
eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
# spc: (Time, Channel, Freq) or (Time, Freq)
spc = np.abs(x_stft)
# mel_basis: (Mel_freq, Freq)
mel_basis = librosa.filters.mel(
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
return lmspc
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
return spc
def logmelspectrogram(
x,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect", ):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode, )
return stft2logmelspectrogram(
x_stft,
fs=fs,
n_mels=n_mels,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
eps=eps)
class Spectrogram():
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, ))
def __call__(self, x):
return spectrogram(
x,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, )
class LogMelSpectrogram():
def __init__(
self,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10, ):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
    def __call__(self, x):
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
class Stft2LogMelSpectrogram():
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
    def __call__(self, x):
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
class Stft():
def __init__(
self,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect", ):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, ))
def __call__(self, x):
return stft(
x,
self.n_fft,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, )
class IStft():
def __init__(self, n_shift, win_length=None, window="hann", center=True):
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
def __repr__(self):
return ("{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, ))
def __call__(self, x):
return istft(
x,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, )
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.1):
"""
The Kaldi implementation of LogMelSpectrogram
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame window
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant
Returns:
LogMelSpectrogramKaldi
"""
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
        ValueError: if x has shape (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
mat = kaldi.fbank(
waveform,
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=dither,
energy_floor=self.energy_floor,
sr=self.fs)
mat = np.squeeze(mat.numpy())
return mat
class LogMelSpectrogramKaldi_decay():
def __init__(
self,
fs=16000,
n_mels=80,
n_fft=512, # fft point
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
window="povey",
fmin=20,
fmax=None,
eps=1e-10,
dither=1.0):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
if n_shift > win_length:
raise ValueError("Stride size must not be greater than "
"window size.")
        self.n_shift = n_shift / fs  # unit: second
        self.win_length = win_length / fs  # unit: second
self.window = window
self.fmin = fmin
        if fmax is None:
            fmax_ = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        else:
            fmax_ = fmax
        self.fmax = fmax_
self.eps = eps
self.remove_dc_offset = True
self.preemph = 0.97
self.dither = dither # only work in train mode
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
preemph=self.preemph,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
        ValueError: if x has shape (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
if x.dtype in np.sctypes['float']:
# PCM32 -> PCM16
bits = np.iinfo(np.int16).bits
x = x * 2**(bits - 1)
# logfbank need PCM16 input
y = logfbank(
signal=x,
samplerate=self.fs,
        winlen=self.win_length,  # unit: second
        winstep=self.n_shift,  # unit: second
nfilt=self.n_mels,
nfft=self.n_fft,
lowfreq=self.fmin,
highfreq=self.fmax,
dither=dither,
remove_dc_offset=self.remove_dc_offset,
preemph=self.preemph,
wintype=self.window)
return y
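A usage sketch for LogMelSpectrogramKaldi on a synthetic waveform (paddle and the kaldi compliance module imported above are required):

import numpy as np

wav = np.random.randn(16000).astype(np.float32)  # 1 s of fake 16 kHz audio
fbank = LogMelSpectrogramKaldi(fs=16000, n_mels=80)
feats = fbank(wav, train=True)  # dither is only applied in train mode
print(feats.shape)              # about (98, 80): (T, D)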

@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
class TransformInterface:
"""Transform Interface"""
def __call__(self, x):
raise NotImplementedError("__call__ method is not implemented")
@classmethod
def add_arguments(cls, parser):
return parser
def __repr__(self):
return self.__class__.__name__ + "()"
class Identity(TransformInterface):
"""Identity Function"""
def __call__(self, x):
return x
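A sketch of a custom transform written against the interface above (Scale is a made-up example):

class Scale(TransformInterface):
    """Multiply the input by a constant factor"""

    def __init__(self, factor=2.0):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor

print(Scale(3.0)(2.0))  # 6.0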

@ -0,0 +1,158 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Transformation module."""
import copy
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature
import yaml
from ..utils.dynamic_import import dynamic_import
import_alias = dict(
identity="paddlespeech.audio.transform.transform_interface:Identity",
time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
cmvn="paddlespeech.audio.transform.cmvn:CMVN",
utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
stft="paddlespeech.audio.transform.spectrogram:Stft",
istft="paddlespeech.audio.transform.spectrogram:IStft",
stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="paddlespeech.audio.transform.wpe:WPE",
channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")
class Transformation():
"""Apply some functions to the mini-batch
Examples:
>>> kwargs = {"process": [{"type": "fbank",
... "n_mels": 80,
... "fs": 16000},
... {"type": "cmvn",
... "stats": "data/train/cmvn.ark",
... "norm_vars": True},
... {"type": "delta", "window": 2, "order": 2}]}
>>> transform = Transformation(kwargs)
>>> bs = 10
>>> xs = [np.random.randn(100, 80).astype(np.float32)
... for _ in range(bs)]
>>> xs = transform(xs)
"""
def __init__(self, conffile=None):
if conffile is not None:
if isinstance(conffile, dict):
self.conf = copy.deepcopy(conffile)
else:
with io.open(conffile, encoding="utf-8") as f:
self.conf = yaml.safe_load(f)
assert isinstance(self.conf, dict), type(self.conf)
else:
self.conf = {"mode": "sequential", "process": []}
self.functions = OrderedDict()
if self.conf.get("mode", "sequential") == "sequential":
for idx, process in enumerate(self.conf["process"]):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
class_obj = dynamic_import(process_type, import_alias)
# TODO(karita): assert issubclass(class_obj, TransformInterface)
try:
self.functions[idx] = class_obj(**opts)
except TypeError:
try:
signa = signature(class_obj)
except ValueError:
                        # Some functions, e.g. built-in functions, do not
                        # provide a signature
pass
else:
logging.error("Expected signature: {}({})".format(
class_obj.__name__, signa))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
def __repr__(self):
rep = "\n" + "\n".join(" {}: {}".format(k, v)
for k, v in self.functions.items())
return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs):
"""Return new mini-batch
:param Union[Sequence[np.ndarray], np.ndarray] xs:
:param Union[Sequence[str], str] uttid_list:
:return: batch:
:rtype: List[np.ndarray]
"""
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx in range(len(self.conf["process"])):
func = self.functions[idx]
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
                    # Some functions, e.g. built-in functions, do not
                    # provide a signature
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logging.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]
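Beyond the docstring example, a minimal sketch of a Transformation built from an in-memory config (assuming paddlespeech is installed so the alias imports resolve); transforms whose __call__ accepts uttid, e.g. speaker CMVN, receive it automatically from uttid_list:

import numpy as np

conf = {"mode": "sequential", "process": [{"type": "identity"}]}
transform = Transformation(conf)
xs = [np.random.randn(100, 80).astype(np.float32) for _ in range(2)]
ys = transform(xs, uttid_list=["utt1", "utt2"])  # uttid forwarded if accepted
assert all((a == b).all() for a, b in zip(xs, ys))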

@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from nara_wpe.wpe import wpe
class WPE(object):
def __init__(self,
taps=10,
delay=3,
iterations=3,
psd_context=0,
statistics_mode="full"):
self.taps = taps
self.delay = delay
self.iterations = iterations
self.psd_context = psd_context
self.statistics_mode = statistics_mode
def __repr__(self):
return ("{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format(
name=self.__class__.__name__,
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode, ))
def __call__(self, xs):
"""Return enhanced
:param np.ndarray xs: (Time, Channel, Frequency)
:return: enhanced_xs
:rtype: np.ndarray
"""
# nara_wpe.wpe: (F, C, T)
xs = wpe(
xs.transpose((2, 1, 0)),
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode, )
return xs.transpose(2, 1, 0)
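A usage sketch for WPE dereverberation on a synthetic multi-channel STFT (the nara_wpe package is required):

import numpy as np

# Fake STFT input: (Time, Channel, Frequency), complex-valued
xs = np.random.randn(100, 4, 257) + 1j * np.random.randn(100, 4, 257)
dereverb = WPE(taps=10, delay=3, iterations=3)
enhanced = dereverb(xs)
print(enhanced.shape)  # (100, 4, 257), same layout as the input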

@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect
def check_kwargs(func, kwargs, name=None):
"""check kwargs are valid for func
    If kwargs are invalid, raise a TypeError, matching Python's default behavior
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in TypeError (default is func name)
"""
try:
params = inspect.signature(func).parameters
except ValueError:
return
if name is None:
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(
f"{name}() got an unexpected keyword argument '{k}'")

@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib
__all__ = ["dynamic_import"]
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'paddlespeech.s2t.models.u2:U2Model'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError(
"import_path should be one of {} or "
'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
"{}".format(set(alias), import_path))
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
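A quick sketch of dynamic_import with and without an alias table (collections:OrderedDict is just a stand-in target):

# Fully qualified 'module:object' path
OrderedDict = dynamic_import("collections:OrderedDict")

# Short name resolved through an alias table
alias = {"odict": "collections:OrderedDict"}
assert dynamic_import("odict", alias) is OrderedDict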

@ -0,0 +1,195 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from typing import List
from typing import Tuple
import paddle
from .log import Logger
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Logger(__name__)
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *`` and batch_first is False, the output is
    ``T x B x *``, and ``B x T x *`` otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
        >>> a = paddle.ones([25, 300])
        >>> b = paddle.ones([22, 300])
        >>> c = paddle.ones([15, 300])
        >>> pad_sequence([a, b, c]).shape
        [25, 3, 300]
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims = tuple(
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
for i, tensor in enumerate(sequences):
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
logger.info(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
)
if batch_first:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length] = tensor
else:
out_tensor[i, length] = tensor
else:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i] = tensor
else:
out_tensor[length, i] = tensor
return out_tensor
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
        eos (int): index of <eos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos)
ys_out = paddle.cat([ys_pad, _eos], dim=1)
ys_out = ys_out.masked_fill(mask_pad, eos)
mask_eos = (ys_out == ignore_id)
ys_out = ys_out.masked_fill(mask_eos, eos)
ys_out = ys_out.masked_fill(mask_pad, ignore_id)
return ys_in, ys_out
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
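A quick sketch of pad_sequence from this file, matching its docstring (requires paddle):

import paddle

a = paddle.ones([5, 3])
b = paddle.ones([3, 3])
out = pad_sequence([a, b], batch_first=True, padding_value=0.0)
print(out.shape)  # [2, 5, 3]: (B, T, *), the shorter sequence zero-padded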

@ -38,7 +38,7 @@ base = [
"pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
"sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
"textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
"yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8'
"yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset'
]
server = [
