From 8f5e61090b569f9bf77f53a16668a533ae925f99 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 22 Jun 2022 05:12:31 +0000 Subject: [PATCH 1/7] new feature: Add webdataset in audio --- paddlespeech/audio/stream_data/__init__.py | 68 ++ paddlespeech/audio/stream_data/cache.py | 190 ++++ paddlespeech/audio/stream_data/compat.py | 170 ++++ paddlespeech/audio/stream_data/filters.py | 912 ++++++++++++++++++ .../audio/stream_data/paddle_utils.py | 33 + paddlespeech/audio/stream_data/pipeline.py | 127 +++ paddlespeech/audio/stream_data/shardlists.py | 257 +++++ .../audio/stream_data/tariterators.py | 283 ++++++ paddlespeech/audio/stream_data/utils.py | 128 +++ paddlespeech/audio/transform/__init__.py | 13 + paddlespeech/audio/transform/add_deltas.py | 54 ++ .../audio/transform/channel_selector.py | 57 ++ paddlespeech/audio/transform/cmvn.py | 201 ++++ paddlespeech/audio/transform/functional.py | 86 ++ paddlespeech/audio/transform/perturb.py | 561 +++++++++++ paddlespeech/audio/transform/spec_augment.py | 214 ++++ paddlespeech/audio/transform/spectrogram.py | 475 +++++++++ .../audio/transform/transform_interface.py | 35 + .../audio/transform/transformation.py | 158 +++ paddlespeech/audio/transform/wpe.py | 58 ++ paddlespeech/audio/utils/check_kwargs.py | 35 + paddlespeech/audio/utils/dynamic_import.py | 38 + paddlespeech/audio/utils/tensor_utils.py | 195 ++++ setup.py | 2 +- 24 files changed, 4349 insertions(+), 1 deletion(-) create mode 100644 paddlespeech/audio/stream_data/__init__.py create mode 100644 paddlespeech/audio/stream_data/cache.py create mode 100644 paddlespeech/audio/stream_data/compat.py create mode 100644 paddlespeech/audio/stream_data/filters.py create mode 100644 paddlespeech/audio/stream_data/paddle_utils.py create mode 100644 paddlespeech/audio/stream_data/pipeline.py create mode 100644 paddlespeech/audio/stream_data/shardlists.py create mode 100644 paddlespeech/audio/stream_data/tariterators.py create mode 100644 
paddlespeech/audio/stream_data/utils.py create mode 100644 paddlespeech/audio/transform/__init__.py create mode 100644 paddlespeech/audio/transform/add_deltas.py create mode 100644 paddlespeech/audio/transform/channel_selector.py create mode 100644 paddlespeech/audio/transform/cmvn.py create mode 100644 paddlespeech/audio/transform/functional.py create mode 100644 paddlespeech/audio/transform/perturb.py create mode 100644 paddlespeech/audio/transform/spec_augment.py create mode 100644 paddlespeech/audio/transform/spectrogram.py create mode 100644 paddlespeech/audio/transform/transform_interface.py create mode 100644 paddlespeech/audio/transform/transformation.py create mode 100644 paddlespeech/audio/transform/wpe.py create mode 100644 paddlespeech/audio/utils/check_kwargs.py create mode 100644 paddlespeech/audio/utils/dynamic_import.py create mode 100644 paddlespeech/audio/utils/tensor_utils.py diff --git a/paddlespeech/audio/stream_data/__init__.py b/paddlespeech/audio/stream_data/__init__.py new file mode 100644 index 00000000..fdb3458c --- /dev/null +++ b/paddlespeech/audio/stream_data/__init__.py @@ -0,0 +1,68 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# See the LICENSE file for licensing terms (BSD-style). 
+# Modified from https://github.com/webdataset/webdataset +# +# flake8: noqa + +from .cache import ( + cached_tarfile_samples, + cached_tarfile_to_samples, + lru_cleanup, + pipe_cleaner, +) +from .compat import WebDataset, WebLoader, FluidWrapper +from webdataset.extradatasets import MockDataset, with_epoch, with_length +from .filters import ( + associate, + batched, + decode, + detshuffle, + extract_keys, + getfirst, + info, + map, + map_dict, + map_tuple, + pipelinefilter, + rename, + rename_keys, + rsample, + select, + shuffle, + slice, + to_tuple, + transform_with, + unbatched, + xdecode, + data_filter, + tokenize, + resample, + compute_fbank, + spec_aug, + sort, + padding, + cmvn +) +from webdataset.handlers import ( + ignore_and_continue, + ignore_and_stop, + reraise_exception, + warn_and_continue, + warn_and_stop, +) +from .pipeline import DataPipeline +from .shardlists import ( + MultiShardSample, + ResampledShards, + SimpleShardList, + non_empty, + resampled, + shardspec, + single_node_only, + split_by_node, + split_by_worker, +) +from .tariterators import tarfile_samples, tarfile_to_samples +from .utils import PipelineStage, repeatedly +from webdataset.writer import ShardWriter, TarWriter, numpy_dumps +from webdataset.mix import RandomMix, RoundRobin diff --git a/paddlespeech/audio/stream_data/cache.py b/paddlespeech/audio/stream_data/cache.py new file mode 100644 index 00000000..724f6911 --- /dev/null +++ b/paddlespeech/audio/stream_data/cache.py @@ -0,0 +1,190 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +import itertools, os, random, re, sys +from urllib.parse import urlparse + +from . 
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Delete cached files, smallest ``keyfn`` value first, until the cache fits.

    Args:
        cache_dir: root directory of the file cache; it is scanned recursively.
            A missing directory is a no-op.
        cache_size: maximum total size in bytes that may remain in the cache.
        keyfn: maps a file path to a sortable key. Files with the smallest key
            are deleted first. The default is ``os.path.getctime``.
        verbose: when true, report every deletion on stderr.

    Files that disappear while the cache is being scanned or trimmed (for
    example because another worker is cleaning the same directory) are
    skipped instead of aborting the cleanup.
    """
    if not os.path.exists(cache_dir):
        return

    # Single directory walk: remember each file together with its size.
    sized = []
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            try:
                size = os.path.getsize(path)
            except OSError:
                # The file vanished between the walk and the stat call.
                continue
            sized.append((path, size))
            total_size += size
    if total_size <= cache_size:
        return

    # Keys are only computed when something actually has to be evicted.
    keyed = []
    for path, size in sized:
        try:
            keyed.append((keyfn(path), size, path))
        except OSError:
            # Already gone, so it no longer counts towards the cache size.
            total_size -= size

    # Largest key first so that pop() hands back the smallest key.
    keyed.sort(key=lambda entry: entry[0], reverse=True)
    while keyed and total_size > cache_size:
        _key, size, path = keyed.pop()
        if verbose:
            print("# deleting %s" % path, file=sys.stderr)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Somebody else removed it first; the space is freed either way.
            pass
        total_size -= size
def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Open every shard URL through the local file cache and yield it with a stream.

    Args:
        data: iterable of dict samples, each carrying a ``"url"`` entry.
        handler: called with any exception raised while handling a sample
            (the URL is appended to ``exn.args`` first). A truthy return value
            skips the sample, a falsy one stops the iteration.
        cache_size: forwarded to ``get_file_cached``.
        cache_dir: forwarded to ``get_file_cached``.
        url_to_name: forwarded to ``get_file_cached``.
        validator: called with the local path; a falsy result deletes the file
            and raises ``ValueError`` describing its first 200 bytes.
        verbose: report the opened path on stderr (also switched on by the
            module-level ``verbose_cache`` flag).
        always: go through the cache even when ``url`` is an existing local path.

    Yields:
        The input sample, updated in place with ``stream`` set to the local
        file opened in binary mode. Closing the stream is left to the consumer.
    """
    # Imported locally so the retry path does not depend on a module-level
    # ``time`` import.
    import time

    verbose = verbose or verbose_cache
    max_attempts = 5
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = None
            for attempt in range(1, max_attempts + 1):
                if not always and os.path.exists(url):
                    dest = url
                else:
                    dest = get_file_cached(
                        url,
                        cache_size=cache_size,
                        cache_dir=cache_dir,
                        url_to_name=url_to_name,
                        verbose=verbose,
                    )
                if verbose:
                    print("# opening %s" % dest, file=sys.stderr)
                assert os.path.exists(dest)
                if not validator(dest):
                    ftype = get_filetype(dest)
                    with open(dest, "rb") as f:
                        head = f.read(200)
                    os.remove(dest)
                    raise ValueError(
                        "%s (%s) is not a tar archive, but a %s, contains %s"
                        % (dest, url, ftype, repr(head))
                    )
                try:
                    stream = open(dest, "rb")
                    break
                except FileNotFoundError:
                    # A concurrent lru_cleanup may have evicted the file after
                    # it was fetched: back off for a random delay, then fetch
                    # and open the same URL again.
                    if attempt == max_attempts:
                        raise
                    time.sleep(random.random() * 10)
            sample.update(stream=stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break
class FluidInterface:
    """Mixin exposing the pipeline filters as chainable methods.

    Every method builds one stage from the ``filters`` module, hands it to
    ``self.compose`` and returns whatever ``compose`` returns. ``compose`` is
    not defined here; it has to be supplied by the class this mixin is
    combined with (presumably ``DataPipeline`` -- confirm against its
    definition).
    """

    def batched(self, batchsize):
        """Compose a ``filters.batched`` stage with the given batch size."""
        return self.compose(filters.batched(batchsize))

    def dynamic_batched(self, max_frames_in_batch):
        """Compose a ``filters.dynamic_batched`` stage with the given frame limit."""
        return self.compose(filters.dynamic_batched(max_frames_in_batch))

    def unbatched(self):
        """Compose a ``filters.unbatched`` stage."""
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        """Compose a ``filters.batched`` stage that does not collate its batches.

        NOTE(review): assumes ``filters.batched`` accepts the ``batchsize``,
        ``collation_fn`` and ``partial`` keywords as in upstream webdataset --
        confirm against its definition.
        """
        return self.compose(
            filters.batched(batchsize=batchsize, collation_fn=None, partial=partial)
        )

    def unlisted(self):
        """Compose a ``filters.unlisted`` stage."""
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        """Compose a ``filters.log_keys`` stage writing to ``logfile``."""
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        """Compose a ``filters.shuffle`` stage; return ``self`` unchanged if ``size < 1``."""
        if size < 1:
            return self
        return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        """Compose a ``filters.map`` stage applying ``f`` to every sample."""
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
        """Map an ``autodecode.Decoder`` over the samples.

        String arguments are wrapped in ``autodecode.ImageHandler``; any other
        argument is passed to the decoder as a handler unchanged.
        """
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        """Compose a ``filters.map_dict`` stage with per-key functions from ``kw``."""
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        """Compose a ``filters.select`` stage with the given predicate."""
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        """Compose a ``filters.to_tuple`` stage for the given field names."""
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        """Compose a ``filters.map_tuple`` stage with per-position functions."""
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        """Compose a ``filters.slice`` stage with the given arguments."""
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        """Compose a ``filters.rename`` stage."""
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        """Compose a ``filters.rsample`` stage with probability ``p``."""
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        """Compose a ``filters.rename_keys`` stage."""
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        """Compose a ``filters.extract_keys`` stage."""
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        """Compose a ``filters.xdecode`` stage."""
        return self.compose(filters.xdecode(*args, **kw))

    def data_filter(self, *args, **kw):
        """Compose a ``filters.data_filter`` stage."""
        return self.compose(filters.data_filter(*args, **kw))

    def tokenize(self, *args, **kw):
        """Compose a ``filters.tokenize`` stage."""
        return self.compose(filters.tokenize(*args, **kw))

    def resample(self, *args, **kw):
        """Compose a ``filters.resample`` stage."""
        return self.compose(filters.resample(*args, **kw))

    def compute_fbank(self, *args, **kw):
        """Compose a ``filters.compute_fbank`` stage."""
        return self.compose(filters.compute_fbank(*args, **kw))

    def spec_aug(self, *args, **kw):
        """Compose a ``filters.spec_aug`` stage."""
        return self.compose(filters.spec_aug(*args, **kw))

    def sort(self, size=500):
        """Compose a ``filters.sort`` stage with the given size."""
        return self.compose(filters.sort(size))

    def padding(self):
        """Compose a ``filters.padding`` stage."""
        return self.compose(filters.padding())

    def cmvn(self, cmvn_file):
        """Compose a ``filters.cmvn`` stage configured with ``cmvn_file``."""
        return self.compose(filters.cmvn(cmvn_file))
not None: + if detshuffle: + self.append(filters.detshuffle(shardshuffle)) + else: + self.append(filters.shuffle(shardshuffle)) + if cache_size == 0: + self.append(tariterators.tarfile_to_samples(handler=handler)) + else: + assert cache_size == -1 or cache_size > 0 + self.append( + cache.cached_tarfile_to_samples( + handler=handler, + verbose=verbose, + cache_size=cache_size, + cache_dir=cache_dir, + ) + ) + + +class FluidWrapper(DataPipeline, FluidInterface): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__(self, initial): + super().__init__() + self.append(initial) + + +class WebLoader(DataPipeline, FluidInterface): + def __init__(self, *args, **kw): + super().__init__(DataLoader(*args, **kw)) diff --git a/paddlespeech/audio/stream_data/filters.py b/paddlespeech/audio/stream_data/filters.py new file mode 100644 index 00000000..3112c954 --- /dev/null +++ b/paddlespeech/audio/stream_data/filters.py @@ -0,0 +1,912 @@ +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) +"""A collection of iterators for data transformations. + +These functions are plain iterator functions. You can find curried versions +in webdataset.filters, and you can find IterableDataset wrappers in +webdataset.processing. +""" + +import io +from fnmatch import fnmatch +import re +import itertools, os, random, sys, time +from functools import reduce, wraps + +import numpy as np + +from webdataset import autodecode +from . import utils +from .paddle_utils import PaddleTensor +from .utils import PipelineStage + +from .. 
class FilterFunction(object):
    """Callable that applies ``f`` to a data source with pre-bound arguments.

    We use this roundabout construct (rather than a closure or lambda)
    because it can be pickled.
    """

    def __init__(self, f, *args, **kw):
        """Store the function and the arguments to pass after the data source."""
        self.f = f
        self.args = args
        self.kw = kw

    def __call__(self, data):
        """Return ``f(data, *args, **kw)`` with the stored arguments."""
        return self.f(data, *self.args, **self.kw)

    def __repr__(self):
        """Describe the stage as ``<name args kwargs>``.

        Callables without ``__name__`` (for example ``functools.partial``
        objects) are shown through their own ``repr`` instead of raising
        ``AttributeError``.
        """
        name = getattr(self.f, "__name__", None) or repr(self.f)
        return f"<{name} {self.args} {self.kw}>"

    def __str__(self):
        """Return the same text as ``repr(self)``."""
        return repr(self)
def getfirst(a, keys, default=None, missing_is_error=True):
    """Return the value of the first entry of ``keys`` that is present in ``a``.

    ``keys`` is either an iterable of key names or a single string holding
    alternatives separated by ';' (the string must not contain spaces).

    When none of the keys is present, ``ValueError`` is raised if
    ``missing_is_error`` is true; otherwise ``default`` is returned.
    """
    if isinstance(keys, str):
        assert " " not in keys
        keys = keys.split(";")
    absent = object()
    hit = next((name for name in keys if name in a), absent)
    if hit is not absent:
        return a[hit]
    if not missing_is_error:
        return default
    raise ValueError(f"didn't find {keys} in {list(a.keys())}")
+ """ + if transformers is None or len(transformers) == 0: + return sample + result = list(sample) + assert len(transformers) <= len(sample) + for i in range(len(transformers)): # skipcq: PYL-C0200 + f = transformers[i] + if f is not None: + result[i] = f(sample[i]) + return result + +### +# Iterators +### + +def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""): + """Print information about the samples that are passing through. + + :param data: source iterator + :param fmt: format statement (using sample dict as keyword) + :param n: when to stop + :param every: how often to print + :param width: maximum width + :param stream: output stream + :param name: identifier printed before any output + """ + for i, sample in enumerate(data): + if i < n or (every > 0 and (i + 1) % every == 0): + if fmt is None: + print("---", name, file=stream) + for k, v in sample.items(): + print(k, repr(v)[:width], file=stream) + else: + print(fmt.format(**sample), file=stream) + yield sample + + +info = pipelinefilter(_info) + + +def pick(buf, rng): + k = rng.randint(0, len(buf) - 1) + sample = buf[k] + buf[k] = buf[-1] + buf.pop() + return sample + + +def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): + """Shuffle the data in the stream. + + This uses a buffer of size `bufsize`. Shuffling at + startup is less random; this is traded off against + yielding samples quickly. 
+ + data: iterator + bufsize: buffer size for shuffling + returns: iterator + rng: either random module or random.Random instance + + """ + if rng is None: + rng = random.Random(int((os.getpid() + time.time()) * 1e9)) + initial = min(initial, bufsize) + buf = [] + for sample in data: + buf.append(sample) + if len(buf) < bufsize: + try: + buf.append(next(data)) # skipcq: PYL-R1708 + except StopIteration: + pass + if len(buf) >= initial: + yield pick(buf, rng) + while len(buf) > 0: + yield pick(buf, rng) + + +shuffle = pipelinefilter(_shuffle) + + +class detshuffle(PipelineStage): + def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + self.epoch += 1 + rng = random.Random() + rng.seed((self.seed, self.epoch)) + return _shuffle(src, self.bufsize, self.initial, rng) + + +def _select(data, predicate): + """Select samples based on a predicate. + + :param data: source iterator + :param predicate: predicate (function) + """ + for sample in data: + if predicate(sample): + yield sample + + +select = pipelinefilter(_select) + + +def _log_keys(data, logfile=None): + import fcntl + + if logfile is None or logfile == "": + for sample in data: + yield sample + else: + with open(logfile, "a") as stream: + for i, sample in enumerate(data): + buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n" + try: + fcntl.flock(stream.fileno(), fcntl.LOCK_EX) + stream.write(buf) + finally: + fcntl.flock(stream.fileno(), fcntl.LOCK_UN) + yield sample + + +log_keys = pipelinefilter(_log_keys) + + +def _decode(data, *args, handler=reraise_exception, **kw): + """Decode data based on the decoding functions given as arguments.""" + + decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x + handlers = [decoder(x) for x in args] + f = autodecode.Decoder(handlers, **kw) + + for sample in data: + assert 
isinstance(sample, dict), sample + try: + decoded = f(sample) + except Exception as exn: # skipcq: PYL-W0703 + if handler(exn): + continue + else: + break + yield decoded + + +decode = pipelinefilter(_decode) + + +def _map(data, f, handler=reraise_exception): + """Map samples.""" + for sample in data: + try: + result = f(sample) + except Exception as exn: + if handler(exn): + continue + else: + break + if result is None: + continue + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + yield result + + +map = pipelinefilter(_map) + + +def _rename(data, handler=reraise_exception, keep=True, **kw): + """Rename samples based on keyword arguments.""" + for sample in data: + try: + if not keep: + yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()} + else: + + def listify(v): + return v.split(";") if isinstance(v, str) else v + + to_be_replaced = {x for v in kw.values() for x in listify(v)} + result = {k: v for k, v in sample.items() if k not in to_be_replaced} + result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}) + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +rename = pipelinefilter(_rename) + + +def _associate(data, associator, **kw): + """Associate additional data with samples.""" + for sample in data: + if callable(associator): + extra = associator(sample["__key__"]) + else: + extra = associator.get(sample["__key__"], {}) + sample.update(extra) # destructive + yield sample + + +associate = pipelinefilter(_associate) + + +def _map_dict(data, handler=reraise_exception, **kw): + """Map the entries in a dict sample with individual functions.""" + assert len(list(kw.keys())) > 0 + for key, f in kw.items(): + assert callable(f), (key, f) + + for sample in data: + assert isinstance(sample, dict) + try: + for k, f in kw.items(): + sample[k] = f(sample[k]) + except Exception as exn: + if handler(exn): + continue + else: + 
break + yield sample + + +map_dict = pipelinefilter(_map_dict) + + +def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None): + """Convert dict samples to tuples.""" + if none_is_error is None: + none_is_error = missing_is_error + if len(args) == 1 and isinstance(args[0], str) and " " in args[0]: + args = args[0].split() + + for sample in data: + try: + result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args]) + if none_is_error and any(x is None for x in result): + raise ValueError(f"to_tuple {args} got {sample.keys()}") + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +to_tuple = pipelinefilter(_to_tuple) + + +def _map_tuple(data, *args, handler=reraise_exception): + """Map the entries of a tuple with individual functions.""" + args = [f if f is not None else utils.identity for f in args] + for f in args: + assert callable(f), f + for sample in data: + assert isinstance(sample, (list, tuple)) + sample = list(sample) + n = min(len(args), len(sample)) + try: + for i in range(n): + sample[i] = args[i](sample[i]) + except Exception as exn: + if handler(exn): + continue + else: + break + yield tuple(sample) + + +map_tuple = pipelinefilter(_map_tuple) + + +def _unlisted(data): + """Turn batched data back into unbatched data.""" + for batch in data: + assert isinstance(batch, list), sample + for sample in batch: + yield sample + + +unlisted = pipelinefilter(_unlisted) + + +def _unbatched(data): + """Turn batched data back into unbatched data.""" + for sample in data: + assert isinstance(sample, (tuple, list)), sample + assert len(sample) > 0 + for i in range(len(sample[0])): + yield tuple(x[i] for x in sample) + + +unbatched = pipelinefilter(_unbatched) + + +def _rsample(data, p=0.5): + """Randomly subsample a stream of data.""" + assert p >= 0.0 and p <= 1.0 + for sample in data: + if random.uniform(0.0, 1.0) < p: + yield sample + + +rsample = 
def decode_pickle(stream):
    """Unpickle and return one object from the binary file-like ``stream``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only decode pickled entries that come from shards you trust.
    """
    # Imported locally so the function does not depend on a module-level
    # ``pickle`` import.
    import pickle

    return pickle.load(stream)
+default_decoders = [ + ("*.bin", decode_bin), + ("*.txt", decode_text), + ("*.pyd", decode_pickle), +] + + +def find_decoder(decoders, path): + fname = re.sub(r".*/", "", path) + if fname.startswith("__"): + return lambda x: x + for pattern, fun in decoders[::-1]: + if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern): + return fun + return None + + +def _xdecode( + source, + *args, + must_decode=True, + defaults=default_decoders, + **kw, +): + decoders = list(defaults) + list(args) + decoders += [("*." + k, v) for k, v in kw.items()] + for sample in source: + new_sample = {} + for path, data in sample.items(): + if path.startswith("__"): + new_sample[path] = data + continue + decoder = find_decoder(decoders, path) + if decoder is False: + value = data + elif decoder is None: + if must_decode: + raise ValueError(f"No decoder found for {path}.") + value = data + else: + if isinstance(data, bytes): + data = io.BytesIO(data) + value = decoder(data) + new_sample[path] = value + yield new_sample + +xdecode = pipelinefilter(_xdecode) + + + +def _data_filter(source, + frame_shift=10, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + source: Iterable[{fname, wav, label, sample_rate}] + frame_shift: length of frame shift (ms) + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is paddle.Tensor, we have 100 frames every second (default) + num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift) + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + +data_filter = pipelinefilter(_data_filter) + +def _tokenize(source, + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + source: Iterable[{fname, wav, txt, sample_rate}] + + Returns: + Iterable[{fname, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in source: + 
assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + +tokenize = pipelinefilter(_tokenize) + +def _resample(source, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{fname, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample( + waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate + )) + yield sample + +resample = pipelinefilter(_resample) + +def _compute_fbank(source, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + source: Iterable[{fname, wav, label, sample_rate}] + num_mel_bins: number of mel filter bank + frame_length: length of one frame (ms) + frame_shift: length of frame shift (ms) + dither: value of dither + + Returns: + Iterable[{fname, feat, label}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'fname' in sample + assert 'label' in sample + sample_rate = 
sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep fname, feat, label + mat = kaldi.fbank(waveform, + n_mels=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sr=sample_rate) + yield dict(fname=sample['fname'], label=sample['label'], feat=mat) + + +compute_fbank = pipelinefilter(_compute_fbank) + +def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + source: Iterable[{fname, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{fname, feat, label}] + """ + for sample in source: + x = sample['feat'] + x = x.numpy() + x = time_warp(x, max_time_warp=max_w, inplace = True, mode= "PIL") + x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = True, replace_with_zero = False) + x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = True, replace_with_zero = False) + sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) + yield sample + +spec_aug = pipelinefilter(_spec_aug) + + +def _sort(source, sort_size=500): + """ Sort the data by feature length. 
+ Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + source: Iterable[{fname, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{fname, feat, label}] + """ + + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + +sort = pipelinefilter(_sort) + +def _batched(source, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{fname, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + +batched = pipelinefilter(_batched) + +def dynamic_batched(source, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + source: Iterable[{fname, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in source: + assert 'feat' in sample + assert isinstance(sample['feat'], paddle.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def _padding(source): + """ Padding the data into training data + + Args: + source: Iterable[List[{fname, feat, label}]] + + Returns: + Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)] + """ + for sample in source: + 
assert isinstance(sample, list) + feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample], + dtype="int64") + order = paddle.argsort(feats_length, descending=True) + feats_lengths = paddle.to_tensor( + [sample[i]['feat'].shape[0] for i in order], dtype="int64") + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['fname'] for i in order] + sorted_labels = [ + paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order + ] + label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels], + dtype="int64") + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +padding = pipelinefilter(_padding) + +def _cmvn(source, cmvn_file): + global_cmvn = GlobalCMVN(cmvn_file) + for batch in source: + sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch + padded_feats = padded_feats.numpy() + padded_feats = global_cmvn(padded_feats) + padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32) + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +cmvn = pipelinefilter(_cmvn) diff --git a/paddlespeech/audio/stream_data/paddle_utils.py b/paddlespeech/audio/stream_data/paddle_utils.py new file mode 100644 index 00000000..02bc4c84 --- /dev/null +++ b/paddlespeech/audio/stream_data/paddle_utils.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). 
# Modified from https://github.com/webdataset/webdataset
#

