parent e04cd18846
commit 8f5e61090b
@@ -0,0 +1,68 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa

from .cache import (
    cached_tarfile_samples,
    cached_tarfile_to_samples,
    lru_cleanup,
    pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from webdataset.extradatasets import MockDataset, with_epoch, with_length
from .filters import (
    associate,
    batched,
    decode,
    detshuffle,
    extract_keys,
    getfirst,
    info,
    map,
    map_dict,
    map_tuple,
    pipelinefilter,
    rename,
    rename_keys,
    rsample,
    select,
    shuffle,
    slice,
    to_tuple,
    transform_with,
    unbatched,
    xdecode,
    data_filter,
    tokenize,
    resample,
    compute_fbank,
    spec_aug,
    sort,
    padding,
    cmvn,
)
from webdataset.handlers import (
    ignore_and_continue,
    ignore_and_stop,
    reraise_exception,
    warn_and_continue,
    warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
    MultiShardSample,
    ResampledShards,
    SimpleShardList,
    non_empty,
    resampled,
    shardspec,
    single_node_only,
    split_by_node,
    split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
from webdataset.mix import RandomMix, RoundRobin
@@ -0,0 +1,190 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time
from urllib.parse import urlparse

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from .tariterators import tar_file_and_group_expander

default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))


def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Performs cleanup of the file cache in cache_dir using an LRU strategy,
    keeping the total size of all remaining files below cache_size."""
    if not os.path.exists(cache_dir):
        return
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, filename))
    if total_size <= cache_size:
        return
    # sort files by last access time
    files = []
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            files.append(os.path.join(dirpath, filename))
    files.sort(key=keyfn, reverse=True)
    # delete files until we're under the cache size
    while len(files) > 0 and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)


def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
        with open(temp, "wb") as f:
            while True:
                data = stream.read(chunk_size)
                if not data:
                    break
                f.write(data)
    os.rename(temp, dest)


def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if spec.startswith("pipe:"):
        spec = spec[5:]
        words = spec.split(" ")
        for word in words:
            if re.match(r"^(https?|gs|ais|s3)", word):
                return word
    return spec


def get_file_cached(
    spec,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    verbose=False,
):
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest


def get_filetype(fname):
    with os.popen("file '%s'" % fname) as f:
        ftype = f.read()
    return ftype


def check_tar_format(fname):
    """Check whether a file is a tar archive."""
    ftype = get_filetype(fname)
    return "tar archive" in ftype or "gzip compressed" in ftype


verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))


def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                with open(dest, "rb") as f:
                    data = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(data))
                )
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break


def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    streams = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples


cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
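
# A minimal usage sketch (hypothetical shard URLs; DataPipeline and
# SimpleShardList are defined in pipeline.py and shardlists.py below):
# cached_tarfile_to_samples keeps one local copy of each remote shard under
# WDS_CACHE and trims old copies with lru_cleanup.
#
#   pipeline = DataPipeline(
#       SimpleShardList("https://example.com/shards/train-{000..099}.tar"),
#       cached_tarfile_to_samples(cache_dir="./_cache", cache_size=int(10e9)),
#   )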
@@ -0,0 +1,170 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset


class FluidInterface:
    def batched(self, batchsize):
        return self.compose(filters.batched(batchsize))

    def dynamic_batched(self, max_frames_in_batch):
        return self.compose(filters.dynamic_batched(max_frames_in_batch))

    def unbatched(self):
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        # note: the local batched() yields plain lists (no collation), so this
        # is equivalent to batched() here
        return self.compose(filters.batched(batchsize))

    def unlisted(self):
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        return self.compose(filters.xdecode(*args, **kw))

    def data_filter(self, *args, **kw):
        return self.compose(filters.data_filter(*args, **kw))

    def tokenize(self, *args, **kw):
        return self.compose(filters.tokenize(*args, **kw))

    def resample(self, *args, **kw):
        return self.compose(filters.resample(*args, **kw))

    def compute_fbank(self, *args, **kw):
        return self.compose(filters.compute_fbank(*args, **kw))

    def spec_aug(self, *args, **kw):
        return self.compose(filters.spec_aug(*args, **kw))

    def sort(self, size=500):
        return self.compose(filters.sort(size))

    def padding(self):
        return self.compose(filters.padding())

    def cmvn(self, cmvn_file):
        return self.compose(filters.cmvn(cmvn_file))


class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=0,
        cache_dir=None,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        verbose=False,
    ):
        super().__init__()
        if isinstance(urls, IterableDataset):
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with open(urls) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
        elif isinstance(urls, dict):
            assert "datasets" in urls
            self.append(shardlists.MultiShardSample(urls))
        elif resampled:
            self.append(shardlists.ResampledShards(urls))
        else:
            self.append(shardlists.SimpleShardList(urls))
            self.append(nodesplitter)
            self.append(shardlists.split_by_worker)
            if shardshuffle is True:
                shardshuffle = 100
            if shardshuffle is not None:
                if detshuffle:
                    self.append(filters.detshuffle(shardshuffle))
                else:
                    self.append(filters.shuffle(shardshuffle))
        if cache_size == 0:
            self.append(tariterators.tarfile_to_samples(handler=handler))
        else:
            assert cache_size == -1 or cache_size > 0
            self.append(
                cache.cached_tarfile_to_samples(
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                )
            )


class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)


class WebLoader(DataPipeline, FluidInterface):
    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
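
# A minimal fluent-interface sketch (hypothetical shard names and symbol
# table; the stage order follows the wenet-style recipe in filters.py):
#
#   dataset = (
#       WebDataset("train-{000..099}.tar", shardshuffle=True)
#       .shuffle(1000)
#       .tokenize(symbol_table)
#       .data_filter(max_length=10240)
#       .compute_fbank(num_mel_bins=80)
#       .sort(size=500)
#       .batched(16)
#       .padding()
#   )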
@@ -0,0 +1,912 @@
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.

These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""

import io
import pickle
from fnmatch import fnmatch
import re
import itertools, os, random, sys, time
from functools import reduce, wraps

import numpy as np

from webdataset import autodecode
from . import utils
from .paddle_utils import PaddleTensor
from .utils import PipelineStage

from .. import backends
from ..compliance import kaldi
import paddle
from ..transform.cmvn import GlobalCMVN
from ..utils.tensor_utils import pad_sequence
from ..transform.spec_augment import time_warp
from ..transform.spec_augment import time_mask
from ..transform.spec_augment import freq_mask


class FilterFunction(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f, *args, **kw):
        """Create a curried function."""
        self.f = f
        self.args = args
        self.kw = kw

    def __call__(self, data):
        """Call the curried function with the given argument."""
        return self.f(data, *self.args, **self.kw)

    def __str__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"

    def __repr__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"


class RestCurried(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f):
        """Store the function for future currying."""
        self.f = f

    def __call__(self, *args, **kw):
        """Curry with the given arguments."""
        return FilterFunction(self.f, *args, **kw)


def pipelinefilter(f):
    """Turn the decorated function into one that is partially applied for
    all arguments other than the first."""
    result = RestCurried(f)
    return result
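
# Example of the currying behavior (a sketch): for `_shuffle(data, bufsize)`
# below, `shuffle = pipelinefilter(_shuffle)` makes `shuffle(100)` return a
# picklable FilterFunction; calling that on an iterator runs
# `_shuffle(samples, 100)`.
#
#   stage = shuffle(100)        # RestCurried.__call__ -> FilterFunction
#   shuffled = stage(samples)   # FilterFunction.__call__ -> _shuffle(samples, 100)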


def reraise_exception(exn):
    """Reraises the given exception; used as a handler.

    :param exn: exception
    """
    raise exn


def identity(x):
    """Return the argument."""
    return x


def compose2(f, g):
    """Compose two functions, g(f(x))."""
    return lambda x: g(f(x))


def compose(*args):
    """Compose a sequence of functions (left-to-right)."""
    return reduce(compose2, args)


def pipeline(source, *args):
    """Write an input pipeline; first argument is source, rest are filters."""
    if len(args) == 0:
        return source
    return compose(*args)(source)


def getfirst(a, keys, default=None, missing_is_error=True):
    """Get the first matching key from a dictionary.

    Keys can be specified as a list, or as a string of keys separated by ';'.
    """
    if isinstance(keys, str):
        assert " " not in keys
        keys = keys.split(";")
    for k in keys:
        if k in a:
            return a[k]
    if missing_is_error:
        raise ValueError(f"didn't find {keys} in {list(a.keys())}")
    return default
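
# Example: getfirst(sample, "png;jpg;jpeg") returns sample["png"] if present,
# otherwise sample["jpg"], then sample["jpeg"]; with missing_is_error=True it
# raises ValueError when none of the alternatives is present.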


def parse_field_spec(fields):
    """Parse a specification for a list of fields to be extracted.

    Keys are separated by spaces in the spec. Each key can itself
    be composed of key alternatives separated by ';'.
    """
    if isinstance(fields, str):
        fields = fields.split()
    return [field.split(";") for field in fields]


def transform_with(sample, transformers):
    """Transform a list of values using a list of functions.

    sample: list of values
    transformers: list of functions

    If there are fewer transformers than inputs, or if a transformer
    function is None, then the identity function is used for the
    corresponding sample fields.
    """
    if transformers is None or len(transformers) == 0:
        return sample
    result = list(sample)
    assert len(transformers) <= len(sample)
    for i in range(len(transformers)):  # skipcq: PYL-C0200
        f = transformers[i]
        if f is not None:
            result[i] = f(sample[i])
    return result


###
# Iterators
###


def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
    """Print information about the samples that are passing through.

    :param data: source iterator
    :param fmt: format statement (using sample dict as keyword)
    :param n: when to stop
    :param every: how often to print
    :param width: maximum width
    :param stream: output stream
    :param name: identifier printed before any output
    """
    for i, sample in enumerate(data):
        if i < n or (every > 0 and (i + 1) % every == 0):
            if fmt is None:
                print("---", name, file=stream)
                for k, v in sample.items():
                    print(k, repr(v)[:width], file=stream)
            else:
                print(fmt.format(**sample), file=stream)
        yield sample


info = pipelinefilter(_info)


def pick(buf, rng):
    k = rng.randint(0, len(buf) - 1)
    sample = buf[k]
    buf[k] = buf[-1]
    buf.pop()
    return sample


def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle the data in the stream.

    This uses a buffer of size `bufsize`. Shuffling at
    startup is less random; this is traded off against
    yielding samples quickly.

    data: iterator
    bufsize: buffer size for shuffling
    returns: iterator
    rng: either random module or random.Random instance

    """
    if rng is None:
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
    initial = min(initial, bufsize)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) < bufsize:
            try:
                buf.append(next(data))  # skipcq: PYL-R1708
            except StopIteration:
                pass
        if len(buf) >= initial:
            yield pick(buf, rng)
    while len(buf) > 0:
        yield pick(buf, rng)


shuffle = pipelinefilter(_shuffle)


class detshuffle(PipelineStage):
    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        self.epoch += 1
        rng = random.Random()
        rng.seed((self.seed, self.epoch))
        return _shuffle(src, self.bufsize, self.initial, rng)


def _select(data, predicate):
    """Select samples based on a predicate.

    :param data: source iterator
    :param predicate: predicate (function)
    """
    for sample in data:
        if predicate(sample):
            yield sample


select = pipelinefilter(_select)


def _log_keys(data, logfile=None):
    import fcntl

    if logfile is None or logfile == "":
        for sample in data:
            yield sample
    else:
        with open(logfile, "a") as stream:
            for i, sample in enumerate(data):
                buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
                try:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
                    stream.write(buf)
                finally:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
                yield sample


log_keys = pipelinefilter(_log_keys)


def _decode(data, *args, handler=reraise_exception, **kw):
    """Decode data based on the decoding functions given as arguments."""

    decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
    handlers = [decoder(x) for x in args]
    f = autodecode.Decoder(handlers, **kw)

    for sample in data:
        assert isinstance(sample, dict), sample
        try:
            decoded = f(sample)
        except Exception as exn:  # skipcq: PYL-W0703
            if handler(exn):
                continue
            else:
                break
        yield decoded


decode = pipelinefilter(_decode)


def _map(data, f, handler=reraise_exception):
    """Map samples."""
    for sample in data:
        try:
            result = f(sample)
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        if result is None:
            continue
        if isinstance(sample, dict) and isinstance(result, dict):
            result["__key__"] = sample.get("__key__")
        yield result


map = pipelinefilter(_map)


def _rename(data, handler=reraise_exception, keep=True, **kw):
    """Rename samples based on keyword arguments."""
    for sample in data:
        try:
            if not keep:
                yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
            else:

                def listify(v):
                    return v.split(";") if isinstance(v, str) else v

                to_be_replaced = {x for v in kw.values() for x in listify(v)}
                result = {k: v for k, v in sample.items() if k not in to_be_replaced}
                result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
                yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break


rename = pipelinefilter(_rename)


def _associate(data, associator, **kw):
    """Associate additional data with samples."""
    for sample in data:
        if callable(associator):
            extra = associator(sample["__key__"])
        else:
            extra = associator.get(sample["__key__"], {})
        sample.update(extra)  # destructive
        yield sample


associate = pipelinefilter(_associate)


def _map_dict(data, handler=reraise_exception, **kw):
    """Map the entries in a dict sample with individual functions."""
    assert len(list(kw.keys())) > 0
    for key, f in kw.items():
        assert callable(f), (key, f)

    for sample in data:
        assert isinstance(sample, dict)
        try:
            for k, f in kw.items():
                sample[k] = f(sample[k])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield sample


map_dict = pipelinefilter(_map_dict)


def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
    """Convert dict samples to tuples."""
    if none_is_error is None:
        none_is_error = missing_is_error
    if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
        args = args[0].split()

    for sample in data:
        try:
            result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
            if none_is_error and any(x is None for x in result):
                raise ValueError(f"to_tuple {args} got {sample.keys()}")
            yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break


to_tuple = pipelinefilter(_to_tuple)


def _map_tuple(data, *args, handler=reraise_exception):
    """Map the entries of a tuple with individual functions."""
    args = [f if f is not None else utils.identity for f in args]
    for f in args:
        assert callable(f), f
    for sample in data:
        assert isinstance(sample, (list, tuple))
        sample = list(sample)
        n = min(len(args), len(sample))
        try:
            for i in range(n):
                sample[i] = args[i](sample[i])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield tuple(sample)


map_tuple = pipelinefilter(_map_tuple)


def _unlisted(data):
    """Turn batched data back into unbatched data."""
    for batch in data:
        assert isinstance(batch, list), batch
        for sample in batch:
            yield sample


unlisted = pipelinefilter(_unlisted)


def _unbatched(data):
    """Turn batched data back into unbatched data."""
    for sample in data:
        assert isinstance(sample, (tuple, list)), sample
        assert len(sample) > 0
        for i in range(len(sample[0])):
            yield tuple(x[i] for x in sample)


unbatched = pipelinefilter(_unbatched)


def _rsample(data, p=0.5):
    """Randomly subsample a stream of data."""
    assert p >= 0.0 and p <= 1.0
    for sample in data:
        if random.uniform(0.0, 1.0) < p:
            yield sample


rsample = pipelinefilter(_rsample)

slice = pipelinefilter(itertools.islice)


def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
    for sample in source:
        result = []
        for pattern in patterns:
            pattern = pattern.split(";") if isinstance(pattern, str) else pattern
            matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
            if len(matches) == 0:
                if ignore_missing:
                    continue
                else:
                    raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
            if len(matches) > 1 and duplicate_is_error:
                raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
            value = sample[matches[0]]
            result.append(value)
        yield tuple(result)


extract_keys = pipelinefilter(_extract_keys)


def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
    renamings = [(pattern, output) for output, pattern in args]
    renamings += [(pattern, output) for output, pattern in kw.items()]
    for sample in source:
        new_sample = {}
        matched = {k: False for k, _ in renamings}
        for path, value in sample.items():
            fname = re.sub(r".*/", "", path)
            new_name = None
            for pattern, name in renamings[::-1]:
                if fnmatch(fname.lower(), pattern):
                    matched[pattern] = True
                    new_name = name
                    break
            if new_name is None:
                if keep_unselected:
                    new_sample[path] = value
                continue
            if new_name in new_sample:
                if duplicate_is_error:
                    raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
                continue
            new_sample[new_name] = value
        if must_match and not all(matched.values()):
            raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")

        yield new_sample


rename_keys = pipelinefilter(_rename_keys)


def decode_bin(stream):
    return stream.read()


def decode_text(stream):
    binary = stream.read()
    return binary.decode("utf-8")


def decode_pickle(stream):
    return pickle.load(stream)


default_decoders = [
    ("*.bin", decode_bin),
    ("*.txt", decode_text),
    ("*.pyd", decode_pickle),
]


def find_decoder(decoders, path):
    fname = re.sub(r".*/", "", path)
    if fname.startswith("__"):
        return lambda x: x
    for pattern, fun in decoders[::-1]:
        if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
            return fun
    return None


def _xdecode(
    source,
    *args,
    must_decode=True,
    defaults=default_decoders,
    **kw,
):
    decoders = list(defaults) + list(args)
    decoders += [("*." + k, v) for k, v in kw.items()]
    for sample in source:
        new_sample = {}
        for path, data in sample.items():
            if path.startswith("__"):
                new_sample[path] = data
                continue
            decoder = find_decoder(decoders, path)
            if decoder is False:
                value = data
            elif decoder is None:
                if must_decode:
                    raise ValueError(f"No decoder found for {path}.")
                value = data
            else:
                if isinstance(data, bytes):
                    data = io.BytesIO(data)
                value = decoder(data)
            new_sample[path] = value
        yield new_sample


xdecode = pipelinefilter(_xdecode)


def _data_filter(source,
                 frame_shift=10,
                 max_length=10240,
                 min_length=10,
                 token_max_length=200,
                 token_min_length=1,
                 min_output_input_ratio=0.0005,
                 max_output_input_ratio=1):
    """ Filter samples according to feature and label length.
    Inplace operation.

    Args:
        source: Iterable[{fname, wav, label, sample_rate}]
        frame_shift: length of frame shift (ms)
        max_length: drop utterances longer than max_length (in 10ms frames)
        min_length: drop utterances shorter than min_length (in 10ms frames)
        token_max_length: drop utterances with more tokens than
            token_max_length, especially when using char units for
            English modeling
        token_min_length: drop utterances with fewer tokens than
            token_min_length
        min_output_input_ratio: minimum ratio of
            token_length / feats_length (in 10ms frames)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length (in 10ms frames)

    Returns:
        Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample


data_filter = pipelinefilter(_data_filter)
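
# NOTE: _tokenize below calls __tokenize_by_bpe_model, which the original diff
# does not define in this file. The following is a sketch of the helper,
# modeled on wenet's processor.py (an assumption, not part of the original):
# CJK characters pass through one by one, everything else goes through the
# sentencepiece BPE model.
def __tokenize_by_bpe_model(sp, txt):
    tokens = []
    # CJK Unified Ideographs range [U+4E00, U+9FFF]
    pattern = re.compile(r'([\u4e00-\u9fff])')
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        if pattern.fullmatch(ch_or_w) is not None:
            # a single CJK character: keep it as one token
            tokens.append(ch_or_w)
        else:
            # non-CJK text: encode with the BPE model
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)
    return tokens
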
def _tokenize(source,
              symbol_table,
              bpe_model=None,
              non_lang_syms=None,
              split_with_space=False):
    """ Decode text to chars or BPE.
    Inplace operation.

    Args:
        source: Iterable[{fname, wav, txt, sample_rate}]

    Returns:
        Iterable[{fname, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    for sample in source:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == ' ':
                            ch = "<space>"
                        tokens.append(ch)

        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                label.append(symbol_table['<unk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample


tokenize = pipelinefilter(_tokenize)


def _resample(source, resample_rate=16000):
    """ Resample data.
    Inplace operation.

    Args:
        data: Iterable[{fname, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
                waveform.numpy(), src_sr=sample_rate, target_sr=resample_rate
            ))
        yield sample


resample = pipelinefilter(_resample)


def _compute_fbank(source,
                   num_mel_bins=80,
                   frame_length=25,
                   frame_shift=10,
                   dither=0.0):
    """ Extract fbank features.

    Args:
        source: Iterable[{fname, wav, label, sample_rate}]
        num_mel_bins: number of mel filter banks
        frame_length: length of one frame (ms)
        frame_shift: length of frame shift (ms)
        dither: value of dither

    Returns:
        Iterable[{fname, feat, label}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'fname' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep fname, feat, label
        mat = kaldi.fbank(waveform,
                          n_mels=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sr=sample_rate)
        yield dict(fname=sample['fname'], label=sample['label'], feat=mat)


compute_fbank = pipelinefilter(_compute_fbank)


def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80):
    """ Do spec augmentation.
    Inplace operation.

    Args:
        source: Iterable[{fname, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of freq masks to apply
        max_t: max width of time mask
        max_f: max width of freq mask
        max_w: max width of time warp

    Returns:
        Iterable[{fname, feat, label}]
    """
    for sample in source:
        x = sample['feat']
        x = x.numpy()
        x = time_warp(x, max_time_warp=max_w, inplace=True, mode="PIL")
        x = freq_mask(x, F=max_f, n_mask=num_f_mask, inplace=True, replace_with_zero=False)
        x = time_mask(x, T=max_t, n_mask=num_t_mask, inplace=True, replace_with_zero=False)
        sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
        yield sample


spec_aug = pipelinefilter(_spec_aug)


def _sort(source, sort_size=500):
    """ Sort the data by feature length.
    Sort is used after shuffle and before batch, so we can group
    utts with similar lengths into a batch, and `sort_size` should
    be less than `shuffle_size`.

    Args:
        source: Iterable[{fname, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{fname, feat, label}]
    """

    buf = []
    for sample in source:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['feat'].shape[0])
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['feat'].shape[0])
    for x in buf:
        yield x


sort = pipelinefilter(_sort)


def _batched(source, batch_size=16):
    """ Static batching of the data by `batch_size`.

    Args:
        data: Iterable[{fname, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{fname, feat, label}]]
    """
    buf = []
    for sample in source:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


batched = pipelinefilter(_batched)


def _dynamic_batched(source, max_frames_in_batch=12000):
    """ Dynamic batching of the data until the total frames in a batch
    reach `max_frames_in_batch`.

    Args:
        source: Iterable[{fname, feat, label}]
        max_frames_in_batch: max frames in one batch

    Returns:
        Iterable[List[{fname, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in source:
        assert 'feat' in sample
        assert isinstance(sample['feat'], paddle.Tensor)
        new_sample_frames = sample['feat'].shape[0]
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


dynamic_batched = pipelinefilter(_dynamic_batched)
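
# Worked example of the padding-aware bound: with max_frames_in_batch=12000
# and buffered samples of 900, 1000, and 1100 frames, a 7000-frame sample
# arriving next would make the padded batch 7000 * 4 = 28000 frames, so the
# three buffered samples are flushed as one batch and the long sample starts
# a new one.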


def _padding(source):
    """ Pad a batch of samples into training tensors.

    Args:
        source: Iterable[List[{fname, feat, label}]]

    Returns:
        Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
    """
    for sample in source:
        assert isinstance(sample, list)
        feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
                                        dtype="int64")
        order = paddle.argsort(feats_length, descending=True)
        feats_lengths = paddle.to_tensor(
            [sample[i]['feat'].shape[0] for i in order], dtype="int64")
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['fname'] for i in order]
        sorted_labels = [
            paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
        ]
        label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
                                         dtype="int64")
        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        padding_labels = pad_sequence(sorted_labels,
                                      batch_first=True,
                                      padding_value=-1)

        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


padding = pipelinefilter(_padding)


def _cmvn(source, cmvn_file):
    global_cmvn = GlobalCMVN(cmvn_file)
    for batch in source:
        sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
        padded_feats = padded_feats.numpy()
        padded_feats = global_cmvn(padded_feats)
        padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


cmvn = pipelinefilter(_cmvn)
@@ -0,0 +1,33 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Mock implementations of paddle interfaces when paddle is not available."""


try:
    from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
        """Empty implementation of IterableDataset when paddle is not available."""

        pass

    class DataLoader:
        """Empty implementation of DataLoader when paddle is not available."""

        pass


try:
    from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:

    class PaddleTensor:
        """Empty implementation of PaddleTensor when paddle is not available."""

        pass
@@ -0,0 +1,127 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage."""
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        if self.repetitions != 1:
            if self.nsamples > 0:
                return islice(self.iterator(), self.nsamples)
            else:
                return self.iterator()
        else:
            return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy."""
        result = copy.copy(self)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        An IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self
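
# A minimal composition sketch (hypothetical shard list; stages as defined in
# shardlists.py and tariterators.py): with_epoch(10000) repeats the pipeline
# indefinitely but declares an epoch after exactly 10000 samples.
#
#   pipeline = DataPipeline(
#       SimpleShardList("train-{000..099}.tar"),
#       split_by_worker,
#       tarfile_to_samples(),
#   ).with_epoch(10000)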
@@ -0,0 +1,257 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).

# Modified from https://github.com/webdataset/webdataset

"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""

import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List

import braceexpand, yaml

from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset


def expand_urls(urls):
    if isinstance(urls, str):
        urllist = urls.split("::")
        result = []
        for url in urllist:
            result.extend(braceexpand.braceexpand(url))
        return result
    else:
        return list(urls)
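
# Example: brace notation expands per URL and "::" separates independent
# lists, so expand_urls("a-{0..2}.tar::b-{0..1}.tar") yields
# ["a-0.tar", "a-1.tar", "a-2.tar", "b-0.tar", "b-1.tar"].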


class SimpleShardList(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(self, urls, seed=None):
        """Iterate through the list of shards.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.seed = seed

    def __len__(self):
        return len(self.urls)

    def __iter__(self):
        """Return an iterator over the shards."""
        urls = self.urls.copy()
        if self.seed is not None:
            random.Random(self.seed).shuffle(urls)
        for url in urls:
            yield dict(url=url)


def split_by_node(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
            yield s
    else:
        for s in src:
            yield s


def single_node_only(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        raise ValueError("input pipeline needs to be reconfigured for multinode training")
    for s in src:
        yield s


def split_by_worker(src):
    rank, world_size, worker, num_workers = utils.paddle_worker_info()
    if num_workers > 1:
        for s in islice(src, worker, None, num_workers):
            yield s
    else:
        for s in src:
            yield s
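
# Example of the slicing behavior: with world_size=4 and rank=1, split_by_node
# keeps shards 1, 5, 9, ... of the incoming stream; split_by_worker then
# slices a node's shards across DataLoader workers the same way, so each
# (rank, worker) pair sees a disjoint subset of shards.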


def resampled_(src, n=sys.maxsize):
    import random

    seed = time.time()
    try:
        seed = open("/dev/random", "rb").read(20)
    except Exception as exn:
        print(repr(exn)[:50], file=sys.stderr)
    rng = random.Random(seed)
    print("# resampled loading", file=sys.stderr)
    items = list(src)
    print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
    for i in range(n):
        yield rng.choice(items)


resampled = pipelinefilter(resampled_)


def non_empty(src):
    count = 0
    for s in src:
        yield s
        count += 1
    if count == 0:
        raise ValueError("pipeline stage received no data at all and this was declared as an error")


@dataclass
class MSSource:
    """Class representing a data source."""

    name: str = ""
    perepoch: int = -1
    resample: bool = False
    urls: List[str] = field(default_factory=list)


default_rng = random.Random()


def expand(s):
    return os.path.expanduser(os.path.expandvars(s))


class MultiShardSample(IterableDataset):
    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
        self.epoch = -1
        self.parse_spec(fname)

    def parse_spec(self, fname):
        self.rng = default_rng  # capture default_rng if we fork
        if isinstance(fname, dict):
            spec = fname
            fname = "{dict}"
        else:
            with open(fname) as stream:
                spec = yaml.safe_load(stream)
        assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
        prefix = expand(spec.get("prefix", ""))
        self.sources = []
        for ds in spec["datasets"]:
            assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
                ds.keys()
            )
            buckets = ds.get("buckets", spec.get("buckets", []))
            if isinstance(buckets, str):
                buckets = [buckets]
            buckets = [expand(s) for s in buckets]
            if buckets == []:
                buckets = [""]
            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
            bucket = buckets[0]
            name = ds.get("name", "@" + bucket)
            urls = ds["shards"]
            if isinstance(urls, str):
                urls = [urls]
            # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
            urls = [
                prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
            ]
            resample = ds.get("resample", -1)
            nsample = ds.get("choose", -1)
            if nsample > len(urls):
                raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
            if (nsample > 0) and (resample > 0):
                raise ValueError("specify only one of perepoch or choose")
            entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
            self.sources.append(entry)
            print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)

    def set_epoch(self, seed):
        """Set the current epoch (for consistent shard selection among nodes)."""
        self.rng = random.Random(seed)

    def get_shards_for_epoch(self):
        result = []
        for source in self.sources:
            if source.resample > 0:
                # sample with replacement
                l = self.rng.choices(source.urls, k=source.resample)
            elif source.perepoch > 0:
                # sample without replacement
                l = list(source.urls)
                self.rng.shuffle(l)
                l = l[: source.perepoch]
            else:
                l = list(source.urls)
            result += l
        self.rng.shuffle(result)
        return result

    def __iter__(self):
        shards = self.get_shards_for_epoch()
        for shard in shards:
            yield dict(url=shard)


def shardspec(spec):
    if spec.endswith(".yaml"):
        return MultiShardSample(spec)
    else:
        return SimpleShardList(spec)


class ResampledShards(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        """Return an iterator over the shards."""
        self.epoch += 1
        if self.deterministic:
            seed = utils.make_seed(self.worker_seed(), self.epoch)
        else:
            seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
        if os.environ.get("WDS_SHOW_SEED", "0") == "1":
            print(f"# ResampledShards seed {seed}")
        self.rng = random.Random(seed)
        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.urls) - 1)
            yield dict(url=self.urls[index])
@@ -0,0 +1,283 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).

# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)

"""Low level iteration functions for tar archives."""

import random, re, tarfile

import braceexpand

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception

trace = False
meta_prefix = "__"
meta_suffix = "__"

from ... import audio as paddleaudio
import paddle
import numpy as np

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def base_plus_ext(path):
    """Split off all file extensions.

    Returns base, allext.

    :param path: path with extensions
    :param returns: path with all extensions removed

    """
    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if not match:
        return None, None
    return match.group(1), match.group(2)
|
||||
|
||||
def valid_sample(sample):
|
||||
"""Check whether a sample is valid.
|
||||
|
||||
:param sample: sample to be checked
|
||||
"""
|
||||
return (
|
||||
sample is not None
|
||||
and isinstance(sample, dict)
|
||||
and len(list(sample.keys())) > 0
|
||||
and not sample.get("__bad__", False)
|
||||
)
|
||||
|
||||
|
||||
# FIXME: UNUSED
|
||||
def shardlist(urls, *, shuffle=False):
|
||||
"""Given a list of URLs, yields that list, possibly shuffled."""
|
||||
if isinstance(urls, str):
|
||||
urls = braceexpand.braceexpand(urls)
|
||||
else:
|
||||
urls = list(urls)
|
||||
if shuffle:
|
||||
random.shuffle(urls)
|
||||
for url in urls:
|
||||
yield dict(url=url)
|
||||
|
||||
|
||||
def url_opener(data, handler=reraise_exception, **kw):
|
||||
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
|
||||
for sample in data:
|
||||
assert isinstance(sample, dict), sample
|
||||
assert "url" in sample
|
||||
url = sample["url"]
|
||||
try:
|
||||
stream = gopen.gopen(url, **kw)
|
||||
sample.update(stream=stream)
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (url,)
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def tar_file_iterator(
|
||||
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
|
||||
):
|
||||
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
|
||||
|
||||
:param fileobj: byte stream suitable for tarfile
|
||||
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
|
||||
|
||||
"""
|
||||
stream = tarfile.open(fileobj=fileobj, mode="r:*")
|
||||
for tarinfo in stream:
|
||||
fname = tarinfo.name
|
||||
try:
|
||||
if not tarinfo.isreg():
|
||||
continue
|
||||
if fname is None:
|
||||
continue
|
||||
if (
|
||||
"/" not in fname
|
||||
and fname.startswith(meta_prefix)
|
||||
and fname.endswith(meta_suffix)
|
||||
):
|
||||
# skipping metadata for now
|
||||
continue
|
||||
if skip_meta is not None and re.match(skip_meta, fname):
|
||||
continue
|
||||
|
||||
name = tarinfo.name
|
||||
pos = name.rfind('.')
|
||||
assert pos > 0
|
||||
prefix, postfix = name[:pos], name[pos + 1:]
|
||||
if postfix == 'wav':
|
||||
waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False)
|
||||
result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
|
||||
else:
|
||||
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
|
||||
result = dict(fname=prefix, txt=txt)
|
||||
#result = dict(fname=fname, data=data)
|
||||
yield result
|
||||
stream.members = []
|
||||
except Exception as exn:
|
||||
if hasattr(exn, "args") and len(exn.args) > 0:
|
||||
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
del stream
|
||||
|
||||
def tar_file_and_group_iterator(
|
||||
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
|
||||
):
|
||||
""" Expand a stream of open tar files into a stream of tar file contents.
|
||||
And groups the file with same prefix
|
||||
|
||||
Args:
|
||||
data: Iterable[{src, stream}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, txt, sample_rate}]
|
||||
"""
|
||||
stream = tarfile.open(fileobj=fileobj, mode="r:*")
|
||||
prev_prefix = None
|
||||
example = {}
|
||||
valid = True
|
||||
for tarinfo in stream:
|
||||
name = tarinfo.name
|
||||
pos = name.rfind('.')
|
||||
assert pos > 0
|
||||
prefix, postfix = name[:pos], name[pos + 1:]
|
||||
if prev_prefix is not None and prefix != prev_prefix:
|
||||
example['fname'] = prev_prefix
|
||||
if valid:
|
||||
yield example
|
||||
example = {}
|
||||
valid = True
|
||||
with stream.extractfile(tarinfo) as file_obj:
|
||||
try:
|
||||
if postfix == 'txt':
|
||||
example['txt'] = file_obj.read().decode('utf8').strip()
|
||||
elif postfix in AUDIO_FORMAT_SETS:
|
||||
waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
|
||||
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
|
||||
|
||||
example['wav'] = waveform
|
||||
example['sample_rate'] = sample_rate
|
||||
else:
|
||||
example[postfix] = file_obj.read()
|
||||
except Exception as exn:
|
||||
if hasattr(exn, "args") and len(exn.args) > 0:
|
||||
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
valid = False
|
||||
# logging.warning('error to parse {}'.format(name))
|
||||
prev_prefix = prefix
|
||||
if prev_prefix is not None:
|
||||
example['fname'] = prev_prefix
|
||||
yield example
|
||||
stream.close()
|
||||
|
||||
def tar_file_expander(data, handler=reraise_exception):
|
||||
"""Expand a stream of open tar files into a stream of tar file contents.
|
||||
|
||||
This returns an iterator over (filename, file_contents).
|
||||
"""
|
||||
for source in data:
|
||||
url = source["url"]
|
||||
try:
|
||||
assert isinstance(source, dict)
|
||||
assert "stream" in source
|
||||
for sample in tar_file_iterator(source["stream"]):
|
||||
assert (
|
||||
isinstance(sample, dict) and "data" in sample and "fname" in sample
|
||||
)
|
||||
sample["__url__"] = url
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (source.get("stream"), source.get("url"))
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
def tar_file_and_group_expander(data, handler=reraise_exception):
|
||||
"""Expand a stream of open tar files into a stream of tar file contents.
|
||||
|
||||
This returns an iterator over (filename, file_contents).
|
||||
"""
|
||||
for source in data:
|
||||
url = source["url"]
|
||||
try:
|
||||
assert isinstance(source, dict)
|
||||
assert "stream" in source
|
||||
for sample in tar_file_and_group_iterator(source["stream"]):
|
||||
assert (
|
||||
isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
|
||||
)
|
||||
sample["__url__"] = url
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (source.get("stream"), source.get("url"))
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
|
||||
"""Return function over iterator that groups key, value pairs into samples.
|
||||
|
||||
:param keys: function that splits the key into key and extension (base_plus_ext)
|
||||
:param lcase: convert suffixes to lower case (Default value = True)
|
||||
"""
|
||||
current_sample = None
|
||||
for filesample in data:
|
||||
assert isinstance(filesample, dict)
|
||||
fname, value = filesample["fname"], filesample["data"]
|
||||
prefix, suffix = keys(fname)
|
||||
if trace:
|
||||
print(
|
||||
prefix,
|
||||
suffix,
|
||||
current_sample.keys() if isinstance(current_sample, dict) else None,
|
||||
)
|
||||
if prefix is None:
|
||||
continue
|
||||
if lcase:
|
||||
suffix = suffix.lower()
|
||||
if current_sample is None or prefix != current_sample["__key__"]:
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
||||
if suffix in current_sample:
|
||||
raise ValueError(
|
||||
f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
|
||||
)
|
||||
if suffixes is None or suffix in suffixes:
|
||||
current_sample[suffix] = value
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
|
||||
|
||||
def tarfile_samples(src, handler=reraise_exception):
|
||||
streams = url_opener(src, handler=handler)
|
||||
samples = tar_file_and_group_expander(streams, handler=handler)
|
||||
return samples
|
||||
|
||||
|
||||
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
|
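To see how these stages compose, here is a minimal sketch. The shard filename is hypothetical, and the `paddlespeech.audio.streamdata` module path and the DataPipeline/SimpleShardList composition follow this package's own exports, not a verified API:

# Sketch: stream grouped (wav, txt) samples out of a tar shard.  The shard
# "train-000000.tar" is assumed to contain paired files like "utt1.wav"/"utt1.txt".
from paddlespeech.audio.streamdata import DataPipeline, SimpleShardList, tarfile_to_samples

pipeline = DataPipeline(
    SimpleShardList("train-000000.tar"),  # yields {"url": ...}
    tarfile_to_samples(),                 # url_opener + tar_file_and_group_expander
)
for sample in pipeline:
    print(sample["fname"], sample["sample_rate"], sample["txt"])
    break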
@ -0,0 +1,128 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

# Modified from https://github.com/webdataset/webdataset

"""Miscellaneous utility functions."""

import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union


def make_seed(*args):
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed


class PipelineStage:
    def invoke(self, *args, **kw):
        raise NotImplementedError


def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely."""
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))


def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules."""
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None


def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Repeatedly return batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in itt.islice(loader, nbatches):
            yield sample


def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator."""
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # use the configurable batchsize estimator
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def paddle_worker_info(group=None):
    """Return node and worker info for Paddle and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed
            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            # `import paddle.io.get_worker_info` in the original always raised
            # ModuleNotFoundError, silently skipping this branch
            import paddle.io
            worker_info = paddle.io.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers


def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return rank * 1000 + worker
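A small sketch of how the seeding helpers behave outside a distributed/worker context (no env vars set, no DataLoader worker), where they fall back to rank 0 / worker 0 and are therefore deterministic:

rank, world_size, worker, num_workers = paddle_worker_info()
assert (rank, worker) == (0, 0)
epoch = 3
seed = make_seed(paddle_worker_seed(), epoch)  # stable across runs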
@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np


def delta(feat, window):
    assert window > 0
    delta_feat = np.zeros_like(feat)
    for i in range(1, window + 1):
        delta_feat[:-i] += i * feat[i:]
        delta_feat[i:] += -i * feat[:-i]
        delta_feat[-i:] += i * feat[-1]
        delta_feat[:i] += -i * feat[0]
    delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
    return delta_feat


def add_deltas(x, window=2, order=2):
    """
    Args:
        x (np.ndarray): speech feat, (T, D).

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    feats = [x]
    for _ in range(order):
        feats.append(delta(feats[-1], window))
    return np.concatenate(feats, axis=1)


class AddDeltas():
    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)
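A shape-level sketch of the transform above, with random features standing in for real filter banks:

import numpy as np

feat = np.random.randn(100, 80).astype(np.float32)  # (T=100, D=80)
out = AddDeltas(window=2, order=2)(feat)
assert out.shape == (100, 240)                      # (T, (1 + order) * D)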
@ -0,0 +1,57 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy


class ChannelSelector():
    """Select one channel from a multi-channel signal."""

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return ("{name}(train_channel={train_channel}, "
                "eval_channel={eval_channel}, axis={axis})".format(
                    name=self.__class__.__name__,
                    train_channel=self.train_channel,
                    eval_channel=self.eval_channel,
                    axis=self.axis, ))

    def __call__(self, x, train=True):
        # Assuming x: [Time, Channel] by default

        if x.ndim <= self.axis:
            # If the dimension is insufficient, then unsqueeze
            # (e.g. [Time] -> [Time, 1])
            ind = tuple(
                slice(None) if i < x.ndim else None
                for i in range(self.axis + 1))
            x = x[ind]

        if train:
            channel = self.train_channel
        else:
            channel = self.eval_channel

        if channel == "random":
            ch = numpy.random.randint(0, x.shape[self.axis])
        else:
            ch = channel

        ind = tuple(
            slice(None) if i != self.axis else ch for i in range(x.ndim))
        return x[ind]
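A quick sketch of the selection behavior on a synthetic multi-channel signal:

import numpy

x = numpy.random.randn(16000, 4)                      # (Time, Channel)
mono = ChannelSelector(train_channel="random")(x, train=True)
assert mono.shape == (16000,)                         # channel axis indexed away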
@ -0,0 +1,201 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json

import h5py
import kaldiio
import numpy as np


class CMVN():
    "Apply global/speaker CMVN or its inverse."

    def __init__(
            self,
            stats,
            norm_means=True,
            norm_vars=False,
            filetype="mat",
            utt2spk=None,
            spk2utt=None,
            reverse=False,
            std_floor=1.0e-20, ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                stats_dict = h5py.File(stats, "r")
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for these statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            assert len(stats) == 2, stats.shape

            count = stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # only the first element is used
                count = count.flatten()[0]

            mean = stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = stats[1, :-1] / count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])
        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])

        return x


class UtteranceCMVN():
    "Apply utterance-level CMVN."

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        square_sums = (x**2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.maximum(np.sqrt(var), self.std_floor)
            x = np.divide(x, std)

        return x


class GlobalCMVN():
    "Apply global CMVN."

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        # cmvn_path: Union[str, dict]
        cmvn = cmvn_path
        self.cmvn = cmvn
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor
        if isinstance(cmvn, dict):
            cmvn_stats = cmvn
        else:
            with open(cmvn) as f:
                cmvn_stats = json.load(f)
        self.count = cmvn_stats['frame_num']
        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        self.std = np.maximum(np.sqrt(self.var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
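A minimal sketch of GlobalCMVN with an in-memory stats dict; the numbers are fabricated so that mean = 0.5 and variance = 1.0 per dimension:

import numpy as np

stats = {
    "frame_num": 1000,
    "mean_stat": (np.zeros(80) + 500.0).tolist(),  # sum of feats over frames
    "var_stat": (np.ones(80) * 1250.0).tolist(),   # sum of squared feats
}
cmvn = GlobalCMVN(stats)            # mean = 0.5, var = 1.25 - 0.25 = 1.0
x = np.random.randn(200, 80).astype(np.float32)
y = cmvn(x)                         # (x - mean) / std, per dimension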
@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect

from paddlespeech.audio.transform.transform_interface import TransformInterface
from paddlespeech.audio.utils.check_kwargs import check_kwargs


class FuncTrans(TransformInterface):
    """Functional Transformation

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:

    >>> def foo_bar(x, a=1, b=2):
    ...     '''Foo bar
    ...     :param x: input
    ...     :param int a: default 1
    ...     :param int b: default 2
    ...     '''
    ...     return x + a - b


    >>> class FooBar(FuncTrans):
    ...     _func = foo_bar
    ...     __doc__ = foo_bar.__doc__
    """

    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        return type(self)._func

    @classmethod
    def default_params(cls):
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            d = dict()
        return {
            k: v.default
            for k, v in d.items() if v.default != inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        return ret[:-2] + ")"
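A short sketch following the docstring's own usage pattern; `scale_shift` and `ScaleShift` are made-up names for illustration:

def scale_shift(x, gain=2.0, bias=0.0):
    return x * gain + bias

class ScaleShift(FuncTrans):
    _func = scale_shift
    __doc__ = scale_shift.__doc__

t = ScaleShift(gain=0.5)
print(t)       # ScaleShift(gain=0.5, bias=0.0)
print(t(4.0))  # 2.0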
@ -0,0 +1,561 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import os
import sys  # used below by the soxbindings install fallback

import h5py
import librosa
import numpy
import numpy as np
import scipy
import soundfile


class SoundHDF5File():
    """Collecting sound files to a HDF5 file

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']

    :param: str filepath:
    :param: str mode:
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
    :param: str dtype:
    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        for k in self.file:
            yield self[k]

    def items(self):
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()


class SpeedPerturbation():
    """SpeedPerturbation

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just resamples the input, i.e. both pitch and tempo
    are changed.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    Warning:
        This function is very slow because of resampling.
        I recommend applying speed-perturb outside the training using sox.
    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            res_type="kaiser_best",
            seed=None, ):
        self.res_type = res_type
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # only True when a per-utterance ratio schedule is given
        self.accept_uttid = False

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterance
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given at runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, keep_length={}, res_type={})".format(
                self.__class__.__name__,
                self.lower,
                self.upper,
                self.keep_length,
                self.res_type, )
        else:
            return "{}({}, res_type={})".format(
                self.__class__.__name__, self.utt2ratio_file, self.res_type)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        # Note1: resample requires the sampling-rate of input and output,
        # but actually only the ratio is used.
        y = librosa.resample(
            x, orig_sr=ratio, target_sr=1, res_type=self.res_type)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate the ends
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")
        return y


class SpeedPerturbationSox():
    """SpeedPerturbationSox

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just resamples the input, i.e. both pitch and tempo
    are changed.

    To speed up or slow down the sound of a file,
    use speed to modify the pitch and the duration of the file.
    This raises the speed and reduces the time.
    The default factor is 1.0 which makes no change to the audio.
    2.0 doubles speed, thus the time length is cut by a half and
    the pitch is one octave higher.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    tempo option:
        sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9

    speed option:
        sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9

    If we use the speed option as above, the pitch of the audio is also
    changed, whereas the tempo option does not change the pitch.
    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            sr=16000,
            seed=None, ):
        self.sr = sr
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)
        # only True when a per-utterance ratio schedule is given
        self.accept_uttid = False

        try:
            import soxbindings as sox
        except ImportError:
            try:
                from paddlespeech.s2t.utils import dynamic_pip_install
                package = "sox"
                dynamic_pip_install.install(package)
                package = "soxbindings"
                if sys.platform != "win32":
                    dynamic_pip_install.install(package)
                import soxbindings as sox
            except Exception:
                raise RuntimeError(
                    "Cannot install soxbindings on your system.")
        self.sox = sox

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterance
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given at runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return f"""{self.__class__.__name__}(
                lower={self.lower},
                upper={self.upper},
                keep_length={self.keep_length},
                sample_rate={self.sr})"""
        else:
            return f"""{self.__class__.__name__}(
                utt2ratio={self.utt2ratio_file},
                sample_rate={self.sr})"""

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        tfm = self.sox.Transformer()
        tfm.set_globals(multithread=False)
        tfm.speed(ratio)
        y = tfm.build_array(input_array=x, sample_rate_in=self.sr)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate the ends
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")

        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T,); ".sequence(1)" in the original looks like a
            # typo for numpy's squeeze
            y = y.squeeze(1)
        return y


class BandpassPerturbation():
    """BandpassPerturbation

    Randomly drop out along the frequency axis.

    The original idea comes from the following:
        "a randomly-selected frequency band was cut off under the constraint of
        leaving at least 1,000 Hz band within the range of less than 4,000 Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
        everyday home environments using multiple microphone arrays;
        http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        # normalize negative axes; "ndim - i" in the original flipped the
        # sign twice and produced out-of-range indices
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft


class VolumePerturbation():
    def __init__(self,
                 lower=-1.6,
                 upper=1.6,
                 utt2ratio=None,
                 dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # only True when a per-utterance ratio schedule is given
        self.accept_uttid = False

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
            self.utt2ratio = {}
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            # The ratio is given at runtime randomly
            self.utt2ratio = None

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            ratio = 10**(ratio / 20)
        return x * ratio


class NoiseInjection():
    """Add isotropic noise"""

    def __init__(
            self,
            utt2noise=None,
            lower=-20,
            upper=-5,
            utt2ratio=None,
            filetype="list",
            dbunit=True,
            seed=None, ):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
            self.utt2ratio = {}
            with open(utt2noise, "r") as f:
                for line in f:
                    utt, snr = line.rstrip().split(None, 1)
                    snr = float(snr)
                    self.utt2ratio[utt] = snr
        else:
            # The ratio is given at runtime randomly
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)
            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source
                noise = self.state.choice(list(self.utt2noise.values()))
            # Normalize the level (out of place to avoid an in-place cast
            # error on int16 noise)
            noise = noise / numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            offset = self.state.randint(0, diff)
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:-(diff - offset)]
            else:
                noise = numpy.pad(
                    noise, pad_width=[offset, diff - offset], mode="wrap")
        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale


class RIRConvolve():
    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        self.utt2rir = {}
        if filetype == "list":
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)
        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        rir, rate = self.utt2rir[uttid]
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            return numpy.stack(
                [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
        else:
            return scipy.convolve(x, rir, mode="same")
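As a quick sanity check of the length-preserving behavior of SpeedPerturbation (keeping in mind the class docstring's warning that resampling is slow):

import numpy

x = numpy.random.randn(16000).astype(numpy.float32)   # 1 s at 16 kHz
perturb = SpeedPerturbation(lower=0.9, upper=1.1, seed=0)
y = perturb(x, train=True)
assert len(y) == len(x)                               # keep_length pads/truncates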
@ -0,0 +1,214 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec augment module for preprocessing, i.e., data augmentation."""
import random

import numpy
from PIL import Image
from PIL.Image import BICUBIC

from .functional import FuncTrans


def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """time warp for spec augment

    move random center frame by the random width ~ uniform(-window, window)
    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp
    :param bool inplace: overwrite x with the result
    :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
        (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq)
    """
    window = max_time_warp
    if window == 0:
        return x

    if mode == "PIL":
        t = x.shape[0]
        if t - window <= window:
            return x
        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
        center = random.randrange(window, t - window)
        warped = random.randrange(center - window, center +
                                  window) + 1  # 1 ... t - 1

        left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
        right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
                                                   BICUBIC)
        if inplace:
            x[:warped] = left
            x[warped:] = right
            return x
        return numpy.concatenate((left, right), 0)
    elif mode == "sparse_image_warp":
        import paddle

        from espnet.utils import spec_augment

        # TODO(karita): make this differentiable again
        return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
    else:
        raise NotImplementedError("unknown resize mode: " + mode +
                                  ", choose one from (PIL, sparse_image_warp).")


class TimeWarp(FuncTrans):
    _func = time_warp
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """freq mask for spec augment

    :param numpy.ndarray x: (time, freq)
    :param int F: maximum width of each mask
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = x
    else:
        cloned = x.copy()

    num_mel_channels = cloned.shape[1]
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, mask_end in fs:
        f_zero = random.randrange(0, num_mel_channels - f)
        mask_end += f_zero

        # avoids randrange error if values are equal and range is empty
        if f_zero == f_zero + f:
            continue

        if replace_with_zero:
            cloned[:, f_zero:mask_end] = 0
        else:
            cloned[:, f_zero:mask_end] = cloned.mean()
    return cloned


class FreqMask(FuncTrans):
    _func = freq_mask
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """time mask for spec augment

    :param numpy.ndarray spec: (time, freq)
    :param int T: maximum width of each mask
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = spec
    else:
        cloned = spec.copy()
    len_spectro = cloned.shape[0]
    ts = numpy.random.randint(0, T, size=(n_mask, 2))
    for t, mask_end in ts:
        # avoid randint range error
        if len_spectro - t <= 0:
            continue
        t_zero = random.randrange(0, len_spectro - t)

        # avoids randrange error if values are equal and range is empty
        if t_zero == t_zero + t:
            continue

        mask_end += t_zero
        if replace_with_zero:
            cloned[t_zero:mask_end] = 0
        else:
            cloned[t_zero:mask_end] = cloned.mean()
    return cloned


class TimeMask(FuncTrans):
    _func = time_mask
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def spec_augment(
        x,
        resize_mode="PIL",
        max_time_warp=80,
        max_freq_width=27,
        n_freq_mask=2,
        max_time_width=100,
        n_time_mask=2,
        inplace=True,
        replace_with_zero=True, ):
    """spec augment

    apply random time warping and time/freq masking
    default setting is based on LD (Librispeech double) in Table 2
    https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
        (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2
    x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    x = freq_mask(
        x,
        max_freq_width,
        n_freq_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    x = time_mask(
        x,
        max_time_width,
        n_time_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    return x


class SpecAugment(FuncTrans):
    _func = spec_augment
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)
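A minimal sketch on a random stand-in for a log-mel spectrogram (values are noise; only the shape matters here):

import numpy

x = numpy.random.randn(300, 80).astype(numpy.float32)  # (time, freq)
y = spec_augment(x, max_time_warp=5, max_freq_width=27, n_freq_mask=2,
                 max_time_width=40, n_time_mask=2)
assert y.shape == x.shape  # warping and masking are applied in place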
@ -0,0 +1,475 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
from python_speech_features import logfbank
|
||||
|
||||
from ..compliance import kaldi
|
||||
|
||||
|
||||
def stft(x,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect"):
|
||||
# x: [Time, Channel]
|
||||
if x.ndim == 1:
|
||||
single_channel = True
|
||||
# x: [Time] -> [Time, Channel]
|
||||
x = x[:, None]
|
||||
else:
|
||||
single_channel = False
|
||||
x = x.astype(np.float32)
|
||||
|
||||
# FIXME(kamo): librosa.stft can't use multi-channel?
|
||||
# x: [Time, Channel, Freq]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.stft(
|
||||
y=x[:, ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1, )
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel, Freq] -> [Time, Freq]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def istft(x, n_shift, win_length=None, window="hann", center=True):
|
||||
# x: [Time, Channel, Freq]
|
||||
if x.ndim == 2:
|
||||
single_channel = True
|
||||
# x: [Time, Freq] -> [Time, Channel, Freq]
|
||||
x = x[:, None, :]
|
||||
else:
|
||||
single_channel = False
|
||||
|
||||
# x: [Time, Channel]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.istft(
|
||||
stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time]
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center, ) for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1, )
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel] -> [Time]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def stft2logmelspectrogram(x_stft,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10):
|
||||
# x_stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
|
||||
# spc: (Time, Channel, Freq) or (Time, Freq)
|
||||
spc = np.abs(x_stft)
|
||||
# mel_basis: (Mel_freq, Freq)
|
||||
mel_basis = librosa.filters.mel(
|
||||
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
|
||||
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||
|
||||
return lmspc
|
||||
|
||||
|
||||
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
|
||||
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
|
||||
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||
return spc
|
||||
|
||||
|
||||
def logmelspectrogram(
|
||||
x,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
pad_mode="reflect", ):
|
||||
# stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
x_stft = stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
n_shift=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode=pad_mode, )
|
||||
|
||||
return stft2logmelspectrogram(
|
||||
x_stft,
|
||||
fs=fs,
|
||||
n_mels=n_mels,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
eps=eps)
|
||||
|
||||
|
||||
class Spectrogram():
|
||||
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return spectrogram(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window, )
|
||||
|
||||
|
||||
class LogMelSpectrogram():
    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "n_shift={n_shift}, win_length={win_length}, window={window}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        # pass the stored fmin/fmax/eps through to the underlying function
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )

class Stft2LogMelSpectrogram():
    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )

class Stft():
    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

    def __repr__(self):
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center}, pad_mode={pad_mode})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode=self.pad_mode, ))

    def __call__(self, x):
        return stft(
            x,
            self.n_fft,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )

class IStft():
    def __init__(self, n_shift, win_length=None, window="hann", center=True):
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center

    def __repr__(self):
        return ("{name}(n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center})".format(
                    name=self.__class__.__name__,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center, ))

    def __call__(self, x):
        return istft(
            x,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center, )

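# Illustrative round trip (a sketch, not part of the original module): Stft
# and IStft approximately invert each other when given matching parameters;
# the istft() helper used by IStft is defined earlier in this file.
# >>> to_spec = Stft(n_fft=512, n_shift=160)
# >>> to_wave = IStft(n_shift=160)
# >>> x_hat = to_wave(to_spec(x))   # x_hat ~ x up to framing effects
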
class LogMelSpectrogramKaldi():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            energy_floor=0.0,
            dither=0.1):
        """
        The Kaldi implementation of LogMelSpectrogram
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame window
            energy_floor (float): floor on energy in spectrogram computation (absolute)
            dither (float): dithering constant

        Returns:
            LogMelSpectrogramKaldi
        """

        self.fs = fs
        self.n_mels = n_mels
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms  # unit: ms
        self.n_frame_shift = n_shift / num_point_ms  # unit: ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
        mat = kaldi.fbank(
            waveform,
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=dither,
            energy_floor=self.energy_floor,
            sr=self.fs)
        mat = np.squeeze(mat.numpy())
        return mat

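# Illustrative usage (a sketch, not part of the original module): dithering is
# only applied in train mode, so the eval-mode call below is deterministic.
# >>> frontend = LogMelSpectrogramKaldi(fs=16000, n_mels=80)
# >>> feats = frontend(x, train=False)   # x: 1-D float np.ndarray
# >>> feats.shape   # (T, 80)
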
class LogMelSpectrogramKaldi_decay():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_fft=512,  # fft point
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            window="povey",
            fmin=20,
            fmax=None,
            eps=1e-10,
            dither=1.0):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        if n_shift > win_length:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        self.n_shift = n_shift / fs  # unit: second
        self.win_length = win_length / fs  # unit: second

        self.window = window
        self.fmin = fmin
        if fmax is None:
            fmax_ = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        else:
            fmax_ = fmax
        self.fmax = fmax_

        self.eps = eps
        self.remove_dc_offset = True
        self.preemph = 0.97
        self.dither = dither  # only works in train mode

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                preemph=self.preemph,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")

        if x.dtype in np.sctypes['float']:
            # scale a float waveform in [-1, 1] to the int16 (PCM16) range
            bits = np.iinfo(np.int16).bits
            x = x * 2**(bits - 1)

        # logfbank needs PCM16 input
        y = logfbank(
            signal=x,
            samplerate=self.fs,
            winlen=self.win_length,  # unit: second
            winstep=self.n_shift,  # unit: second
            nfilt=self.n_mels,
            nfft=self.n_fft,
            lowfreq=self.fmin,
            highfreq=self.fmax,
            dither=dither,
            remove_dc_offset=self.remove_dc_offset,
            preemph=self.preemph,
            wintype=self.window)
        return y
@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)


class TransformInterface:
    """Transform Interface"""

    def __call__(self, x):
        raise NotImplementedError("__call__ method is not implemented")

    @classmethod
    def add_arguments(cls, parser):
        return parser

    def __repr__(self):
        return self.__class__.__name__ + "()"


class Identity(TransformInterface):
    """Identity Function"""

    def __call__(self, x):
        return x
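# Illustrative subclass (a sketch, not part of the original module): a
# transform only has to implement __call__; Scale and its factor argument are
# hypothetical names chosen for this example.
# >>> class Scale(TransformInterface):
# ...     def __init__(self, factor=2.0):
# ...         self.factor = factor
# ...     def __call__(self, x):
# ...         return x * self.factor
# >>> Scale()(np.ones(3))   # array([2., 2., 2.])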
@ -0,0 +1,158 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Transformation module."""
import copy
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature

import yaml

from ..utils.dynamic_import import dynamic_import

import_alias = dict(
    identity="paddlespeech.audio.transform.transform_interface:Identity",
    time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
    time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
    freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
    spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
    speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
    speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
    volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
    noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
    bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
    rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
    delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
    cmvn="paddlespeech.audio.transform.cmvn:CMVN",
    utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
    fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
    spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
    stft="paddlespeech.audio.transform.spectrogram:Stft",
    istft="paddlespeech.audio.transform.spectrogram:IStft",
    stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
    wpe="paddlespeech.audio.transform.wpe:WPE",
    channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
    fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
    cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")


class Transformation():
    """Apply some functions to the mini-batch

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # some callables, e.g. built-in functions, do not
                        # provide a signature
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

    def __repr__(self):
        rep = "\n" + "\n".join("    {}: {}".format(k, v)
                               for k, v in self.functions.items())
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        :param Union[Sequence[np.ndarray], np.ndarray] xs:
        :param Union[Sequence[str], str] uttid_list:
        :return: batch
        :rtype: List[np.ndarray]
        """
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx in range(len(self.conf["process"])):
                func = self.functions[idx]
                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
                # derive only the args which the func has
                try:
                    param = signature(func).parameters
                except ValueError:
                    # some callables, e.g. built-in functions, do not
                    # provide a signature
                    param = {}
                _kwargs = {k: v for k, v in kwargs.items() if k in param}
                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    logging.fatal("Caught an exception from the {}th func: {}".
                                  format(idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
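# Illustrative usage (a sketch, not part of the original module): a single
# utterance runs through the same pipeline as a batch; a lone ndarray comes
# back as a lone ndarray rather than a list.
# >>> import numpy as np
# >>> transform = Transformation({"process": [{"type": "identity"}]})
# >>> x = np.random.randn(100, 80).astype(np.float32)
# >>> transform(x).shape   # (100, 80)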
@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from nara_wpe.wpe import wpe


class WPE(object):
    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        return ("{name}(taps={taps}, delay={delay}, "
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray

        """
        # nara_wpe.wpe expects (Frequency, Channel, Time); transpose in and
        # back out again
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        return xs.transpose(2, 1, 0)
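# Illustrative usage (a sketch, not part of the original module): WPE
# dereverberates a complex multi-channel STFT; the shape is preserved.
# >>> import numpy as np
# >>> xs = np.random.randn(100, 4, 257) + 1j * np.random.randn(100, 4, 257)
# >>> WPE()(xs).shape   # (100, 4, 257)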
@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect


def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise a TypeError, the same as Python's default
    behavior.
    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        return
    if name is None:
        name = func.__name__
    for k in kwargs.keys():
        if k not in params:
            raise TypeError(
                f"{name}() got an unexpected keyword argument '{k}'")
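# Illustrative usage (a sketch, not part of the original module): validating
# keyword arguments before deferring a call.
# >>> def f(a, b=1):
# ...     return a + b
# >>> check_kwargs(f, {"a": 0, "b": 2})   # OK, returns None
# >>> check_kwargs(f, {"c": 3})           # raises TypeError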
@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib

__all__ = ["dynamic_import"]


def dynamic_import(import_path, alias=dict()):
    """dynamic import module and class

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'paddlespeech.s2t.models.u2:U2Model'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
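# Illustrative usage (a sketch, not part of the original module): resolving a
# class through a full "module:Class" path or through an alias table.
# >>> OD = dynamic_import("collections:OrderedDict")
# >>> OD is __import__("collections").OrderedDict   # True
# >>> dynamic_import("odict", alias={"odict": "collections:OrderedDict"})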
@ -0,0 +1,195 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for Transformer."""
from typing import List
from typing import Tuple

import paddle

from .log import Logger

__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]

logger = Logger(__name__)


def has_tensor(val):
    if isinstance(val, (list, tuple)):
        for item in val:
            if has_tensor(item):
                return True
    elif isinstance(val, dict):
        for k, v in val.items():
            if has_tensor(v):
                return True
    else:
        return paddle.is_tensor(val)
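# Illustrative usage (a sketch, not part of the original module):
# >>> has_tensor([0, {"x": paddle.zeros([1])}])   # True
# >>> has_tensor([1, 2, 3])                       # falsy (no tensor found)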
def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *``, the output is of size ``T x B x *`` if
    batch_first is False, and ``B x T x *`` otherwise.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> a = paddle.ones([25, 300])
        >>> b = paddle.ones([22, 300])
        >>> c = paddle.ones([15, 300])
        >>> pad_sequence([a, b, c]).shape
        paddle.Tensor([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = paddle.shape(sequences[0])
    # (TODO Hui Zhang): slice does not support `end==start`
    # trailing_dims = max_size[1:]
    trailing_dims = tuple(
        max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
    max_len = max([s.shape[0] for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
    for i, tensor in enumerate(sequences):
        length = tensor.shape[0]
        # use index notation to prevent duplicate references to the tensor
        logger.info(
            f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
        )
        if batch_first:
            # TODO (Hui Zhang): set_value op does not support `end==start`
            # TODO (Hui Zhang): set_value op does not support int16
            # TODO (Hui Zhang): set_varbase 2 rank does not support [0,0,...]
            # out_tensor[i, :length, ...] = tensor
            if length != 0:
                out_tensor[i, :length] = tensor
            else:
                out_tensor[i, length] = tensor
        else:
            # TODO (Hui Zhang): set_value op does not support `end==start`
            # out_tensor[:length, i, ...] = tensor
            if length != 0:
                out_tensor[:length, i] = tensor
            else:
                out_tensor[length, i] = tensor

    return out_tensor

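# Illustrative usage (a sketch, not part of the original module): with
# batch_first=True the batch axis leads instead.
# >>> a, b = paddle.ones([5, 4]), paddle.ones([3, 4])
# >>> pad_sequence([a, b], batch_first=True, padding_value=-1).shape
# [2, 5, 4]
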
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Add <sos> and <eos> labels.
    Args:
        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding
    Returns:
        ys_in (paddle.Tensor) : (B, Lmax + 1)
        ys_out (paddle.Tensor) : (B, Lmax + 1)
    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    # TODO(Hui Zhang): using comment code,
    #_sos = paddle.to_tensor(
    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #_eos = paddle.to_tensor(
    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
    B = ys_pad.shape[0]
    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
    ys_in = paddle.concat([_sos, ys_pad], axis=1)
    mask_pad = (ys_in == ignore_id)
    ys_in = ys_in.masked_fill(mask_pad, eos)

    ys_out = paddle.concat([ys_pad, _eos], axis=1)
    ys_out = ys_out.masked_fill(mask_pad, eos)
    mask_eos = (ys_out == ignore_id)
    ys_out = ys_out.masked_fill(mask_eos, eos)
    ys_out = ys_out.masked_fill(mask_pad, ignore_id)
    return ys_in, ys_out

def th_accuracy(pad_outputs: paddle.Tensor,
                pad_targets: paddle.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
                                pad_outputs.shape[1]).argmax(2)
    mask = pad_targets != ignore_label
    # TODO(Hui Zhang): sum does not support bool type
    # numerator = paddle.sum(
    #     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = (
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = paddle.sum(numerator.type_as(pad_targets))
    # TODO(Hui Zhang): sum does not support bool type
    # denominator = paddle.sum(mask)
    denominator = paddle.sum(mask.type_as(pad_targets))
    return float(numerator) / float(denominator)
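# Illustrative usage (a sketch, not part of the original module): logits are
# flattened to (B * Lmax, D) before the call, and padded positions carry the
# ignore label.
# >>> logits = paddle.randn([2 * 5, 100])         # B=2, Lmax=5, D=100
# >>> targets = paddle.randint(0, 100, [2, 5])
# >>> acc = th_accuracy(logits, targets, ignore_label=-1)
# >>> 0.0 <= acc <= 1.0   # True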