From 8f5e61090b569f9bf77f53a16668a533ae925f99 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 22 Jun 2022 05:12:31 +0000 Subject: [PATCH] new feature: Add webdataset in audio --- paddlespeech/audio/stream_data/__init__.py | 68 ++ paddlespeech/audio/stream_data/cache.py | 190 ++++ paddlespeech/audio/stream_data/compat.py | 170 ++++ paddlespeech/audio/stream_data/filters.py | 912 ++++++++++++++++++ .../audio/stream_data/paddle_utils.py | 33 + paddlespeech/audio/stream_data/pipeline.py | 127 +++ paddlespeech/audio/stream_data/shardlists.py | 257 +++++ .../audio/stream_data/tariterators.py | 283 ++++++ paddlespeech/audio/stream_data/utils.py | 128 +++ paddlespeech/audio/transform/__init__.py | 13 + paddlespeech/audio/transform/add_deltas.py | 54 ++ .../audio/transform/channel_selector.py | 57 ++ paddlespeech/audio/transform/cmvn.py | 201 ++++ paddlespeech/audio/transform/functional.py | 86 ++ paddlespeech/audio/transform/perturb.py | 561 +++++++++++ paddlespeech/audio/transform/spec_augment.py | 214 ++++ paddlespeech/audio/transform/spectrogram.py | 475 +++++++++ .../audio/transform/transform_interface.py | 35 + .../audio/transform/transformation.py | 158 +++ paddlespeech/audio/transform/wpe.py | 58 ++ paddlespeech/audio/utils/check_kwargs.py | 35 + paddlespeech/audio/utils/dynamic_import.py | 38 + paddlespeech/audio/utils/tensor_utils.py | 195 ++++ setup.py | 2 +- 24 files changed, 4349 insertions(+), 1 deletion(-) create mode 100644 paddlespeech/audio/stream_data/__init__.py create mode 100644 paddlespeech/audio/stream_data/cache.py create mode 100644 paddlespeech/audio/stream_data/compat.py create mode 100644 paddlespeech/audio/stream_data/filters.py create mode 100644 paddlespeech/audio/stream_data/paddle_utils.py create mode 100644 paddlespeech/audio/stream_data/pipeline.py create mode 100644 paddlespeech/audio/stream_data/shardlists.py create mode 100644 paddlespeech/audio/stream_data/tariterators.py create mode 100644 paddlespeech/audio/stream_data/utils.py create mode 100644 paddlespeech/audio/transform/__init__.py create mode 100644 paddlespeech/audio/transform/add_deltas.py create mode 100644 paddlespeech/audio/transform/channel_selector.py create mode 100644 paddlespeech/audio/transform/cmvn.py create mode 100644 paddlespeech/audio/transform/functional.py create mode 100644 paddlespeech/audio/transform/perturb.py create mode 100644 paddlespeech/audio/transform/spec_augment.py create mode 100644 paddlespeech/audio/transform/spectrogram.py create mode 100644 paddlespeech/audio/transform/transform_interface.py create mode 100644 paddlespeech/audio/transform/transformation.py create mode 100644 paddlespeech/audio/transform/wpe.py create mode 100644 paddlespeech/audio/utils/check_kwargs.py create mode 100644 paddlespeech/audio/utils/dynamic_import.py create mode 100644 paddlespeech/audio/utils/tensor_utils.py diff --git a/paddlespeech/audio/stream_data/__init__.py b/paddlespeech/audio/stream_data/__init__.py new file mode 100644 index 00000000..fdb3458c --- /dev/null +++ b/paddlespeech/audio/stream_data/__init__.py @@ -0,0 +1,68 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# +# flake8: noqa + +from .cache import ( + cached_tarfile_samples, + cached_tarfile_to_samples, + lru_cleanup, + pipe_cleaner, +) +from .compat import WebDataset, WebLoader, FluidWrapper +from webdataset.extradatasets import MockDataset, with_epoch, with_length +from .filters import ( + associate, + batched, + decode, + detshuffle, + extract_keys, + getfirst, + info, + map, + map_dict, + map_tuple, + pipelinefilter, + rename, + rename_keys, + rsample, + select, + shuffle, + slice, + to_tuple, + transform_with, + unbatched, + xdecode, + data_filter, + tokenize, + resample, + compute_fbank, + spec_aug, + sort, + padding, + cmvn +) +from webdataset.handlers import ( + ignore_and_continue, + ignore_and_stop, + reraise_exception, + warn_and_continue, + warn_and_stop, +) +from .pipeline import DataPipeline +from .shardlists import ( + MultiShardSample, + ResampledShards, + SimpleShardList, + non_empty, + resampled, + shardspec, + single_node_only, + split_by_node, + split_by_worker, +) +from .tariterators import tarfile_samples, tarfile_to_samples +from .utils import PipelineStage, repeatedly +from webdataset.writer import ShardWriter, TarWriter, numpy_dumps +from webdataset.mix import RandomMix, RoundRobin diff --git a/paddlespeech/audio/stream_data/cache.py b/paddlespeech/audio/stream_data/cache.py new file mode 100644 index 00000000..724f6911 --- /dev/null +++ b/paddlespeech/audio/stream_data/cache.py @@ -0,0 +1,190 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +import itertools, os, random, re, sys +from urllib.parse import urlparse + +from . import filters +from webdataset import gopen +from webdataset.handlers import reraise_exception +from .tariterators import tar_file_and_group_expander + +default_cache_dir = os.environ.get("WDS_CACHE", "./_cache") +default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18")) + + +def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False): + """Performs cleanup of the file cache in cache_dir using an LRU strategy, + keeping the total size of all remaining files below cache_size.""" + if not os.path.exists(cache_dir): + return + total_size = 0 + for dirpath, dirnames, filenames in os.walk(cache_dir): + for filename in filenames: + total_size += os.path.getsize(os.path.join(dirpath, filename)) + if total_size <= cache_size: + return + # sort files by last access time + files = [] + for dirpath, dirnames, filenames in os.walk(cache_dir): + for filename in filenames: + files.append(os.path.join(dirpath, filename)) + files.sort(key=keyfn, reverse=True) + # delete files until we're under the cache size + while len(files) > 0 and total_size > cache_size: + fname = files.pop() + total_size -= os.path.getsize(fname) + if verbose: + print("# deleting %s" % fname, file=sys.stderr) + os.remove(fname) + + +def download(url, dest, chunk_size=1024 ** 2, verbose=False): + """Download a file from `url` to `dest`.""" + temp = dest + f".temp{os.getpid()}" + with gopen.gopen(url) as stream: + with open(temp, "wb") as f: + while True: + data = stream.read(chunk_size) + if not data: + break + f.write(data) + os.rename(temp, dest) + + +def pipe_cleaner(spec): + """Guess the actual URL from a "pipe:" specification.""" + if spec.startswith("pipe:"): + spec = spec[5:] + words = spec.split(" ") + for word in words: + if re.match(r"^(https?|gs|ais|s3)", word): + return word + return spec + + +def get_file_cached( + spec, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + verbose=False, +): + if cache_size == -1: + cache_size = default_cache_size + if cache_dir is None: + cache_dir = default_cache_dir + url = url_to_name(spec) + parsed = urlparse(url) + dirname, filename = os.path.split(parsed.path) + dirname = dirname.lstrip("/") + dirname = re.sub(r"[:/|;]", "_", dirname) + destdir = os.path.join(cache_dir, dirname) + os.makedirs(destdir, exist_ok=True) + dest = os.path.join(cache_dir, dirname, filename) + if not os.path.exists(dest): + if verbose: + print("# downloading %s to %s" % (url, dest), file=sys.stderr) + lru_cleanup(cache_dir, cache_size, verbose=verbose) + download(spec, dest, verbose=verbose) + return dest + + +def get_filetype(fname): + with os.popen("file '%s'" % fname) as f: + ftype = f.read() + return ftype + + +def check_tar_format(fname): + """Check whether a file is a tar archive.""" + ftype = get_filetype(fname) + return "tar archive" in ftype or "gzip compressed" in ftype + + +verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0")) + + +def cached_url_opener( + data, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + validator=check_tar_format, + verbose=False, + always=False, +): + """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" + verbose = verbose or verbose_cache + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + attempts = 5 + try: + if not always and os.path.exists(url): + dest = url + else: + dest = get_file_cached( + url, + cache_size=cache_size, + cache_dir=cache_dir, + url_to_name=url_to_name, + verbose=verbose, + ) + if verbose: + print("# opening %s" % dest, file=sys.stderr) + assert os.path.exists(dest) + if not validator(dest): + ftype = get_filetype(dest) + with open(dest, "rb") as f: + data = f.read(200) + os.remove(dest) + raise ValueError( + "%s (%s) is not a tar archive, but a %s, contains %s" + % (dest, url, ftype, repr(data)) + ) + try: + stream = open(dest, "rb") + sample.update(stream=stream) + yield sample + except FileNotFoundError as exn: + # dealing with race conditions in lru_cleanup + attempts -= 1 + if attempts > 0: + time.sleep(random.random() * 10) + continue + raise exn + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def cached_tarfile_samples( + src, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + verbose=False, + url_to_name=pipe_cleaner, + always=False, +): + streams = cached_url_opener( + src, + handler=handler, + cache_size=cache_size, + cache_dir=cache_dir, + verbose=verbose, + url_to_name=url_to_name, + always=always, + ) + samples = tar_file_and_group_expander(streams, handler=handler) + return samples + + +cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples) diff --git a/paddlespeech/audio/stream_data/compat.py b/paddlespeech/audio/stream_data/compat.py new file mode 100644 index 00000000..ee564431 --- /dev/null +++ b/paddlespeech/audio/stream_data/compat.py @@ -0,0 +1,170 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +from dataclasses import dataclass +from itertools import islice +from typing import List + +import braceexpand, yaml + +from webdataset import autodecode +from . import cache, filters, shardlists, tariterators +from .filters import reraise_exception +from .pipeline import DataPipeline +from .paddle_utils import DataLoader, IterableDataset + + +class FluidInterface: + def batched(self, batchsize): + return self.compose(filters.batched(batchsize)) + + def dynamic_batched(self, max_frames_in_batch): + return self.compose(filter.dynamic_batched(max_frames_in_batch)) + + def unbatched(self): + return self.compose(filters.unbatched()) + + def listed(self, batchsize, partial=True): + return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None) + + def unlisted(self): + return self.compose(filters.unlisted()) + + def log_keys(self, logfile=None): + return self.compose(filters.log_keys(logfile)) + + def shuffle(self, size, **kw): + if size < 1: + return self + else: + return self.compose(filters.shuffle(size, **kw)) + + def map(self, f, handler=reraise_exception): + return self.compose(filters.map(f, handler=handler)) + + def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception): + handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args] + decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial) + return self.map(decoder, handler=handler) + + def map_dict(self, handler=reraise_exception, **kw): + return self.compose(filters.map_dict(handler=handler, **kw)) + + def select(self, predicate, **kw): + return self.compose(filters.select(predicate, **kw)) + + def to_tuple(self, *args, handler=reraise_exception): + return self.compose(filters.to_tuple(*args, handler=handler)) + + def map_tuple(self, *args, handler=reraise_exception): + return self.compose(filters.map_tuple(*args, handler=handler)) + + def slice(self, *args): + return self.compose(filters.slice(*args)) + + def rename(self, **kw): + return self.compose(filters.rename(**kw)) + + def rsample(self, p=0.5): + return self.compose(filters.rsample(p)) + + def rename_keys(self, *args, **kw): + return self.compose(filters.rename_keys(*args, **kw)) + + def extract_keys(self, *args, **kw): + return self.compose(filters.extract_keys(*args, **kw)) + + def xdecode(self, *args, **kw): + return self.compose(filters.xdecode(*args, **kw)) + + def data_filter(self, *args, **kw): + return self.compose(filters.data_filter(*args, **kw)) + + def tokenize(self, *args, **kw): + return self.compose(filters.tokenize(*args, **kw)) + + def resample(self, *args, **kw): + return self.compose(filters.resample(*args, **kw)) + + def compute_fbank(self, *args, **kw): + return self.compose(filters.compute_fbank(*args, **kw)) + + def spec_aug(self, *args, **kw): + return self.compose(filters.spec_aug(*args, **kw)) + + def sort(self, size=500): + return self.compose(filters.sort(size)) + + def padding(self): + return self.compose(filters.padding()) + + def cmvn(self, cmvn_file): + return self.compose(filters.cmvn(cmvn_file)) + +class WebDataset(DataPipeline, FluidInterface): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__( + self, + urls, + handler=reraise_exception, + resampled=False, + repeat=False, + shardshuffle=None, + cache_size=0, + cache_dir=None, + detshuffle=False, + nodesplitter=shardlists.single_node_only, + verbose=False, + ): + super().__init__() + if isinstance(urls, IterableDataset): + assert not resampled + self.append(urls) + elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")): + with (open(urls)) as stream: + spec = yaml.safe_load(stream) + assert "datasets" in spec + self.append(shardlists.MultiShardSample(spec)) + elif isinstance(urls, dict): + assert "datasets" in urls + self.append(shardlists.MultiShardSample(urls)) + elif resampled: + self.append(shardlists.ResampledShards(urls)) + else: + self.append(shardlists.SimpleShardList(urls)) + self.append(nodesplitter) + self.append(shardlists.split_by_worker) + if shardshuffle is True: + shardshuffle = 100 + if shardshuffle is not None: + if detshuffle: + self.append(filters.detshuffle(shardshuffle)) + else: + self.append(filters.shuffle(shardshuffle)) + if cache_size == 0: + self.append(tariterators.tarfile_to_samples(handler=handler)) + else: + assert cache_size == -1 or cache_size > 0 + self.append( + cache.cached_tarfile_to_samples( + handler=handler, + verbose=verbose, + cache_size=cache_size, + cache_dir=cache_dir, + ) + ) + + +class FluidWrapper(DataPipeline, FluidInterface): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__(self, initial): + super().__init__() + self.append(initial) + + +class WebLoader(DataPipeline, FluidInterface): + def __init__(self, *args, **kw): + super().__init__(DataLoader(*args, **kw)) diff --git a/paddlespeech/audio/stream_data/filters.py b/paddlespeech/audio/stream_data/filters.py new file mode 100644 index 00000000..3112c954 --- /dev/null +++ b/paddlespeech/audio/stream_data/filters.py @@ -0,0 +1,912 @@ +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) +"""A collection of iterators for data transformations. + +These functions are plain iterator functions. You can find curried versions +in webdataset.filters, and you can find IterableDataset wrappers in +webdataset.processing. +""" + +import io +from fnmatch import fnmatch +import re +import itertools, os, random, sys, time +from functools import reduce, wraps + +import numpy as np + +from webdataset import autodecode +from . import utils +from .paddle_utils import PaddleTensor +from .utils import PipelineStage + +from .. import backends +from ..compliance import kaldi +import paddle +from ..transform.cmvn import GlobalCMVN +from ..utils.tensor_utils import pad_sequence +from ..transform.spec_augment import time_warp +from ..transform.spec_augment import time_mask +from ..transform.spec_augment import freq_mask + +class FilterFunction(object): + """Helper class for currying pipeline stages. + + We use this roundabout construct becauce it can be pickled. + """ + + def __init__(self, f, *args, **kw): + """Create a curried function.""" + self.f = f + self.args = args + self.kw = kw + + def __call__(self, data): + """Call the curried function with the given argument.""" + return self.f(data, *self.args, **self.kw) + + def __str__(self): + """Compute a string representation.""" + return f"<{self.f.__name__} {self.args} {self.kw}>" + + def __repr__(self): + """Compute a string representation.""" + return f"<{self.f.__name__} {self.args} {self.kw}>" + + +class RestCurried(object): + """Helper class for currying pipeline stages. + + We use this roundabout construct because it can be pickled. + """ + + def __init__(self, f): + """Store the function for future currying.""" + self.f = f + + def __call__(self, *args, **kw): + """Curry with the given arguments.""" + return FilterFunction(self.f, *args, **kw) + + +def pipelinefilter(f): + """Turn the decorated function into one that is partially applied for + all arguments other than the first.""" + result = RestCurried(f) + return result + + +def reraise_exception(exn): + """Reraises the given exception; used as a handler. + + :param exn: exception + """ + raise exn + + +def identity(x): + """Return the argument.""" + return x + + +def compose2(f, g): + """Compose two functions, g(f(x)).""" + return lambda x: g(f(x)) + + +def compose(*args): + """Compose a sequence of functions (left-to-right).""" + return reduce(compose2, args) + + +def pipeline(source, *args): + """Write an input pipeline; first argument is source, rest are filters.""" + if len(args) == 0: + return source + return compose(*args)(source) + + +def getfirst(a, keys, default=None, missing_is_error=True): + """Get the first matching key from a dictionary. + + Keys can be specified as a list, or as a string of keys separated by ';'. + """ + if isinstance(keys, str): + assert " " not in keys + keys = keys.split(";") + for k in keys: + if k in a: + return a[k] + if missing_is_error: + raise ValueError(f"didn't find {keys} in {list(a.keys())}") + return default + + +def parse_field_spec(fields): + """Parse a specification for a list of fields to be extracted. + + Keys are separated by spaces in the spec. Each key can itself + be composed of key alternatives separated by ';'. + """ + if isinstance(fields, str): + fields = fields.split() + return [field.split(";") for field in fields] + + +def transform_with(sample, transformers): + """Transform a list of values using a list of functions. + + sample: list of values + transformers: list of functions + + If there are fewer transformers than inputs, or if a transformer + function is None, then the identity function is used for the + corresponding sample fields. + """ + if transformers is None or len(transformers) == 0: + return sample + result = list(sample) + assert len(transformers) <= len(sample) + for i in range(len(transformers)): # skipcq: PYL-C0200 + f = transformers[i] + if f is not None: + result[i] = f(sample[i]) + return result + +### +# Iterators +### + +def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""): + """Print information about the samples that are passing through. + + :param data: source iterator + :param fmt: format statement (using sample dict as keyword) + :param n: when to stop + :param every: how often to print + :param width: maximum width + :param stream: output stream + :param name: identifier printed before any output + """ + for i, sample in enumerate(data): + if i < n or (every > 0 and (i + 1) % every == 0): + if fmt is None: + print("---", name, file=stream) + for k, v in sample.items(): + print(k, repr(v)[:width], file=stream) + else: + print(fmt.format(**sample), file=stream) + yield sample + + +info = pipelinefilter(_info) + + +def pick(buf, rng): + k = rng.randint(0, len(buf) - 1) + sample = buf[k] + buf[k] = buf[-1] + buf.pop() + return sample + + +def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): + """Shuffle the data in the stream. + + This uses a buffer of size `bufsize`. Shuffling at + startup is less random; this is traded off against + yielding samples quickly. + + data: iterator + bufsize: buffer size for shuffling + returns: iterator + rng: either random module or random.Random instance + + """ + if rng is None: + rng = random.Random(int((os.getpid() + time.time()) * 1e9)) + initial = min(initial, bufsize) + buf = [] + for sample in data: + buf.append(sample) + if len(buf) < bufsize: + try: + buf.append(next(data)) # skipcq: PYL-R1708 + except StopIteration: + pass + if len(buf) >= initial: + yield pick(buf, rng) + while len(buf) > 0: + yield pick(buf, rng) + + +shuffle = pipelinefilter(_shuffle) + + +class detshuffle(PipelineStage): + def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + self.epoch += 1 + rng = random.Random() + rng.seed((self.seed, self.epoch)) + return _shuffle(src, self.bufsize, self.initial, rng) + + +def _select(data, predicate): + """Select samples based on a predicate. + + :param data: source iterator + :param predicate: predicate (function) + """ + for sample in data: + if predicate(sample): + yield sample + + +select = pipelinefilter(_select) + + +def _log_keys(data, logfile=None): + import fcntl + + if logfile is None or logfile == "": + for sample in data: + yield sample + else: + with open(logfile, "a") as stream: + for i, sample in enumerate(data): + buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n" + try: + fcntl.flock(stream.fileno(), fcntl.LOCK_EX) + stream.write(buf) + finally: + fcntl.flock(stream.fileno(), fcntl.LOCK_UN) + yield sample + + +log_keys = pipelinefilter(_log_keys) + + +def _decode(data, *args, handler=reraise_exception, **kw): + """Decode data based on the decoding functions given as arguments.""" + + decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x + handlers = [decoder(x) for x in args] + f = autodecode.Decoder(handlers, **kw) + + for sample in data: + assert isinstance(sample, dict), sample + try: + decoded = f(sample) + except Exception as exn: # skipcq: PYL-W0703 + if handler(exn): + continue + else: + break + yield decoded + + +decode = pipelinefilter(_decode) + + +def _map(data, f, handler=reraise_exception): + """Map samples.""" + for sample in data: + try: + result = f(sample) + except Exception as exn: + if handler(exn): + continue + else: + break + if result is None: + continue + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + yield result + + +map = pipelinefilter(_map) + + +def _rename(data, handler=reraise_exception, keep=True, **kw): + """Rename samples based on keyword arguments.""" + for sample in data: + try: + if not keep: + yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()} + else: + + def listify(v): + return v.split(";") if isinstance(v, str) else v + + to_be_replaced = {x for v in kw.values() for x in listify(v)} + result = {k: v for k, v in sample.items() if k not in to_be_replaced} + result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}) + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +rename = pipelinefilter(_rename) + + +def _associate(data, associator, **kw): + """Associate additional data with samples.""" + for sample in data: + if callable(associator): + extra = associator(sample["__key__"]) + else: + extra = associator.get(sample["__key__"], {}) + sample.update(extra) # destructive + yield sample + + +associate = pipelinefilter(_associate) + + +def _map_dict(data, handler=reraise_exception, **kw): + """Map the entries in a dict sample with individual functions.""" + assert len(list(kw.keys())) > 0 + for key, f in kw.items(): + assert callable(f), (key, f) + + for sample in data: + assert isinstance(sample, dict) + try: + for k, f in kw.items(): + sample[k] = f(sample[k]) + except Exception as exn: + if handler(exn): + continue + else: + break + yield sample + + +map_dict = pipelinefilter(_map_dict) + + +def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None): + """Convert dict samples to tuples.""" + if none_is_error is None: + none_is_error = missing_is_error + if len(args) == 1 and isinstance(args[0], str) and " " in args[0]: + args = args[0].split() + + for sample in data: + try: + result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args]) + if none_is_error and any(x is None for x in result): + raise ValueError(f"to_tuple {args} got {sample.keys()}") + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +to_tuple = pipelinefilter(_to_tuple) + + +def _map_tuple(data, *args, handler=reraise_exception): + """Map the entries of a tuple with individual functions.""" + args = [f if f is not None else utils.identity for f in args] + for f in args: + assert callable(f), f + for sample in data: + assert isinstance(sample, (list, tuple)) + sample = list(sample) + n = min(len(args), len(sample)) + try: + for i in range(n): + sample[i] = args[i](sample[i]) + except Exception as exn: + if handler(exn): + continue + else: + break + yield tuple(sample) + + +map_tuple = pipelinefilter(_map_tuple) + + +def _unlisted(data): + """Turn batched data back into unbatched data.""" + for batch in data: + assert isinstance(batch, list), sample + for sample in batch: + yield sample + + +unlisted = pipelinefilter(_unlisted) + + +def _unbatched(data): + """Turn batched data back into unbatched data.""" + for sample in data: + assert isinstance(sample, (tuple, list)), sample + assert len(sample) > 0 + for i in range(len(sample[0])): + yield tuple(x[i] for x in sample) + + +unbatched = pipelinefilter(_unbatched) + + +def _rsample(data, p=0.5): + """Randomly subsample a stream of data.""" + assert p >= 0.0 and p <= 1.0 + for sample in data: + if random.uniform(0.0, 1.0) < p: + yield sample + + +rsample = pipelinefilter(_rsample) + +slice = pipelinefilter(itertools.islice) + + +def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False): + for sample in source: + result = [] + for pattern in patterns: + pattern = pattern.split(";") if isinstance(pattern, str) else pattern + matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)] + if len(matches) == 0: + if ignore_missing: + continue + else: + raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + if len(matches) > 1 and duplicate_is_error: + raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + value = sample[matches[0]] + result.append(value) + yield tuple(result) + + +extract_keys = pipelinefilter(_extract_keys) + + +def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw): + renamings = [(pattern, output) for output, pattern in args] + renamings += [(pattern, output) for output, pattern in kw.items()] + for sample in source: + new_sample = {} + matched = {k: False for k, _ in renamings} + for path, value in sample.items(): + fname = re.sub(r".*/", "", path) + new_name = None + for pattern, name in renamings[::-1]: + if fnmatch(fname.lower(), pattern): + matched[pattern] = True + new_name = name + break + if new_name is None: + if keep_unselected: + new_sample[path] = value + continue + if new_name in new_sample: + if duplicate_is_error: + raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + continue + new_sample[new_name] = value + if must_match and not all(matched.values()): + raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + + yield new_sample + + +rename_keys = pipelinefilter(_rename_keys) + + +def decode_bin(stream): + return stream.read() + + +def decode_text(stream): + binary = stream.read() + return binary.decode("utf-8") + + +def decode_pickle(stream): + return pickle.load(stream) + + +default_decoders = [ + ("*.bin", decode_bin), + ("*.txt", decode_text), + ("*.pyd", decode_pickle), +] + + +def find_decoder(decoders, path): + fname = re.sub(r".*/", "", path) + if fname.startswith("__"): + return lambda x: x + for pattern, fun in decoders[::-1]: + if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern): + return fun + return None + + +def _xdecode( + source, + *args, + must_decode=True, + defaults=default_decoders, + **kw, +): + decoders = list(defaults) + list(args) + decoders += [("*." + k, v) for k, v in kw.items()] + for sample in source: + new_sample = {} + for path, data in sample.items(): + if path.startswith("__"): + new_sample[path] = data + continue + decoder = find_decoder(decoders, path) + if decoder is False: + value = data + elif decoder is None: + if must_decode: + raise ValueError(f"No decoder found for {path}.") + value = data + else: + if isinstance(data, bytes): + data = io.BytesIO(data) + value = decoder(data) + new_sample[path] = value + yield new_sample + +xdecode = pipelinefilter(_xdecode) + + + +def _data_filter(source, + frame_shift=10, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + source: Iterable[{fname, wav, label, sample_rate}] + frame_shift: length of frame shift (ms) + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is paddle.Tensor, we have 100 frames every second (default) + num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift) + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + +data_filter = pipelinefilter(_data_filter) + +def _tokenize(source, + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + source: Iterable[{fname, wav, txt, sample_rate}] + + Returns: + Iterable[{fname, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in source: + assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + +tokenize = pipelinefilter(_tokenize) + +def _resample(source, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{fname, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample( + waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate + )) + yield sample + +resample = pipelinefilter(_resample) + +def _compute_fbank(source, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + source: Iterable[{fname, wav, label, sample_rate}] + num_mel_bins: number of mel filter bank + frame_length: length of one frame (ms) + frame_shift: length of frame shift (ms) + dither: value of dither + + Returns: + Iterable[{fname, feat, label}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'fname' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep fname, feat, label + mat = kaldi.fbank(waveform, + n_mels=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sr=sample_rate) + yield dict(fname=sample['fname'], label=sample['label'], feat=mat) + + +compute_fbank = pipelinefilter(_compute_fbank) + +def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + source: Iterable[{fname, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{fname, feat, label}] + """ + for sample in source: + x = sample['feat'] + x = x.numpy() + x = time_warp(x, max_time_warp=max_w, inplace = True, mode= "PIL") + x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = True, replace_with_zero = False) + x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = True, replace_with_zero = False) + sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) + yield sample + +spec_aug = pipelinefilter(_spec_aug) + + +def _sort(source, sort_size=500): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + source: Iterable[{fname, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{fname, feat, label}] + """ + + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + +sort = pipelinefilter(_sort) + +def _batched(source, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{fname, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + +batched = pipelinefilter(_batched) + +def dynamic_batched(source, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + source: Iterable[{fname, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in source: + assert 'feat' in sample + assert isinstance(sample['feat'], paddle.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def _padding(source): + """ Padding the data into training data + + Args: + source: Iterable[List[{fname, feat, label}]] + + Returns: + Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)] + """ + for sample in source: + assert isinstance(sample, list) + feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample], + dtype="int64") + order = paddle.argsort(feats_length, descending=True) + feats_lengths = paddle.to_tensor( + [sample[i]['feat'].shape[0] for i in order], dtype="int64") + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['fname'] for i in order] + sorted_labels = [ + paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order + ] + label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels], + dtype="int64") + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +padding = pipelinefilter(_padding) + +def _cmvn(source, cmvn_file): + global_cmvn = GlobalCMVN(cmvn_file) + for batch in source: + sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch + padded_feats = padded_feats.numpy() + padded_feats = global_cmvn(padded_feats) + padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32) + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +cmvn = pipelinefilter(_cmvn) diff --git a/paddlespeech/audio/stream_data/paddle_utils.py b/paddlespeech/audio/stream_data/paddle_utils.py new file mode 100644 index 00000000..02bc4c84 --- /dev/null +++ b/paddlespeech/audio/stream_data/paddle_utils.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Mock implementations of paddle interfaces when paddle is not available.""" + + +try: + from paddle.io import DataLoader, IterableDataset +except ModuleNotFoundError: + + class IterableDataset: + """Empty implementation of IterableDataset when paddle is not available.""" + + pass + + class DataLoader: + """Empty implementation of DataLoader when paddle is not available.""" + + pass + +try: + from paddle import Tensor as PaddleTensor +except ModuleNotFoundError: + + class TorchTensor: + """Empty implementation of PaddleTensor when paddle is not available.""" + + pass diff --git a/paddlespeech/audio/stream_data/pipeline.py b/paddlespeech/audio/stream_data/pipeline.py new file mode 100644 index 00000000..b672773b --- /dev/null +++ b/paddlespeech/audio/stream_data/pipeline.py @@ -0,0 +1,127 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +#%% +import copy, os, random, sys, time +from dataclasses import dataclass +from itertools import islice +from typing import List + +import braceexpand, yaml + +from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators +from webdataset.handlers import reraise_exception +from .paddle_utils import DataLoader, IterableDataset +from .utils import PipelineStage + + +def add_length_method(obj): + def length(self): + return self.size + + Combined = type( + obj.__class__.__name__ + "_Length", + (obj.__class__, IterableDataset), + {"__len__": length}, + ) + obj.__class__ = Combined + return obj + + +class DataPipeline(IterableDataset, PipelineStage): + """A pipeline starting with an IterableDataset and a series of filters.""" + + def __init__(self, *args, **kwargs): + super().__init__() + self.pipeline = [] + self.length = -1 + self.repetitions = 1 + self.nsamples = -1 + for arg in args: + if arg is None: + continue + if isinstance(arg, list): + self.pipeline.extend(arg) + else: + self.pipeline.append(arg) + + def invoke(self, f, *args, **kwargs): + """Apply a pipeline stage, possibly to the output of a previous stage.""" + if isinstance(f, PipelineStage): + return f.run(*args, **kwargs) + if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0: + return iter(f) + if isinstance(f, list): + return iter(f) + if callable(f): + result = f(*args, **kwargs) + return result + raise ValueError(f"{f}: not a valid pipeline stage") + + def iterator1(self): + """Create an iterator through one epoch in the pipeline.""" + source = self.invoke(self.pipeline[0]) + for step in self.pipeline[1:]: + source = self.invoke(step, source) + return source + + def iterator(self): + """Create an iterator through the entire dataset, using the given number of repetitions.""" + for i in range(self.repetitions): + for sample in self.iterator1(): + yield sample + + def __iter__(self): + """Create an iterator through the pipeline, repeating and slicing as requested.""" + if self.repetitions != 1: + if self.nsamples > 0: + return islice(self.iterator(), self.nsamples) + else: + return self.iterator() + else: + return self.iterator() + + def stage(self, i): + """Return pipeline stage i.""" + return self.pipeline[i] + + def append(self, f): + """Append a pipeline stage (modifies the object).""" + self.pipeline.append(f) + + def compose(self, *args): + """Append a pipeline stage to a copy of the pipeline and returns the copy.""" + result = copy.copy(self) + for arg in args: + result.append(arg) + return result + + def with_length(self, n): + """Add a __len__ method returning the desired value. + + This does not change the actual number of samples in an epoch. + PyTorch IterableDataset should not have a __len__ method. + This is provided only as a workaround for some broken training environments + that require a __len__ method. + """ + self.size = n + return add_length_method(self) + + def with_epoch(self, nsamples=-1, nbatches=-1): + """Change the epoch to return the given number of samples/batches. + + The two arguments mean the same thing.""" + self.repetitions = sys.maxsize + self.nsamples = max(nsamples, nbatches) + return self + + def repeat(self, nepochs=-1, nbatches=-1): + """Repeat iterating through the dataset for the given #epochs up to the given #samples.""" + if nepochs > 0: + self.repetitions = nepochs + self.nsamples = nbatches + else: + self.repetitions = sys.maxsize + self.nsamples = nbatches + return self diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/stream_data/shardlists.py new file mode 100644 index 00000000..503bfe57 --- /dev/null +++ b/paddlespeech/audio/stream_data/shardlists.py @@ -0,0 +1,257 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset + +"""Train PyTorch models directly from POSIX tar archive. + +Code works locally or over HTTP connections. +""" + +import os, random, sys, time +from dataclasses import dataclass, field +from itertools import islice +from typing import List + +import braceexpand, yaml + +from . import utils +from .filters import pipelinefilter +from .paddle_utils import IterableDataset + + +def expand_urls(urls): + if isinstance(urls, str): + urllist = urls.split("::") + result = [] + for url in urllist: + result.extend(braceexpand.braceexpand(url)) + return result + else: + return list(urls) + + +class SimpleShardList(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__(self, urls, seed=None): + """Iterate through the list of shards. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.seed = seed + + def __len__(self): + return len(self.urls) + + def __iter__(self): + """Return an iterator over the shards.""" + urls = self.urls.copy() + if self.seed is not None: + random.Random(self.seed).shuffle(urls) + for url in urls: + yield dict(url=url) + + +def split_by_node(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + if world_size > 1: + for s in islice(src, rank, None, world_size): + yield s + else: + for s in src: + yield s + + +def single_node_only(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + if world_size > 1: + raise ValueError("input pipeline needs to be reconfigured for multinode training") + for s in src: + yield s + + +def split_by_worker(src): + rank, world_size, worker, num_workers = utils.paddle_worker_info() + if num_workers > 1: + for s in islice(src, worker, None, num_workers): + yield s + else: + for s in src: + yield s + + +def resampled_(src, n=sys.maxsize): + import random + + seed = time.time() + try: + seed = open("/dev/random", "rb").read(20) + except Exception as exn: + print(repr(exn)[:50], file=sys.stderr) + rng = random.Random(seed) + print("# resampled loading", file=sys.stderr) + items = list(src) + print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) + for i in range(n): + yield rng.choice(items) + + +resampled = pipelinefilter(resampled_) + + +def non_empty(src): + count = 0 + for s in src: + yield s + count += 1 + if count == 0: + raise ValueError("pipeline stage received no data at all and this was declared as an error") + + +@dataclass +class MSSource: + """Class representing a data source.""" + + name: str = "" + perepoch: int = -1 + resample: bool = False + urls: List[str] = field(default_factory=list) + + +default_rng = random.Random() + + +def expand(s): + return os.path.expanduser(os.path.expandvars(s)) + + +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 + self.parse_spec(fname) + + def parse_spec(self, fname): + self.rng = default_rng # capture default_rng if we fork + if isinstance(fname, dict): + spec = fname + fname = "{dict}" + else: + with open(fname) as stream: + spec = yaml.safe_load(stream) + assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys()) + prefix = expand(spec.get("prefix", "")) + self.sources = [] + for ds in spec["datasets"]: + assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list( + ds.keys() + ) + buckets = ds.get("buckets", spec.get("buckets", [])) + if isinstance(buckets, str): + buckets = [buckets] + buckets = [expand(s) for s in buckets] + if buckets == []: + buckets = [""] + assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" + bucket = buckets[0] + name = ds.get("name", "@" + bucket) + urls = ds["shards"] + if isinstance(urls, str): + urls = [urls] + # urls = [u for url in urls for u in braceexpand.braceexpand(url)] + urls = [ + prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url)) + ] + resample = ds.get("resample", -1) + nsample = ds.get("choose", -1) + if nsample > len(urls): + raise ValueError(f"perepoch {nsample} must be no greater than the number of shards") + if (nsample > 0) and (resample > 0): + raise ValueError("specify only one of perepoch or choose") + entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample) + self.sources.append(entry) + print(f"# {name} {len(urls)} {nsample}", file=sys.stderr) + + def set_epoch(self, seed): + """Set the current epoch (for consistent shard selection among nodes).""" + self.rng = random.Random(seed) + + def get_shards_for_epoch(self): + result = [] + for source in self.sources: + if source.resample > 0: + # sample with replacement + l = self.rng.choices(source.urls, k=source.resample) + elif source.perepoch > 0: + # sample without replacement + l = list(source.urls) + self.rng.shuffle(l) + l = l[: source.perepoch] + else: + l = list(source.urls) + result += l + self.rng.shuffle(result) + return result + + def __iter__(self): + shards = self.get_shards_for_epoch() + for shard in shards: + yield dict(url=shard) + + +def shardspec(spec): + if spec.endswith(".yaml"): + return MultiShardSample(spec) + else: + return SimpleShardList(spec) + + +class ResampledShards(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = -1 + + def __iter__(self): + """Return an iterator over the shards.""" + self.epoch += 1 + if self.deterministic: + seed = utils.make_seed(self.worker_seed(), self.epoch) + else: + seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4)) + if os.environ.get("WDS_SHOW_SEED", "0") == "1": + print(f"# ResampledShards seed {seed}") + self.rng = random.Random(seed) + for _ in range(self.nshards): + index = self.rng.randint(0, len(self.urls) - 1) + yield dict(url=self.urls[index]) diff --git a/paddlespeech/audio/stream_data/tariterators.py b/paddlespeech/audio/stream_data/tariterators.py new file mode 100644 index 00000000..d9469797 --- /dev/null +++ b/paddlespeech/audio/stream_data/tariterators.py @@ -0,0 +1,283 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) + +"""Low level iteration functions for tar archives.""" + +import random, re, tarfile + +import braceexpand + +from . import filters +from webdataset import gopen +from webdataset.handlers import reraise_exception + +trace = False +meta_prefix = "__" +meta_suffix = "__" + +from ... import audio as paddleaudio +import paddle +import numpy as np + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + +def base_plus_ext(path): + """Split off all file extensions. + + Returns base, allext. + + :param path: path with extensions + :param returns: path with all extensions removed + + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + +def valid_sample(sample): + """Check whether a sample is valid. + + :param sample: sample to be checked + """ + return ( + sample is not None + and isinstance(sample, dict) + and len(list(sample.keys())) > 0 + and not sample.get("__bad__", False) + ) + + +# FIXME: UNUSED +def shardlist(urls, *, shuffle=False): + """Given a list of URLs, yields that list, possibly shuffled.""" + if isinstance(urls, str): + urls = braceexpand.braceexpand(urls) + else: + urls = list(urls) + if shuffle: + random.shuffle(urls) + for url in urls: + yield dict(url=url) + + +def url_opener(data, handler=reraise_exception, **kw): + """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + try: + stream = gopen.gopen(url, **kw) + sample.update(stream=stream) + yield sample + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def tar_file_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. + + :param fileobj: byte stream suitable for tarfile + :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)") + + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + for tarinfo in stream: + fname = tarinfo.name + try: + if not tarinfo.isreg(): + continue + if fname is None: + continue + if ( + "/" not in fname + and fname.startswith(meta_prefix) + and fname.endswith(meta_suffix) + ): + # skipping metadata for now + continue + if skip_meta is not None and re.match(skip_meta, fname): + continue + + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if postfix == 'wav': + waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False) + result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) + else: + txt = stream.extractfile(tarinfo).read().decode('utf8').strip() + result = dict(fname=prefix, txt=txt) + #result = dict(fname=fname, data=data) + yield result + stream.members = [] + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + del stream + +def tar_file_and_group_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['fname'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = paddleaudio.load(file_obj, normal=False) + waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) + + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + valid = False + # logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['fname'] = prev_prefix + yield example + stream.close() + +def tar_file_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "data" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + + + +def tar_file_and_group_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_and_group_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + +def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + raise ValueError( + f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}" + ) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_samples(src, handler=reraise_exception): + streams = url_opener(src, handler=handler) + samples = tar_file_and_group_expander(streams, handler=handler) + return samples + + +tarfile_to_samples = filters.pipelinefilter(tarfile_samples) diff --git a/paddlespeech/audio/stream_data/utils.py b/paddlespeech/audio/stream_data/utils.py new file mode 100644 index 00000000..83a42bad --- /dev/null +++ b/paddlespeech/audio/stream_data/utils.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset + +"""Miscellaneous utility functions.""" + +import importlib +import itertools as itt +import os +import re +import sys +from typing import Any, Callable, Iterator, Optional, Union + + +def make_seed(*args): + seed = 0 + for arg in args: + seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF + return seed + + +class PipelineStage: + def invoke(self, *args, **kw): + raise NotImplementedError + + +def identity(x: Any) -> Any: + """Return the argument as is.""" + return x + + +def safe_eval(s: str, expr: str = "{}"): + """Evaluate the given expression more safely.""" + if re.sub("[^A-Za-z0-9_]", "", s) != s: + raise ValueError(f"safe_eval: illegal characters in: '{s}'") + return eval(expr.format(s)) + + +def lookup_sym(sym: str, modules: list): + """Look up a symbol in a list of modules.""" + for mname in modules: + module = importlib.import_module(mname, package="webdataset") + result = getattr(module, sym, None) + if result is not None: + return result + return None + + +def repeatedly0( + loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize +): + """Repeatedly returns batches from a DataLoader.""" + for epoch in range(nepochs): + for sample in itt.islice(loader, nbatches): + yield sample + + +def guess_batchsize(batch: Union[tuple, list]): + """Guess the batch size by looking at the length of the first element in a tuple.""" + return len(batch[0]) + + +def repeatedly( + source: Iterator, + nepochs: int = None, + nbatches: int = None, + nsamples: int = None, + batchsize: Callable[..., int] = guess_batchsize, +): + """Repeatedly yield samples from an iterator.""" + epoch = 0 + batch = 0 + total = 0 + while True: + for sample in source: + yield sample + batch += 1 + if nbatches is not None and batch >= nbatches: + return + if nsamples is not None: + total += guess_batchsize(sample) + if total >= nsamples: + return + epoch += 1 + if nepochs is not None and epoch >= nepochs: + return + +def paddle_worker_info(group=None): + """Return node and worker info for PyTorch and some distributed environments.""" + rank = 0 + world_size = 1 + worker = 0 + num_workers = 1 + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + else: + try: + import paddle.distributed + group = group or paddle.distributed.get_group() + rank = paddle.distributed.get_rank() + world_size = paddle.distributed.get_world_size() + except ModuleNotFoundError: + pass + if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: + worker = int(os.environ["WORKER"]) + num_workers = int(os.environ["NUM_WORKERS"]) + else: + try: + import paddle.io.get_worker_info + worker_info = paddle.io.get_worker_info() + if worker_info is not None: + worker = worker_info.id + num_workers = worker_info.num_workers + except ModuleNotFoundError: + pass + + return rank, world_size, worker, num_workers + +def paddle_worker_seed(group=None): + """Compute a distinct, deterministic RNG seed for each worker and node.""" + rank, world_size, worker, num_workers = paddle_worker_info(group=group) + return rank * 1000 + worker diff --git a/paddlespeech/audio/transform/__init__.py b/paddlespeech/audio/transform/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/paddlespeech/audio/transform/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/audio/transform/add_deltas.py b/paddlespeech/audio/transform/add_deltas.py new file mode 100644 index 00000000..1387fe9d --- /dev/null +++ b/paddlespeech/audio/transform/add_deltas.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import numpy as np + + +def delta(feat, window): + assert window > 0 + delta_feat = np.zeros_like(feat) + for i in range(1, window + 1): + delta_feat[:-i] += i * feat[i:] + delta_feat[i:] += -i * feat[:-i] + delta_feat[-i:] += i * feat[-1] + delta_feat[:i] += -i * feat[0] + delta_feat /= 2 * sum(i**2 for i in range(1, window + 1)) + return delta_feat + + +def add_deltas(x, window=2, order=2): + """ + Args: + x (np.ndarray): speech feat, (T, D). + + Return: + np.ndarray: (T, (1+order)*D) + """ + feats = [x] + for _ in range(order): + feats.append(delta(feats[-1], window)) + return np.concatenate(feats, axis=1) + + +class AddDeltas(): + def __init__(self, window=2, order=2): + self.window = window + self.order = order + + def __repr__(self): + return "{name}(window={window}, order={order}".format( + name=self.__class__.__name__, window=self.window, order=self.order) + + def __call__(self, x): + return add_deltas(x, window=self.window, order=self.order) diff --git a/paddlespeech/audio/transform/channel_selector.py b/paddlespeech/audio/transform/channel_selector.py new file mode 100644 index 00000000..b078dcf8 --- /dev/null +++ b/paddlespeech/audio/transform/channel_selector.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import numpy + + +class ChannelSelector(): + """Select 1ch from multi-channel signal""" + + def __init__(self, train_channel="random", eval_channel=0, axis=1): + self.train_channel = train_channel + self.eval_channel = eval_channel + self.axis = axis + + def __repr__(self): + return ("{name}(train_channel={train_channel}, " + "eval_channel={eval_channel}, axis={axis})".format( + name=self.__class__.__name__, + train_channel=self.train_channel, + eval_channel=self.eval_channel, + axis=self.axis, )) + + def __call__(self, x, train=True): + # Assuming x: [Time, Channel] by default + + if x.ndim <= self.axis: + # If the dimension is insufficient, then unsqueeze + # (e.g [Time] -> [Time, 1]) + ind = tuple( + slice(None) if i < x.ndim else None + for i in range(self.axis + 1)) + x = x[ind] + + if train: + channel = self.train_channel + else: + channel = self.eval_channel + + if channel == "random": + ch = numpy.random.randint(0, x.shape[self.axis]) + else: + ch = channel + + ind = tuple( + slice(None) if i != self.axis else ch for i in range(x.ndim)) + return x[ind] diff --git a/paddlespeech/audio/transform/cmvn.py b/paddlespeech/audio/transform/cmvn.py new file mode 100644 index 00000000..2db0070b --- /dev/null +++ b/paddlespeech/audio/transform/cmvn.py @@ -0,0 +1,201 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import io +import json + +import h5py +import kaldiio +import numpy as np + + +class CMVN(): + "Apply Global/Spk CMVN/iverserCMVN." + + def __init__( + self, + stats, + norm_means=True, + norm_vars=False, + filetype="mat", + utt2spk=None, + spk2utt=None, + reverse=False, + std_floor=1.0e-20, ): + self.stats_file = stats + self.norm_means = norm_means + self.norm_vars = norm_vars + self.reverse = reverse + + if isinstance(stats, dict): + stats_dict = dict(stats) + else: + # Use for global CMVN + if filetype == "mat": + stats_dict = {None: kaldiio.load_mat(stats)} + # Use for global CMVN + elif filetype == "npy": + stats_dict = {None: np.load(stats)} + # Use for speaker CMVN + elif filetype == "ark": + self.accept_uttid = True + stats_dict = dict(kaldiio.load_ark(stats)) + # Use for speaker CMVN + elif filetype == "hdf5": + self.accept_uttid = True + stats_dict = h5py.File(stats) + else: + raise ValueError("Not supporting filetype={}".format(filetype)) + + if utt2spk is not None: + self.utt2spk = {} + with io.open(utt2spk, "r", encoding="utf-8") as f: + for line in f: + utt, spk = line.rstrip().split(None, 1) + self.utt2spk[utt] = spk + elif spk2utt is not None: + self.utt2spk = {} + with io.open(spk2utt, "r", encoding="utf-8") as f: + for line in f: + spk, utts = line.rstrip().split(None, 1) + for utt in utts.split(): + self.utt2spk[utt] = spk + else: + self.utt2spk = None + + # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1), + # and the first vector contains the sum of feats and the second is + # the sum of squares. The last value of the first, i.e. stats[0,-1], + # is the number of samples for this statistics. + self.bias = {} + self.scale = {} + for spk, stats in stats_dict.items(): + assert len(stats) == 2, stats.shape + + count = stats[0, -1] + + # If the feature has two or more dimensions + if not (np.isscalar(count) or isinstance(count, (int, float))): + # The first is only used + count = count.flatten()[0] + + mean = stats[0, :-1] / count + # V(x) = E(x^2) - (E(x))^2 + var = stats[1, :-1] / count - mean * mean + std = np.maximum(np.sqrt(var), std_floor) + self.bias[spk] = -mean + self.scale[spk] = 1 / std + + def __repr__(self): + return ("{name}(stats_file={stats_file}, " + "norm_means={norm_means}, norm_vars={norm_vars}, " + "reverse={reverse})".format( + name=self.__class__.__name__, + stats_file=self.stats_file, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + reverse=self.reverse, )) + + def __call__(self, x, uttid=None): + if self.utt2spk is not None: + spk = self.utt2spk[uttid] + else: + spk = uttid + + if not self.reverse: + # apply cmvn + if self.norm_means: + x = np.add(x, self.bias[spk]) + if self.norm_vars: + x = np.multiply(x, self.scale[spk]) + + else: + # apply reverse cmvn + if self.norm_vars: + x = np.divide(x, self.scale[spk]) + if self.norm_means: + x = np.subtract(x, self.bias[spk]) + + return x + + +class UtteranceCMVN(): + "Apply Utterance CMVN" + + def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20): + self.norm_means = norm_means + self.norm_vars = norm_vars + self.std_floor = std_floor + + def __repr__(self): + return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format( + name=self.__class__.__name__, + norm_means=self.norm_means, + norm_vars=self.norm_vars, ) + + def __call__(self, x, uttid=None): + # x: [Time, Dim] + square_sums = (x**2).sum(axis=0) + mean = x.mean(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + + if self.norm_vars: + var = square_sums / x.shape[0] - mean**2 + std = np.maximum(np.sqrt(var), self.std_floor) + x = np.divide(x, std) + + return x + + +class GlobalCMVN(): + "Apply Global CMVN" + + def __init__(self, + cmvn_path, + norm_means=True, + norm_vars=True, + std_floor=1.0e-20): + # cmvn_path: Option[str, dict] + cmvn = cmvn_path + self.cmvn = cmvn + self.norm_means = norm_means + self.norm_vars = norm_vars + self.std_floor = std_floor + if isinstance(cmvn, dict): + cmvn_stats = cmvn + else: + with open(cmvn) as f: + cmvn_stats = json.load(f) + self.count = cmvn_stats['frame_num'] + self.mean = np.array(cmvn_stats['mean_stat']) / self.count + self.square_sums = np.array(cmvn_stats['var_stat']) + self.var = self.square_sums / self.count - self.mean**2 + self.std = np.maximum(np.sqrt(self.var), self.std_floor) + + def __repr__(self): + return f"""{self.__class__.__name__}( + cmvn_path={self.cmvn}, + norm_means={self.norm_means}, + norm_vars={self.norm_vars},)""" + + def __call__(self, x, uttid=None): + # x: [Time, Dim] + if self.norm_means: + x = np.subtract(x, self.mean) + + if self.norm_vars: + x = np.divide(x, self.std) + return x diff --git a/paddlespeech/audio/transform/functional.py b/paddlespeech/audio/transform/functional.py new file mode 100644 index 00000000..271819ad --- /dev/null +++ b/paddlespeech/audio/transform/functional.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import inspect + +from paddlespeech.audio.transform.transform_interface import TransformInterface +from paddlespeech.audio.utils.check_kwargs import check_kwargs + + +class FuncTrans(TransformInterface): + """Functional Transformation + + WARNING: + Builtin or C/C++ functions may not work properly + because this class heavily depends on the `inspect` module. + + Usage: + + >>> def foo_bar(x, a=1, b=2): + ... '''Foo bar + ... :param x: input + ... :param int a: default 1 + ... :param int b: default 2 + ... ''' + ... return x + a - b + + + >>> class FooBar(FuncTrans): + ... _func = foo_bar + ... __doc__ = foo_bar.__doc__ + """ + + _func = None + + def __init__(self, **kwargs): + self.kwargs = kwargs + check_kwargs(self.func, kwargs) + + def __call__(self, x): + return self.func(x, **self.kwargs) + + @classmethod + def add_arguments(cls, parser): + fname = cls._func.__name__.replace("_", "-") + group = parser.add_argument_group(fname + " transformation setting") + for k, v in cls.default_params().items(): + # TODO(karita): get help and choices from docstring? + attr = k.replace("_", "-") + group.add_argument(f"--{fname}-{attr}", default=v, type=type(v)) + return parser + + @property + def func(self): + return type(self)._func + + @classmethod + def default_params(cls): + try: + d = dict(inspect.signature(cls._func).parameters) + except ValueError: + d = dict() + return { + k: v.default + for k, v in d.items() if v.default != inspect.Parameter.empty + } + + def __repr__(self): + params = self.default_params() + params.update(**self.kwargs) + ret = self.__class__.__name__ + "(" + if len(params) == 0: + return ret + ")" + for k, v in params.items(): + ret += "{}={}, ".format(k, v) + return ret[:-2] + ")" diff --git a/paddlespeech/audio/transform/perturb.py b/paddlespeech/audio/transform/perturb.py new file mode 100644 index 00000000..8044dc36 --- /dev/null +++ b/paddlespeech/audio/transform/perturb.py @@ -0,0 +1,561 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import librosa +import numpy +import scipy +import soundfile + +import io +import os +import h5py +import numpy as np + +class SoundHDF5File(): + """Collecting sound files to a HDF5 file + + >>> f = SoundHDF5File('a.flac.h5', mode='a') + >>> array = np.random.randint(0, 100, 100, dtype=np.int16) + >>> f['id'] = (array, 16000) + >>> array, rate = f['id'] + + + :param: str filepath: + :param: str mode: + :param: str format: The type used when saving wav. flac, nist, htk, etc. + :param: str dtype: + + """ + + def __init__(self, + filepath, + mode="r+", + format=None, + dtype="int16", + **kwargs): + self.filepath = filepath + self.mode = mode + self.dtype = dtype + + self.file = h5py.File(filepath, mode, **kwargs) + if format is None: + # filepath = a.flac.h5 -> format = flac + second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1] + format = second_ext[1:] + if format.upper() not in soundfile.available_formats(): + # If not found, flac is selected + format = "flac" + + # This format affects only saving + self.format = format + + def __repr__(self): + return ''.format( + self.filepath, self.mode, self.format, self.dtype) + + def create_dataset(self, name, shape=None, data=None, **kwds): + f = io.BytesIO() + array, rate = data + soundfile.write(f, array, rate, format=self.format) + self.file.create_dataset( + name, shape=shape, data=np.void(f.getvalue()), **kwds) + + def __setitem__(self, name, data): + self.create_dataset(name, data=data) + + def __getitem__(self, key): + data = self.file[key][()] + f = io.BytesIO(data.tobytes()) + array, rate = soundfile.read(f, dtype=self.dtype) + return array, rate + + def keys(self): + return self.file.keys() + + def values(self): + for k in self.file: + yield self[k] + + def items(self): + for k in self.file: + yield k, self[k] + + def __iter__(self): + return iter(self.file) + + def __contains__(self, item): + return item in self.file + + def __len__(self, item): + return len(self.file) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def close(self): + self.file.close() + +class SpeedPerturbation(): + """SpeedPerturbation + + The speed perturbation in kaldi uses sox-speed instead of sox-tempo, + and sox-speed just to resample the input, + i.e pitch and tempo are changed both. + + "Why use speed option instead of tempo -s in SoX for speed perturbation" + https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 + + Warning: + This function is very slow because of resampling. + I recommmend to apply speed-perturb outside the training using sox. + + """ + + def __init__( + self, + lower=0.9, + upper=1.1, + utt2ratio=None, + keep_length=True, + res_type="kaiser_best", + seed=None, ): + self.res_type = res_type + self.keep_length = keep_length + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + self.utt2ratio = {} + # Use the scheduled ratio for each utterances + self.utt2ratio_file = utt2ratio + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + self.utt2ratio = None + # The ratio is given on runtime randomly + self.lower = lower + self.upper = upper + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format( + self.__class__.__name__, + self.lower, + self.upper, + self.keep_length, + self.res_type, ) + else: + return "{}({}, res_type={})".format( + self.__class__.__name__, self.utt2ratio_file, self.res_type) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + x = x.astype(numpy.float32) + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + # Note1: resample requires the sampling-rate of input and output, + # but actually only the ratio is used. + y = librosa.resample( + x, orig_sr=ratio, target_sr=1, res_type=self.res_type) + + if self.keep_length: + diff = abs(len(x) - len(y)) + if len(y) > len(x): + # Truncate noise + y = y[diff // 2:-((diff + 1) // 2)] + elif len(y) < len(x): + # Assume the time-axis is the first: (Time, Channel) + pad_width = [(diff // 2, (diff + 1) // 2)] + [ + (0, 0) for _ in range(y.ndim - 1) + ] + y = numpy.pad( + y, pad_width=pad_width, constant_values=0, mode="constant") + return y + + +class SpeedPerturbationSox(): + """SpeedPerturbationSox + + The speed perturbation in kaldi uses sox-speed instead of sox-tempo, + and sox-speed just to resample the input, + i.e pitch and tempo are changed both. + + To speed up or slow down the sound of a file, + use speed to modify the pitch and the duration of the file. + This raises the speed and reduces the time. + The default factor is 1.0 which makes no change to the audio. + 2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher. + + "Why use speed option instead of tempo -s in SoX for speed perturbation" + https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 + + tempo option: + sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9 + + speed option: + sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9 + + If we use speed option like above, the pitch of audio also will be changed, + but the tempo option does not change the pitch. + """ + + def __init__( + self, + lower=0.9, + upper=1.1, + utt2ratio=None, + keep_length=True, + sr=16000, + seed=None, ): + self.sr = sr + self.keep_length = keep_length + self.state = numpy.random.RandomState(seed) + + try: + import soxbindings as sox + except ImportError: + try: + from paddlespeech.s2t.utils import dynamic_pip_install + package = "sox" + dynamic_pip_install.install(package) + package = "soxbindings" + if sys.platform != "win32": + dynamic_pip_install.install(package) + import soxbindings as sox + except Exception: + raise RuntimeError( + "Can not install soxbindings on your system.") + self.sox = sox + + if utt2ratio is not None: + self.utt2ratio = {} + # Use the scheduled ratio for each utterances + self.utt2ratio_file = utt2ratio + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + self.utt2ratio = None + # The ratio is given on runtime randomly + self.lower = lower + self.upper = upper + + def __repr__(self): + if self.utt2ratio is None: + return f"""{self.__class__.__name__}( + lower={self.lower}, + upper={self.upper}, + keep_length={self.keep_length}, + sample_rate={self.sr})""" + + else: + return f"""{self.__class__.__name__}( + utt2ratio={self.utt2ratio_file}, + sample_rate={self.sr})""" + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + tfm = self.sox.Transformer() + tfm.set_globals(multithread=False) + tfm.speed(ratio) + y = tfm.build_array(input_array=x, sample_rate_in=self.sr) + + if self.keep_length: + diff = abs(len(x) - len(y)) + if len(y) > len(x): + # Truncate noise + y = y[diff // 2:-((diff + 1) // 2)] + elif len(y) < len(x): + # Assume the time-axis is the first: (Time, Channel) + pad_width = [(diff // 2, (diff + 1) // 2)] + [ + (0, 0) for _ in range(y.ndim - 1) + ] + y = numpy.pad( + y, pad_width=pad_width, constant_values=0, mode="constant") + + if y.ndim == 2 and x.ndim == 1: + # (T, C) -> (T) + y = y.sequence(1) + return y + + +class BandpassPerturbation(): + """BandpassPerturbation + + Randomly dropout along the frequency axis. + + The original idea comes from the following: + "randomly-selected frequency band was cut off under the constraint of + leaving at least 1,000 Hz band within the range of less than 4,000Hz." + (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for + everyday home environments using multiple microphone arrays; + http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf) + + """ + + def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )): + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + # x_stft: (Time, Channel, Freq) + self.axes = axes + + def __repr__(self): + return "{}(lower={}, upper={})".format(self.__class__.__name__, + self.lower, self.upper) + + def __call__(self, x_stft, uttid=None, train=True): + if not train: + return x_stft + + if x_stft.ndim == 1: + raise RuntimeError("Input in time-freq domain: " + "(Time, Channel, Freq) or (Time, Freq)") + + ratio = self.state.uniform(self.lower, self.upper) + axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] + shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)] + + mask = self.state.randn(*shape) > ratio + x_stft *= mask + return x_stft + + +class VolumePerturbation(): + def __init__(self, + lower=-1.6, + upper=1.6, + utt2ratio=None, + dbunit=True, + seed=None): + self.dbunit = dbunit + self.utt2ratio_file = utt2ratio + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + if self.dbunit: + ratio = 10**(ratio / 20) + return x * ratio + + +class NoiseInjection(): + """Add isotropic noise""" + + def __init__( + self, + utt2noise=None, + lower=-20, + upper=-5, + utt2ratio=None, + filetype="list", + dbunit=True, + seed=None, ): + self.utt2noise_file = utt2noise + self.utt2ratio_file = utt2ratio + self.filetype = filetype + self.dbunit = dbunit + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + with open(utt2noise, "r") as f: + for line in f: + utt, snr = line.rstrip().split(None, 1) + snr = float(snr) + self.utt2ratio[utt] = snr + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + if utt2noise is not None: + self.utt2noise = {} + if filetype == "list": + with open(utt2noise, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + # Load all files in memory + self.utt2noise[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2noise = SoundHDF5File(utt2noise, "r") + else: + raise ValueError(filetype) + else: + self.utt2noise = None + + if utt2noise is not None and utt2ratio is not None: + if set(self.utt2ratio) != set(self.utt2noise): + raise RuntimeError("The uttids mismatch between {} and {}". + format(utt2ratio, utt2noise)) + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + x = x.astype(numpy.float32) + + # 1. Get ratio of noise to signal in sound pressure level + if uttid is not None and self.utt2ratio is not None: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + if self.dbunit: + ratio = 10**(ratio / 20) + scale = ratio * numpy.sqrt((x**2).mean()) + + # 2. Get noise + if self.utt2noise is not None: + # Get noise from the external source + if uttid is not None: + noise, rate = self.utt2noise[uttid] + else: + # Randomly select the noise source + noise = self.state.choice(list(self.utt2noise.values())) + # Normalize the level + noise /= numpy.sqrt((noise**2).mean()) + + # Adjust the noise length + diff = abs(len(x) - len(noise)) + offset = self.state.randint(0, diff) + if len(noise) > len(x): + # Truncate noise + noise = noise[offset:-(diff - offset)] + else: + noise = numpy.pad( + noise, pad_width=[offset, diff - offset], mode="wrap") + + else: + # Generate white noise + noise = self.state.normal(0, 1, x.shape) + + # 3. Add noise to signal + return x + noise * scale + + +class RIRConvolve(): + def __init__(self, utt2rir, filetype="list"): + self.utt2rir_file = utt2rir + self.filetype = filetype + + self.utt2rir = {} + if filetype == "list": + with open(utt2rir, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + self.utt2rir[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2rir = SoundHDF5File(utt2rir, "r") + else: + raise NotImplementedError(filetype) + + def __repr__(self): + return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if x.ndim != 1: + # Must be single channel + raise RuntimeError( + "Input x must be one dimensional array, but got {}".format( + x.shape)) + + rir, rate = self.utt2rir[uttid] + if rir.ndim == 2: + # FIXME(kamo): Use chainer.convolution_1d? + # return [Time, Channel] + return numpy.stack( + [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) + else: + return scipy.convolve(x, rir, mode="same") + diff --git a/paddlespeech/audio/transform/spec_augment.py b/paddlespeech/audio/transform/spec_augment.py new file mode 100644 index 00000000..c8f0a855 --- /dev/null +++ b/paddlespeech/audio/transform/spec_augment.py @@ -0,0 +1,214 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +"""Spec Augment module for preprocessing i.e., data augmentation""" +import random + +import numpy +from PIL import Image +from PIL.Image import BICUBIC + +from .functional import FuncTrans + + +def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): + """time warp for spec augment + + move random center frame by the random width ~ uniform(-window, window) + :param numpy.ndarray x: spectrogram (time, freq) + :param int max_time_warp: maximum time frames to warp + :param bool inplace: overwrite x with the result + :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" + (slow, differentiable) + :returns numpy.ndarray: time warped spectrogram (time, freq) + """ + window = max_time_warp + if window == 0: + return x + + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + + window) + 1 # 1 ... t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), + BICUBIC) + if inplace: + x[:warped] = left + x[warped:] = right + return x + return numpy.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + import paddle + + from espnet.utils import spec_augment + + # TODO(karita): make this differentiable again + return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() + else: + raise NotImplementedError("unknown resize mode: " + mode + + ", choose one from (PIL, sparse_image_warp).") + + +class TimeWarp(FuncTrans): + _func = time_warp + __doc__ = time_warp.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray x: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = x + else: + cloned = x.copy() + + num_mel_channels = cloned.shape[1] + fs = numpy.random.randint(0, F, size=(n_mask, 2)) + + for f, mask_end in fs: + f_zero = random.randrange(0, num_mel_channels - f) + mask_end += f_zero + + # avoids randrange error if values are equal and range is empty + if f_zero == f_zero + f: + continue + + if replace_with_zero: + cloned[:, f_zero:mask_end] = 0 + else: + cloned[:, f_zero:mask_end] = cloned.mean() + return cloned + + +class FreqMask(FuncTrans): + _func = freq_mask + __doc__ = freq_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray spec: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = spec + else: + cloned = spec.copy() + len_spectro = cloned.shape[0] + ts = numpy.random.randint(0, T, size=(n_mask, 2)) + for t, mask_end in ts: + # avoid randint range error + if len_spectro - t <= 0: + continue + t_zero = random.randrange(0, len_spectro - t) + + # avoids randrange error if values are equal and range is empty + if t_zero == t_zero + t: + continue + + mask_end += t_zero + if replace_with_zero: + cloned[t_zero:mask_end] = 0 + else: + cloned[t_zero:mask_end] = cloned.mean() + return cloned + + +class TimeMask(FuncTrans): + _func = time_mask + __doc__ = time_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def spec_augment( + x, + resize_mode="PIL", + max_time_warp=80, + max_freq_width=27, + n_freq_mask=2, + max_time_width=100, + n_time_mask=2, + inplace=True, + replace_with_zero=True, ): + """spec agument + + apply random time warping and time/freq masking + default setting is based on LD (Librispeech double) in Table 2 + https://arxiv.org/pdf/1904.08779.pdf + + :param numpy.ndarray x: (time, freq) + :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" + (slow, differentiable) + :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) + :param int freq_mask_width: maximum width of the random freq mask (F) + :param int n_freq_mask: the number of the random freq mask (m_F) + :param int time_mask_width: maximum width of the random time mask (T) + :param int n_time_mask: the number of the random time mask (m_T) + :param bool inplace: overwrite intermediate array + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + assert isinstance(x, numpy.ndarray) + assert x.ndim == 2 + x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) + x = freq_mask( + x, + max_freq_width, + n_freq_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, ) + x = time_mask( + x, + max_time_width, + n_time_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, ) + return x + + +class SpecAugment(FuncTrans): + _func = spec_augment + __doc__ = spec_augment.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py new file mode 100644 index 00000000..864f3f99 --- /dev/null +++ b/paddlespeech/audio/transform/spectrogram.py @@ -0,0 +1,475 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import librosa +import numpy as np +import paddle +from python_speech_features import logfbank + +from ..compliance import kaldi + + +def stft(x, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect"): + # x: [Time, Channel] + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? + # x: [Time, Channel, Freq] + x = np.stack( + [ + librosa.stft( + y=x[:, ch], + n_fft=n_fft, + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, ).T for ch in range(x.shape[1]) + ], + axis=1, ) + + if single_channel: + # x: [Time, Channel, Freq] -> [Time, Freq] + x = x[:, 0] + return x + + +def istft(x, n_shift, win_length=None, window="hann", center=True): + # x: [Time, Channel, Freq] + if x.ndim == 2: + single_channel = True + # x: [Time, Freq] -> [Time, Channel, Freq] + x = x[:, None, :] + else: + single_channel = False + + # x: [Time, Channel] + x = np.stack( + [ + librosa.istft( + stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time] + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, ) for ch in range(x.shape[1]) + ], + axis=1, ) + + if single_channel: + # x: [Time, Channel] -> [Time] + x = x[:, 0] + return x + + +def stft2logmelspectrogram(x_stft, + fs, + n_mels, + n_fft, + fmin=None, + fmax=None, + eps=1e-10): + # x_stft: (Time, Channel, Freq) or (Time, Freq) + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + + # spc: (Time, Channel, Freq) or (Time, Freq) + spc = np.abs(x_stft) + # mel_basis: (Mel_freq, Freq) + mel_basis = librosa.filters.mel( + sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) + lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + return lmspc + + +def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): + # x: (Time, Channel) -> spc: (Time, Channel, Freq) + spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) + return spc + + +def logmelspectrogram( + x, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + pad_mode="reflect", ): + # stft: (Time, Channel, Freq) or (Time, Freq) + x_stft = stft( + x, + n_fft=n_fft, + n_shift=n_shift, + win_length=win_length, + window=window, + pad_mode=pad_mode, ) + + return stft2logmelspectrogram( + x_stft, + fs=fs, + n_mels=n_mels, + n_fft=n_fft, + fmin=fmin, + fmax=fmax, + eps=eps) + + +class Spectrogram(): + def __init__(self, n_fft, n_shift, win_length=None, window="hann"): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + + def __repr__(self): + return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, )) + + def __call__(self, x): + return spectrogram( + x, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, ) + + +class LogMelSpectrogram(): + def __init__( + self, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, ): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "n_shift={n_shift}, win_length={win_length}, window={window}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, )) + + def __call__(self, x): + return logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, ) + + +class Stft2LogMelSpectrogram(): + def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, )) + + def __call__(self, x): + return stft2logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, ) + + +class Stft(): + def __init__( + self, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect", ): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + self.pad_mode = pad_mode + + def __repr__(self): + return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center}, pad_mode={pad_mode})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, )) + + def __call__(self, x): + return stft( + x, + self.n_fft, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, ) + + +class IStft(): + def __init__(self, n_shift, win_length=None, window="hann", center=True): + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + + def __repr__(self): + return ("{name}(n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center})".format( + name=self.__class__.__name__, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, )) + + def __call__(self, x): + return istft( + x, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, ) + + +class LogMelSpectrogramKaldi(): + def __init__( + self, + fs=16000, + n_mels=80, + n_shift=160, # unit:sample, 10ms + win_length=400, # unit:sample, 25ms + energy_floor=0.0, + dither=0.1): + """ + The Kaldi implementation of LogMelSpectrogram + Args: + fs (int): sample rate of the audio + n_mels (int): number of mel filter banks + n_shift (int): number of points in a frame shift + win_length (int): number of points in a frame windows + energy_floor (float): Floor on energy in Spectrogram computation (absolute) + dither (float): Dithering constant + + Returns: + LogMelSpectrogramKaldi + """ + + self.fs = fs + self.n_mels = n_mels + num_point_ms = fs / 1000 + self.n_frame_length = win_length / num_point_ms + self.n_frame_shift = n_shift / num_point_ms + self.energy_floor = energy_floor + self.dither = dither + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, " + "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, " + "dither={dither}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_frame_shift=self.n_frame_shift, + n_frame_length=self.n_frame_length, + dither=self.dither, )) + + def __call__(self, x, train): + """ + Args: + x (np.ndarray): shape (Ti,) + train (bool): True, train mode. + + Raises: + ValueError: not support (Ti, C) + + Returns: + np.ndarray: (T, D) + """ + dither = self.dither if train else 0.0 + if x.ndim != 1: + raise ValueError("Not support x: [Time, Channel]") + waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32) + mat = kaldi.fbank( + waveform, + n_mels=self.n_mels, + frame_length=self.n_frame_length, + frame_shift=self.n_frame_shift, + dither=dither, + energy_floor=self.energy_floor, + sr=self.fs) + mat = np.squeeze(mat.numpy()) + return mat + + +class LogMelSpectrogramKaldi_decay(): + def __init__( + self, + fs=16000, + n_mels=80, + n_fft=512, # fft point + n_shift=160, # unit:sample, 10ms + win_length=400, # unit:sample, 25ms + window="povey", + fmin=20, + fmax=None, + eps=1e-10, + dither=1.0): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + if n_shift > win_length: + raise ValueError("Stride size must not be greater than " + "window size.") + self.n_shift = n_shift / fs # unit: ms + self.win_length = win_length / fs # unit: ms + + self.window = window + self.fmin = fmin + if fmax is None: + fmax_ = fmax if fmax else self.fs / 2 + elif fmax > int(self.fs / 2): + raise ValueError("fmax must not be greater than half of " + "sample rate.") + self.fmax = fmax_ + + self.eps = eps + self.remove_dc_offset = True + self.preemph = 0.97 + self.dither = dither # only work in train mode + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, " + "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + preemph=self.preemph, + win_length=self.win_length, + window=self.window, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, + dither=self.dither, )) + + def __call__(self, x, train): + """ + + Args: + x (np.ndarray): shape (Ti,) + train (bool): True, train mode. + + Raises: + ValueError: not support (Ti, C) + + Returns: + np.ndarray: (T, D) + """ + dither = self.dither if train else 0.0 + if x.ndim != 1: + raise ValueError("Not support x: [Time, Channel]") + + if x.dtype in np.sctypes['float']: + # PCM32 -> PCM16 + bits = np.iinfo(np.int16).bits + x = x * 2**(bits - 1) + + # logfbank need PCM16 input + y = logfbank( + signal=x, + samplerate=self.fs, + winlen=self.win_length, # unit ms + winstep=self.n_shift, # unit ms + nfilt=self.n_mels, + nfft=self.n_fft, + lowfreq=self.fmin, + highfreq=self.fmax, + dither=dither, + remove_dc_offset=self.remove_dc_offset, + preemph=self.preemph, + wintype=self.window) + return y diff --git a/paddlespeech/audio/transform/transform_interface.py b/paddlespeech/audio/transform/transform_interface.py new file mode 100644 index 00000000..8bc62420 --- /dev/null +++ b/paddlespeech/audio/transform/transform_interface.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) + + +class TransformInterface: + """Transform Interface""" + + def __call__(self, x): + raise NotImplementedError("__call__ method is not implemented") + + @classmethod + def add_arguments(cls, parser): + return parser + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class Identity(TransformInterface): + """Identity Function""" + + def __call__(self, x): + return x diff --git a/paddlespeech/audio/transform/transformation.py b/paddlespeech/audio/transform/transformation.py new file mode 100644 index 00000000..d24d6437 --- /dev/null +++ b/paddlespeech/audio/transform/transformation.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +"""Transformation module.""" +import copy +import io +import logging +from collections import OrderedDict +from collections.abc import Sequence +from inspect import signature + +import yaml + +from ..utils.dynamic_import import dynamic_import + +import_alias = dict( + identity="paddlespeech.audio.transform.transform_interface:Identity", + time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp", + time_mask="paddlespeech.audio.transform.spec_augment:TimeMask", + freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask", + spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment", + speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation", + speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox", + volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation", + noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection", + bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation", + rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve", + delta="paddlespeech.audio.transform.add_deltas:AddDeltas", + cmvn="paddlespeech.audio.transform.cmvn:CMVN", + utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN", + fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram", + spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram", + stft="paddlespeech.audio.transform.spectrogram:Stft", + istft="paddlespeech.audio.transform.spectrogram:IStft", + stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram", + wpe="paddlespeech.audio.transform.wpe:WPE", + channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector", + fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi", + cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN") + + +class Transformation(): + """Apply some functions to the mini-batch + + Examples: + >>> kwargs = {"process": [{"type": "fbank", + ... "n_mels": 80, + ... "fs": 16000}, + ... {"type": "cmvn", + ... "stats": "data/train/cmvn.ark", + ... "norm_vars": True}, + ... {"type": "delta", "window": 2, "order": 2}]} + >>> transform = Transformation(kwargs) + >>> bs = 10 + >>> xs = [np.random.randn(100, 80).astype(np.float32) + ... for _ in range(bs)] + >>> xs = transform(xs) + """ + + def __init__(self, conffile=None): + if conffile is not None: + if isinstance(conffile, dict): + self.conf = copy.deepcopy(conffile) + else: + with io.open(conffile, encoding="utf-8") as f: + self.conf = yaml.safe_load(f) + assert isinstance(self.conf, dict), type(self.conf) + else: + self.conf = {"mode": "sequential", "process": []} + + self.functions = OrderedDict() + if self.conf.get("mode", "sequential") == "sequential": + for idx, process in enumerate(self.conf["process"]): + assert isinstance(process, dict), type(process) + opts = dict(process) + process_type = opts.pop("type") + class_obj = dynamic_import(process_type, import_alias) + # TODO(karita): assert issubclass(class_obj, TransformInterface) + try: + self.functions[idx] = class_obj(**opts) + except TypeError: + try: + signa = signature(class_obj) + except ValueError: + # Some function, e.g. built-in function, are failed + pass + else: + logging.error("Expected signature: {}({})".format( + class_obj.__name__, signa)) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"])) + + def __repr__(self): + rep = "\n" + "\n".join(" {}: {}".format(k, v) + for k, v in self.functions.items()) + return "{}({})".format(self.__class__.__name__, rep) + + def __call__(self, xs, uttid_list=None, **kwargs): + """Return new mini-batch + + :param Union[Sequence[np.ndarray], np.ndarray] xs: + :param Union[Sequence[str], str] uttid_list: + :return: batch: + :rtype: List[np.ndarray] + """ + if not isinstance(xs, Sequence): + is_batch = False + xs = [xs] + else: + is_batch = True + + if isinstance(uttid_list, str): + uttid_list = [uttid_list for _ in range(len(xs))] + + if self.conf.get("mode", "sequential") == "sequential": + for idx in range(len(self.conf["process"])): + func = self.functions[idx] + # TODO(karita): use TrainingTrans and UttTrans to check __call__ args + # Derive only the args which the func has + try: + param = signature(func).parameters + except ValueError: + # Some function, e.g. built-in function, are failed + param = {} + _kwargs = {k: v for k, v in kwargs.items() if k in param} + try: + if uttid_list is not None and "uttid" in param: + xs = [ + func(x, u, **_kwargs) + for x, u in zip(xs, uttid_list) + ] + else: + xs = [func(x, **_kwargs) for x in xs] + except Exception: + logging.fatal("Catch a exception from {}th func: {}".format( + idx, func)) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"])) + + if is_batch: + return xs + else: + return xs[0] diff --git a/paddlespeech/audio/transform/wpe.py b/paddlespeech/audio/transform/wpe.py new file mode 100644 index 00000000..777379d0 --- /dev/null +++ b/paddlespeech/audio/transform/wpe.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +from nara_wpe.wpe import wpe + + +class WPE(object): + def __init__(self, + taps=10, + delay=3, + iterations=3, + psd_context=0, + statistics_mode="full"): + self.taps = taps + self.delay = delay + self.iterations = iterations + self.psd_context = psd_context + self.statistics_mode = statistics_mode + + def __repr__(self): + return ("{name}(taps={taps}, delay={delay}" + "iterations={iterations}, psd_context={psd_context}, " + "statistics_mode={statistics_mode})".format( + name=self.__class__.__name__, + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, )) + + def __call__(self, xs): + """Return enhanced + + :param np.ndarray xs: (Time, Channel, Frequency) + :return: enhanced_xs + :rtype: np.ndarray + + """ + # nara_wpe.wpe: (F, C, T) + xs = wpe( + xs.transpose((2, 1, 0)), + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, ) + return xs.transpose(2, 1, 0) diff --git a/paddlespeech/audio/utils/check_kwargs.py b/paddlespeech/audio/utils/check_kwargs.py new file mode 100644 index 00000000..0aa839ac --- /dev/null +++ b/paddlespeech/audio/utils/check_kwargs.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import inspect + + +def check_kwargs(func, kwargs, name=None): + """check kwargs are valid for func + + If kwargs are invalid, raise TypeError as same as python default + :param function func: function to be validated + :param dict kwargs: keyword arguments for func + :param str name: name used in TypeError (default is func name) + """ + try: + params = inspect.signature(func).parameters + except ValueError: + return + if name is None: + name = func.__name__ + for k in kwargs.keys(): + if k not in params: + raise TypeError( + f"{name}() got an unexpected keyword argument '{k}'") diff --git a/paddlespeech/audio/utils/dynamic_import.py b/paddlespeech/audio/utils/dynamic_import.py new file mode 100644 index 00000000..99f93356 --- /dev/null +++ b/paddlespeech/audio/utils/dynamic_import.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import importlib + +__all__ = ["dynamic_import"] + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'paddlespeech.s2t.models.u2:U2Model' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError( + "import_path should be one of {} or " + 'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : ' + "{}".format(set(alias), import_path)) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname) diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py new file mode 100644 index 00000000..bae473ec --- /dev/null +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -0,0 +1,195 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unility functions for Transformer.""" +from typing import List +from typing import Tuple + +import paddle + +from .log import Logger + +__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"] + +logger = Logger(__name__) + + +def has_tensor(val): + if isinstance(val, (list, tuple)): + for item in val: + if has_tensor(item): + return True + elif isinstance(val, dict): + for k, v in val.items(): + print(k) + if has_tensor(v): + return True + else: + return paddle.is_tensor(val) + + +def pad_sequence(sequences: List[paddle.Tensor], + batch_first: bool=False, + padding_value: float=0.0) -> paddle.Tensor: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is list of + sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` + otherwise. + + `B` is batch size. It is equal to the number of elements in ``sequences``. + `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> from paddle.nn.utils.rnn import pad_sequence + >>> a = paddle.ones(25, 300) + >>> b = paddle.ones(22, 300) + >>> c = paddle.ones(15, 300) + >>> pad_sequence([a, b, c]).shape + paddle.Tensor([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise + padding_value (float, optional): value for padded elements. Default: 0. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise + """ + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_size = paddle.shape(sequences[0]) + # (TODO Hui Zhang): slice not supprot `end==start` + # trailing_dims = max_size[1:] + trailing_dims = tuple( + max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () + max_len = max([s.shape[0] for s in sequences]) + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype) + for i, tensor in enumerate(sequences): + length = tensor.shape[0] + # use index notation to prevent duplicate references to the tensor + logger.info( + f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}" + ) + if batch_first: + # TODO (Hui Zhang): set_value op not supprot `end==start` + # TODO (Hui Zhang): set_value op not support int16 + # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] + # out_tensor[i, :length, ...] = tensor + if length != 0: + out_tensor[i, :length] = tensor + else: + out_tensor[i, length] = tensor + else: + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[:length, i, ...] = tensor + if length != 0: + out_tensor[:length, i] = tensor + else: + out_tensor[length, i] = tensor + + return out_tensor + + +def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, + ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Add and labels. + Args: + ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax) + sos (int): index of + eos (int): index of + ignore_id (int): index of padding + Returns: + ys_in (paddle.Tensor) : (B, Lmax + 1) + ys_out (paddle.Tensor) : (B, Lmax + 1) + Examples: + >>> sos_id = 10 + >>> eos_id = 11 + >>> ignore_id = -1 + >>> ys_pad + tensor([[ 1, 2, 3, 4, 5], + [ 4, 5, 6, -1, -1], + [ 7, 8, 9, -1, -1]], dtype=paddle.int32) + >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id) + >>> ys_in + tensor([[10, 1, 2, 3, 4, 5], + [10, 4, 5, 6, 11, 11], + [10, 7, 8, 9, 11, 11]]) + >>> ys_out + tensor([[ 1, 2, 3, 4, 5, 11], + [ 4, 5, 6, 11, -1, -1], + [ 7, 8, 9, 11, -1, -1]]) + """ + # TODO(Hui Zhang): using comment code, + #_sos = paddle.to_tensor( + # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + #_eos = paddle.to_tensor( + # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] + #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] + #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) + B = ys_pad.shape[0] + _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos + _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos + ys_in = paddle.cat([_sos, ys_pad], dim=1) + mask_pad = (ys_in == ignore_id) + ys_in = ys_in.masked_fill(mask_pad, eos) + + ys_out = paddle.cat([ys_pad, _eos], dim=1) + ys_out = ys_out.masked_fill(mask_pad, eos) + mask_eos = (ys_out == ignore_id) + ys_out = ys_out.masked_fill(mask_eos, eos) + ys_out = ys_out.masked_fill(mask_pad, ignore_id) + return ys_in, ys_out + + +def th_accuracy(pad_outputs: paddle.Tensor, + pad_targets: paddle.Tensor, + ignore_label: int) -> float: + """Calculate accuracy. + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + Returns: + float: Accuracy value (0.0 - 1.0). + """ + pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]).argmax(2) + mask = pad_targets != ignore_label + #TODO(Hui Zhang): sum not support bool type + # numerator = paddle.sum( + # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = ( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = paddle.sum(numerator.type_as(pad_targets)) + #TODO(Hui Zhang): sum not support bool type + # denominator = paddle.sum(mask) + denominator = paddle.sum(mask.type_as(pad_targets)) + return float(numerator) / float(denominator) diff --git a/setup.py b/setup.py index 679549b4..b94a4cb2 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ base = [ "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", - "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8' + "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset' ] server = [