"""Mock implementations of paddle interfaces when paddle is not available."""


try:
    from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
        """Empty implementation of IterableDataset when paddle is not available."""

        pass

    class DataLoader:
        """Empty implementation of DataLoader when paddle is not available."""

        pass

try:
    from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:

    # The fallback must bind the same name as the successful import above;
    # it used to be called ``TorchTensor``, which left ``PaddleTensor``
    # undefined whenever paddle was missing.
    class PaddleTensor:
        """Empty implementation of PaddleTensor when paddle is not available."""

        pass

    # Backward-compatible alias for the name the fallback used to define.
    TorchTensor = PaddleTensor

# ---- patch metadata (kept verbatim, commented out) ----
# diff --git a/paddlespeech/audio/stream_data/pipeline.py b/paddlespeech/audio/stream_data/pipeline.py
# new file mode 100644
# index 00000000..b672773b
# --- /dev/null
# +++ b/paddlespeech/audio/stream_data/pipeline.py
# @@ -0,0 +1,127 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    """Give ``obj`` a ``__len__`` that returns ``obj.size``.

    The instance's class is swapped for a dynamically created subclass named
    ``<OriginalClass>_Length``; the same (mutated) object is returned.
    """

    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        """Collect the stages; ``None`` is skipped and lists are flattened one level."""
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        Raises:
            ValueError: if ``f`` is none of the supported stage kinds.
        """
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            return f(*args, **kwargs)
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for _ in range(self.repetitions):
            yield from self.iterator1()

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        # The sample cap only applies when a repetition count other than the
        # default of 1 has been configured (via with_epoch() / repeat()).
        if self.repetitions != 1 and self.nsamples > 0:
            return islice(self.iterator(), self.nsamples)
        return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy.

        ``copy.copy`` is shallow, so the copy initially shares the ``pipeline``
        list with ``self``; the list is duplicated here so that appending to
        the copy does not also modify the original pipeline.
        """
        result = copy.copy(self)
        result.pipeline = list(self.pipeline)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        An IterableDataset should not normally have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing; the larger one is used."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        # A non-positive nepochs means "repeat (practically) forever".
        self.repetitions = nepochs if nepochs > 0 else sys.maxsize
        self.nsamples = nbatches
        return self

# ---- patch metadata (kept verbatim, commented out) ----
# diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/stream_data/shardlists.py
# new file mode 100644
# index 00000000..503bfe57
# --- /dev/null
# +++ b/paddlespeech/audio/stream_data/shardlists.py
# @@ -0,0 +1,257 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
+# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset + +"""Train PyTorch models directly from POSIX tar archive. + +Code works locally or over HTTP connections. +""" + +import os, random, sys, time +from dataclasses import dataclass, field +from itertools import islice +from typing import List + +import braceexpand, yaml + +from . import utils +from .filters import pipelinefilter +from .paddle_utils import IterableDataset + + +def expand_urls(urls): + if isinstance(urls, str): + urllist = urls.split("::") + result = [] + for url in urllist: + result.extend(braceexpand.braceexpand(url)) + return result + else: + return list(urls) + + +class SimpleShardList(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__(self, urls, seed=None): + """Iterate through the list of shards. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.seed = seed + + def __len__(self): + return len(self.urls) + + def __iter__(self): + """Return an iterator over the shards.""" + urls = self.urls.copy() + if self.seed is not None: + random.Random(self.seed).shuffle(urls) + for url in urls: + yield dict(url=url) + + +def split_by_node(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + if world_size > 1: + for s in islice(src, rank, None, world_size): + yield s + else: + for s in src: + yield s + + +def single_node_only(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + if world_size > 1: + raise ValueError("input pipeline needs to be reconfigured for multinode training") + for s in src: + yield s + + +def split_by_worker(src): + rank, world_size, worker, num_workers = utils.paddle_worker_info() + if num_workers > 1: + for s in islice(src, worker, None, 
num_workers): + yield s + else: + for s in src: + yield s + + +def resampled_(src, n=sys.maxsize): + import random + + seed = time.time() + try: + seed = open("/dev/random", "rb").read(20) + except Exception as exn: + print(repr(exn)[:50], file=sys.stderr) + rng = random.Random(seed) + print("# resampled loading", file=sys.stderr) + items = list(src) + print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) + for i in range(n): + yield rng.choice(items) + + +resampled = pipelinefilter(resampled_) + + +def non_empty(src): + count = 0 + for s in src: + yield s + count += 1 + if count == 0: + raise ValueError("pipeline stage received no data at all and this was declared as an error") + + +@dataclass +class MSSource: + """Class representing a data source.""" + + name: str = "" + perepoch: int = -1 + resample: bool = False + urls: List[str] = field(default_factory=list) + + +default_rng = random.Random() + + +def expand(s): + return os.path.expanduser(os.path.expandvars(s)) + + +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 + self.parse_spec(fname) + + def parse_spec(self, fname): + self.rng = default_rng # capture default_rng if we fork + if isinstance(fname, dict): + spec = fname + fname = "{dict}" + else: + with open(fname) as stream: + spec = yaml.safe_load(stream) + assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys()) + prefix = expand(spec.get("prefix", "")) + self.sources = [] + for ds in spec["datasets"]: + assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list( + ds.keys() + ) + buckets = ds.get("buckets", spec.get("buckets", [])) + if isinstance(buckets, str): + buckets = [buckets] + buckets = 
[expand(s) for s in buckets] + if buckets == []: + buckets = [""] + assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" + bucket = buckets[0] + name = ds.get("name", "@" + bucket) + urls = ds["shards"] + if isinstance(urls, str): + urls = [urls] + # urls = [u for url in urls for u in braceexpand.braceexpand(url)] + urls = [ + prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url)) + ] + resample = ds.get("resample", -1) + nsample = ds.get("choose", -1) + if nsample > len(urls): + raise ValueError(f"perepoch {nsample} must be no greater than the number of shards") + if (nsample > 0) and (resample > 0): + raise ValueError("specify only one of perepoch or choose") + entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample) + self.sources.append(entry) + print(f"# {name} {len(urls)} {nsample}", file=sys.stderr) + + def set_epoch(self, seed): + """Set the current epoch (for consistent shard selection among nodes).""" + self.rng = random.Random(seed) + + def get_shards_for_epoch(self): + result = [] + for source in self.sources: + if source.resample > 0: + # sample with replacement + l = self.rng.choices(source.urls, k=source.resample) + elif source.perepoch > 0: + # sample without replacement + l = list(source.urls) + self.rng.shuffle(l) + l = l[: source.perepoch] + else: + l = list(source.urls) + result += l + self.rng.shuffle(result) + return result + + def __iter__(self): + shards = self.get_shards_for_epoch() + for shard in shards: + yield dict(url=shard) + + +def shardspec(spec): + if spec.endswith(".yaml"): + return MultiShardSample(spec) + else: + return SimpleShardList(spec) + + +class ResampledShards(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + ): + """Sample shards from the shard list with replacement. 
+ + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = -1 + + def __iter__(self): + """Return an iterator over the shards.""" + self.epoch += 1 + if self.deterministic: + seed = utils.make_seed(self.worker_seed(), self.epoch) + else: + seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4)) + if os.environ.get("WDS_SHOW_SEED", "0") == "1": + print(f"# ResampledShards seed {seed}") + self.rng = random.Random(seed) + for _ in range(self.nshards): + index = self.rng.randint(0, len(self.urls) - 1) + yield dict(url=self.urls[index]) diff --git a/paddlespeech/audio/stream_data/tariterators.py b/paddlespeech/audio/stream_data/tariterators.py new file mode 100644 index 00000000..d9469797 --- /dev/null +++ b/paddlespeech/audio/stream_data/tariterators.py @@ -0,0 +1,283 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) + +"""Low level iteration functions for tar archives.""" + +import random, re, tarfile + +import braceexpand + +from . import filters +from webdataset import gopen +from webdataset.handlers import reraise_exception + +trace = False +meta_prefix = "__" +meta_suffix = "__" + +from ... import audio as paddleaudio +import paddle +import numpy as np + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + +def base_plus_ext(path): + """Split off all file extensions. + + Returns base, allext. 
+ + :param path: path with extensions + :param returns: path with all extensions removed + + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + +def valid_sample(sample): + """Check whether a sample is valid. + + :param sample: sample to be checked + """ + return ( + sample is not None + and isinstance(sample, dict) + and len(list(sample.keys())) > 0 + and not sample.get("__bad__", False) + ) + + +# FIXME: UNUSED +def shardlist(urls, *, shuffle=False): + """Given a list of URLs, yields that list, possibly shuffled.""" + if isinstance(urls, str): + urls = braceexpand.braceexpand(urls) + else: + urls = list(urls) + if shuffle: + random.shuffle(urls) + for url in urls: + yield dict(url=url) + + +def url_opener(data, handler=reraise_exception, **kw): + """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + try: + stream = gopen.gopen(url, **kw) + sample.update(stream=stream) + yield sample + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def tar_file_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. 
+ + :param fileobj: byte stream suitable for tarfile + :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)") + + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + for tarinfo in stream: + fname = tarinfo.name + try: + if not tarinfo.isreg(): + continue + if fname is None: + continue + if ( + "/" not in fname + and fname.startswith(meta_prefix) + and fname.endswith(meta_suffix) + ): + # skipping metadata for now + continue + if skip_meta is not None and re.match(skip_meta, fname): + continue + + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if postfix == 'wav': + waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False) + result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) + else: + txt = stream.extractfile(tarinfo).read().decode('utf8').strip() + result = dict(fname=prefix, txt=txt) + #result = dict(fname=fname, data=data) + yield result + stream.members = [] + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + del stream + +def tar_file_and_group_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """ Expand a stream of open tar files into a stream of tar file contents. 
+ And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['fname'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = paddleaudio.load(file_obj, normal=False) + waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) + + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + valid = False + # logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['fname'] = prev_prefix + yield example + stream.close() + +def tar_file_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). 
+ """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "data" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + + + +def tar_file_and_group_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_and_group_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + +def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. 
+ + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + raise ValueError( + f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}" + ) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_samples(src, handler=reraise_exception): + streams = url_opener(src, handler=handler) + samples = tar_file_and_group_expander(streams, handler=handler) + return samples + + +tarfile_to_samples = filters.pipelinefilter(tarfile_samples) diff --git a/paddlespeech/audio/stream_data/utils.py b/paddlespeech/audio/stream_data/utils.py new file mode 100644 index 00000000..83a42bad --- /dev/null +++ b/paddlespeech/audio/stream_data/utils.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). 
#

# Modified from https://github.com/webdataset/webdataset

"""Miscellaneous utility functions."""

import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union


def make_seed(*args):
    """Fold the hashes of ``args`` into one non-negative 31-bit integer."""
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed


class PipelineStage:
    """Base class for pipeline stages."""

    def invoke(self, *args, **kw):
        raise NotImplementedError


def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely.

    ``s`` may only contain ``[A-Za-z0-9_]``; it is substituted into ``expr``
    and the result is passed to ``eval``.

    Raises:
        ValueError: if ``s`` contains any other character.
    """
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    # NOTE(review): only `s` is sanitised; `expr` is evaluated as given, so it
    # must never come from untrusted input.
    return eval(expr.format(s))


def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules.

    Returns the first attribute found that is not ``None``, else ``None``.
    """
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None


def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Repeatedly returns batches from a DataLoader."""
    for _ in range(nepochs):
        yield from itt.islice(loader, nbatches)


def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator.

    Iteration stops after ``nepochs`` passes over ``source``, after
    ``nbatches`` yielded items, or once the sizes reported by ``batchsize``
    add up to at least ``nsamples`` -- whichever comes first. A limit left at
    ``None`` is not applied.
    """
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # Use the caller-supplied size function; the parameter used to
                # be ignored in favour of a hard-coded guess_batchsize().
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def paddle_worker_info(group=None):
    """Return node and worker info for Paddle and some distributed environments.

    The ``RANK``/``WORLD_SIZE`` and ``WORKER``/``NUM_WORKERS`` environment
    variables take precedence; otherwise paddle is queried when it can be
    imported.

    Returns:
        tuple: ``(rank, world_size, worker, num_workers)``, defaulting to
        ``(0, 1, 0, 1)``.
    """
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed
            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            # get_worker_info is a function inside paddle.io, not a module.
            # `import paddle.io.get_worker_info` always raised
            # ModuleNotFoundError, which was swallowed below, so worker info
            # was never picked up.
            import paddle.io
            worker_info = paddle.io.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers


def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return rank * 1000 + worker

# ---- patch metadata (kept verbatim, commented out) ----
# diff --git a/paddlespeech/audio/transform/__init__.py b/paddlespeech/audio/transform/__init__.py
# new file mode 100644
# index 00000000..185a92b8
# --- /dev/null
# +++ b/paddlespeech/audio/transform/__init__.py
# @@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---- patch metadata (kept verbatim, commented out) ----
# diff --git a/paddlespeech/audio/transform/add_deltas.py b/paddlespeech/audio/transform/add_deltas.py
# new file mode 100644
# index 00000000..1387fe9d
# --- /dev/null
# +++ b/paddlespeech/audio/transform/add_deltas.py
# @@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np


def delta(feat, window):
    """Compute delta features along axis 0 over +/- ``window`` frames.

    Frames beyond either end are replaced by the first / last frame.

    Args:
        feat (np.ndarray): input features; axis 0 is the frame axis.
        window (int): number of neighbouring frames on each side, > 0.

    Returns:
        np.ndarray: array with the shape of ``feat``. Floating/complex input
        keeps its dtype; any other dtype yields float64.
    """
    assert window > 0
    # Accumulate in an inexact dtype: with an integer buffer the in-place
    # true division at the end raises a casting error.
    in_dtype = np.result_type(feat)
    out_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else np.float64
    delta_feat = np.zeros_like(feat, dtype=out_dtype)
    for i in range(1, window + 1):
        delta_feat[:-i] += i * feat[i:]      # +i * x[t+i]
        delta_feat[i:] += -i * feat[:-i]     # -i * x[t-i]
        delta_feat[-i:] += i * feat[-1]      # right edge: reuse the last frame
        delta_feat[:i] += -i * feat[0]       # left edge: reuse the first frame
    delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
    return delta_feat


def add_deltas(x, window=2, order=2):
    """Append ``order`` successive delta features to ``x``.

    Args:
        x (np.ndarray): speech feat, (T, D).
        window (int): delta window passed to :func:`delta`.
        order (int): how many deltas (delta, delta-delta, ...) to append.

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    feats = [x]
    for _ in range(order):
        # Each order is the delta of the previously appended block.
        feats.append(delta(feats[-1], window))
    return np.concatenate(feats, axis=1)


class AddDeltas():
    """Callable wrapper around :func:`add_deltas` with fixed window/order."""

    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        # The closing parenthesis was missing from the original format string.
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)

# ---- patch metadata (kept verbatim, commented out) ----
# diff --git a/paddlespeech/audio/transform/channel_selector.py b/paddlespeech/audio/transform/channel_selector.py
# new file mode 100644
# index 00000000..b078dcf8
# --- /dev/null
# +++ b/paddlespeech/audio/transform/channel_selector.py
# @@ -0,0 +1,57 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ChannelSelector():
    """Pick a single channel out of a multi-channel signal.

    ``train_channel`` / ``eval_channel`` are either an index along ``axis``
    or the string "random", in which case a channel is drawn with
    ``numpy.random.randint`` on every call.
    """

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return (f"{self.__class__.__name__}"
                f"(train_channel={self.train_channel}, "
                f"eval_channel={self.eval_channel}, axis={self.axis})")

    def __call__(self, x, train=True):
        # By default x is laid out as [Time, Channel].
        n_missing = self.axis + 1 - x.ndim
        if n_missing > 0:
            # Too few dimensions: append singleton axes until ``axis``
            # exists (e.g. [Time] -> [Time, 1]).
            x = x[(slice(None), ) * x.ndim + (None, ) * n_missing]

        channel = self.train_channel if train else self.eval_channel
        if channel == "random":
            channel = numpy.random.randint(0, x.shape[self.axis])

        selector = [slice(None)] * x.ndim
        # Only a non-negative axis is ever matched; a negative axis leaves
        # the array untouched.
        if 0 <= self.axis < x.ndim:
            selector[self.axis] = channel
        return x[tuple(selector)]
class CMVN():
    """Apply global / per-speaker CMVN, or its inverse.

    Statistics follow the Kaldi layout: a ``(2, feat_dim + 1)`` matrix whose
    first row holds the sum of the features followed by the frame count and
    whose second row holds the sum of squares.

    Args:
        stats: a dict mapping speaker id (or ``None`` for global stats) to a
            statistics matrix, or a path that is loaded according to
            ``filetype``.
        norm_means (bool): subtract the mean.
        norm_vars (bool): divide by the standard deviation.
        filetype (str): "mat" or "npy" (global stats), "ark" or "hdf5"
            (per-speaker stats).  Ignored when ``stats`` is a dict.
        utt2spk (str): optional path of ``<utt> <spk>`` lines.
        spk2utt (str): optional path of ``<spk> <utt> <utt> ...`` lines; only
            used when ``utt2spk`` is not given.
        reverse (bool): undo the normalisation instead of applying it.
        std_floor (float): lower bound for the standard deviation.

    Raises:
        ValueError: if ``filetype`` is not one of the supported values.
    """

    def __init__(
            self,
            stats,
            norm_means=True,
            norm_vars=False,
            filetype="mat",
            utt2spk=None,
            spk2utt=None,
            reverse=False,
            std_floor=1.0e-20, ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                # The statistics are only read, so open read-only explicitly
                # instead of relying on the h5py default mode.
                stats_dict = h5py.File(stats, "r")
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, spk_stats in stats_dict.items():
            assert len(spk_stats) == 2, spk_stats.shape

            count = spk_stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = spk_stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = spk_stats[1, :-1] / count - mean * mean
            # Rounding can push a near-zero variance below zero, which would
            # turn into NaN under sqrt; clamp before taking the root.
            std = np.maximum(np.sqrt(np.maximum(var, 0.0)), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        """Normalise (or de-normalise) ``x`` with the stats chosen by ``uttid``.

        When an utt2spk table was loaded, ``uttid`` is first mapped to its
        speaker; otherwise ``uttid`` itself is the statistics key (``None``
        for global statistics).
        """
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])

        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])

        return x


class UtteranceCMVN():
    """Apply CMVN with statistics computed from the utterance itself."""

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        # Both statistics are taken from the raw input, before the mean is
        # removed, so that var = E[x^2] - E[x]^2 stays valid.
        square_sums = (x**2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            # Clamp tiny negative variances caused by rounding (NaN guard).
            std = np.maximum(np.sqrt(np.maximum(var, 0.0)), self.std_floor)
            x = np.divide(x, std)

        return x


class GlobalCMVN():
    """Apply global CMVN from JSON-style statistics.

    Args:
        cmvn_path (str or dict): a dict, or the path of a JSON file, with the
            keys ``frame_num``, ``mean_stat`` (sum of features) and
            ``var_stat`` (sum of squared features).
        norm_means (bool): subtract the mean.
        norm_vars (bool): divide by the standard deviation.
        std_floor (float): lower bound for the standard deviation.
    """

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        # cmvn_path: Option[str, dict]
        cmvn = cmvn_path
        self.cmvn = cmvn
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor
        if isinstance(cmvn, dict):
            cmvn_stats = cmvn
        else:
            with open(cmvn, encoding="utf-8") as f:
                cmvn_stats = json.load(f)
        self.count = cmvn_stats['frame_num']
        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        # Clamp tiny negative variances caused by rounding (NaN guard).
        self.std = np.maximum(
            np.sqrt(np.maximum(self.var, 0.0)), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
        cmvn_path={self.cmvn},
        norm_means={self.norm_means},
        norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
class FuncTrans(TransformInterface):
    """Functional Transformation

    Wraps the plain function stored in the class attribute ``_func`` so that
    its keyword arguments are fixed at construction time and only the input
    ``x`` is supplied on each call.

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:

        >>> def foo_bar(x, a=1, b=2):
        ...     '''Foo bar
        ...     :param x: input
        ...     :param int a: default 1
        ...     :param int b: default 2
        ...     '''
        ...     return x + a - b


        >>> class FooBar(FuncTrans):
        ...     _func = foo_bar
        ...     __doc__ = foo_bar.__doc__
    """

    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        """Register one ``--<func-name>-<param>`` option per defaulted parameter."""
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        # Read through the class so the stored function is returned as-is
        # rather than being bound to the instance as a method.
        return type(self)._func

    @classmethod
    def default_params(cls):
        """Return ``{name: default}`` for every parameter of ``_func`` with a default."""
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            d = dict()
        # ``Parameter.empty`` is a sentinel and must be tested by identity:
        # ``!=`` would invoke the default's own comparison, which is
        # ambiguous for array-like defaults.
        return {
            k: v.default
            for k, v in d.items() if v.default is not inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        # Drop the trailing ", " before closing the parenthesis.
        return ret[:-2] + ")"
def _load_utt2ratio(path):
    """Parse a text file of ``<uttid> <ratio>`` lines into ``{uttid: float}``."""
    utt2ratio = {}
    with open(path, "r") as f:
        for line in f:
            utt, ratio = line.rstrip().split(None, 1)
            utt2ratio[utt] = float(ratio)
    return utt2ratio


def _fit_length(y, target_len):
    """Center-crop or zero-pad ``y`` along axis 0 so it has ``target_len`` frames."""
    diff = abs(target_len - len(y))
    if len(y) > target_len:
        start = diff // 2
        return y[start:start + target_len]
    if len(y) < target_len:
        # Assume the time-axis is the first: (Time, Channel)
        pad_width = [(diff // 2, (diff + 1) // 2)] + [
            (0, 0) for _ in range(y.ndim - 1)
        ]
        return numpy.pad(
            y, pad_width=pad_width, constant_values=0, mode="constant")
    return y


class SoundHDF5File():
    """Collecting sound files to a HDF5 file

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']


    :param: str filepath:
    :param: str mode:
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
    :param: str dtype:

    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        # The template was empty, so every instance printed as ''.
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        """Encode ``data = (array, rate)`` with soundfile and store the bytes."""
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        for k in self.file:
            yield self[k]

    def items(self):
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        # ``len(obj)`` passes no extra argument; the former ``item``
        # parameter made the protocol call fail with TypeError.
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()


class SpeedPerturbation():
    """SpeedPerturbation

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    Warning:
        This function is very slow because of resampling.
        I recommend to apply speed-perturb outside the training using sox.

    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            res_type="kaiser_best",
            seed=None, ):
        self.res_type = res_type
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # Always defined: __call__ reads it in both modes.
        self.accept_uttid = False

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True
            self.utt2ratio = _load_utt2ratio(utt2ratio)
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
                self.__class__.__name__,
                self.lower,
                self.upper,
                self.keep_length,
                self.res_type, )
        else:
            return "{}({}, res_type={})".format(
                self.__class__.__name__, self.utt2ratio_file, self.res_type)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        # Note1: resample requires the sampling-rate of input and output,
        #        but actually only the ratio is used.
        y = librosa.resample(
            x, orig_sr=ratio, target_sr=1, res_type=self.res_type)

        if self.keep_length:
            y = _fit_length(y, len(x))
        return y


class SpeedPerturbationSox():
    """SpeedPerturbationSox

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    To speed up or slow down the sound of a file,
    use speed to modify the pitch and the duration of the file.
    This raises the speed and reduces the time.
    The default factor is 1.0 which makes no change to the audio.
    2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    tempo option:
    sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9

    speed option:
    sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9

    If we use speed option like above, the pitch of audio also will be changed,
    but the tempo option does not change the pitch.
    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            sr=16000,
            seed=None, ):
        # ``sys`` was referenced below without being imported anywhere, which
        # turned every install attempt into the generic RuntimeError.
        import sys

        self.sr = sr
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # Always defined: __call__ reads it in both modes.
        self.accept_uttid = False

        try:
            import soxbindings as sox
        except ImportError:
            try:
                from paddlespeech.s2t.utils import dynamic_pip_install
                package = "sox"
                dynamic_pip_install.install(package)
                package = "soxbindings"
                if sys.platform != "win32":
                    dynamic_pip_install.install(package)
                    import soxbindings as sox
            except Exception as err:
                raise RuntimeError(
                    "Can not install soxbindings on your system.") from err
        self.sox = sox

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True
            self.utt2ratio = _load_utt2ratio(utt2ratio)
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return f"""{self.__class__.__name__}(
                lower={self.lower},
                upper={self.upper},
                keep_length={self.keep_length},
                sample_rate={self.sr})"""

        else:
            return f"""{self.__class__.__name__}(
                utt2ratio={self.utt2ratio_file},
                sample_rate={self.sr})"""

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        tfm = self.sox.Transformer()
        tfm.set_globals(multithread=False)
        tfm.speed(ratio)
        y = tfm.build_array(input_array=x, sample_rate_in=self.sr)

        if self.keep_length:
            y = _fit_length(y, len(x))

        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T); ndarray has no ``sequence`` method.
            y = y.squeeze(1)
        return y


class BandpassPerturbation():
    """BandpassPerturbation

    Randomly dropout along the frequency axis.

    The original idea comes from the following:
        "randomly-selected frequency band was cut off under the constraint of
         leaving at least 1,000 Hz band within the range of less than 4,000Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
         everyday home environments using multiple microphone arrays;
         http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)

    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        """Multiply ``x_stft`` in place by a random 0/1 mask and return it."""
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        # A negative axis counts from the end, i.e. ndim + i.  The former
        # ``ndim - i`` pointed past the last axis, so the mask collapsed to
        # a single value for the whole array.
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
        # The mask varies only along ``axes`` and broadcasts elsewhere.
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft


class VolumePerturbation():
    """Scale the waveform by a random or per-utterance gain.

    With ``dbunit=True`` the ratio is interpreted in dB and converted with
    ``10 ** (ratio / 20)``; otherwise it is used as a linear factor.
    """

    def __init__(self,
                 lower=-1.6,
                 upper=1.6,
                 utt2ratio=None,
                 dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # Always defined: __call__ reads it in both modes.
        self.accept_uttid = False

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.lower = None
            self.upper = None
            self.accept_uttid = True
            self.utt2ratio = _load_utt2ratio(utt2ratio)
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            ratio = 10**(ratio / 20)
        return x * ratio


class NoiseInjection():
    """Add isotropic noise

    The noise is either white noise or a recording taken from ``utt2noise``;
    its level relative to the RMS of the input is drawn from
    ``[lower, upper]`` or looked up per utterance in ``utt2ratio``.
    """

    def __init__(
            self,
            utt2noise=None,
            lower=-20,
            upper=-5,
            utt2ratio=None,
            filetype="list",
            dbunit=True,
            seed=None, ):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances.  The table must
            # come from ``utt2ratio``; it used to be read from ``utt2noise``.
            self.utt2ratio = _load_utt2ratio(utt2ratio)
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)

            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, _rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source.  Entries are
                # (signal, rate) pairs of differing lengths, so draw an
                # index rather than handing the list to ``choice``.
                candidates = list(self.utt2noise.values())
                noise, _rate = candidates[self.state.randint(len(candidates))]
            # Work on a float copy: the stored signal is int16, which can
            # neither be divided in place nor squared without overflow, and
            # the cached array must not be modified.
            noise = noise.astype(numpy.float32)
            # Normalize the level
            noise = noise / numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            # ``diff + 1`` keeps the range non-empty when lengths match.
            offset = self.state.randint(0, diff + 1)
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:offset + len(x)]
            elif diff > 0:
                # Repeat the noise cyclically along the time axis only.
                pad_width = [(offset, diff - offset)] + [
                    (0, 0) for _ in range(noise.ndim - 1)
                ]
                noise = numpy.pad(noise, pad_width=pad_width, mode="wrap")

        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale


class RIRConvolve():
    """Convolve a single-channel signal with a room impulse response.

    ``utt2rir`` is either a text file of ``<uttid> <wav path>`` lines
    (``filetype="list"``) or a ``SoundHDF5File`` archive
    (``filetype="sound.hdf5"``).
    """

    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        self.utt2rir = {}
        if filetype == "list":
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)

        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        rir, _rate = self.utt2rir[uttid]
        # ``scipy.convolve`` was only an alias of ``numpy.convolve`` and is
        # gone from current SciPy, so call NumPy directly.
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            # soundfile yields (frames, channels); iterate the transpose so
            # that each ``r`` is one channel's impulse response.
            return numpy.stack(
                [numpy.convolve(x, r, mode="same") for r in rir.T], axis=-1)
        else:
            return numpy.convolve(x, rir, mode="same")
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +"""Spec Augment module for preprocessing i.e., data augmentation""" +import random + +import numpy +from PIL import Image +from PIL.Image import BICUBIC + +from .functional import FuncTrans + + +def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): + """time warp for spec augment + + move random center frame by the random width ~ uniform(-window, window) + :param numpy.ndarray x: spectrogram (time, freq) + :param int max_time_warp: maximum time frames to warp + :param bool inplace: overwrite x with the result + :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" + (slow, differentiable) + :returns numpy.ndarray: time warped spectrogram (time, freq) + """ + window = max_time_warp + if window == 0: + return x + + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + + window) + 1 # 1 ... 
t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), + BICUBIC) + if inplace: + x[:warped] = left + x[warped:] = right + return x + return numpy.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + import paddle + + from espnet.utils import spec_augment + + # TODO(karita): make this differentiable again + return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() + else: + raise NotImplementedError("unknown resize mode: " + mode + + ", choose one from (PIL, sparse_image_warp).") + + +class TimeWarp(FuncTrans): + _func = time_warp + __doc__ = time_warp.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray x: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = x + else: + cloned = x.copy() + + num_mel_channels = cloned.shape[1] + fs = numpy.random.randint(0, F, size=(n_mask, 2)) + + for f, mask_end in fs: + f_zero = random.randrange(0, num_mel_channels - f) + mask_end += f_zero + + # avoids randrange error if values are equal and range is empty + if f_zero == f_zero + f: + continue + + if replace_with_zero: + cloned[:, f_zero:mask_end] = 0 + else: + cloned[:, f_zero:mask_end] = cloned.mean() + return cloned + + +class FreqMask(FuncTrans): + _func = freq_mask + __doc__ = freq_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray spec: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: 
pad zero on mask if true else use mean + """ + if inplace: + cloned = spec + else: + cloned = spec.copy() + len_spectro = cloned.shape[0] + ts = numpy.random.randint(0, T, size=(n_mask, 2)) + for t, mask_end in ts: + # avoid randint range error + if len_spectro - t <= 0: + continue + t_zero = random.randrange(0, len_spectro - t) + + # avoids randrange error if values are equal and range is empty + if t_zero == t_zero + t: + continue + + mask_end += t_zero + if replace_with_zero: + cloned[t_zero:mask_end] = 0 + else: + cloned[t_zero:mask_end] = cloned.mean() + return cloned + + +class TimeMask(FuncTrans): + _func = time_mask + __doc__ = time_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def spec_augment( + x, + resize_mode="PIL", + max_time_warp=80, + max_freq_width=27, + n_freq_mask=2, + max_time_width=100, + n_time_mask=2, + inplace=True, + replace_with_zero=True, ): + """spec agument + + apply random time warping and time/freq masking + default setting is based on LD (Librispeech double) in Table 2 + https://arxiv.org/pdf/1904.08779.pdf + + :param numpy.ndarray x: (time, freq) + :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" + (slow, differentiable) + :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) + :param int freq_mask_width: maximum width of the random freq mask (F) + :param int n_freq_mask: the number of the random freq mask (m_F) + :param int time_mask_width: maximum width of the random time mask (T) + :param int n_time_mask: the number of the random time mask (m_T) + :param bool inplace: overwrite intermediate array + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + assert isinstance(x, numpy.ndarray) + assert x.ndim == 2 + x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) + x = freq_mask( + x, + max_freq_width, + n_freq_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, 
) + x = time_mask( + x, + max_time_width, + n_time_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, ) + return x + + +class SpecAugment(FuncTrans): + _func = spec_augment + __doc__ = spec_augment.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py new file mode 100644 index 00000000..864f3f99 --- /dev/null +++ b/paddlespeech/audio/transform/spectrogram.py @@ -0,0 +1,475 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import librosa +import numpy as np +import paddle +from python_speech_features import logfbank + +from ..compliance import kaldi + + +def stft(x, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect"): + # x: [Time, Channel] + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? 
def stft2logmelspectrogram(x_stft,
                           fs,
                           n_mels,
                           n_fft,
                           fmin=None,
                           fmax=None,
                           eps=1e-10):
    """Convert an STFT into a log10 mel spectrogram.

    Args:
        x_stft: STFT array; per the original comments its layout is
            (Time, Channel, Freq) or (Time, Freq).
        fs: sampling rate handed to the mel filterbank as ``sr``.
        n_mels: number of mel bins.
        n_fft: FFT size used to build the filterbank.
        fmin: lowest filterbank frequency; ``None`` means 0.
        fmax: highest filterbank frequency; ``None`` means ``fs / 2``.
        eps: floor applied to the mel energies before taking ``log10``.

    Returns:
        ``log10(max(eps, |x_stft| . mel_basis^T))``; the last axis is the mel
        axis, the leading axes are those of ``x_stft``.
    """
    if fmin is None:
        fmin = 0
    if fmax is None:
        fmax = fs / 2

    # Magnitude spectrum, same layout as the input.
    magnitude = np.abs(x_stft)
    # Filterbank laid out as (Mel_freq, Freq), hence the transpose below.
    filterbank = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    mel_energy = np.dot(magnitude, filterbank.T)
    # Clamp at eps so the logarithm never sees zero.
    return np.log10(np.maximum(eps, mel_energy))
class LogMelSpectrogram():
    """Callable transform that computes a log-mel spectrogram of a waveform.

    The constructor only stores its arguments; each call forwards all of
    them to ``logmelspectrogram``.

    Args:
        fs: sampling rate.
        n_mels: number of mel bins.
        n_fft: FFT size.
        n_shift: hop length handed to the STFT.
        win_length: STFT window length (``None`` is passed through).
        window: window name.
        fmin: lowest mel frequency (``None`` is resolved downstream).
        fmax: highest mel frequency (``None`` is resolved downstream).
        eps: floor applied before the logarithm.
    """

    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        # The original template ended with "))", leaving an unbalanced
        # parenthesis in the rendered text.
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "n_shift={n_shift}, win_length={win_length}, window={window}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        """Return the log-mel spectrogram of ``x``.

        ``fmin``, ``fmax`` and ``eps`` are now forwarded as well; previously
        they were stored and shown in ``repr`` but never reached
        ``logmelspectrogram``, so configured values had no effect.
        """
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
class Stft():
    """Callable wrapper that stores STFT settings and applies ``stft``.

    Args:
        n_fft: FFT size.
        n_shift: hop length.
        win_length: window length (``None`` is passed through).
        window: window name.
        center: forwarded to ``stft`` as ``center``.
        pad_mode: forwarded to ``stft`` as ``pad_mode``.
    """

    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft, self.n_shift = n_fft, n_shift
        self.win_length, self.window = win_length, window
        self.center, self.pad_mode = center, pad_mode

    def __repr__(self):
        # Template kept verbatim (including the missing space after
        # "window={window},") so the rendered text is unchanged.
        template = ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                    "win_length={win_length}, window={window},"
                    "center={center}, pad_mode={pad_mode})")
        fields = dict(
            name=self.__class__.__name__,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )
        return template.format(**fields)

    def __call__(self, x):
        """Return ``stft(x, ...)`` computed with the stored settings."""
        options = dict(
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )
        return stft(x, self.n_fft, self.n_shift, **options)
energy_floor=0.0, + dither=0.1): + """ + The Kaldi implementation of LogMelSpectrogram + Args: + fs (int): sample rate of the audio + n_mels (int): number of mel filter banks + n_shift (int): number of points in a frame shift + win_length (int): number of points in a frame windows + energy_floor (float): Floor on energy in Spectrogram computation (absolute) + dither (float): Dithering constant + + Returns: + LogMelSpectrogramKaldi + """ + + self.fs = fs + self.n_mels = n_mels + num_point_ms = fs / 1000 + self.n_frame_length = win_length / num_point_ms + self.n_frame_shift = n_shift / num_point_ms + self.energy_floor = energy_floor + self.dither = dither + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, " + "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, " + "dither={dither}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_frame_shift=self.n_frame_shift, + n_frame_length=self.n_frame_length, + dither=self.dither, )) + + def __call__(self, x, train): + """ + Args: + x (np.ndarray): shape (Ti,) + train (bool): True, train mode. 
class LogMelSpectrogramKaldi_decay():
    """Log-mel filterbank features computed with ``logfbank``.

    Args:
        fs: sampling rate in Hz.
        n_mels: number of mel filters.
        n_fft: FFT size.
        n_shift: frame shift in samples.
        win_length: frame length in samples.
        window: window type forwarded as ``wintype``.
        fmin: lowest filter frequency.
        fmax: highest filter frequency; ``None`` means ``fs / 2``.
        eps: stored only; not used by ``__call__``.
        dither: dithering constant, applied only when ``train`` is true.

    Raises:
        ValueError: if ``n_shift > win_length`` or ``fmax`` exceeds
            ``int(fs / 2)``.
    """

    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_fft=512,  # fft point
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            window="povey",
            fmin=20,
            fmax=None,
            eps=1e-10,
            dither=1.0):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        if n_shift > win_length:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # samples / (samples per second) -> seconds (e.g. 160 / 16000 = 0.01)
        self.n_shift = n_shift / fs
        self.win_length = win_length / fs

        self.window = window
        self.fmin = fmin
        # The original only assigned the value in the ``fmax is None`` branch,
        # so any explicit, valid fmax raised UnboundLocalError.
        if fmax is None:
            fmax = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        self.fmax = fmax

        self.eps = eps
        self.remove_dc_offset = True
        self.preemph = 0.97
        self.dither = dither  # only work in train mode

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                preemph=self.preemph,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps,
                dither=self.dither, ))

    def __call__(self, x, train):
        """Compute features for one single-channel waveform.

        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode (enables dithering).

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D) per the original documentation; this is
            whatever ``logfbank`` returns.
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")

        # np.sctypes was removed in NumPy 2.0; issubdtype covers the same
        # floating-point dtypes.
        if np.issubdtype(x.dtype, np.floating):
            # Scale float samples by 2**15 to the int16 amplitude range.
            # NOTE(review): assumes float input is normalised to [-1, 1] - confirm.
            bits = np.iinfo(np.int16).bits
            x = x * 2**(bits - 1)

        # logfbank need PCM16 input
        y = logfbank(
            signal=x,
            samplerate=self.fs,
            winlen=self.win_length,  # seconds
            winstep=self.n_shift,  # seconds
            nfilt=self.n_mels,
            nfft=self.n_fft,
            lowfreq=self.fmin,
            highfreq=self.fmax,
            dither=dither,
            remove_dc_offset=self.remove_dc_offset,
            preemph=self.preemph,
            wintype=self.window)
        return y
class TransformInterface:
    """Base class for transforms.

    Subclasses are expected to override ``__call__``; the base version
    always raises ``NotImplementedError``.
    """

    def __repr__(self):
        # Rendered as "<ClassName>()", so subclasses get their own name.
        return "{}()".format(self.__class__.__name__)

    @classmethod
    def add_arguments(cls, parser):
        """Return ``parser`` unchanged (the base class adds no options)."""
        return parser

    def __call__(self, x):
        raise NotImplementedError("__call__ method is not implemented")
class Transformation():
    """Apply some functions to the mini-batch

    The configuration is a dict (or a YAML file that loads to one) whose
    ``"process"`` entry lists the steps; each step is a dict with a
    ``"type"`` key plus the keyword arguments of the class it names.

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        """Build the transform pipeline.

        :param conffile: ``None`` (empty sequential pipeline), a dict
            (deep-copied), or a path opened as UTF-8 and parsed with
            ``yaml.safe_load``
        :raises NotImplementedError: if ``mode`` is not ``"sequential"``
        :raises TypeError: re-raised when a step's class rejects its options
        """
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        # Instantiated steps, keyed by their position in conf["process"].
        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                # Copy so that popping "type" leaves the stored config intact.
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    # Log the expected signature when it can be inspected,
                    # then re-raise the original TypeError.
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # Some callables, e.g. built-in functions, cannot be inspected
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

    def __repr__(self):
        rep = "\n" + "\n".join(" {}: {}".format(k, v)
                               for k, v in self.functions.items())
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        Every step is applied, in order, to each element of ``xs``. Only the
        keyword arguments that appear in a step's signature are passed to it;
        the utterance id is passed as the second positional argument when the
        step has a parameter named ``uttid``.

        :param Union[Sequence[np.ndarray], np.ndarray] xs: a batch, or a
            single item (anything that is not a ``Sequence``)
        :param Union[Sequence[str], str] uttid_list: a single string is
            repeated for every element of ``xs``
        :return: batch: a list when ``xs`` was a ``Sequence``, otherwise the
            single transformed item
        :rtype: List[np.ndarray]
        """
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx in range(len(self.conf["process"])):
                func = self.functions[idx]
                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
                # Derive only the args which the func has
                try:
                    param = signature(func).parameters
                except ValueError:
                    # Some callables, e.g. built-in functions, cannot be inspected
                    param = {}
                _kwargs = {k: v for k, v in kwargs.items() if k in param}
                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    # Log which step failed, then propagate the exception.
                    logging.fatal("Catch a exception from {}th func: {}".format(
                        idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
class WPE(object):
    """Callable wrapper around ``nara_wpe.wpe.wpe``.

    The constructor stores the keyword arguments that every call forwards
    to ``wpe``.

    Args:
        taps: forwarded as ``taps``.
        delay: forwarded as ``delay``.
        iterations: forwarded as ``iterations``.
        psd_context: forwarded as ``psd_context``.
        statistics_mode: forwarded as ``statistics_mode``.
    """

    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        # The separator after {delay} was missing, which rendered
        # "delay=3iterations=3".
        return ("{name}(taps={taps}, delay={delay}, "
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray

        """
        # nara_wpe.wpe: (F, C, T), so reverse the axes on the way in and
        # again on the way out.
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        return xs.transpose(2, 1, 0)
def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise TypeError as same as python default.
    Callables whose signature cannot be inspected are accepted silently.

    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    :raises TypeError: if a key of ``kwargs`` is not a parameter of ``func``
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        # No signature available (e.g. some built-ins): nothing to check.
        return
    if name is None:
        # Callables such as functools.partial objects have no __name__;
        # fall back to the type name instead of raising AttributeError.
        name = getattr(func, "__name__", None) or type(func).__name__
    for k in kwargs:
        if k not in params:
            raise TypeError(
                f"{name}() got an unexpected keyword argument '{k}'")
def dynamic_import(import_path, alias=None):
    """dynamic import module and class

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'paddlespeech.s2t.models.u2:U2Model'
    :param dict alias: shortcut for registered class; maps a short name to a
        'module_name:class_name' string (``None`` means no aliases)
    :return: imported class
    :raises ValueError: if ``import_path`` is neither an alias nor contains ":"
    """
    # Avoid a shared mutable default; None behaves like the former dict().
    if alias is None:
        alias = {}
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    # A path without ":" must be an alias at this point; resolve it.
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
def has_tensor(val):
    """Return whether ``val`` is, or (recursively) contains, a paddle tensor.

    Lists, tuples and dict values are searched recursively; any other value
    is tested with ``paddle.is_tensor``.

    Args:
        val: a value, or a nesting of list / tuple / dict of values.

    Returns:
        bool: True if a tensor is found, otherwise False. (The original fell
        through and returned None for containers without tensors, and printed
        every dict key as a debugging leftover.)
    """
    if isinstance(val, (list, tuple)):
        return any(has_tensor(item) for item in val)
    if isinstance(val, dict):
        return any(has_tensor(v) for v in val.values())
    return paddle.is_tensor(val)
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Add start-of-sequence (sos) and end-of-sequence (eos) labels.

    Args:
        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of the start-of-sequence label
        eos (int): index of the end-of-sequence label
        ignore_id (int): index of padding
    Returns:
        ys_in (paddle.Tensor) : (B, Lmax + 1), sos prepended, padding
            replaced by eos
        ys_out (paddle.Tensor) : (B, Lmax + 1), eos placed right after the
            last label, padding kept as ignore_id
    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    # TODO(Hui Zhang): using comment code,
    #_sos = paddle.to_tensor(
    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #_eos = paddle.to_tensor(
    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
    B = ys_pad.shape[0]
    # One column of sos / eos per batch row.
    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
    ys_in = paddle.cat([_sos, ys_pad], dim=1)
    # Padding positions in the sos-shifted layout, i.e. one column to the
    # right of where they sit in ys_pad.
    mask_pad = (ys_in == ignore_id)
    ys_in = ys_in.masked_fill(mask_pad, eos)

    ys_out = paddle.cat([ys_pad, _eos], dim=1)
    # Temporarily overwrite the shifted padding columns so that the only
    # ignore_id left in a padded row is its first padding slot.
    # NOTE(review): assumes padding is trailing and ignore_id never occurs
    # among real labels - confirm.
    ys_out = ys_out.masked_fill(mask_pad, eos)
    # That remaining slot is where eos belongs.
    mask_eos = (ys_out == ignore_id)
    ys_out = ys_out.masked_fill(mask_eos, eos)
    # Restore ignore_id on the columns overwritten above.
    ys_out = ys_out.masked_fill(mask_pad, ignore_id)
    return ys_in, ys_out
+ Returns: + float: Accuracy value (0.0 - 1.0). + """ + pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]).argmax(2) + mask = pad_targets != ignore_label + #TODO(Hui Zhang): sum not support bool type + # numerator = paddle.sum( + # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = ( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = paddle.sum(numerator.type_as(pad_targets)) + #TODO(Hui Zhang): sum not support bool type + # denominator = paddle.sum(mask) + denominator = paddle.sum(mask.type_as(pad_targets)) + return float(numerator) / float(denominator) diff --git a/setup.py b/setup.py index 679549b4..b94a4cb2 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ base = [ "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", - "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8' + "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset' ] server = [ From c7a7b113c856a455f92b70b256f782d25133bc8d Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 24 Jun 2022 05:01:44 +0000 Subject: [PATCH 2/7] support multi-gpu training with webdataset --- examples/wenetspeech/asr1/conf/conformer.yaml | 35 +- paddlespeech/audio/stream_data/__init__.py | 3 +- paddlespeech/audio/stream_data/filters.py | 39 +- paddlespeech/audio/stream_data/pipeline.py | 6 + paddlespeech/audio/stream_data/shardlists.py | 2 + paddlespeech/audio/utils/log.py | 3 +- paddlespeech/audio/utils/tensor_utils.py | 3 - paddlespeech/s2t/exps/u2/model.py | 269 ++++++---- paddlespeech/s2t/io/dataloader.py | 87 ++++ paddlespeech/s2t/io/reader.py | 2 +- paddlespeech/s2t/transform/__init__.py | 13 - paddlespeech/s2t/transform/add_deltas.py | 54 -- .../s2t/transform/channel_selector.py | 57 --- 
paddlespeech/s2t/transform/cmvn.py | 201 -------- paddlespeech/s2t/transform/functional.py | 86 ---- paddlespeech/s2t/transform/perturb.py | 471 ----------------- paddlespeech/s2t/transform/spec_augment.py | 214 -------- paddlespeech/s2t/transform/spectrogram.py | 475 ------------------ .../s2t/transform/transform_interface.py | 35 -- paddlespeech/s2t/transform/transformation.py | 158 ------ paddlespeech/s2t/transform/wpe.py | 58 --- 21 files changed, 341 insertions(+), 1930 deletions(-) delete mode 100644 paddlespeech/s2t/transform/__init__.py delete mode 100644 paddlespeech/s2t/transform/add_deltas.py delete mode 100644 paddlespeech/s2t/transform/channel_selector.py delete mode 100644 paddlespeech/s2t/transform/cmvn.py delete mode 100644 paddlespeech/s2t/transform/functional.py delete mode 100644 paddlespeech/s2t/transform/perturb.py delete mode 100644 paddlespeech/s2t/transform/spec_augment.py delete mode 100644 paddlespeech/s2t/transform/spectrogram.py delete mode 100644 paddlespeech/s2t/transform/transform_interface.py delete mode 100644 paddlespeech/s2t/transform/transformation.py delete mode 100644 paddlespeech/s2t/transform/wpe.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index 6c2bbca4..dd4ff0e2 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -50,26 +50,41 @@ test_manifest: data/manifest.test ########################################### # Dataloader # ########################################### -vocab_filepath: data/lang_char/vocab.txt +use_stream_data: True unit_type: 'char' +vocab_filepath: data/lang_char/vocab.txt +cmvn_file: data/mean_std.json preprocess_config: conf/preprocess.yaml spm_model_prefix: '' feat_dim: 80 stride_ms: 10.0 window_ms: 25.0 +dither: 0.1 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs batch_size: 64 +minlen_in: 10 
def _spec_aug(source,
              max_w=5,
              w_inplace=True,
              w_mode="PIL",
              max_f=30,
              num_f_mask=2,
              f_inplace=True,
              f_replace_with_zero=False,
              max_t=40,
              num_t_mask=2,
              t_inplace=True,
              t_replace_with_zero=False, ):
    """Run time warping, frequency masking and time masking on every sample.

    Each sample's ``'feat'`` entry is converted with ``.numpy()``, passed
    through ``time_warp``, ``freq_mask`` and ``time_mask`` in that order, and
    written back into the same sample dict as a ``paddle.float32`` tensor.

    Args:
        source: Iterable[{fname, feat, label}]
        max_w: max width of time warp
        w_inplace: ``inplace`` flag forwarded to ``time_warp``
        w_mode: time warp mode
        max_f: max width of freq mask
        num_f_mask: number of freq mask to apply
        f_inplace: ``inplace`` flag forwarded to ``freq_mask``
        f_replace_with_zero: use zero to mask (freq mask)
        max_t: max width of time mask
        num_t_mask: number of time mask to apply
        t_inplace: ``inplace`` flag forwarded to ``time_mask``
        t_replace_with_zero: use zero to mask (time mask)

    Returns:
        Iterable[{fname, feat, label}]
    """
    # The augmentation options do not change between samples, so they are
    # collected once up front.
    warp_opts = dict(max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
    freq_opts = dict(
        F=max_f,
        n_mask=num_f_mask,
        inplace=f_inplace,
        replace_with_zero=f_replace_with_zero)
    time_opts = dict(
        T=max_t,
        n_mask=num_t_mask,
        inplace=t_inplace,
        replace_with_zero=t_replace_with_zero)

    for sample in source:
        spec = sample['feat'].numpy()
        spec = time_warp(spec, **warp_opts)
        spec = freq_mask(spec, **freq_opts)
        spec = time_mask(spec, **time_opts)
        sample['feat'] = paddle.to_tensor(spec, dtype=paddle.float32)
        yield sample


def _placeholder(source):
    """Pass every item of ``source`` through unchanged (no-op stage)."""
    yield from source


placeholder = pipelinefilter(_placeholder)
arg in args: + self.pipeline.append(arg) + return self def compose(self, *args): """Append a pipeline stage to a copy of the pipeline and returns the copy.""" diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/stream_data/shardlists.py index 503bfe57..3d1801cc 100644 --- a/paddlespeech/audio/stream_data/shardlists.py +++ b/paddlespeech/audio/stream_data/shardlists.py @@ -24,6 +24,8 @@ from .filters import pipelinefilter from .paddle_utils import IterableDataset +from ..utils.log import Logger +logger = Logger(__name__) def expand_urls(urls): if isinstance(urls, str): urllist = urls.split("::") diff --git a/paddlespeech/audio/utils/log.py b/paddlespeech/audio/utils/log.py index 5656b286..0a25bbd5 100644 --- a/paddlespeech/audio/utils/log.py +++ b/paddlespeech/audio/utils/log.py @@ -65,6 +65,7 @@ class Logger(object): def __init__(self, name: str=None): name = 'PaddleAudio' if not name else name + self.name = name self.logger = logging.getLogger(name) for key, conf in log_config.items(): @@ -101,7 +102,7 @@ class Logger(object): if not self.is_enable: return - self.logger.log(log_level, msg) + self.logger.log(log_level, self.name + " | " + msg) @contextlib.contextmanager def use_terminator(self, terminator: str): diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index bae473ec..16f60810 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -93,9 +93,6 @@ def pad_sequence(sequences: List[paddle.Tensor], for i, tensor in enumerate(sequences): length = tensor.shape[0] # use index notation to prevent duplicate references to the tensor - logger.info( - f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}" - ) if batch_first: # TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support int16 diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 
efcc9629..d6c68f96 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -26,6 +26,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import StreamDataLoader from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -106,7 +107,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -132,7 +134,7 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -152,7 +154,8 @@ class U2Trainer(Trainer): self.before_train() - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -170,7 +173,8 @@ class U2Trainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -218,92 +222,188 @@ 
def setup_dataloader(self):
    """Create the data loaders needed by the current run mode.

    ``use_stream_data`` is read from the config (default ``False``) and stored
    as ``self.use_streamdata``; it selects ``StreamDataLoader`` or
    ``BatchDataLoader``.

    * When ``self.train`` is truthy, ``self.train_loader`` and
      ``self.valid_loader`` are built from ``train_manifest`` and
      ``dev_manifest``.
    * Otherwise ``self.test_loader`` and ``self.align_loader`` are both built
      from ``test_manifest`` with ``decode.decode_batch_size`` (default 1) as
      batch size, for the streaming and the non-streaming loader alike.
    """
    config = self.config.clone()
    self.use_streamdata = config.get("use_stream_data", False)

    if self.use_streamdata:
        # Options shared by every StreamDataLoader below. They are read only
        # on the streaming path, so non-streaming configs need not define
        # them.
        stream_common = dict(
            unit_type=config.unit_type,
            num_mel_bins=config.feat_dim,
            frame_length=config.window_ms,
            frame_shift=config.stride_ms,
            resample_rate=config.resample_rate,
            augment_conf=config.augment_conf,  # dict
            shuffle_size=config.shuffle_size,
            sort_size=config.sort_size,
            n_iter_processes=config.num_workers,
            prefetch_factor=config.prefetch_factor,
            dist_sampler=config.get('dist_sampler', False),
            cmvn_file=config.cmvn_file,
            vocab_filepath=config.vocab_filepath, )
    else:
        # Options shared by every non-training BatchDataLoader below.
        eval_common = dict(
            train_mode=False,
            sortagrad=False,
            maxlen_in=float('inf'),
            maxlen_out=float('inf'),
            minibatches=0,
            batch_count='auto',
            batch_bins=0,
            batch_frames_in=0,
            batch_frames_out=0,
            batch_frames_inout=0,
            preprocess_conf=config.preprocess_config,
            subsampling_factor=1,
            num_encs=1, )

    if self.train:
        # train/valid dataset, return token ids
        if self.use_streamdata:
            # Train and valid use the same filtering / dither settings; only
            # the manifest and the train_mode flag differ.
            fit_kwargs = dict(
                stream_common,
                batch_size=config.batch_size,
                dither=config.dither,
                minlen_in=config.minlen_in,
                maxlen_in=config.maxlen_in,
                minlen_out=config.minlen_out,
                maxlen_out=config.maxlen_out, )
            self.train_loader = StreamDataLoader(
                manifest_file=config.train_manifest,
                train_mode=True,
                **fit_kwargs)
            self.valid_loader = StreamDataLoader(
                manifest_file=config.dev_manifest,
                train_mode=False,
                **fit_kwargs)
        else:
            self.train_loader = BatchDataLoader(
                json_file=config.train_manifest,
                train_mode=True,
                sortagrad=config.sortagrad,
                batch_size=config.batch_size,
                maxlen_in=config.maxlen_in,
                maxlen_out=config.maxlen_out,
                minibatches=config.minibatches,
                mini_batch_size=self.args.ngpu,
                batch_count=config.batch_count,
                batch_bins=config.batch_bins,
                batch_frames_in=config.batch_frames_in,
                batch_frames_out=config.batch_frames_out,
                batch_frames_inout=config.batch_frames_inout,
                preprocess_conf=config.preprocess_config,
                n_iter_processes=config.num_workers,
                subsampling_factor=1,
                num_encs=1,
                dist_sampler=config.get('dist_sampler', False),
                shortest_first=False)

            self.valid_loader = BatchDataLoader(
                json_file=config.dev_manifest,
                batch_size=config.batch_size,
                mini_batch_size=self.args.ngpu,
                n_iter_processes=config.num_workers,
                dist_sampler=config.get('dist_sampler', False),
                shortest_first=False,
                **eval_common)
        logger.info("Setup train/valid Dataloader!")
    else:
        decode_batch_size = config.get('decode', dict()).get(
            'decode_batch_size', 1)
        # test dataset, return raw text
        if self.use_streamdata:
            # Decoding disables dither and length filtering. The batch size
            # comes from the decode section, matching the BatchDataLoader
            # branch below (it previously used the training batch size).
            decode_kwargs = dict(
                stream_common,
                manifest_file=config.test_manifest,
                train_mode=False,
                batch_size=decode_batch_size,
                dither=0.0,
                minlen_in=0.0,
                maxlen_in=float('inf'),
                minlen_out=0,
                maxlen_out=float('inf'), )
            self.test_loader = StreamDataLoader(**decode_kwargs)
            self.align_loader = StreamDataLoader(**decode_kwargs)
        else:
            decode_kwargs = dict(
                eval_common,
                json_file=config.test_manifest,
                batch_size=decode_batch_size,
                mini_batch_size=1,
                n_iter_processes=1, )
            self.test_loader = BatchDataLoader(**decode_kwargs)
            self.align_loader = BatchDataLoader(**decode_kwargs)
        logger.info("Setup test/align Dataloader!")
import paddlespeech.audio.stream_data as stream_data
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer


class StreamDataLoader():
    """Iterable data loader that streams samples from a list of tar shards.

    ``manifest_file`` is a text file with one shard location per line. The
    shards are fed through a ``stream_data.DataPipeline`` (tokenize, length
    filter, resample, fbank, optional SpecAugment, shuffle, sort, batch, pad,
    CMVN) and wrapped in a ``stream_data.WebLoader``.

    Args:
        manifest_file: path of the shard list file.
        train_mode: when true, ``stream_data.spec_aug`` is inserted into the
            pipeline; otherwise a no-op ``placeholder`` stage is used.
        unit_type: forwarded to ``TextFeaturizer``.
        batch_size: forwarded to ``stream_data.batched``.
        num_mel_bins, frame_length, frame_shift, dither: forwarded to
            ``stream_data.compute_fbank`` (``frame_shift`` also goes to
            ``stream_data.data_filter``).
        minlen_in, maxlen_in: ``min_length`` / ``max_length`` of
            ``stream_data.data_filter``.
        minlen_out, maxlen_out: ``token_min_length`` / ``token_max_length`` of
            ``stream_data.data_filter``.
        resample_rate: forwarded to ``stream_data.resample``.
        augment_conf: keyword arguments for ``stream_data.spec_aug``; ``None``
            means "use that stage's defaults".
        shuffle_size: forwarded to ``stream_data.shuffle``.
        sort_size: forwarded to ``stream_data.sort``.
        n_iter_processes: ``num_workers`` of the ``WebLoader``.
        prefetch_factor: ``prefetch_factor`` of the ``WebLoader``.
        dist_sampler: when true, ``stream_data.split_by_node`` is added before
            ``split_by_worker``.
        cmvn_file: forwarded to ``stream_data.cmvn``.
        vocab_filepath: forwarded to ``TextFeaturizer``.
    """

    def __init__(self,
                 manifest_file: str,
                 train_mode: bool,
                 unit_type: str='char',
                 batch_size: int=0,
                 num_mel_bins=80,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 minlen_in: float=0.0,
                 maxlen_in: float=float('inf'),
                 minlen_out: float=0.0,
                 maxlen_out: float=float('inf'),
                 resample_rate: int=16000,
                 augment_conf: dict=None,
                 shuffle_size: int=10000,
                 sort_size: int=1000,
                 n_iter_processes: int=1,
                 prefetch_factor: int=2,
                 dist_sampler: bool=False,
                 cmvn_file="data/mean_std.json",
                 vocab_filepath='data/lang_char/vocab.txt'):
        self.manifest_file = manifest_file
        # ``train_model`` is the (misspelt) attribute name this class has
        # exposed so far; it is kept so existing readers do not break.
        self.train_model = train_mode
        self.train_mode = train_mode
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor
        self.dist_sampler = dist_sampler
        self.n_iter_processes = n_iter_processes

        text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
        symbol_table = text_featurizer.vocab_dict
        self.feat_dim = num_mel_bins
        self.vocab_size = text_featurizer.vocab_size

        # The list of shard; blank lines (e.g. a trailing newline) would
        # otherwise turn into empty shard locations.
        with open(manifest_file, "r") as f:
            shardlist = [line.strip() for line in f if line.strip()]

        head_stages = [stream_data.SimpleShardList(shardlist)]
        if self.dist_sampler:
            head_stages.append(stream_data.split_by_node)
        head_stages.append(stream_data.split_by_worker)
        head_stages.append(
            stream_data.tarfile_to_samples(stream_data.reraise_exception))
        base_dataset = stream_data.DataPipeline(*head_stages)

        if train_mode:
            # ``augment_conf`` defaults to None; fall back to spec_aug's own
            # defaults instead of failing on ``**None``.
            augment_stage = stream_data.spec_aug(**(augment_conf or {}))
        else:
            augment_stage = stream_data.placeholder()

        self.dataset = base_dataset.append_list(
            stream_data.tokenize(symbol_table),
            stream_data.data_filter(
                frame_shift=frame_shift,
                max_length=maxlen_in,
                min_length=minlen_in,
                token_max_length=maxlen_out,
                # Token-length bound comes from the *output* limit (this used
                # to be fed ``minlen_in`` by mistake).
                token_min_length=minlen_out),
            stream_data.resample(resample_rate=resample_rate),
            stream_data.compute_fbank(
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither),
            augment_stage,
            stream_data.shuffle(shuffle_size),
            stream_data.sort(sort_size=sort_size),
            stream_data.batched(batch_size),
            stream_data.padding(),
            stream_data.cmvn(cmvn_file))
        # batch_size=None: batches are already formed inside the pipeline.
        self.loader = stream_data.WebLoader(
            self.dataset,
            num_workers=self.n_iter_processes,
            prefetch_factor=self.prefetch_factor,
            batch_size=None)

    def __iter__(self):
        """Return a fresh iterator over the underlying ``WebLoader``."""
        return self.loader.__iter__()

    def __call__(self):
        """Alias of ``__iter__``."""
        return self.__iter__()

    def __len__(self):
        """Log that the length is unknown and return ``-1``.

        NOTE(review): Python's built-in ``len()`` (and truth testing) raises
        ``ValueError`` when ``__len__`` returns a negative number, so only a
        direct ``loader.__len__()`` call ever sees the ``-1``. Kept as is
        because callers may rely on it.
        """
        logger.info(
            "Stream dataloader does not support calculate the length of the dataset"
        )
        return -1
a/paddlespeech/s2t/transform/__init__.py b/paddlespeech/s2t/transform/__init__.py deleted file mode 100644 index 185a92b8..00000000 --- a/paddlespeech/s2t/transform/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/paddlespeech/s2t/transform/add_deltas.py b/paddlespeech/s2t/transform/add_deltas.py deleted file mode 100644 index 1387fe9d..00000000 --- a/paddlespeech/s2t/transform/add_deltas.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np


def delta(feat, window):
    """Return the regression-based delta of ``feat`` along its first axis.

    Args:
        feat (np.ndarray): input features, indexed by frame on axis 0.
        window (int): number of neighbouring frames used on each side; must be
            positive.

    Return:
        np.ndarray: array with the shape and dtype of ``feat``.
    """
    assert window > 0
    delta_feat = np.zeros_like(feat)
    for i in range(1, window + 1):
        delta_feat[:-i] += i * feat[i:]
        delta_feat[i:] += -i * feat[:-i]
        # Frames closer than ``i`` to an edge have no neighbour on that side;
        # the last / first frame stands in for the missing one.
        delta_feat[-i:] += i * feat[-1]
        delta_feat[:i] += -i * feat[0]
    delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
    return delta_feat


def add_deltas(x, window=2, order=2):
    """Append ``order`` successive deltas of ``x`` along the feature axis.

    Args:
        x (np.ndarray): speech feat, (T, D).
        window (int): window passed to ``delta``.
        order (int): how many times ``delta`` is applied; each result is the
            delta of the previous one.

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    feats = [x]
    for _ in range(order):
        feats.append(delta(feats[-1], window))
    return np.concatenate(feats, axis=1)


class AddDeltas():
    """Callable wrapper around ``add_deltas`` with fixed window and order."""

    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        # The closing parenthesis was missing from this format string.
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)
# Modified from espnet(https://github.com/espnet/espnet)
import numpy


class ChannelSelector():
    """Pick a single channel out of a multi-channel signal.

    ``train_channel`` is used when the call is made with ``train=True`` and
    ``eval_channel`` otherwise. The value ``"random"`` draws the channel index
    with ``numpy.random.randint``; any other value is used as the index
    directly. ``axis`` is the channel axis.
    """

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return "{name}(train_channel={train_channel}, eval_channel={eval_channel}, axis={axis})".format(
            name=self.__class__.__name__,
            train_channel=self.train_channel,
            eval_channel=self.eval_channel,
            axis=self.axis)

    def __call__(self, x, train=True):
        # Assuming x: [Time, Channel] by default
        if x.ndim <= self.axis:
            # Too few dimensions: append new axes until ``axis`` exists
            # (e.g. [Time] -> [Time, 1]).
            missing = self.axis + 1 - x.ndim
            x = x[(slice(None), ) * x.ndim + (None, ) * missing]

        channel = self.train_channel if train else self.eval_channel
        if channel == "random":
            ch = numpy.random.randint(0, x.shape[self.axis])
        else:
            ch = channel

        # Full slices everywhere except the channel axis.
        selector = [slice(None)] * x.ndim
        if 0 <= self.axis < x.ndim:
            selector[self.axis] = ch
        return x[tuple(selector)]
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json

import numpy as np


class CMVN():
    """Apply global / per-speaker CMVN, or the inverse transform.

    Args:
        stats: either a ``dict`` mapping a key (``None`` for global stats, a
            speaker id otherwise) to a statistics matrix, or a path that is
            loaded according to ``filetype``.
        norm_means: add the stored bias (negative mean) to the input.
        norm_vars: multiply the input by the stored scale (1 / std).
        filetype: ``"mat"`` or ``"npy"`` (one global matrix), ``"ark"`` or
            ``"hdf5"`` (one matrix per key). Ignored when ``stats`` is a dict.
        utt2spk: optional path of a text file with ``<utt> <spk>`` lines.
        spk2utt: optional path of a text file with ``<spk> <utt> <utt> ...``
            lines; used only when ``utt2spk`` is not given.
        reverse: undo the normalisation instead of applying it.
        std_floor: lower bound applied to the standard deviation.

    Raises:
        ValueError: ``stats`` is not a dict and ``filetype`` is unknown.
    """

    def __init__(
            self,
            stats,
            norm_means=True,
            norm_vars=False,
            filetype="mat",
            utt2spk=None,
            spk2utt=None,
            reverse=False,
            std_floor=1.0e-20, ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        # kaldiio / h5py are imported only by the branches that need them, so
        # dict- and npy-based statistics work without those packages.
        if isinstance(stats, dict):
            stats_dict = dict(stats)
        elif filetype == "mat":
            # Use for global CMVN
            import kaldiio
            stats_dict = {None: kaldiio.load_mat(stats)}
        elif filetype == "npy":
            # Use for global CMVN
            stats_dict = {None: np.load(stats)}
        elif filetype == "ark":
            # Use for speaker CMVN
            import kaldiio
            self.accept_uttid = True
            stats_dict = dict(kaldiio.load_ark(stats))
        elif filetype == "hdf5":
            # Use for speaker CMVN
            import h5py
            self.accept_uttid = True
            # Read everything up front (read-only) so the file handle is
            # closed again instead of staying open for the process lifetime.
            with h5py.File(stats, "r") as h5_file:
                stats_dict = {key: h5_file[key][()] for key in h5_file}
        else:
            raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, spk_stats in stats_dict.items():
            assert len(spk_stats) == 2, spk_stats.shape

            count = spk_stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = spk_stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = spk_stats[1, :-1] / count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        """Normalise (or de-normalise) ``x`` with the statistics for ``uttid``.

        ``uttid`` is mapped through ``utt2spk`` when that table exists and is
        used directly as the statistics key otherwise (``None`` selects the
        global statistics).
        """
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])

        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])

        return x


class UtteranceCMVN():
    """Apply CMVN computed from the utterance itself.

    Mean and variance are taken over axis 0 of the input; the standard
    deviation is floored at ``std_floor``.
    """

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        # Both statistics are taken before the mean is removed.
        square_sums = (x**2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.maximum(np.sqrt(var), self.std_floor)
            x = np.divide(x, std)

        return x


class GlobalCMVN():
    """Apply global CMVN from accumulated statistics.

    ``cmvn_path`` is either a dict or the path of a JSON file; in both cases
    the keys ``frame_num``, ``mean_stat`` (sum of features) and ``var_stat``
    (sum of squared features) are read.
    """

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        # cmvn_path: Option[str, dict]
        cmvn = cmvn_path
        self.cmvn = cmvn
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor
        if isinstance(cmvn, dict):
            cmvn_stats = cmvn
        else:
            with open(cmvn) as f:
                cmvn_stats = json.load(f)
        self.count = cmvn_stats['frame_num']
        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        self.std = np.maximum(np.sqrt(self.var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
# Modified from espnet(https://github.com/espnet/espnet)
import inspect

from paddlespeech.s2t.transform.transform_interface import TransformInterface
from paddlespeech.s2t.utils.check_kwargs import check_kwargs


class FuncTrans(TransformInterface):
    """Functional Transformation

    Wraps the plain function stored in the class attribute ``_func``: the
    keyword arguments given at construction are checked with ``check_kwargs``
    and passed to the function on every call.

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:

    >>> def foo_bar(x, a=1, b=2):
    ...     '''Foo bar
    ...     :param x: input
    ...     :param int a: default 1
    ...     :param int b: default 2
    ...     '''
    ...     return x + a - b


    >>> class FooBar(FuncTrans):
    ...     _func = foo_bar
    ...     __doc__ = foo_bar.__doc__
    """

    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        """Add one ``--<func>-<param>`` option per defaulted parameter.

        Option names use ``-`` instead of ``_``; the default value and its
        type become the option's ``default`` and ``type``. Returns ``parser``.
        """
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        # Looked up on the class so the function is not bound to the instance.
        return type(self)._func

    @classmethod
    def default_params(cls):
        """Return ``{name: default}`` for every parameter that has a default.

        An empty dict is returned when no signature can be obtained.
        """
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except (TypeError, ValueError):
            # ValueError: no signature available (e.g. some builtins);
            # TypeError: ``_func`` is not callable (still None on the base).
            d = dict()
        # ``empty`` is a sentinel: compare by identity, because ``!=`` would
        # invoke the default value's own comparison (ambiguous for arrays).
        return {
            k: v.default
            for k, v in d.items() if v.default is not inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        body = ", ".join("{}={}".format(k, v) for k, v in params.items())
        return self.__class__.__name__ + "(" + body + ")"
- - "Why use speed option instead of tempo -s in SoX for speed perturbation" - https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 - - Warning: - This function is very slow because of resampling. - I recommmend to apply speed-perturb outside the training using sox. - - """ - - def __init__( - self, - lower=0.9, - upper=1.1, - utt2ratio=None, - keep_length=True, - res_type="kaiser_best", - seed=None, ): - self.res_type = res_type - self.keep_length = keep_length - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - self.utt2ratio = {} - # Use the scheduled ratio for each utterances - self.utt2ratio_file = utt2ratio - self.lower = None - self.upper = None - self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - self.utt2ratio = None - # The ratio is given on runtime randomly - self.lower = lower - self.upper = upper - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format( - self.__class__.__name__, - self.lower, - self.upper, - self.keep_length, - self.res_type, ) - else: - return "{}({}, res_type={})".format( - self.__class__.__name__, self.utt2ratio_file, self.res_type) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - x = x.astype(numpy.float32) - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - # Note1: resample requires the sampling-rate of input and output, - # but actually only the ratio is used. 
- y = librosa.resample( - x, orig_sr=ratio, target_sr=1, res_type=self.res_type) - - if self.keep_length: - diff = abs(len(x) - len(y)) - if len(y) > len(x): - # Truncate noise - y = y[diff // 2:-((diff + 1) // 2)] - elif len(y) < len(x): - # Assume the time-axis is the first: (Time, Channel) - pad_width = [(diff // 2, (diff + 1) // 2)] + [ - (0, 0) for _ in range(y.ndim - 1) - ] - y = numpy.pad( - y, pad_width=pad_width, constant_values=0, mode="constant") - return y - - -class SpeedPerturbationSox(): - """SpeedPerturbationSox - - The speed perturbation in kaldi uses sox-speed instead of sox-tempo, - and sox-speed just to resample the input, - i.e pitch and tempo are changed both. - - To speed up or slow down the sound of a file, - use speed to modify the pitch and the duration of the file. - This raises the speed and reduces the time. - The default factor is 1.0 which makes no change to the audio. - 2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher. - - "Why use speed option instead of tempo -s in SoX for speed perturbation" - https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 - - tempo option: - sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9 - - speed option: - sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9 - - If we use speed option like above, the pitch of audio also will be changed, - but the tempo option does not change the pitch. 
- """ - - def __init__( - self, - lower=0.9, - upper=1.1, - utt2ratio=None, - keep_length=True, - sr=16000, - seed=None, ): - self.sr = sr - self.keep_length = keep_length - self.state = numpy.random.RandomState(seed) - - try: - import soxbindings as sox - except ImportError: - try: - from paddlespeech.s2t.utils import dynamic_pip_install - package = "sox" - dynamic_pip_install.install(package) - package = "soxbindings" - if sys.platform != "win32": - dynamic_pip_install.install(package) - import soxbindings as sox - except Exception: - raise RuntimeError( - "Can not install soxbindings on your system.") - self.sox = sox - - if utt2ratio is not None: - self.utt2ratio = {} - # Use the scheduled ratio for each utterances - self.utt2ratio_file = utt2ratio - self.lower = None - self.upper = None - self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - self.utt2ratio = None - # The ratio is given on runtime randomly - self.lower = lower - self.upper = upper - - def __repr__(self): - if self.utt2ratio is None: - return f"""{self.__class__.__name__}( - lower={self.lower}, - upper={self.upper}, - keep_length={self.keep_length}, - sample_rate={self.sr})""" - - else: - return f"""{self.__class__.__name__}( - utt2ratio={self.utt2ratio_file}, - sample_rate={self.sr})""" - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - tfm = self.sox.Transformer() - tfm.set_globals(multithread=False) - tfm.speed(ratio) - y = tfm.build_array(input_array=x, sample_rate_in=self.sr) - - if self.keep_length: - diff = abs(len(x) - len(y)) - if len(y) > len(x): - # Truncate noise - y = y[diff // 2:-((diff + 1) // 2)] - elif len(y) < len(x): - # Assume the time-axis is the first: (Time, Channel) 
- pad_width = [(diff // 2, (diff + 1) // 2)] + [ - (0, 0) for _ in range(y.ndim - 1) - ] - y = numpy.pad( - y, pad_width=pad_width, constant_values=0, mode="constant") - - if y.ndim == 2 and x.ndim == 1: - # (T, C) -> (T) - y = y.sequence(1) - return y - - -class BandpassPerturbation(): - """BandpassPerturbation - - Randomly dropout along the frequency axis. - - The original idea comes from the following: - "randomly-selected frequency band was cut off under the constraint of - leaving at least 1,000 Hz band within the range of less than 4,000Hz." - (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for - everyday home environments using multiple microphone arrays; - http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf) - - """ - - def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )): - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - # x_stft: (Time, Channel, Freq) - self.axes = axes - - def __repr__(self): - return "{}(lower={}, upper={})".format(self.__class__.__name__, - self.lower, self.upper) - - def __call__(self, x_stft, uttid=None, train=True): - if not train: - return x_stft - - if x_stft.ndim == 1: - raise RuntimeError("Input in time-freq domain: " - "(Time, Channel, Freq) or (Time, Freq)") - - ratio = self.state.uniform(self.lower, self.upper) - axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] - shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)] - - mask = self.state.randn(*shape) > ratio - x_stft *= mask - return x_stft - - -class VolumePerturbation(): - def __init__(self, - lower=-1.6, - upper=1.6, - utt2ratio=None, - dbunit=True, - seed=None): - self.dbunit = dbunit - self.utt2ratio_file = utt2ratio - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - # Use the scheduled ratio for each utterances - self.utt2ratio = {} - self.lower = None - self.upper = None - 
self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - # The ratio is given on runtime randomly - self.utt2ratio = None - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, dbunit={})".format( - self.__class__.__name__, self.lower, self.upper, self.dbunit) - else: - return '{}("{}", dbunit={})'.format( - self.__class__.__name__, self.utt2ratio_file, self.dbunit) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - if self.dbunit: - ratio = 10**(ratio / 20) - return x * ratio - - -class NoiseInjection(): - """Add isotropic noise""" - - def __init__( - self, - utt2noise=None, - lower=-20, - upper=-5, - utt2ratio=None, - filetype="list", - dbunit=True, - seed=None, ): - self.utt2noise_file = utt2noise - self.utt2ratio_file = utt2ratio - self.filetype = filetype - self.dbunit = dbunit - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - # Use the scheduled ratio for each utterances - self.utt2ratio = {} - with open(utt2noise, "r") as f: - for line in f: - utt, snr = line.rstrip().split(None, 1) - snr = float(snr) - self.utt2ratio[utt] = snr - else: - # The ratio is given on runtime randomly - self.utt2ratio = None - - if utt2noise is not None: - self.utt2noise = {} - if filetype == "list": - with open(utt2noise, "r") as f: - for line in f: - utt, filename = line.rstrip().split(None, 1) - signal, rate = soundfile.read(filename, dtype="int16") - # Load all files in memory - self.utt2noise[utt] = (signal, rate) - - elif filetype == "sound.hdf5": - self.utt2noise = SoundHDF5File(utt2noise, "r") - else: - raise ValueError(filetype) - else: - self.utt2noise = None - - if 
utt2noise is not None and utt2ratio is not None: - if set(self.utt2ratio) != set(self.utt2noise): - raise RuntimeError("The uttids mismatch between {} and {}". - format(utt2ratio, utt2noise)) - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, dbunit={})".format( - self.__class__.__name__, self.lower, self.upper, self.dbunit) - else: - return '{}("{}", dbunit={})'.format( - self.__class__.__name__, self.utt2ratio_file, self.dbunit) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - x = x.astype(numpy.float32) - - # 1. Get ratio of noise to signal in sound pressure level - if uttid is not None and self.utt2ratio is not None: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - if self.dbunit: - ratio = 10**(ratio / 20) - scale = ratio * numpy.sqrt((x**2).mean()) - - # 2. Get noise - if self.utt2noise is not None: - # Get noise from the external source - if uttid is not None: - noise, rate = self.utt2noise[uttid] - else: - # Randomly select the noise source - noise = self.state.choice(list(self.utt2noise.values())) - # Normalize the level - noise /= numpy.sqrt((noise**2).mean()) - - # Adjust the noise length - diff = abs(len(x) - len(noise)) - offset = self.state.randint(0, diff) - if len(noise) > len(x): - # Truncate noise - noise = noise[offset:-(diff - offset)] - else: - noise = numpy.pad( - noise, pad_width=[offset, diff - offset], mode="wrap") - - else: - # Generate white noise - noise = self.state.normal(0, 1, x.shape) - - # 3. 
Add noise to signal - return x + noise * scale - - -class RIRConvolve(): - def __init__(self, utt2rir, filetype="list"): - self.utt2rir_file = utt2rir - self.filetype = filetype - - self.utt2rir = {} - if filetype == "list": - with open(utt2rir, "r") as f: - for line in f: - utt, filename = line.rstrip().split(None, 1) - signal, rate = soundfile.read(filename, dtype="int16") - self.utt2rir[utt] = (signal, rate) - - elif filetype == "sound.hdf5": - self.utt2rir = SoundHDF5File(utt2rir, "r") - else: - raise NotImplementedError(filetype) - - def __repr__(self): - return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - - if x.ndim != 1: - # Must be single channel - raise RuntimeError( - "Input x must be one dimensional array, but got {}".format( - x.shape)) - - rir, rate = self.utt2rir[uttid] - if rir.ndim == 2: - # FIXME(kamo): Use chainer.convolution_1d? - # return [Time, Channel] - return numpy.stack( - [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) - else: - return scipy.convolve(x, rir, mode="same") diff --git a/paddlespeech/s2t/transform/spec_augment.py b/paddlespeech/s2t/transform/spec_augment.py deleted file mode 100644 index 5ce95085..00000000 --- a/paddlespeech/s2t/transform/spec_augment.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -"""Spec Augment module for preprocessing i.e., data augmentation""" -import random - -import numpy -from PIL import Image -from PIL.Image import BICUBIC - -from paddlespeech.s2t.transform.functional import FuncTrans - - -def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): - """time warp for spec augment - - move random center frame by the random width ~ uniform(-window, window) - :param numpy.ndarray x: spectrogram (time, freq) - :param int max_time_warp: maximum time frames to warp - :param bool inplace: overwrite x with the result - :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" - (slow, differentiable) - :returns numpy.ndarray: time warped spectrogram (time, freq) - """ - window = max_time_warp - if window == 0: - return x - - if mode == "PIL": - t = x.shape[0] - if t - window <= window: - return x - # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 - center = random.randrange(window, t - window) - warped = random.randrange(center - window, center + - window) + 1 # 1 ... 
t - 1 - - left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) - right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), - BICUBIC) - if inplace: - x[:warped] = left - x[warped:] = right - return x - return numpy.concatenate((left, right), 0) - elif mode == "sparse_image_warp": - import paddle - - from espnet.utils import spec_augment - - # TODO(karita): make this differentiable again - return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() - else: - raise NotImplementedError("unknown resize mode: " + mode + - ", choose one from (PIL, sparse_image_warp).") - - -class TimeWarp(FuncTrans): - _func = time_warp - __doc__ = time_warp.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): - """freq mask for spec agument - - :param numpy.ndarray x: (time, freq) - :param int n_mask: the number of masks - :param bool inplace: overwrite - :param bool replace_with_zero: pad zero on mask if true else use mean - """ - if inplace: - cloned = x - else: - cloned = x.copy() - - num_mel_channels = cloned.shape[1] - fs = numpy.random.randint(0, F, size=(n_mask, 2)) - - for f, mask_end in fs: - f_zero = random.randrange(0, num_mel_channels - f) - mask_end += f_zero - - # avoids randrange error if values are equal and range is empty - if f_zero == f_zero + f: - continue - - if replace_with_zero: - cloned[:, f_zero:mask_end] = 0 - else: - cloned[:, f_zero:mask_end] = cloned.mean() - return cloned - - -class FreqMask(FuncTrans): - _func = freq_mask - __doc__ = freq_mask.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): - """freq mask for spec agument - - :param numpy.ndarray spec: (time, freq) - :param int n_mask: the number of masks - :param bool inplace: overwrite - :param bool replace_with_zero: 
pad zero on mask if true else use mean - """ - if inplace: - cloned = spec - else: - cloned = spec.copy() - len_spectro = cloned.shape[0] - ts = numpy.random.randint(0, T, size=(n_mask, 2)) - for t, mask_end in ts: - # avoid randint range error - if len_spectro - t <= 0: - continue - t_zero = random.randrange(0, len_spectro - t) - - # avoids randrange error if values are equal and range is empty - if t_zero == t_zero + t: - continue - - mask_end += t_zero - if replace_with_zero: - cloned[t_zero:mask_end] = 0 - else: - cloned[t_zero:mask_end] = cloned.mean() - return cloned - - -class TimeMask(FuncTrans): - _func = time_mask - __doc__ = time_mask.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def spec_augment( - x, - resize_mode="PIL", - max_time_warp=80, - max_freq_width=27, - n_freq_mask=2, - max_time_width=100, - n_time_mask=2, - inplace=True, - replace_with_zero=True, ): - """spec agument - - apply random time warping and time/freq masking - default setting is based on LD (Librispeech double) in Table 2 - https://arxiv.org/pdf/1904.08779.pdf - - :param numpy.ndarray x: (time, freq) - :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" - (slow, differentiable) - :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) - :param int freq_mask_width: maximum width of the random freq mask (F) - :param int n_freq_mask: the number of the random freq mask (m_F) - :param int time_mask_width: maximum width of the random time mask (T) - :param int n_time_mask: the number of the random time mask (m_T) - :param bool inplace: overwrite intermediate array - :param bool replace_with_zero: pad zero on mask if true else use mean - """ - assert isinstance(x, numpy.ndarray) - assert x.ndim == 2 - x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) - x = freq_mask( - x, - max_freq_width, - n_freq_mask, - inplace=inplace, - replace_with_zero=replace_with_zero, 
) - x = time_mask( - x, - max_time_width, - n_time_mask, - inplace=inplace, - replace_with_zero=replace_with_zero, ) - return x - - -class SpecAugment(FuncTrans): - _func = spec_augment - __doc__ = spec_augment.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) diff --git a/paddlespeech/s2t/transform/spectrogram.py b/paddlespeech/s2t/transform/spectrogram.py deleted file mode 100644 index 19f0237b..00000000 --- a/paddlespeech/s2t/transform/spectrogram.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import librosa -import numpy as np -import paddle -from python_speech_features import logfbank - -import paddlespeech.audio.compliance.kaldi as kaldi - - -def stft(x, - n_fft, - n_shift, - win_length=None, - window="hann", - center=True, - pad_mode="reflect"): - # x: [Time, Channel] - if x.ndim == 1: - single_channel = True - # x: [Time] -> [Time, Channel] - x = x[:, None] - else: - single_channel = False - x = x.astype(np.float32) - - # FIXME(kamo): librosa.stft can't use multi-channel? 
- # x: [Time, Channel, Freq] - x = np.stack( - [ - librosa.stft( - y=x[:, ch], - n_fft=n_fft, - hop_length=n_shift, - win_length=win_length, - window=window, - center=center, - pad_mode=pad_mode, ).T for ch in range(x.shape[1]) - ], - axis=1, ) - - if single_channel: - # x: [Time, Channel, Freq] -> [Time, Freq] - x = x[:, 0] - return x - - -def istft(x, n_shift, win_length=None, window="hann", center=True): - # x: [Time, Channel, Freq] - if x.ndim == 2: - single_channel = True - # x: [Time, Freq] -> [Time, Channel, Freq] - x = x[:, None, :] - else: - single_channel = False - - # x: [Time, Channel] - x = np.stack( - [ - librosa.istft( - stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time] - hop_length=n_shift, - win_length=win_length, - window=window, - center=center, ) for ch in range(x.shape[1]) - ], - axis=1, ) - - if single_channel: - # x: [Time, Channel] -> [Time] - x = x[:, 0] - return x - - -def stft2logmelspectrogram(x_stft, - fs, - n_mels, - n_fft, - fmin=None, - fmax=None, - eps=1e-10): - # x_stft: (Time, Channel, Freq) or (Time, Freq) - fmin = 0 if fmin is None else fmin - fmax = fs / 2 if fmax is None else fmax - - # spc: (Time, Channel, Freq) or (Time, Freq) - spc = np.abs(x_stft) - # mel_basis: (Mel_freq, Freq) - mel_basis = librosa.filters.mel( - sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) - # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) - lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) - - return lmspc - - -def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): - # x: (Time, Channel) -> spc: (Time, Channel, Freq) - spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) - return spc - - -def logmelspectrogram( - x, - fs, - n_mels, - n_fft, - n_shift, - win_length=None, - window="hann", - fmin=None, - fmax=None, - eps=1e-10, - pad_mode="reflect", ): - # stft: (Time, Channel, Freq) or (Time, Freq) - x_stft = stft( - x, - n_fft=n_fft, - n_shift=n_shift, - win_length=win_length, - window=window, - 
pad_mode=pad_mode, ) - - return stft2logmelspectrogram( - x_stft, - fs=fs, - n_mels=n_mels, - n_fft=n_fft, - fmin=fmin, - fmax=fmax, - eps=eps) - - -class Spectrogram(): - def __init__(self, n_fft, n_shift, win_length=None, window="hann"): - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - - def __repr__(self): - return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " - "win_length={win_length}, window={window})".format( - name=self.__class__.__name__, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, )) - - def __call__(self, x): - return spectrogram( - x, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, ) - - -class LogMelSpectrogram(): - def __init__( - self, - fs, - n_mels, - n_fft, - n_shift, - win_length=None, - window="hann", - fmin=None, - fmax=None, - eps=1e-10, ): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.fmin = fmin - self.fmax = fmax - self.eps = eps - - def __repr__(self): - return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "n_shift={n_shift}, win_length={win_length}, window={window}, " - "fmin={fmin}, fmax={fmax}, eps={eps}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, )) - - def __call__(self, x): - return logmelspectrogram( - x, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, ) - - -class Stft2LogMelSpectrogram(): - def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.fmin = fmin - self.fmax = fmax - self.eps = eps - - def __repr__(self): - return 
("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "fmin={fmin}, fmax={fmax}, eps={eps}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, )) - - def __call__(self, x): - return stft2logmelspectrogram( - x, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - fmin=self.fmin, - fmax=self.fmax, ) - - -class Stft(): - def __init__( - self, - n_fft, - n_shift, - win_length=None, - window="hann", - center=True, - pad_mode="reflect", ): - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.center = center - self.pad_mode = pad_mode - - def __repr__(self): - return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " - "win_length={win_length}, window={window}," - "center={center}, pad_mode={pad_mode})".format( - name=self.__class__.__name__, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode=self.pad_mode, )) - - def __call__(self, x): - return stft( - x, - self.n_fft, - self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode=self.pad_mode, ) - - -class IStft(): - def __init__(self, n_shift, win_length=None, window="hann", center=True): - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.center = center - - def __repr__(self): - return ("{name}(n_shift={n_shift}, " - "win_length={win_length}, window={window}," - "center={center})".format( - name=self.__class__.__name__, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, )) - - def __call__(self, x): - return istft( - x, - self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, ) - - -class LogMelSpectrogramKaldi(): - def __init__( - self, - fs=16000, - n_mels=80, - n_shift=160, # unit:sample, 10ms - win_length=400, # unit:sample, 25ms - 
energy_floor=0.0, - dither=0.1): - """ - The Kaldi implementation of LogMelSpectrogram - Args: - fs (int): sample rate of the audio - n_mels (int): number of mel filter banks - n_shift (int): number of points in a frame shift - win_length (int): number of points in a frame windows - energy_floor (float): Floor on energy in Spectrogram computation (absolute) - dither (float): Dithering constant - - Returns: - LogMelSpectrogramKaldi - """ - - self.fs = fs - self.n_mels = n_mels - num_point_ms = fs / 1000 - self.n_frame_length = win_length / num_point_ms - self.n_frame_shift = n_shift / num_point_ms - self.energy_floor = energy_floor - self.dither = dither - - def __repr__(self): - return ( - "{name}(fs={fs}, n_mels={n_mels}, " - "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, " - "dither={dither}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_frame_shift=self.n_frame_shift, - n_frame_length=self.n_frame_length, - dither=self.dither, )) - - def __call__(self, x, train): - """ - Args: - x (np.ndarray): shape (Ti,) - train (bool): True, train mode. 
- - Raises: - ValueError: not support (Ti, C) - - Returns: - np.ndarray: (T, D) - """ - dither = self.dither if train else 0.0 - if x.ndim != 1: - raise ValueError("Not support x: [Time, Channel]") - waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32) - mat = kaldi.fbank( - waveform, - n_mels=self.n_mels, - frame_length=self.n_frame_length, - frame_shift=self.n_frame_shift, - dither=dither, - energy_floor=self.energy_floor, - sr=self.fs) - mat = np.squeeze(mat.numpy()) - return mat - - -class LogMelSpectrogramKaldi_decay(): - def __init__( - self, - fs=16000, - n_mels=80, - n_fft=512, # fft point - n_shift=160, # unit:sample, 10ms - win_length=400, # unit:sample, 25ms - window="povey", - fmin=20, - fmax=None, - eps=1e-10, - dither=1.0): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - if n_shift > win_length: - raise ValueError("Stride size must not be greater than " - "window size.") - self.n_shift = n_shift / fs # unit: ms - self.win_length = win_length / fs # unit: ms - - self.window = window - self.fmin = fmin - if fmax is None: - fmax_ = fmax if fmax else self.fs / 2 - elif fmax > int(self.fs / 2): - raise ValueError("fmax must not be greater than half of " - "sample rate.") - self.fmax = fmax_ - - self.eps = eps - self.remove_dc_offset = True - self.preemph = 0.97 - self.dither = dither # only work in train mode - - def __repr__(self): - return ( - "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, " - "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - preemph=self.preemph, - win_length=self.win_length, - window=self.window, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, - dither=self.dither, )) - - def __call__(self, x, train): - """ - - Args: - x (np.ndarray): shape (Ti,) - train (bool): True, train mode. 
- - Raises: - ValueError: not support (Ti, C) - - Returns: - np.ndarray: (T, D) - """ - dither = self.dither if train else 0.0 - if x.ndim != 1: - raise ValueError("Not support x: [Time, Channel]") - - if x.dtype in np.sctypes['float']: - # PCM32 -> PCM16 - bits = np.iinfo(np.int16).bits - x = x * 2**(bits - 1) - - # logfbank need PCM16 input - y = logfbank( - signal=x, - samplerate=self.fs, - winlen=self.win_length, # unit ms - winstep=self.n_shift, # unit ms - nfilt=self.n_mels, - nfft=self.n_fft, - lowfreq=self.fmin, - highfreq=self.fmax, - dither=dither, - remove_dc_offset=self.remove_dc_offset, - preemph=self.preemph, - wintype=self.window) - return y diff --git a/paddlespeech/s2t/transform/transform_interface.py b/paddlespeech/s2t/transform/transform_interface.py deleted file mode 100644 index 8bc62420..00000000 --- a/paddlespeech/s2t/transform/transform_interface.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Modified from espnet(https://github.com/espnet/espnet) - - -class TransformInterface: - """Transform Interface""" - - def __call__(self, x): - raise NotImplementedError("__call__ method is not implemented") - - @classmethod - def add_arguments(cls, parser): - return parser - - def __repr__(self): - return self.__class__.__name__ + "()" - - -class Identity(TransformInterface): - """Identity Function""" - - def __call__(self, x): - return x diff --git a/paddlespeech/s2t/transform/transformation.py b/paddlespeech/s2t/transform/transformation.py deleted file mode 100644 index 3b433cb0..00000000 --- a/paddlespeech/s2t/transform/transformation.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Modified from espnet(https://github.com/espnet/espnet) -"""Transformation module.""" -import copy -import io -import logging -from collections import OrderedDict -from collections.abc import Sequence -from inspect import signature - -import yaml - -from paddlespeech.s2t.utils.dynamic_import import dynamic_import - -import_alias = dict( - identity="paddlespeech.s2t.transform.transform_interface:Identity", - time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp", - time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask", - freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask", - spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment", - speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation", - speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox", - volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation", - noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection", - bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation", - rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve", - delta="paddlespeech.s2t.transform.add_deltas:AddDeltas", - cmvn="paddlespeech.s2t.transform.cmvn:CMVN", - utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN", - fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram", - spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram", - stft="paddlespeech.s2t.transform.spectrogram:Stft", - istft="paddlespeech.s2t.transform.spectrogram:IStft", - stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram", - wpe="paddlespeech.s2t.transform.wpe:WPE", - channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", - fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", - cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") - - -class Transformation(): - """Apply some functions to the mini-batch - - Examples: - >>> kwargs = 
{"process": [{"type": "fbank", - ... "n_mels": 80, - ... "fs": 16000}, - ... {"type": "cmvn", - ... "stats": "data/train/cmvn.ark", - ... "norm_vars": True}, - ... {"type": "delta", "window": 2, "order": 2}]} - >>> transform = Transformation(kwargs) - >>> bs = 10 - >>> xs = [np.random.randn(100, 80).astype(np.float32) - ... for _ in range(bs)] - >>> xs = transform(xs) - """ - - def __init__(self, conffile=None): - if conffile is not None: - if isinstance(conffile, dict): - self.conf = copy.deepcopy(conffile) - else: - with io.open(conffile, encoding="utf-8") as f: - self.conf = yaml.safe_load(f) - assert isinstance(self.conf, dict), type(self.conf) - else: - self.conf = {"mode": "sequential", "process": []} - - self.functions = OrderedDict() - if self.conf.get("mode", "sequential") == "sequential": - for idx, process in enumerate(self.conf["process"]): - assert isinstance(process, dict), type(process) - opts = dict(process) - process_type = opts.pop("type") - class_obj = dynamic_import(process_type, import_alias) - # TODO(karita): assert issubclass(class_obj, TransformInterface) - try: - self.functions[idx] = class_obj(**opts) - except TypeError: - try: - signa = signature(class_obj) - except ValueError: - # Some function, e.g. 
built-in function, are failed - pass - else: - logging.error("Expected signature: {}({})".format( - class_obj.__name__, signa)) - raise - else: - raise NotImplementedError( - "Not supporting mode={}".format(self.conf["mode"])) - - def __repr__(self): - rep = "\n" + "\n".join(" {}: {}".format(k, v) - for k, v in self.functions.items()) - return "{}({})".format(self.__class__.__name__, rep) - - def __call__(self, xs, uttid_list=None, **kwargs): - """Return new mini-batch - - :param Union[Sequence[np.ndarray], np.ndarray] xs: - :param Union[Sequence[str], str] uttid_list: - :return: batch: - :rtype: List[np.ndarray] - """ - if not isinstance(xs, Sequence): - is_batch = False - xs = [xs] - else: - is_batch = True - - if isinstance(uttid_list, str): - uttid_list = [uttid_list for _ in range(len(xs))] - - if self.conf.get("mode", "sequential") == "sequential": - for idx in range(len(self.conf["process"])): - func = self.functions[idx] - # TODO(karita): use TrainingTrans and UttTrans to check __call__ args - # Derive only the args which the func has - try: - param = signature(func).parameters - except ValueError: - # Some function, e.g. built-in function, are failed - param = {} - _kwargs = {k: v for k, v in kwargs.items() if k in param} - try: - if uttid_list is not None and "uttid" in param: - xs = [ - func(x, u, **_kwargs) - for x, u in zip(xs, uttid_list) - ] - else: - xs = [func(x, **_kwargs) for x in xs] - except Exception: - logging.fatal("Catch a exception from {}th func: {}".format( - idx, func)) - raise - else: - raise NotImplementedError( - "Not supporting mode={}".format(self.conf["mode"])) - - if is_batch: - return xs - else: - return xs[0] diff --git a/paddlespeech/s2t/transform/wpe.py b/paddlespeech/s2t/transform/wpe.py deleted file mode 100644 index 777379d0..00000000 --- a/paddlespeech/s2t/transform/wpe.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -from nara_wpe.wpe import wpe - - -class WPE(object): - def __init__(self, - taps=10, - delay=3, - iterations=3, - psd_context=0, - statistics_mode="full"): - self.taps = taps - self.delay = delay - self.iterations = iterations - self.psd_context = psd_context - self.statistics_mode = statistics_mode - - def __repr__(self): - return ("{name}(taps={taps}, delay={delay}" - "iterations={iterations}, psd_context={psd_context}, " - "statistics_mode={statistics_mode})".format( - name=self.__class__.__name__, - taps=self.taps, - delay=self.delay, - iterations=self.iterations, - psd_context=self.psd_context, - statistics_mode=self.statistics_mode, )) - - def __call__(self, xs): - """Return enhanced - - :param np.ndarray xs: (Time, Channel, Frequency) - :return: enhanced_xs - :rtype: np.ndarray - - """ - # nara_wpe.wpe: (F, C, T) - xs = wpe( - xs.transpose((2, 1, 0)), - taps=self.taps, - delay=self.delay, - iterations=self.iterations, - psd_context=self.psd_context, - statistics_mode=self.statistics_mode, ) - return xs.transpose(2, 1, 0) From 0c7abc1f1753d4af1b104a3aae30b5f55661b8c1 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 27 Jun 2022 11:06:30 +0000 Subject: [PATCH 3/7] add training scripts --- examples/wenetspeech/asr1/conf/conformer.yaml | 28 +- examples/wenetspeech/asr1/local/data.sh | 125 ++--- .../asr1/local/wenetspeech_data_prep.sh | 4 +- .../{stream_data 
=> streamdata}/__init__.py | 6 +- paddlespeech/audio/streamdata/autodecode.py | 445 +++++++++++++++++ .../{stream_data => streamdata}/cache.py | 4 +- .../{stream_data => streamdata}/compat.py | 2 +- .../audio/streamdata/extradatasets.py | 141 ++++++ .../{stream_data => streamdata}/filters.py | 4 +- paddlespeech/audio/streamdata/gopen.py | 340 +++++++++++++ paddlespeech/audio/streamdata/handlers.py | 47 ++ paddlespeech/audio/streamdata/mix.py | 85 ++++ .../paddle_utils.py | 0 .../{stream_data => streamdata}/pipeline.py | 3 +- .../{stream_data => streamdata}/shardlists.py | 0 .../tariterators.py | 4 +- .../{stream_data => streamdata}/utils.py | 0 paddlespeech/audio/streamdata/writer.py | 450 ++++++++++++++++++ paddlespeech/s2t/io/dataloader.py | 61 ++- setup.py | 3 +- 20 files changed, 1620 insertions(+), 132 deletions(-) rename paddlespeech/audio/{stream_data => streamdata}/__init__.py (87%) create mode 100644 paddlespeech/audio/streamdata/autodecode.py rename paddlespeech/audio/{stream_data => streamdata}/cache.py (98%) rename paddlespeech/audio/{stream_data => streamdata}/compat.py (99%) create mode 100644 paddlespeech/audio/streamdata/extradatasets.py rename paddlespeech/audio/{stream_data => streamdata}/filters.py (99%) create mode 100644 paddlespeech/audio/streamdata/gopen.py create mode 100644 paddlespeech/audio/streamdata/handlers.py create mode 100644 paddlespeech/audio/streamdata/mix.py rename paddlespeech/audio/{stream_data => streamdata}/paddle_utils.py (100%) rename paddlespeech/audio/{stream_data => streamdata}/pipeline.py (96%) rename paddlespeech/audio/{stream_data => streamdata}/shardlists.py (100%) rename paddlespeech/audio/{stream_data => streamdata}/tariterators.py (99%) rename paddlespeech/audio/{stream_data => streamdata}/utils.py (100%) create mode 100644 paddlespeech/audio/streamdata/writer.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index dd4ff0e2..f46d4bd9 100644 --- 
a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -1,7 +1,6 @@ ############################################ # Network Architecture # ############################################ -cmvn_file: cmvn_file_type: "json" # encoder related encoder: conformer @@ -43,9 +42,9 @@ model_conf: ########################################### # Data # ########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test +train_manifest: data/train_l/data.list +dev_manifest: data/dev/data.list +test_manifest: data/test_meeting/data.list ########################################### # Dataloader # @@ -54,23 +53,22 @@ use_stream_data: True unit_type: 'char' vocab_filepath: data/lang_char/vocab.txt cmvn_file: data/mean_std.json -preprocess_config: conf/preprocess.yaml spm_model_prefix: '' feat_dim: 80 stride_ms: 10.0 window_ms: 25.0 dither: 0.1 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs -batch_size: 64 +batch_size: 32 minlen_in: 10 -maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed minlen_out: 0 -maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed resample_rate: 16000 -shuffle_size: 10000 -sort_size: 500 -num_workers: 4 -prefetch_factor: 100 +shuffle_size: 1500 +sort_size: 1000 +num_workers: 0 +prefetch_factor: 10 dist_sampler: True num_encs: 1 augment_conf: @@ -90,10 +88,10 @@ augment_conf: ########################################### # Training # ########################################### -n_epoch: 240 -accum_grad: 16 +n_epoch: 30 +accum_grad: 32 global_grad_clip: 5.0 -log_interval: 1 +log_interval: 100 checkpoint: 
kbest_n: 50 latest_n: 5 diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index d216dd84..b3472a8f 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -2,6 +2,8 @@ # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang) # NPU, ASLP Group (Author: Qijie Shao) +# +# Modified from wenet(https://github.com/wenet-e2e/wenet) stage=-1 stop_stage=100 @@ -30,7 +32,7 @@ mkdir -p data TARGET_DIR=${MAIN_ROOT}/dataset mkdir -p ${TARGET_DIR} -if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then # download data echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data." exit 0; @@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then data || exit 1; fi -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # generate manifests - python3 ${TARGET_DIR}/aishell/aishell.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/aishell" - - if [ $? -ne 0 ]; then - echo "Prepare Aishell failed. Terminated." - exit 1 - fi - - for dataset in train dev test; do - mv data/manifest.${dataset} data/manifest.${dataset}.raw - done -fi - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # compute mean and stddev for normalizer - if $cmvn; then - full_size=`cat data/${train_set}/wav.scp | wc -l` - sampling_size=$((full_size / cmvn_sampling_divisor)) - shuf -n $sampling_size data/$train_set/wav.scp \ - > data/$train_set/wav.scp.sampled - num_workers=$(nproc) - - python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ - --manifest_path="data/manifest.train.raw" \ - --spectrum_type="fbank" \ - --feat_dim=80 \ - --delta_delta=false \ - --stride_ms=10 \ - --window_ms=25 \ - --sample_rate=16000 \ - --use_dB_normalization=False \ - --num_samples=-1 \ - --num_workers=${num_workers} \ - --output_path="data/mean_std.json" - - if [ $? -ne 0 ]; then - echo "Compute mean and stddev failed. 
Terminated." - exit 1 - fi - fi -fi - -dict=data/dict/lang_char.txt +dict=data/lang_char/vocab.txt if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # download data, generate manifests - # build vocabulary - python3 ${MAIN_ROOT}/utils/build_vocab.py \ - --unit_type="char" \ - --count_threshold=0 \ - --vocab_path="data/lang_char/vocab.txt" \ - --manifest_paths "data/manifest.train.raw" - - if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." - exit 1 - fi + echo "Make a dictionary" + echo "dictionary: ${dict}" + mkdir -p $(dirname $dict) + echo "" > ${dict} # 0 will be used for "blank" in CTC + echo "" >> ${dict} # must be 1 + echo "▁" >> ${dict} # ▁ is for space + utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \ + | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' \ + | grep -v "▁" \ + | awk '{print $0}' >> ${dict} \ + || exit 1; + echo "" >> $dict fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for dataset in train dev test; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.json" \ - --unit_type "char" \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${dataset}.raw" \ - --output_path="data/manifest.${dataset}" + echo "Compute cmvn" + # Here we use all the training data, you can sample some some data to save time + # BUG!!! We should use the segmented data for CMVN + if $cmvn; then + full_size=`cat data/${train_set}/wav.scp | wc -l` + sampling_size=$((full_size / cmvn_sampling_divisor)) + shuf -n $sampling_size data/$train_set/wav.scp \ + > data/$train_set/wav.scp.sampled + python3 utils/compute_cmvn_stats.py \ + --num_workers 16 \ + --train_config $train_config \ + --in_scp data/$train_set/wav.scp.sampled \ + --out_cmvn data/$train_set/mean_std.json \ + || exit 1; + fi +fi - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." 
- exit 1 - fi - } & - done - wait +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Making shards, please wait..." + RED='\033[0;31m' + NOCOLOR='\033[0m' + echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space" + echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads" + for x in $dev_set $test_sets ${train_set}; do + dst=$shards_dir/$x + mkdir -p $dst + utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \ + --do_filter --num_node 1 --num_gpus_per_node 8 \ + --num_threads 32 --segments data/$x/segments \ + data/$x/wav.scp data/$x/text \ + $(realpath $dst) data/$x/data.list + done fi -echo "Aishell data preparation done." +echo "Wenetspeech data preparation done." exit 0 diff --git a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh index 85853053..baa2b32d 100755 --- a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh +++ b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh @@ -24,7 +24,7 @@ stage=1 prefix= train_subset=L -. ./tools/parse_options.sh || exit 1; +. 
./utils/parse_options.sh || exit 1; filter_by_id () { idlist=$1 @@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then done fi -echo "$0: Done" \ No newline at end of file +echo "$0: Done" diff --git a/paddlespeech/audio/stream_data/__init__.py b/paddlespeech/audio/streamdata/__init__.py similarity index 87% rename from paddlespeech/audio/stream_data/__init__.py rename to paddlespeech/audio/streamdata/__init__.py index e9706d4e..d84fbb52 100644 --- a/paddlespeech/audio/stream_data/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -11,7 +11,7 @@ from .cache import ( pipe_cleaner, ) from .compat import WebDataset, WebLoader, FluidWrapper -from webdataset.extradatasets import MockDataset, with_epoch, with_length +from .extradatasets import MockDataset, with_epoch, with_length from .filters import ( associate, batched, @@ -65,5 +65,5 @@ from .shardlists import ( ) from .tariterators import tarfile_samples, tarfile_to_samples from .utils import PipelineStage, repeatedly -from webdataset.writer import ShardWriter, TarWriter, numpy_dumps -from webdataset.mix import RandomMix, RoundRobin +from .writer import ShardWriter, TarWriter, numpy_dumps +from .mix import RandomMix, RoundRobin diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py new file mode 100644 index 00000000..8c74b685 --- /dev/null +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -0,0 +1,445 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). 
+# Modified from https://github.com/webdataset/webdataset +# + +"""Automatically decode webdataset samples.""" + +import io, json, os, pickle, re, tempfile +from functools import partial + +import numpy as np + +"""Extensions passed on to the image decoder.""" +image_extensions = "jpg jpeg png ppm pgm pbm pnm".split() + + +################################################################ +# handle basic datatypes +################################################################ + + +def paddle_loads(data): + """Load data using paddle.loads, importing paddle only if needed. + + :param data: data to be decoded + """ + import io + + import paddle + + stream = io.BytesIO(data) + return paddle.load(stream) + + +def tenbin_loads(data): + from . import tenbin + + return tenbin.decode_buffer(data) + + +def msgpack_loads(data): + import msgpack + + return msgpack.unpackb(data) + + +def npy_loads(data): + import numpy.lib.format + + stream = io.BytesIO(data) + return numpy.lib.format.read_array(stream) + + +def cbor_loads(data): + import cbor + + return cbor.loads(data) + + +decoders = { + "txt": lambda data: data.decode("utf-8"), + "text": lambda data: data.decode("utf-8"), + "transcript": lambda data: data.decode("utf-8"), + "cls": lambda data: int(data), + "cls2": lambda data: int(data), + "index": lambda data: int(data), + "inx": lambda data: int(data), + "id": lambda data: int(data), + "json": lambda data: json.loads(data), + "jsn": lambda data: json.loads(data), + "pyd": lambda data: pickle.loads(data), + "pickle": lambda data: pickle.loads(data), + "pdparams": lambda data: paddle_loads(data), + "ten": tenbin_loads, + "tb": tenbin_loads, + "mp": msgpack_loads, + "msg": msgpack_loads, + "npy": npy_loads, + "npz": lambda data: np.load(io.BytesIO(data)), + "cbor": cbor_loads, +} + + +def basichandlers(key, data): + """Handle basic file decoding. + + This function is usually part of the post= decoders. 
+ This handles the following forms of decoding: + + - txt -> unicode string + - cls cls2 class count index inx id -> int + - json jsn -> JSON decoding + - pyd pickle -> pickle decoding + - pdparams -> paddle.loads + - ten tenbin -> fast tensor loading + - mp messagepack msg -> messagepack decoding + - npy -> Python NPY decoding + + :param key: file name extension + :param data: binary data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + + if extension in decoders: + return decoders[extension](data) + + return None + + +################################################################ +# Generic extension handler. +################################################################ + + +def call_extension_handler(key, data, f, extensions): + """Call the function f with the given data if the key matches the extensions. + + :param key: actual key found in the sample + :param data: binary data + :param f: decoder function + :param extensions: list of matching extensions + """ + extension = key.lower().split(".") + for target in extensions: + target = target.split(".") + if len(target) > len(extension): + continue + if extension[-len(target) :] == target: + return f(data) + return None + + +def handle_extension(extensions, f): + """Return a decoder function for the list of extensions. + + Extensions can be a space separated list of extensions. + Extensions can contain dots, in which case the corresponding number + of extension components must be present in the key given to f. + Comparisons are case insensitive. 
+ + Examples: + handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg + handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg + """ + extensions = extensions.lower().split() + return partial(call_extension_handler, f=f, extensions=extensions) + + +################################################################ +# handle images +################################################################ + +imagespecs = { + "l8": ("numpy", "uint8", "l"), + "rgb8": ("numpy", "uint8", "rgb"), + "rgba8": ("numpy", "uint8", "rgba"), + "l": ("numpy", "float", "l"), + "rgb": ("numpy", "float", "rgb"), + "rgba": ("numpy", "float", "rgba"), + "paddlel8": ("paddle", "uint8", "l"), + "paddlergb8": ("paddle", "uint8", "rgb"), + "paddlergba8": ("paddle", "uint8", "rgba"), + "paddlel": ("paddle", "float", "l"), + "paddlergb": ("paddle", "float", "rgb"), + "paddle": ("paddle", "float", "rgb"), + "paddlergba": ("paddle", "float", "rgba"), + "pill": ("pil", None, "l"), + "pil": ("pil", None, "rgb"), + "pilrgb": ("pil", None, "rgb"), + "pilrgba": ("pil", None, "rgba"), +} + + +class ImageHandler: + """Decode image data using the given `imagespec`. + + The `imagespec` specifies whether the image is decoded + to numpy/paddle/pi, decoded to uint8/float, and decoded + to l/rgb/rgba: + + - l8: numpy uint8 l + - rgb8: numpy uint8 rgb + - rgba8: numpy uint8 rgba + - l: numpy float l + - rgb: numpy float rgb + - rgba: numpy float rgba + - paddlel8: paddle uint8 l + - paddlergb8: paddle uint8 rgb + - paddlergba8: paddle uint8 rgba + - paddlel: paddle float l + - paddlergb: paddle float rgb + - paddle: paddle float rgb + - paddlergba: paddle float rgba + - pill: pil None l + - pil: pil None rgb + - pilrgb: pil None rgb + - pilrgba: pil None rgba + + """ + + def __init__(self, imagespec, extensions=image_extensions): + """Create an image handler. 
+ + :param imagespec: short string indicating the type of decoding + :param extensions: list of extensions the image handler is invoked for + """ + if imagespec not in list(imagespecs.keys()): + raise ValueError("Unknown imagespec: %s" % imagespec) + self.imagespec = imagespec.lower() + self.extensions = extensions + + def __call__(self, key, data): + """Perform image decoding. + + :param key: file name extension + :param data: binary data + """ + import PIL.Image + + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in self.extensions: + return None + imagespec = self.imagespec + atype, etype, mode = imagespecs[imagespec] + with io.BytesIO(data) as stream: + img = PIL.Image.open(stream) + img.load() + img = img.convert(mode.upper()) + if atype == "pil": + return img + elif atype == "numpy": + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: numpy image must be uint8") + if etype == "uint8": + return result + else: + return result.astype("f") / 255.0 + elif atype == "paddle": + import paddle + + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: paddle image must be uint8") + if etype == "uint8": + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) + else: + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) / 255.0 + return None + + +def imagehandler(imagespec, extensions=image_extensions): + """Create an image handler. + + This is just a lower case alias for ImageHander. + + :param imagespec: textual image spec + :param extensions: list of extensions the handler should be applied for + """ + return ImageHandler(imagespec, extensions) + + +################################################################ +# torch video +################################################################ + +''' +def torch_video(key, data): + """Decode video using the torchvideo library. 
+ + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + import torchvision.io + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchvision.io.read_video(fname, pts_unit="sec") +''' + + +################################################################ +# paddleaudio +################################################################ + + +def paddle_audio(key, data): + """Decode audio using the paddleaudio library. + + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: + return None + + import paddleaudio + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return paddleaudio.load(fname) + + +################################################################ +# special class for continuing decoding +################################################################ + + +class Continue: + """Special class for continuing decoding. + + This is mostly used for decompression, as in: + + def decompressor(key, data): + if key.endswith(".gz"): + return Continue(key[:-3], decompress(data)) + return None + """ + + def __init__(self, key, data): + """__init__. + + :param key: + :param data: + """ + self.key, self.data = key, data + + +def gzfilter(key, data): + """Decode .gz files. + + This decodes compressed files and the continues decoding. 
+ + :param key: file name extension + :param data: binary data + """ + import gzip + + if not key.endswith(".gz"): + return None + decompressed = gzip.open(io.BytesIO(data)).read() + return Continue(key[:-3], decompressed) + + +################################################################ +# decode entire training amples +################################################################ + + +default_pre_handlers = [gzfilter] +default_post_handlers = [basichandlers] + + +class Decoder: + """Decode samples using a list of handlers. + + For each key/data item, this iterates through the list of + handlers until some handler returns something other than None. + """ + + def __init__(self, handlers, pre=None, post=None, only=None, partial=False): + """Create a Decoder. + + :param handlers: main list of handlers + :param pre: handlers called before the main list (.gz handler by default) + :param post: handlers called after the main list (default handlers by default) + :param only: a list of extensions; when give, only ignores files with those extensions + :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes) + """ + if isinstance(only, str): + only = only.split() + self.only = only if only is None else set(only) + if pre is None: + pre = default_pre_handlers + if post is None: + post = default_post_handlers + assert all(callable(h) for h in handlers), f"one of {handlers} not callable" + assert all(callable(h) for h in pre), f"one of {pre} not callable" + assert all(callable(h) for h in post), f"one of {post} not callable" + self.handlers = pre + handlers + post + self.partial = partial + + def decode1(self, key, data): + """Decode a single field of a sample. + + :param key: file name extension + :param data: binary data + """ + key = "." 
+ key + for f in self.handlers: + result = f(key, data) + if isinstance(result, Continue): + key, data = result.key, result.data + continue + if result is not None: + return result + return data + + def decode(self, sample): + """Decode an entire sample. + + :param sample: the sample, a dictionary of key value pairs + """ + result = {} + assert isinstance(sample, dict), sample + for k, v in list(sample.items()): + if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + if self.only is not None and k not in self.only: + result[k] = v + continue + assert v is not None + if self.partial: + if isinstance(v, bytes): + result[k] = self.decode1(k, v) + else: + result[k] = v + else: + assert isinstance(v, bytes) + result[k] = self.decode1(k, v) + return result + + def __call__(self, sample): + """Decode an entire sample. + + :param sample: the sample + """ + assert isinstance(sample, dict), (len(sample), sample) + return self.decode(sample) diff --git a/paddlespeech/audio/stream_data/cache.py b/paddlespeech/audio/streamdata/cache.py similarity index 98% rename from paddlespeech/audio/stream_data/cache.py rename to paddlespeech/audio/streamdata/cache.py index 724f6911..e7bbffa1 100644 --- a/paddlespeech/audio/stream_data/cache.py +++ b/paddlespeech/audio/streamdata/cache.py @@ -6,8 +6,8 @@ import itertools, os, random, re, sys from urllib.parse import urlparse from . import filters -from webdataset import gopen -from webdataset.handlers import reraise_exception +from . 
class MockDataset(IterableDataset):
    """MockDataset.

    A mock dataset for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Return an iterator yielding ``sample`` exactly ``length`` times."""
        for _ in range(self.length):
            yield self.sample


class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of Repeatedly.

        :param source: source dataset (stored; ``invoke`` iterates the source
            it is handed)
        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        :param length: stored as ``self.length``
        """
        self.source = source
        self.length = length
        # invoke() reads self.nepochs; it was never stored, so every call
        # failed with AttributeError.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )


class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length/nominal.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.

    """

    def __init__(self, dataset, length):
        """Chop the dataset to the given length.

        :param dataset: IterableDataset; not stored here, the dataset to
            iterate is the one passed to ``invoke``
        :param length: declared length of the dataset
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        This iterator returns as many samples as given by the `length`
        parameter, restarting the dataset when it runs out. It stops early
        only if a freshly restarted dataset yields nothing.
        """
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                # Exhausted mid-epoch: restart once; an empty dataset ends
                # the epoch instead of looping forever.
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None


class with_length(IterableDataset, PipelineStage):
    """Attach a user-specified length to a dataset without changing its samples."""

    def __init__(self, dataset, length):
        """Create an instance of with_length.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return a plain iterator over the given dataset."""
        return iter(dataset)

    def __len__(self):
        """Return the user specified length."""
        return self.length
# global used for printing additional node information during verbose output
info = {}


def _shell_quote(text):
    """Quote *text* so the shell passes it to the command as one literal word."""
    import shlex  # local import: shlex is not among this module's top-level imports

    return shlex.quote(text)


class Pipe:
    """Wrapper class for subprocess.Popen.

    This class looks like a stream from the outside, but it checks
    subprocess status and handles timeouts with exceptions.
    This way, clients of the class do not need to know that they are
    dealing with subprocesses.

    :param *args: passed to `subprocess.Popen`
    :param **kw: passed to `subprocess.Popen`
    :param mode: required; starts with "r" (read the child's stdout) or
        "w" (write to the child's stdin)
    :param timeout: timeout for closing/waiting
    :param ignore_errors: don't raise exceptions on subprocess errors
    :param ignore_status: list of status codes to ignore (0 is always ignored)
    """

    def __init__(
        self,
        *args,
        mode=None,
        timeout=7200.0,
        ignore_errors=False,
        ignore_status=(),
        **kw,
    ):
        """Create an IO Pipe.

        :raises ValueError: if mode is missing or unknown, or the stream
            could not be opened
        """
        # Validate before spawning, so a bad mode can neither crash with a
        # TypeError nor leave an object without proc/stream.
        if not mode or mode[0] not in ("r", "w"):
            raise ValueError(f"{mode}: unknown mode")
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + list(ignore_status or ())
        self.timeout = timeout
        self.args = (args, kw)
        if mode[0] == "r":
            self.proc = Popen(*args, stdout=PIPE, **kw)
            self.stream = self.proc.stdout
        else:
            self.proc = Popen(*args, stdin=PIPE, **kw)
            self.stream = self.proc.stdin
        if self.stream is None:
            raise ValueError(f"{args}: couldn't open")
        self.status = None

    def __str__(self):
        return f"<Pipe {self.args}>"

    def check_status(self):
        """Poll the process and handle any errors."""
        status = self.proc.poll()
        if status is not None:
            self.wait_for_child()

    def wait_for_child(self):
        """Wait for the child and raise an exception if its exit status is an error.

        The check is the same whether or not GOPEN_VERBOSE is set; verbose
        mode only adds a log line on stderr.
        """
        verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
        self.status = self.proc.wait()
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr,
            )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")

    def read(self, *args, **kw):
        """Wrap stream.read and checks status."""
        result = self.stream.read(*args, **kw)
        self.check_status()
        return result

    def write(self, *args, **kw):
        """Wrap stream.write and checks status."""
        result = self.stream.write(*args, **kw)
        self.check_status()
        return result

    def readline(self, *args, **kw):
        """Wrap stream.readline and checks status."""
        result = self.stream.readline(*args, **kw)
        self.check_status()
        return result

    def readLine(self, *args, **kw):
        """Backward-compatible alias of `readline`.

        Python streams have no ``readLine`` method, so this delegates to
        ``readline``.
        """
        return self.readline(*args, **kw)

    def close(self):
        """Wrap stream.close, wait for the subprocess, and handle errors."""
        self.stream.close()
        self.status = self.proc.wait(self.timeout)
        self.wait_for_child()

    def __enter__(self):
        """Context handler."""
        return self

    def __exit__(self, etype, value, traceback):
        """Context handler."""
        self.close()


def set_options(
    obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
    when its argument is a pipe.

    :param obj: any kind of stream
    :param timeout: desired timeout
    :param ignore_errors: desired ignore_errors setting
    :param ignore_status: desired ignore_status setting (status 0 is always
        kept, as in `Pipe.__init__`)
    :param handler: desired error handler
    :returns: True if obj is a Pipe and options were applied, else False
    """
    if not isinstance(obj, Pipe):
        return False
    if timeout is not None:
        obj.timeout = timeout
    if ignore_errors is not None:
        obj.ignore_errors = ignore_errors
    if ignore_status is not None:
        # Keep 0 in the list: otherwise a clean exit would be reported as an error.
        obj.ignore_status = [0] + list(ignore_status)
    if handler is not None:
        obj.handler = handler
    return True


def gopen_file(url, mode="rb", bufsize=8192):
    """Open a local file with the builtin `open`.

    :param url: path of the file to be opened
    :param mode: mode to open it with
    :param bufsize: accepted for signature compatibility; not used
    """
    return open(url, mode)


def gopen_pipe(url, mode="rb", bufsize=8192):
    """Use gopen to open a pipe.

    :param url: a pipe: URL; everything after "pipe:" is run as a shell command
    :param mode: desired mode
    :param bufsize: desired buffer size
    :raises ValueError: if mode starts with neither "r" nor "w"
    """
    assert url.startswith("pipe:")
    cmd = url[5:]
    if mode[:1] not in ("r", "w"):
        raise ValueError(f"{mode}: unknown mode")
    return Pipe(
        cmd,
        mode=mode,
        shell=True,
        bufsize=bufsize,
        ignore_status=[141],
    )  # skipcq: BAN-B604


def gopen_curl(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    :raises ValueError: if mode starts with neither "r" nor "w"
    """
    if mode[:1] == "r":
        # Quote the URL so characters such as ' ; & are not interpreted by the shell.
        cmd = f"curl -s -L {_shell_quote(url)}"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    if mode[:1] == "w":
        cmd = f"curl -s -L -T - {_shell_quote(url)}"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    raise ValueError(f"{mode}: unknown mode")


def gopen_htgs(url, mode="rb", bufsize=8192):
    """Open an htgs:// URL for reading with `curl`.

    The leading "htgs://" is rewritten to "gs://" before curl is invoked.

    :param url: htgs:// url
    :param mode: file mode (read only)
    :param bufsize: buffer size
    :raises ValueError: for write modes and unknown modes
    """
    if mode[:1] == "r":
        # NOTE(review): curl is handed a gs:// URL here; confirm that the
        # deployed curl understands that scheme.
        url = re.sub(r"(?i)^htgs://", "gs://", url)
        cmd = f"curl -s -L {_shell_quote(url)}"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    if mode[:1] == "w":
        raise ValueError(f"{mode}: cannot write")
    raise ValueError(f"{mode}: unknown mode")


def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `gsutil` (`gsutil cat` to read, `gsutil cp -` to write).

    :param url: url (usually, gs:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    :raises ValueError: if mode starts with neither "r" nor "w"
    """
    if mode[:1] == "r":
        cmd = f"gsutil cat {_shell_quote(url)}"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    if mode[:1] == "w":
        cmd = f"gsutil cp - {_shell_quote(url)}"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    raise ValueError(f"{mode}: unknown mode")


def gopen_error(url, *args, **kw):
    """Raise a value error.

    :param url: url
    :param args: other arguments
    :param kw: other keywords
    """
    raise ValueError(f"{url}: no gopen handler defined")


# A dispatch table mapping URL schemes to handlers.
gopen_schemes = dict(
    __default__=gopen_error,
    pipe=gopen_pipe,
    http=gopen_curl,
    https=gopen_curl,
    sftp=gopen_curl,
    ftps=gopen_curl,
    scp=gopen_curl,
    gs=gopen_gsutil,
    htgs=gopen_htgs,
)


def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL.

    This uses the `gopen_schemes` dispatch table to dispatch based
    on scheme.

    Support for the following schemes is built-in: pipe, http, https,
    sftp, ftps, scp, gs, htgs. URLs with the "file" scheme or with no
    scheme are opened as local files, buffered according to the
    GOPEN_BUFFER environment variable (default -1). The URL "-" means
    stdin/stdout.

    Set the GOPEN_VERBOSE environment variable to get info about
    files being opened.

    :param url: the source URL
    :param mode: the mode ("rb" or "wb")
    :param bufsize: the buffer size handed to scheme handlers
    :param kw: other keywords forwarded to the scheme handler
    """
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        return sys.stdin.buffer if mode == "rb" else sys.stdout.buffer
    pr = urlparse(url)
    if pr.scheme in ("", "file"):
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        path = url if pr.scheme == "" else pr.path
        return open(path, mode, buffering=bufsize)
    handler = gopen_schemes.get(pr.scheme, gopen_schemes["__default__"])
    return handler(url, mode, bufsize, **kw)


def reader(url, **kw):
    """Open url with gopen and mode "rb".

    :param url: source URL
    :param kw: other keywords forwarded to gopen
    """
    return gopen(url, "rb", **kw)
def reraise_exception(exn):
    """Call in an exception handler to re-raise the exception."""
    raise exn


def ignore_and_continue(exn):
    """Call in an exception handler to ignore any exception and continue."""
    return True


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    # stacklevel=2 attributes the warning to the code that invoked the handler.
    warnings.warn(repr(exn), stacklevel=2)
    time.sleep(0.5)
    return True


def ignore_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    return False


def warn_and_stop(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and stop further processing."""
    # stacklevel=2 attributes the warning to the code that invoked the handler.
    warnings.warn(repr(exn), stacklevel=2)
    time.sleep(0.5)
    return False
def round_robin_shortest(*sources):
    """Yield one sample from each source in turn; stop when any source runs out.

    :param sources: iterators to interleave
    """
    if not sources:
        # Nothing to interleave (and avoids a modulo by zero below).
        return
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
        except StopIteration:
            break
        yield sample
        i += 1


def round_robin_longest(*sources):
    """Yield one sample from each source in turn until every source runs out.

    Exhausted sources are dropped and the rotation continues over the rest.

    :param sources: iterators to interleave
    """
    # *sources arrives as an immutable tuple; work on a list we can shrink.
    active = list(sources)
    i = 0
    while active:
        if i >= len(active):
            i = 0  # wrap around to the first remaining source
        try:
            sample = next(active[i])
        except StopIteration:
            # Drop the exhausted source; index i now refers to its successor.
            del active[i]
            continue
        i += 1
        yield sample


class RoundRobin(IterableDataset):
    """Dataset that interleaves samples from several datasets in round-robin order."""

    def __init__(self, datasets, longest=False):
        """Create a round-robin mix.

        :param datasets: the datasets to interleave
        :param longest: if True continue until all datasets are exhausted,
            otherwise stop at the first exhausted one
        """
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)


def random_samples(sources, probs=None, longest=False):
    """Yield samples drawn from randomly chosen sources.

    :param sources: sequence of iterators (not modified)
    :param probs: relative weight of each source (uniform when None)
    :param longest: if True drop exhausted sources and continue until all
        are exhausted, otherwise stop at the first exhausted one
    """
    sources = list(sources)  # private copy: exhausted sources are deleted below
    probs = [1] * len(sources) if probs is None else list(probs)
    while sources:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        # Rounding can leave cum[-1] slightly below 1.0, so clamp the index.
        i = min(int(np.searchsorted(cum, r)), len(sources) - 1)
        try:
            sample = next(sources[i])
        except StopIteration:
            if not longest:
                break
            del sources[i]
            del probs[i]
            continue
        yield sample


class RandomMix(IterableDataset):
    """Dataset that mixes samples from several datasets at random."""

    def __init__(self, datasets, probs=None, longest=False):
        """Create a random mix.

        :param datasets: the datasets to sample from
        :param probs: relative weight of each dataset (uniform when None)
        :param longest: passed through to `random_samples`
        """
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
paddlespeech/audio/streamdata/pipeline.py index e738083f..7339a762 100644 --- a/paddlespeech/audio/stream_data/pipeline.py +++ b/paddlespeech/audio/streamdata/pipeline.py @@ -10,8 +10,7 @@ from typing import List import braceexpand, yaml -from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators -from webdataset.handlers import reraise_exception +from .handlers import reraise_exception from .paddle_utils import DataLoader, IterableDataset from .utils import PipelineStage diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py similarity index 100% rename from paddlespeech/audio/stream_data/shardlists.py rename to paddlespeech/audio/streamdata/shardlists.py diff --git a/paddlespeech/audio/stream_data/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py similarity index 99% rename from paddlespeech/audio/stream_data/tariterators.py rename to paddlespeech/audio/streamdata/tariterators.py index d9469797..2c1daae1 100644 --- a/paddlespeech/audio/stream_data/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -14,8 +14,8 @@ import random, re, tarfile import braceexpand from . import filters -from webdataset import gopen -from webdataset.handlers import reraise_exception +from . import gopen +from .handlers import reraise_exception trace = False meta_prefix = "__" diff --git a/paddlespeech/audio/stream_data/utils.py b/paddlespeech/audio/streamdata/utils.py similarity index 100% rename from paddlespeech/audio/stream_data/utils.py rename to paddlespeech/audio/streamdata/utils.py diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py new file mode 100644 index 00000000..7d4f7703 --- /dev/null +++ b/paddlespeech/audio/streamdata/writer.py @@ -0,0 +1,450 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
def imageencoder(image: Any, format: str = "PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return the encoded bytes.

    Can handle float or uint8 images.

    :param image: PIL image or ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM), case-insensitive
    :raises ValueError: if a float image has values outside [0, 1]
    """
    # `import PIL` alone does not load the Image submodule.
    import PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}"
                )
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    # Normalize case first so "jpeg" and "JPEG" get the same quality option.
    format = format.upper()
    if format == "JPG":
        format = "JPEG"
    elif format in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.

    This delays importing paddle until needed.

    :param data: data to be dumped
    """
    import io

    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io

    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: dict):
    """Dump data into a bytestring using numpy npz format.

    :param data: mapping from array name to array (expanded as keyword
        arguments of ``np.savez_compressed``)
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()


def tenbin_dumps(x):
    """Encode a tensor or a list of tensors with the package's tenbin module."""
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    """Encode x with cbor (imported lazily)."""
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    """Encode x with msgpack (imported lazily)."""
    import msgpack

    return msgpack.packb(x)


def add_handlers(d, keys, value):
    """Register value as the handler for every key.

    :param d: handler dictionary to update in place
    :param keys: list of extensions, or a whitespace-separated string
    :param value: the handler
    """
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a dictionary mapping file extensions to encoding handlers."""
    handlers = {}
    add_handlers(
        handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
    )
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    Metadata (names starting with "_") must be a str and is returned
    unchanged; bytes are returned unchanged; str is encoded as UTF-8;
    anything else goes through the handler registered for the extension.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    :raises ValueError: for non-string metadata or an unknown extension
    """
    if tname.startswith("_"):
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
    }


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: False/None for no encoding, True for the default handlers,
        a dict of extension handlers, or a callable used as is
    :raises ValueError: if spec is none of the above
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f

    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g

    else:
        raise ValueError(f"{spec}: unknown encoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two file to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
        tarwriter = TarWriter(stream)
        image = imread("b.jpg")
        image2 = imread("b.out.jpg")
        sample = {"__key__": "a/b", "png": image, "output.png": image2}
        tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[bool] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to, or a file name/URL that is
            opened (and later closed) by this writer
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression; None means gzip only when a
            file name ending in "gz" is given
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file, and the underlying file if this writer opened it."""
        try:
            self.tarstream.close()
        finally:
            # Release our own file even if finalizing the archive failed.
            if self.own_fileobj is not None:
                self.own_fileobj.close()
                self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        :raises ValueError: if "__key__" is missing or a value is not bytes
            after encoding

        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        # Validate every data field up front so a bad sample is rejected
        # before any of its members is written.
        for k, v in list(obj.items()):
            if k.startswith("_"):
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k.startswith("_"):
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            # len() of a memoryview counts elements, not bytes; materialize
            # the payload so the recorded size matches the bytes written.
            payload = bytes(v)
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(payload)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            stream = io.BytesIO(payload)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        **kw,
    ):
        """Create a ShardWriter.

        :param pattern: output file pattern; formatted with the shard number
            using the ``%`` operator
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with the file name of each
            finished shard
        :param start_shard: number of the first shard
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self._fileobj = None  # file backing the current shard; closed in finish()
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        self._fileobj = open(self.fname, "wb")
        try:
            self.tarstream = TarWriter(self._fileobj, **self.kw)
        except Exception:
            # Do not leak the freshly opened file if TarWriter rejects the options.
            self._fileobj.close()
            self._fileobj = None
            raise
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        tarstream = getattr(self, "tarstream", None)
        if tarstream is not None:
            try:
                tarstream.close()
            finally:
                # TarWriter does not own a file object handed to it, so the
                # shard file must be closed here, before `post` looks at it.
                if self._fileobj is not None:
                    self._fileobj.close()
                    self._fileobj = None
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream; calling it more than once is harmless."""
        self.finish()
        for name in ("tarstream", "shard", "count", "size"):
            self.__dict__.pop(name, None)

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) ) else: - base_dataset = stream_data.DataPipeline( - stream_data.SimpleShardList(shardlist), - stream_data.split_by_worker, - stream_data.tarfile_to_samples(stream_data.reraise_exception) + base_dataset = streamdata.DataPipeline( + streamdata.SimpleShardList(shardlist), + streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) ) self.dataset = base_dataset.append_list( - stream_data.tokenize(symbol_table), - stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), - stream_data.resample(resample_rate=resample_rate), - stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) - stream_data.shuffle(shuffle_size), - stream_data.sort(sort_size=sort_size), - stream_data.batched(batch_size), - stream_data.padding(), - stream_data.cmvn(cmvn_file) - ) - self.loader = stream_data.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - prefetch_factor = self.prefetch_factor, - batch_size=None + streamdata.tokenize(symbol_table), + streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), + streamdata.resample(resample_rate=resample_rate), + streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.shuffle(shuffle_size), + streamdata.sort(sort_size=sort_size), + streamdata.batched(batch_size), + streamdata.padding(), + 
streamdata.cmvn(cmvn_file) ) + if paddle.__version__ >= '2.3.2': + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor = self.prefetch_factor, + batch_size=None + ) + else: + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + batch_size=None + ) + def __iter__(self): return self.loader.__iter__() diff --git a/setup.py b/setup.py index b94a4cb2..035d0b2d 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,8 @@ base = [ "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", - "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset' + "yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8", + "braceexpand", "pyyaml" ] server = [ From aa12b9ab523eb7ca1e178dda663ecdcc62c6dd3b Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 28 Jun 2022 02:49:52 +0000 Subject: [PATCH 4/7] replace s2t.transform with audio.transform --- paddlespeech/cli/asr/infer.py | 2 +- paddlespeech/s2t/exps/u2/bin/test_wav.py | 2 +- paddlespeech/server/engine/asr/online/onnx/asr_engine.py | 2 +- .../server/engine/asr/online/paddleinference/asr_engine.py | 2 +- paddlespeech/server/engine/asr/online/python/asr_engine.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 24839a89..df7c5835 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -33,8 +33,8 @@ from ..log import logger from ..utils import CLI_TIMER from ..utils import stats_wrapper from ..utils import timer_register +from paddlespeech.s2t.audio.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.transform.transformation import Transformation from 
paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 86c3db89..887ec7a6 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -20,10 +20,10 @@ import paddle import soundfile from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.utility import UpdateConfig logger = Log(__name__).getlog() diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index 06793164..cb743ea2 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -26,7 +26,7 @@ from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils import onnx_infer diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index efb726aa..bcd0fa7f 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ 
b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -24,9 +24,9 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index daa9fc50..2ffbba99 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -24,9 +24,9 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig From 81934d7191230b62db8241d03f0eef22c2351d93 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 28 Jun 2022 09:20:00 +0000 Subject: [PATCH 5/7] fix run.sh --- examples/wenetspeech/asr1/local/train.sh | 69 ++++++++++++++++++++++++ examples/wenetspeech/asr1/run.sh | 3 +- 2 files changed, 71 insertions(+), 1 deletion(-) create mode 
#!/bin/bash
# Launch training on CPU (no visible GPU) or through paddle.distributed.launch.
#
# Usage: CUDA_VISIBLE_DEVICES=0 ./local/train.sh config_path ckpt_name [ips]
#   config_path  training config passed to train.py as --config
#   ckpt_name    output is written to exp/<ckpt_name>
#   ips          optional value forwarded to the launcher as --ips=<ips>
#
# MAIN_ROOT and BIN_DIR are read from the environment and must already be set.

profiler_options=
benchmark_batch_size=0
benchmark_max_step=0

# seed may break model convergence
seed=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

# Number of comma-separated entries in CUDA_VISIBLE_DEVICES (0 when it is empty).
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

if [ "${seed}" != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

# Exactly 2 or 3 positional arguments are accepted. The two conditions must be
# OR-ed: a count cannot be both below 2 and above 3, so `&&` never triggered.
if [ $# -lt 2 ] || [ $# -gt 3 ]; then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
    exit -1
fi

config_path=$1
ckpt_name=$2
ips=$3

if [ -z "${ips}" ]; then
    ips_config=
else
    ips_config="--ips="${ips}
fi
echo ${ips_config}

mkdir -p exp

if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
#NCCL_SOCKET_IFNAME=eth0
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
fi
# Capture the training exit status now; the clean-up below overwrites $?.
train_status=$?

if [ "${seed}" != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi

if [ ${train_status} -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi

exit 0
${benchmark_max_step} else -#NCCL_SOCKET_IFNAME=eth0 -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ +NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --seed ${seed} \ --config ${config_path} \ diff --git a/paddlespeech/audio/streamdata/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py index 3d1801cc..cfaf9a64 100644 --- a/paddlespeech/audio/streamdata/shardlists.py +++ b/paddlespeech/audio/streamdata/shardlists.py @@ -65,6 +65,7 @@ class SimpleShardList(IterableDataset): def split_by_node(src, group=None): rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + logger.info(f"world_size:{world_size}, rank:{rank}") if world_size > 1: for s in islice(src, rank, None, world_size): yield s @@ -83,6 +84,7 @@ def single_node_only(src, group=None): def split_by_worker(src): rank, world_size, worker, num_workers = utils.paddle_worker_info() + logger.info(f"num_workers:{num_workers}, worker:{worker}") if num_workers > 1: for s in islice(src, worker, None, num_workers): yield s diff --git a/paddlespeech/audio/streamdata/utils.py b/paddlespeech/audio/streamdata/utils.py index 83a42bad..c7294f2b 100644 --- a/paddlespeech/audio/streamdata/utils.py +++ b/paddlespeech/audio/streamdata/utils.py @@ -16,6 +16,9 @@ import re import sys from typing import Any, Callable, Iterator, Optional, Union +from ..utils.log import Logger + +logger = Logger(__name__) def make_seed(*args): seed = 0 @@ -112,13 +115,14 @@ def paddle_worker_info(group=None): num_workers = int(os.environ["NUM_WORKERS"]) else: try: - import paddle.io.get_worker_info + from paddle.io import get_worker_info worker_info = paddle.io.get_worker_info() if worker_info is not None: worker = worker_info.id num_workers = worker_info.num_workers - except ModuleNotFoundError: - pass + except ModuleNotFoundError as E: + logger.info(f"not found 
{E}") + exit(-1) return rank, world_size, worker, num_workers From 92d1d08b9a9c8fbe96457024d60b60240fa3bc79 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 5 Jul 2022 09:23:01 +0000 Subject: [PATCH 7/7] fix scripts --- examples/wenetspeech/asr1/conf/conformer.yaml | 20 +- examples/wenetspeech/asr1/local/data.sh | 4 +- paddlespeech/audio/streamdata/__init__.py | 17 +- paddlespeech/audio/streamdata/autodecode.py | 8 +- paddlespeech/audio/streamdata/compat.py | 24 +- paddlespeech/audio/streamdata/filters.py | 28 +- paddlespeech/audio/streamdata/tariterators.py | 6 +- paddlespeech/audio/text/text_featurizer.py | 235 +++++++++++ paddlespeech/audio/text/utility.py | 393 ++++++++++++++++++ paddlespeech/s2t/exps/deepspeech2/model.py | 2 +- paddlespeech/s2t/exps/u2/model.py | 185 +-------- paddlespeech/s2t/exps/u2_kaldi/model.py | 116 ++---- paddlespeech/s2t/exps/u2_st/model.py | 93 +---- paddlespeech/s2t/io/dataloader.py | 177 +++++++- paddlespeech/s2t/models/u2/u2.py | 6 +- paddlespeech/s2t/models/u2_st/u2_st.py | 4 +- 16 files changed, 901 insertions(+), 417 deletions(-) create mode 100644 paddlespeech/audio/text/text_featurizer.py create mode 100644 paddlespeech/audio/text/utility.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index 013c3e0c..d1ac20b9 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -52,6 +52,7 @@ test_manifest: data/test_meeting/data.list use_stream_data: True unit_type: 'char' vocab_filepath: data/lang_char/vocab.txt +preprocess_config: conf/preprocess.yaml cmvn_file: data/mean_std.json spm_model_prefix: '' feat_dim: 80 @@ -65,30 +66,17 @@ maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automa minlen_out: 0 maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed resample_rate: 16000 -shuffle_size: 1500 -sort_size: 1000 +shuffle_size: 1500 # read number of 
'shuffle_size' data as a chunk, shuffle the data in the chunk +sort_size: 1000 # read number of 'sort_size' data as a chunk, sort the data in the chunk num_workers: 8 prefetch_factor: 10 dist_sampler: True num_encs: 1 -augment_conf: - max_w: 80 - w_inplace: True - w_mode: "PIL" - max_f: 30 - num_f_mask: 2 - f_inplace: True - f_replace_with_zero: False - max_t: 40 - num_t_mask: 2 - t_inplace: True - t_replace_with_zero: False - ########################################### # Training # ########################################### -n_epoch: 30 +n_epoch: 32 accum_grad: 32 global_grad_clip: 5.0 log_interval: 100 diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index b3472a8f..62579ba3 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -90,8 +90,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then for x in $dev_set $test_sets ${train_set}; do dst=$shards_dir/$x mkdir -p $dst - utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \ - --do_filter --num_node 1 --num_gpus_per_node 8 \ + utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \ + --do_filter --resample 16000 \ --num_threads 32 --segments data/$x/segments \ data/$x/wav.scp data/$x/text \ $(realpath $dst) data/$x/data.list diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py index 1acd898a..753fcc11 100644 --- a/paddlespeech/audio/streamdata/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # See the LICENSE file for licensing terms (BSD-style). 
# Modified from https://github.com/webdataset/webdataset # @@ -26,7 +27,7 @@ from .filters import ( pipelinefilter, rename, rename_keys, - rsample, + audio_resample, select, shuffle, slice, @@ -34,14 +35,14 @@ from .filters import ( transform_with, unbatched, xdecode, - data_filter, - tokenize, - resample, - compute_fbank, - spec_aug, + audio_data_filter, + audio_tokenize, + audio_resample, + audio_compute_fbank, + audio_spec_aug, sort, - padding, - cmvn, + audio_padding, + audio_cmvn, placeholder, ) from .handlers import ( diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py index 8c74b685..ca0e2ea2 100644 --- a/paddlespeech/audio/streamdata/autodecode.py +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -291,12 +291,12 @@ def torch_video(key, data): ################################################################ -# paddleaudio +# paddlespeech.audio ################################################################ def paddle_audio(key, data): - """Decode audio using the paddleaudio library. + """Decode audio using the paddlespeech.audio library. 
:param key: file name extension :param data: data to be decoded @@ -305,13 +305,13 @@ def paddle_audio(key, data): if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: return None - import paddleaudio + import paddlespeech.audio with tempfile.TemporaryDirectory() as dirname: fname = os.path.join(dirname, f"file.{extension}") with open(fname, "wb") as stream: stream.write(data) - return paddleaudio.load(fname) + return paddlespeech.audio.load(fname) ################################################################ diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py index 11308d03..deda5338 100644 --- a/paddlespeech/audio/streamdata/compat.py +++ b/paddlespeech/audio/streamdata/compat.py @@ -78,29 +78,29 @@ class FluidInterface: def xdecode(self, *args, **kw): return self.compose(filters.xdecode(*args, **kw)) - def data_filter(self, *args, **kw): - return self.compose(filters.data_filter(*args, **kw)) + def audio_data_filter(self, *args, **kw): + return self.compose(filters.audio_data_filter(*args, **kw)) - def tokenize(self, *args, **kw): - return self.compose(filters.tokenize(*args, **kw)) + def audio_tokenize(self, *args, **kw): + return self.compose(filters.audio_tokenize(*args, **kw)) def resample(self, *args, **kw): return self.compose(filters.resample(*args, **kw)) - def compute_fbank(self, *args, **kw): - return self.compose(filters.compute_fbank(*args, **kw)) + def audio_compute_fbank(self, *args, **kw): + return self.compose(filters.audio_compute_fbank(*args, **kw)) - def spec_aug(self, *args, **kw): - return self.compose(filters.spec_aug(*args, **kw)) + def audio_spec_aug(self, *args, **kw): + return self.compose(filters.audio_spec_aug(*args, **kw)) def sort(self, size=500): return self.compose(filters.sort(size)) - def padding(self): - return self.compose(filters.padding()) + def audio_padding(self): + return self.compose(filters.audio_padding()) - def cmvn(self, cmvn_file): - return 
self.compose(filters.cmvn(cmvn_file)) + def audio_cmvn(self, cmvn_file): + return self.compose(filters.audio_cmvn(cmvn_file)) class WebDataset(DataPipeline, FluidInterface): """Small fluid-interface wrapper for DataPipeline.""" diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py index 0ade66f9..82b9c6ba 100644 --- a/paddlespeech/audio/streamdata/filters.py +++ b/paddlespeech/audio/streamdata/filters.py @@ -579,7 +579,7 @@ xdecode = pipelinefilter(_xdecode) -def _data_filter(source, +def _audio_data_filter(source, frame_shift=10, max_length=10240, min_length=10, @@ -629,9 +629,9 @@ def _data_filter(source, continue yield sample -data_filter = pipelinefilter(_data_filter) +audio_data_filter = pipelinefilter(_audio_data_filter) -def _tokenize(source, +def _audio_tokenize(source, symbol_table, bpe_model=None, non_lang_syms=None, @@ -693,9 +693,9 @@ def _tokenize(source, sample['label'] = label yield sample -tokenize = pipelinefilter(_tokenize) +audio_tokenize = pipelinefilter(_audio_tokenize) -def _resample(source, resample_rate=16000): +def _audio_resample(source, resample_rate=16000): """ Resample data. Inplace operation. 
@@ -718,9 +718,9 @@ def _resample(source, resample_rate=16000): )) yield sample -resample = pipelinefilter(_resample) +audio_resample = pipelinefilter(_audio_resample) -def _compute_fbank(source, +def _audio_compute_fbank(source, num_mel_bins=80, frame_length=25, frame_shift=10, @@ -756,9 +756,9 @@ def _compute_fbank(source, yield dict(fname=sample['fname'], label=sample['label'], feat=mat) -compute_fbank = pipelinefilter(_compute_fbank) +audio_compute_fbank = pipelinefilter(_audio_compute_fbank) -def _spec_aug(source, +def _audio_spec_aug(source, max_w=5, w_inplace=True, w_mode="PIL", @@ -799,7 +799,7 @@ def _spec_aug(source, sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) yield sample -spec_aug = pipelinefilter(_spec_aug) +audio_spec_aug = pipelinefilter(_audio_spec_aug) def _sort(source, sort_size=500): @@ -881,7 +881,7 @@ def dynamic_batched(source, max_frames_in_batch=12000): yield buf -def _padding(source): +def _audio_padding(source): """ Padding the data into training data Args: @@ -914,9 +914,9 @@ def _padding(source): yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) -padding = pipelinefilter(_padding) +audio_padding = pipelinefilter(_audio_padding) -def _cmvn(source, cmvn_file): +def _audio_cmvn(source, cmvn_file): global_cmvn = GlobalCMVN(cmvn_file) for batch in source: sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch @@ -926,7 +926,7 @@ def _cmvn(source, cmvn_file): yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) -cmvn = pipelinefilter(_cmvn) +audio_cmvn = pipelinefilter(_audio_cmvn) def _placeholder(source): for data in source: diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py index 2c1daae1..b1616918 100644 --- a/paddlespeech/audio/streamdata/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -21,7 +21,7 @@ trace = False meta_prefix = "__" meta_suffix = "__" -from ... 
import audio as paddleaudio +import paddlespeech import paddle import numpy as np @@ -118,7 +118,7 @@ def tar_file_iterator( assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if postfix == 'wav': - waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False) + waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False) result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) else: txt = stream.extractfile(tarinfo).read().decode('utf8').strip() @@ -167,7 +167,7 @@ def tar_file_and_group_iterator( if postfix == 'txt': example['txt'] = file_obj.read().decode('utf8').strip() elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = paddleaudio.load(file_obj, normal=False) + waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False) waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) example['wav'] = waveform diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py new file mode 100644 index 00000000..91c4d75c --- /dev/null +++ b/paddlespeech/audio/text/text_featurizer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TextFeaturizer():
    """Tokenize text and convert tokens to vocabulary indices and back."""

    def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and conversion into
        a list of token indices. Note that the token indexing order follows the
        given vocabulary file.

        Args:
            unit_type (str): unit type, one of 'char', 'word', 'spm'.
            vocab (Union[str, list]): Filepath to load vocabulary for token indices conversion, or vocab list.
                When empty/None only a warning is logged and no vocabulary attributes are set.
            spm_model_prefix (str, optional): spm model prefix; '<prefix>.model' is loaded
                when unit_type is 'spm'. Defaults to None.
            maskctc (bool, optional): forwarded to the vocabulary loader. Defaults to False.

        Raises:
            ValueError: unit_type is 'spm' but spm_model_prefix is None.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab:
            (self.vocab_dict, self._id2token, self.vocab_list, self.unk_id,
             self.eos_id, self.blank_id) = self._load_vocabulary_from_file(
                 vocab, maskctc)
            self.vocab_size = len(self.vocab_list)
        else:
            logger.warning("TextFeaturizer: not have vocab file or vocab list.")

        if unit_type == 'spm':
            if spm_model_prefix is None:
                # Fail with a clear message instead of `None + str` TypeError.
                raise ValueError(
                    "spm_model_prefix is required when unit_type is 'spm'")
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        """Split text into tokens according to ``self.unit_type``.

        Args:
            text (str): text string.
            replace_space (bool): only used by the char tokenizer.

        Returns:
            List[str]: tokens (the spm tokenizer returns None for blank text).
        """
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text, replace_space)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        """Join tokens back into text according to ``self.unit_type``."""
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Tokens missing from the vocabulary are mapped to the unknown token.

        Args:
            text (str): Text to process.

        Returns:
            List[int]: List of token indices.
        """
        # spm_tokenize returns None for blank text; treat that as "no tokens"
        # instead of failing while iterating over None.
        tokens = self.tokenize(text) or []
        ids = []
        for token in tokens:
            if token not in self.vocab_dict:
                logger.debug(f"Text Token: {token} -> {self.unk}")
                token = self.unk
            ids.append(self.vocab_dict[token])
        return ids

    def defeaturize(self, idxs):
        """Convert a list of token indices to text string,
        ignore index after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        text = self.detokenize(tokens)
        return text

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Args:
            text (str): text string.
            replace_space (bool): replace " " with the SPACE token;
                False only used by build_vocab.py.

        Returns:
            List[str]: tokens.
        """
        text = text.strip()
        if replace_space:
            text_list = [SPACE if item == " " else item for item in list(text)]
        else:
            text_list = list(text)
        return text_list

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        tokens = [t.replace(SPACE, " ") for t in tokens]
        return "".join(tokens)

    def word_tokenize(self, text):
        """Word tokenizer, split on whitespace."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, join with a single space."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence pieces, or None when the stripped text is empty.
        """
        line = text.strip()
        if not line:
            return None
        return self.sp.EncodeAsPieces(line)

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            tokens (List[str] or List[int]): pieces or ids to decode.
            input_format (str): 'piece' to decode pieces, 'id' to decode ids.

        Returns:
            str: text

        Raises:
            ValueError: input_format is neither 'piece' nor 'id'.
        """
        if input_format == "piece":
            return "".join(self.sp.DecodePieces(tokens))
        if input_format == "id":
            return "".join(self.sp.DecodeIds(tokens))
        raise ValueError(
            f"unsupported input_format: {input_format!r}, expected 'piece' or 'id'"
        )

    def _load_vocabulary_from_file(self, vocab: Union[str, list],
                                   maskctc: bool):
        """Load vocabulary from a file path or use the given token list.

        Returns:
            tuple: (token2id, id2token, vocab_list, unk_id, eos_id, blank_id);
                an id is -1 when its special token is not in the vocabulary.
        """
        if isinstance(vocab, list):
            vocab_list = vocab
        else:
            vocab_list = load_dict(vocab, maskctc)
        assert vocab_list is not None
        logger.debug(f"Vocab: {pformat(vocab_list)}")

        id2token = dict(enumerate(vocab_list))
        token2id = {token: idx for idx, token in enumerate(vocab_list)}

        def index_of(token):
            # First position of the token, -1 when absent.
            return vocab_list.index(token) if token in vocab_list else -1

        blank_id = index_of(BLANK)
        maskctc_id = index_of(MASKCTC)
        unk_id = index_of(UNK)
        eos_id = index_of(EOS)
        sos_id = index_of(SOS)
        space_id = index_of(SPACE)

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
    """Load a token list from a vocabulary file.

    Each line contributes one token: the text before the first space
    (so both ``token`` and ``token index`` line formats are accepted).
    BLANK is inserted at the front and EOS appended at the end when missing.

    Args:
        dict_path (Optional[Text]): vocabulary file path; None returns None.
        maskctc (bool): also append MASKCTC when missing
            (for the non-autoregressive maskctc model).

    Returns:
        Optional[List[Text]]: token list, or None when dict_path is None.
    """
    if dict_path is None:
        return None

    # NOTE(review): vocabulary files are assumed to be UTF-8; reading with the
    # locale's default encoding breaks on non-UTF-8 systems. Confirm.
    with open(dict_path, "r", encoding="utf-8") as f:
        dictionary = f.readlines()
    # Strip only the line terminator. Slicing off the last character
    # unconditionally would truncate the token on a final line that has
    # no trailing newline.
    char_list = [entry.rstrip("\n").split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for non-autoregressive maskctc model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list
+ + Args: + manifest_path ([type]): Manifest file to load and parse. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to 0.0. + max_output_len (float, optional): maximum input seq length, + in modeling units. Defaults to 500.0. + min_output_len (float, optional): minimum input seq length, + in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): + maximum output seq length/output seq length ratio. Defaults to 10.0. + min_output_input_ratio (float, optional): + minimum output seq length/output seq length ratio. Defaults to 0.05. + + Raises: + IOError: If failed to parse the manifest. + + Returns: + List[dict]: Manifest parsing results. + """ + manifest = [] + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + feat_len = json_data["input"][0]["shape"][ + 0] if "input" in json_data and "shape" in json_data["input"][ + 0] else 1.0 + token_len = json_data["output"][0]["shape"][ + 0] if "output" in json_data and "shape" in json_data["output"][ + 0] else 1.0 + conditions = [ + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, + ] + if all(conditions): + manifest.append(json_data) + return manifest + + +# Tar File read +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +def parse_tar(file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + +def subfile_from_tar(file, local_data=None): + """Get subfile object from tar. 
+ + tar:tarpath#filename + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + + if local_data is None: + local_data = TarLocalData(tar2info={}, tar2object={}) + + assert isinstance(local_data, TarLocalData) + + if 'tar2info' not in local_data.__dict__: + local_data.tar2info = {} + if 'tar2object' not in local_data.__dict__: + local_data.tar2object = {} + + if tarpath not in local_data.tar2info: + fobj, infos = parse_tar(tarpath) + local_data.tar2info[tarpath] = infos + local_data.tar2object[tarpath] = fobj + else: + fobj = local_data.tar2object[tarpath] + infos = local_data.tar2info[tarpath] + return fobj.extractfile(infos[filename]) + + +def rms_to_db(rms: float): + """Root Mean Square to dB. + + Args: + rms ([float]): root mean square + + Returns: + float: dB + """ + return 20.0 * math.log10(max(1e-16, rms)) + + +def rms_to_dbfs(rms: float): + """Root Mean Square to dBFS. + https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ + Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. + + dB = dBFS + 3.0103 + dBFS = db - 3.0103 + e.g. 0 dB = -3.0103 dBFS + + Args: + rms ([float]): root mean square + + Returns: + float: dBFS + """ + return rms_to_db(rms) - 3.0103 + + +def max_dbfs(sample_data: np.ndarray): + """Peak dBFS based on the maximum energy sample. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. + + Returns: + float: dBFS + """ + # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. + return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) + + +def mean_dbfs(sample_data): + """Peak dBFS based on the RMS energy. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. 
+ + Returns: + float: dBFS + """ + return rms_to_dbfs( + math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) + + +def gain_db_to_ratio(gain_db: float): + """dB to ratio + + Args: + gain_db (float): gain in dB + + Returns: + float: scale in amp + """ + return math.pow(10.0, gain_db / 20.0) + + +def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): + """Nomalize audio to dBFS. + + Args: + sample_data (np.ndarray): input wave samples, [-1, 1]. + dbfs (float, optional): target dBFS. Defaults to -3.0103. + + Returns: + np.ndarray: normalized wave + """ + return np.maximum( + np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), + 1.0), -1.0) + + +def _load_json_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + + Args: + json_cmvn_file: cmvn stats file in json format + + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_kaldi_cmvn(kaldi_cmvn_file): + """ Load the kaldi format cmvn stats file and calculate cmvn + + Args: + kaldi_cmvn_file: kaldi text style global cmvn file, which + is generated by: + compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + + Returns: + a numpy array of [means, vars] + """ + means = [] + variance = [] + with open(kaldi_cmvn_file, 'r') as fid: + # kaldi binary file start with '\0B' + if fid.read(2) == '\0B': + logger.error('kaldi cmvn binary file is not supported, please ' + 'recompute it by: compute-cmvn-stats --binary=false ' + ' scp:feats.scp global_cmvn') + sys.exit(1) + fid.seek(0) + arr = fid.read().split() + assert (arr[0] 
== '[') + assert (arr[-2] == '0') + assert (arr[-1] == ']') + feat_dim = int((len(arr) - 2 - 2) / 2) + for i in range(1, feat_dim + 1): + means.append(float(arr[i])) + count = float(arr[feat_dim + 1]) + for i in range(feat_dim + 2, 2 * feat_dim + 2): + variance.append(float(arr[i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def load_cmvn(cmvn_file: str, filetype: str): + """load cmvn from file. + + Args: + cmvn_file (str): cmvn path. + filetype (str): file type, optional[npz, json, kaldi]. + + Raises: + ValueError: file type not support. + + Returns: + Tuple[np.ndarray, np.ndarray]: mean, istd + """ + assert filetype in ['npz', 'json', 'kaldi'], filetype + filetype = filetype.lower() + if filetype == "json": + cmvn = _load_json_cmvn(cmvn_file) + elif filetype == "kaldi": + cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] + else: + raise ValueError(f"cmvn file type no support: {filetype}") + return cmvn[0], cmvn[1] + + +def convert_samples_to_float32(samples): + """Convert sample type to float32. + + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + + PCM16 -> PCM32 + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= (1. / 2**(bits - 1)) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + +def convert_samples_from_float32(samples, dtype): + """Convert sample type from float32 to dtype. 
+ + Audio sample type is usually integer or float-point. For integer + type, float32 will be rescaled from [-1, 1] to the maximum range + supported by the integer type. + + PCM32 -> PCM16 + """ + dtype = np.dtype(dtype) + output_samples = samples.copy() + if dtype in np.sctypes['int']: + bits = np.iinfo(dtype).bits + output_samples *= (2**(bits - 1) / 1.) + min_val = np.iinfo(dtype).min + max_val = np.iinfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + elif samples.dtype in np.sctypes['float']: + min_val = np.finfo(dtype).min + max_val = np.finfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return output_samples.astype(dtype) diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 511997a7..7ab8cf85 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -23,7 +23,7 @@ import paddle from paddle import distributed as dist from paddle import inference -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.audio.text.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d6c68f96..cdad3b8f 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -27,6 +27,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import StreamDataLoader +from paddlespeech.s2t.io.dataloader import 
DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -134,7 +135,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -195,7 +197,6 @@ class U2Trainer(Trainer): except Exception as e: logger.error(e) raise e - with Timer("Eval Time Cost: {}"): total_loss, num_seen_utts = self.valid() if dist.get_world_size() > 1: @@ -224,186 +225,14 @@ class U2Trainer(Trainer): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - if self.use_streamdata: - self.train_loader = StreamDataLoader( - manifest_file=config.train_manifest, - train_mode=True, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - self.valid_loader = StreamDataLoader( - manifest_file=config.dev_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - 
frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - else: - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=config.sortagrad, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=config.minibatches, - mini_batch_size=self.args.ngpu, - batch_count=config.batch_count, - batch_bins=config.batch_bins, - batch_frames_in=config.batch_frames_in, - batch_frames_out=config.batch_frames_out, - batch_frames_inout=config.batch_frames_inout, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid 
Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - # test dataset, return raw text - if self.use_streamdata: - self.test_loader = StreamDataLoader( - manifest_file=config.test_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=0.0, - minlen_in=0.0, - maxlen_in=float('inf'), - minlen_out=0, - maxlen_out=float('inf'), - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - self.align_loader = StreamDataLoader( - manifest_file=config.test_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=0.0, - minlen_in=0.0, - maxlen_in=float('inf'), - minlen_out=0, - maxlen_out=float('inf'), - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - else: - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - 
subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) logger.info("Setup test/align Dataloader!") def setup_model(self): diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index bc995977..cb015c11 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -25,7 +25,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.utility import load_dict -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.scheduler import LRSchedulerFactory @@ -104,7 +104,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -131,7 +132,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not 
self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -150,8 +152,8 @@ class U2Trainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -162,7 +164,8 @@ class U2Trainer(Trainer): msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, + if not self.use_streamdata: + msg += "batch : {}/{}, ".format(batch_index + 1, len(self.train_loader)) msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "data time: {:>.3f}s, ".format(dataload_time) @@ -198,87 +201,23 @@ class U2Trainer(Trainer): self.new_epoch() def setup_dataloader(self): - config = self.config.clone() - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - 
batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - # test dataset, return raw text - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - logger.info("Setup train/valid/test/align Dataloader!") + self.use_streamdata = config.get("use_stream_data", False) + if self.train: + config = self.config.clone() + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + logger.info("Setup train/valid Dataloader!") + else: + config = self.config.clone() + config['preprocess_config'] = None + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + logger.info("Setup test/align Dataloader!") + def setup_model(self): config = self.config @@ -406,7 +345,8 @@ class U2Tester(U2Trainer): def 
test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") stride_ms = self.config.stride_ms error_rate_type = None diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 6a32eda7..60382543 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -25,7 +25,7 @@ import paddle from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2_st import U2STModel from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -120,7 +120,8 @@ class U2STTrainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -153,7 +154,8 @@ class U2STTrainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -172,8 +174,8 @@ class U2STTrainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while 
self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -191,7 +193,8 @@ class U2STTrainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -241,79 +244,18 @@ class U2STTrainer(Trainer): load_transcript = True if config.model_conf.asr_weight > 0 else False + config = self.config.clone() + config['load_transcript'] = load_transcript + self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=True) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. 
- preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: - # test dataset, return raw text - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=False) - + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) logger.info("Setup test Dataloader!") + def setup_model(self): config = self.config model_conf = config @@ -468,7 +410,8 @@ class U2STTester(U2STTrainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") decode_cfg = self.config.decode bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index cb466ecb..83183024 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -30,9 +30,10 @@ from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log import 
def read_preprocess_cfg(preprocess_conf_file):
    """Collect spec-augment keyword arguments from a preprocess config file.

    Only the ``time_warp``, ``freq_mask`` and ``time_mask`` entries of the
    file's ``process`` list are read; other entry types are ignored.

    Args:
        preprocess_conf_file: path of the preprocess config file, or ``None``
            when no preprocessing is configured.

    Returns:
        dict: augmentation settings; empty when ``preprocess_conf_file`` is
            ``None`` or the file has none of the three entry types.
    """
    augment_conf = dict()
    # Callers may pass None (the StreamDataLoader default); there is then no
    # file to merge.
    if preprocess_conf_file is None:
        return augment_conf

    preprocess_cfg = CfgNode(new_allowed=True)
    preprocess_cfg.merge_from_file(preprocess_conf_file)
    for process in preprocess_cfg["process"]:
        process_type = process["type"]
        if process_type == 'time_warp':
            augment_conf['max_w'] = process['max_time_warp']
            augment_conf['w_inplace'] = process['inplace']
            augment_conf['w_mode'] = process['mode']
        elif process_type == 'freq_mask':
            augment_conf['max_f'] = process['F']
            augment_conf['num_f_mask'] = process['n_mask']
            augment_conf['f_inplace'] = process['inplace']
            augment_conf['f_replace_with_zero'] = process['replace_with_zero']
        elif process_type == 'time_mask':
            augment_conf['max_t'] = process['T']
            augment_conf['num_t_mask'] = process['n_mask']
            augment_conf['t_inplace'] = process['inplace']
            augment_conf['t_replace_with_zero'] = process['replace_with_zero']
    return augment_conf
[] with open(manifest_file, "r") as f: for line in f.readlines(): shardlist.append(line.strip()) - + world_size = 1 + try: + world_size = paddle.distributed.get_world_size() + except Exception as e: + logger.warninig(e) + logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1") + assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...") + + update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0)) + logger.info(f"update_n_iter_processes {update_n_iter_processes}") + if update_n_iter_processes != self.n_iter_processes: + self.n_iter_processes = update_n_iter_processes + logger.info(f"change nun_workers to {self.n_iter_processes}") + if self.dist_sampler: base_dataset = streamdata.DataPipeline( streamdata.SimpleShardList(shardlist), @@ -116,16 +155,16 @@ class StreamDataLoader(): ) self.dataset = base_dataset.append_list( - streamdata.tokenize(symbol_table), - streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), - streamdata.resample(resample_rate=resample_rate), - streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.audio_tokenize(symbol_table), + streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), + streamdata.audio_resample(resample_rate=resample_rate), + streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) 
class DataLoaderFactory():
    """Build the data loader for one stage from the experiment config."""

    @staticmethod
    def get_dataloader(mode: str, config, args):
        """Create a ``StreamDataLoader`` or ``BatchDataLoader`` for ``mode``.

        The config is cloned first, so the caller's object is not modified.
        ``config.use_stream_data`` selects the loader class.

        Args:
            mode (str): one of 'train', 'valid', 'test', 'align'.
            config: experiment config; must offer ``clone()``, ``get()``,
                item assignment and attribute access.
            args: command line arguments; ``args.ngpu`` is read for the
                'train' and 'valid' batch loaders.

        Raises:
            KeyError: if ``mode`` is not one of the four supported values.

        Returns:
            StreamDataLoader or BatchDataLoader.
        """
        config = config.clone()
        use_streamdata = config.get("use_stream_data", False)
        if use_streamdata:
            if mode == 'train':
                config['manifest'] = config.train_manifest
                config['train_mode'] = True
            elif mode == 'valid':
                config['manifest'] = config.dev_manifest
                config['train_mode'] = False
            elif mode in ('test', 'align'):
                # Evaluation: no dithering, no length filtering, no
                # distributed sampling.
                config['manifest'] = config.test_manifest
                config['train_mode'] = False
                config['dither'] = 0.0
                config['minlen_in'] = 0.0
                config['maxlen_in'] = float('inf')
                config['minlen_out'] = 0
                config['maxlen_out'] = float('inf')
                config['dist_sampler'] = False
            else:
                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
            return StreamDataLoader(
                manifest_file=config.manifest,
                train_mode=config.train_mode,
                unit_type=config.unit_type,
                preprocess_conf=config.preprocess_config,
                batch_size=config.batch_size,
                num_mel_bins=config.feat_dim,
                frame_length=config.window_ms,
                frame_shift=config.stride_ms,
                dither=config.dither,
                minlen_in=config.minlen_in,
                maxlen_in=config.maxlen_in,
                minlen_out=config.minlen_out,
                maxlen_out=config.maxlen_out,
                resample_rate=config.resample_rate,
                shuffle_size=config.shuffle_size,
                sort_size=config.sort_size,
                n_iter_processes=config.num_workers,
                prefetch_factor=config.prefetch_factor,
                # Optional key: not every experiment config defines it.
                dist_sampler=config.get('dist_sampler', False),
                cmvn_file=config.cmvn_file,
                vocab_filepath=config.vocab_filepath, )
        else:
            if mode == 'train':
                config['manifest'] = config.train_manifest
                config['train_mode'] = True
                config['mini_batch_size'] = args.ngpu
                config['subsampling_factor'] = 1
                config['num_encs'] = 1
            elif mode == 'valid':
                config['manifest'] = config.dev_manifest
                config['train_mode'] = False
                config['sortagrad'] = False
                config['maxlen_in'] = float('inf')
                config['maxlen_out'] = float('inf')
                config['minibatches'] = 0
                config['mini_batch_size'] = args.ngpu
                config['batch_count'] = 'auto'
                config['batch_bins'] = 0
                config['batch_frames_in'] = 0
                config['batch_frames_out'] = 0
                config['batch_frames_inout'] = 0
                config['subsampling_factor'] = 1
                config['num_encs'] = 1
                config['shortest_first'] = False
            elif mode in ('test', 'align'):
                config['manifest'] = config.test_manifest
                config['train_mode'] = False
                config['sortagrad'] = False
                config['batch_size'] = config.get('decode', dict()).get(
                    'decode_batch_size', 1)
                config['maxlen_in'] = float('inf')
                config['maxlen_out'] = float('inf')
                config['minibatches'] = 0
                config['mini_batch_size'] = 1
                config['batch_count'] = 'auto'
                config['batch_bins'] = 0
                config['batch_frames_in'] = 0
                config['batch_frames_out'] = 0
                config['batch_frames_inout'] = 0
                config['num_workers'] = 1
                config['subsampling_factor'] = 1
                config['num_encs'] = 1
                config['dist_sampler'] = False
                config['shortest_first'] = False
            else:
                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")

            return BatchDataLoader(
                json_file=config.manifest,
                train_mode=config.train_mode,
                sortagrad=config.sortagrad,
                batch_size=config.batch_size,
                maxlen_in=config.maxlen_in,
                maxlen_out=config.maxlen_out,
                minibatches=config.minibatches,
                mini_batch_size=config.mini_batch_size,
                batch_count=config.batch_count,
                batch_bins=config.batch_bins,
                batch_frames_in=config.batch_frames_in,
                batch_frames_out=config.batch_frames_out,
                batch_frames_inout=config.batch_frames_inout,
                preprocess_conf=config.preprocess_config,
                n_iter_processes=config.num_workers,
                subsampling_factor=config.subsampling_factor,
                load_aux_output=config.get('load_transcript', None),
                num_encs=config.num_encs,
                # Neither key is set by the 'train' branch above, so fall
                # back to False when the config file omits them.
                dist_sampler=config.get('dist_sampler', False),
                shortest_first=config.get('shortest_first', False))