# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import io
import itertools
import os
import pickle
import random
import re
import sys
import time
from fnmatch import fnmatch
from functools import reduce
import paddle
from paddleaudio import backends
from paddleaudio.compliance import kaldi
from . import autodecode
from . import utils
from ..transform.cmvn import GlobalCMVN
from ..transform.spec_augment import freq_mask
from ..transform.spec_augment import time_mask
from ..transform.spec_augment import time_warp
from ..utils.tensor_utils import pad_sequence
from .utils import PipelineStage
class FilterFunction(object):
"""Helper class for currying pipeline stages.
    We use this roundabout construct because it can be pickled.
"""
def __init__(self, f, *args, **kw):
"""Create a curried function."""
self.f = f
self.args = args
self.kw = kw
def __call__(self, data):
"""Call the curried function with the given argument."""
return self.f(data, *self.args, **self.kw)
def __str__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
def __repr__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
class RestCurried(object):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def __init__(self, f):
"""Store the function for future currying."""
self.f = f
def __call__(self, *args, **kw):
"""Curry with the given arguments."""
return FilterFunction(self.f, *args, **kw)
def pipelinefilter(f):
"""Turn the decorated function into one that is partially applied for
all arguments other than the first."""
result = RestCurried(f)
return result
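# Illustrative sketch of the currying (names `_take`/`take`/`source` are
# hypothetical, not part of this module):
#
#     def _take(data, n):
#         for sample in itertools.islice(data, n):
#             yield sample
#     take = pipelinefilter(_take)
#
#     stage = take(10)              # curried: remembers n=10, no data yet
#     for sample in stage(source):  # equivalent to _take(source, 10)
#         ...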
def reraise_exception(exn):
"""Reraises the given exception; used as a handler.
:param exn: exception
"""
raise exn
def identity(x):
"""Return the argument."""
return x
def compose2(f, g):
"""Compose two functions, g(f(x))."""
return lambda x: g(f(x))
def compose(*args):
"""Compose a sequence of functions (left-to-right)."""
return reduce(compose2, args)
def pipeline(source, *args):
"""Write an input pipeline; first argument is source, rest are filters."""
if len(args) == 0:
return source
return compose(*args)(source)
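# For example (sketch; `src`, `f`, `g`, `h` are hypothetical):
#
#     pipeline(src, f, g, h)   # equivalent to h(g(f(src)))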
def getfirst(a, keys, default=None, missing_is_error=True):
"""Get the first matching key from a dictionary.
Keys can be specified as a list, or as a string of keys separated by ';'.
"""
if isinstance(keys, str):
assert " " not in keys
keys = keys.split(";")
for k in keys:
if k in a:
return a[k]
if missing_is_error:
raise ValueError(f"didn't find {keys} in {list(a.keys())}")
return default
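# Example (sketch): getfirst({"jpg": b"..."}, "png;jpg") returns the value
# under "jpg" since "png" is absent; with missing_is_error=False a miss on
# every alternative returns `default` instead of raising.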
def parse_field_spec(fields):
"""Parse a specification for a list of fields to be extracted.
Keys are separated by spaces in the spec. Each key can itself
be composed of key alternatives separated by ';'.
"""
if isinstance(fields, str):
fields = fields.split()
return [field.split(";") for field in fields]
def transform_with(sample, transformers):
"""Transform a list of values using a list of functions.
sample: list of values
transformers: list of functions
If there are fewer transformers than inputs, or if a transformer
function is None, then the identity function is used for the
corresponding sample fields.
"""
if transformers is None or len(transformers) == 0:
return sample
result = list(sample)
assert len(transformers) <= len(sample)
for i in range(len(transformers)): # skipcq: PYL-C0200
f = transformers[i]
if f is not None:
result[i] = f(sample[i])
return result
###
# Iterators
###
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
"""Print information about the samples that are passing through.
:param data: source iterator
:param fmt: format statement (using sample dict as keyword)
    :param n: number of initial samples to print
    :param every: print every `every`-th sample thereafter (-1 to disable)
:param width: maximum width
:param stream: output stream
:param name: identifier printed before any output
"""
for i, sample in enumerate(data):
if i < n or (every > 0 and (i + 1) % every == 0):
if fmt is None:
print("---", name, file=stream)
for k, v in sample.items():
print(k, repr(v)[:width], file=stream)
else:
print(fmt.format(**sample), file=stream)
yield sample
info = pipelinefilter(_info)
def pick(buf, rng):
k = rng.randint(0, len(buf) - 1)
sample = buf[k]
buf[k] = buf[-1]
buf.pop()
return sample
def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
"""Shuffle the data in the stream.
This uses a buffer of size `bufsize`. Shuffling at
startup is less random; this is traded off against
yielding samples quickly.
    data: iterator
    bufsize: buffer size for shuffling
    initial: minimum buffer fill before yielding begins
    rng: either the random module or a random.Random instance
    returns: iterator
"""
if rng is None:
rng = random.Random(int((os.getpid() + time.time()) * 1e9))
initial = min(initial, bufsize)
buf = []
for sample in data:
buf.append(sample)
if len(buf) < bufsize:
try:
buf.append(next(data)) # skipcq: PYL-R1708
except StopIteration:
pass
if len(buf) >= initial:
yield pick(buf, rng)
while len(buf) > 0:
yield pick(buf, rng)
shuffle = pipelinefilter(_shuffle)
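# Typical use as a pipeline stage (sketch; `samples` is a hypothetical
# iterator): yielding begins once `initial` samples are buffered, trading
# early-startup randomness against latency.
#
#     stage = shuffle(bufsize=1000, initial=100)
#     for sample in stage(samples):
#         ...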
class detshuffle(PipelineStage):
def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
self.epoch += 1
rng = random.Random()
rng.seed((self.seed, self.epoch))
return _shuffle(src, self.bufsize, self.initial, rng)
def _select(data, predicate):
"""Select samples based on a predicate.
:param data: source iterator
:param predicate: predicate (function)
"""
for sample in data:
if predicate(sample):
yield sample
select = pipelinefilter(_select)
def _log_keys(data, logfile=None):
import fcntl
if logfile is None or logfile == "":
for sample in data:
yield sample
else:
with open(logfile, "a") as stream:
for i, sample in enumerate(data):
buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
try:
fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
stream.write(buf)
finally:
fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
yield sample
log_keys = pipelinefilter(_log_keys)
def _minedecode(x):
if isinstance(x, str):
return autodecode.imagehandler(x)
else:
return x
def _decode(data, *args, handler=reraise_exception, **kw):
"""Decode data based on the decoding functions given as arguments."""
decoder = _minedecode
handlers = [decoder(x) for x in args]
f = autodecode.Decoder(handlers, **kw)
for sample in data:
assert isinstance(sample, dict), sample
try:
decoded = f(sample)
except Exception as exn: # skipcq: PYL-W0703
if handler(exn):
continue
else:
break
yield decoded
decode = pipelinefilter(_decode)
def _map(data, f, handler=reraise_exception):
"""Map samples."""
for sample in data:
try:
result = f(sample)
except Exception as exn:
if handler(exn):
continue
else:
break
if result is None:
continue
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
map = pipelinefilter(_map)
def _rename(data, handler=reraise_exception, keep=True, **kw):
"""Rename samples based on keyword arguments."""
for sample in data:
try:
if not keep:
yield {
k: getfirst(sample, v, missing_is_error=True)
for k, v in kw.items()
}
else:
def listify(v):
return v.split(";") if isinstance(v, str) else v
to_be_replaced = {x for v in kw.values() for x in listify(v)}
result = {
k: v
for k, v in sample.items() if k not in to_be_replaced
}
result.update({
k: getfirst(sample, v, missing_is_error=True)
for k, v in kw.items()
})
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
rename = pipelinefilter(_rename)
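# Example (sketch): map extension-style keys onto canonical field names.
#
#     stage = rename(image="jpg;png", caption="txt")
#     # {"__key__": k, "jpg": im, "txt": t}
#     #     -> {"__key__": k, "image": im, "caption": t}
#     # (with keep=True, keys not mentioned in the mapping pass through)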
def _associate(data, associator, **kw):
"""Associate additional data with samples."""
for sample in data:
if callable(associator):
extra = associator(sample["__key__"])
else:
extra = associator.get(sample["__key__"], {})
sample.update(extra) # destructive
yield sample
associate = pipelinefilter(_associate)
def _map_dict(data, handler=reraise_exception, **kw):
"""Map the entries in a dict sample with individual functions."""
    assert len(kw) > 0
for key, f in kw.items():
assert callable(f), (key, f)
for sample in data:
assert isinstance(sample, dict)
try:
for k, f in kw.items():
sample[k] = f(sample[k])
except Exception as exn:
if handler(exn):
continue
else:
break
yield sample
map_dict = pipelinefilter(_map_dict)
def _to_tuple(data,
*args,
handler=reraise_exception,
missing_is_error=True,
none_is_error=None):
"""Convert dict samples to tuples."""
if none_is_error is None:
none_is_error = missing_is_error
if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
args = args[0].split()
for sample in data:
try:
result = tuple([
getfirst(sample, f, missing_is_error=missing_is_error)
for f in args
])
if none_is_error and any(x is None for x in result):
raise ValueError(f"to_tuple {args} got {sample.keys()}")
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
to_tuple = pipelinefilter(_to_tuple)
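# Example (sketch): each positional argument selects one field, with ';'
# separating alternatives; a single space-separated string also works.
#
#     stage = to_tuple("jpg;png", "txt")   # or: to_tuple("jpg;png txt")
#     # {"jpg": im, "txt": t} -> (im, t)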
def _map_tuple(data, *args, handler=reraise_exception):
"""Map the entries of a tuple with individual functions."""
args = [f if f is not None else utils.identity for f in args]
for f in args:
assert callable(f), f
for sample in data:
assert isinstance(sample, (list, tuple))
sample = list(sample)
n = min(len(args), len(sample))
try:
for i in range(n):
sample[i] = args[i](sample[i])
except Exception as exn:
if handler(exn):
continue
else:
break
yield tuple(sample)
map_tuple = pipelinefilter(_map_tuple)
def _unlisted(data):
"""Turn batched data back into unbatched data."""
for batch in data:
        assert isinstance(batch, list), batch
for sample in batch:
yield sample
unlisted = pipelinefilter(_unlisted)
def _unbatched(data):
"""Turn batched data back into unbatched data."""
for sample in data:
assert isinstance(sample, (tuple, list)), sample
assert len(sample) > 0
for i in range(len(sample[0])):
yield tuple(x[i] for x in sample)
unbatched = pipelinefilter(_unbatched)
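# Example (sketch): a batch stored as a tuple of equal-length columns is
# split back into per-sample tuples.
#
#     # (feats, labels) with batch size B yields B tuples (feats[i], labels[i])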
def _rsample(data, p=0.5):
"""Randomly subsample a stream of data."""
assert p >= 0.0 and p <= 1.0
for sample in data:
if random.uniform(0.0, 1.0) < p:
yield sample
rsample = pipelinefilter(_rsample)
slice = pipelinefilter(itertools.islice)
def _extract_keys(source,
*patterns,
duplicate_is_error=True,
ignore_missing=False):
for sample in source:
result = []
for pattern in patterns:
pattern = pattern.split(";") if isinstance(pattern,
str) else pattern
matches = [
x for x in sample.keys()
if any(fnmatch("." + x, p) for p in pattern)
]
if len(matches) == 0:
if ignore_missing:
continue
else:
raise ValueError(
f"Cannot find {pattern} in sample keys {sample.keys()}.")
if len(matches) > 1 and duplicate_is_error:
raise ValueError(
f"Multiple sample keys {sample.keys()} match {pattern}.")
value = sample[matches[0]]
result.append(value)
yield tuple(result)
extract_keys = pipelinefilter(_extract_keys)
def _rename_keys(source,
*args,
keep_unselected=False,
must_match=True,
duplicate_is_error=True,
**kw):
renamings = [(pattern, output) for output, pattern in args]
renamings += [(pattern, output) for output, pattern in kw.items()]
for sample in source:
new_sample = {}
matched = {k: False for k, _ in renamings}
for path, value in sample.items():
fname = re.sub(r".*/", "", path)
new_name = None
for pattern, name in renamings[::-1]:
if fnmatch(fname.lower(), pattern):
matched[pattern] = True
new_name = name
break
if new_name is None:
if keep_unselected:
new_sample[path] = value
continue
if new_name in new_sample:
if duplicate_is_error:
raise ValueError(
f"Duplicate value in sample {sample.keys()} after rename."
)
continue
new_sample[new_name] = value
if must_match and not all(matched.values()):
raise ValueError(
f"Not all patterns ({matched}) matched sample keys ({sample.keys()})."
)
yield new_sample
rename_keys = pipelinefilter(_rename_keys)
def decode_bin(stream):
return stream.read()
def decode_text(stream):
binary = stream.read()
return binary.decode("utf-8")
def decode_pickle(stream):
return pickle.load(stream)
default_decoders = [
("*.bin", decode_bin),
("*.txt", decode_text),
("*.pyd", decode_pickle),
]
def find_decoder(decoders, path):
fname = re.sub(r".*/", "", path)
if fname.startswith("__"):
return lambda x: x
for pattern, fun in decoders[::-1]:
if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(),
pattern):
return fun
return None
def _xdecode(
source,
*args,
must_decode=True,
defaults=default_decoders,
**kw, ):
decoders = list(defaults) + list(args)
decoders += [("*." + k, v) for k, v in kw.items()]
for sample in source:
new_sample = {}
for path, data in sample.items():
if path.startswith("__"):
new_sample[path] = data
continue
decoder = find_decoder(decoders, path)
if decoder is False:
value = data
elif decoder is None:
if must_decode:
raise ValueError(f"No decoder found for {path}.")
value = data
else:
if isinstance(data, bytes):
data = io.BytesIO(data)
value = decoder(data)
new_sample[path] = value
yield new_sample
xdecode = pipelinefilter(_xdecode)
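# Example (sketch; `decode_npy` is hypothetical, not one of the defaults):
# extra decoders can be passed positionally as (pattern, function) pairs or
# by extension keyword.
#
#     def decode_npy(stream):
#         import numpy as np
#         return np.load(stream)
#     stage = xdecode(("*.npy", decode_npy))   # or: xdecode(npy=decode_npy)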
def _audio_data_filter(source,
frame_shift=10,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1):
""" Filter sample according to feature and label length
Inplace operation.
Args::
source: Iterable[{fname, wav, label, sample_rate}]
frame_shift: length of frame shift (ms)
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'label' in sample
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (
1000 / frame_shift)
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['label']) < token_min_length:
continue
if len(sample['label']) > token_max_length:
continue
if num_frames != 0:
if len(sample['label']) / num_frames < min_output_input_ratio:
continue
if len(sample['label']) / num_frames > max_output_input_ratio:
continue
yield sample
audio_data_filter = pipelinefilter(_audio_data_filter)
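# Worked example (assuming 16 kHz audio and the default frame_shift=10):
# a 5 s utterance has wav.shape[1] == 80000 samples, so
# num_frames = 80000 / 16000 * (1000 / 10) = 500, i.e. 100 frames per
# second; the default max_length=10240 therefore caps utterances at ~102 s.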
def _audio_tokenize(source,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
""" Decode text to chars or BPE
Inplace operation
Args:
source: Iterable[{fname, wav, txt, sample_rate}]
Returns:
Iterable[{fname, wav, txt, tokens, label, sample_rate}]
"""
if non_lang_syms is not None:
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
else:
non_lang_syms = {}
non_lang_syms_pattern = None
if bpe_model is not None:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
else:
sp = None
for sample in source:
assert 'txt' in sample
txt = sample['txt'].strip()
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
else:
parts = [txt]
label = []
tokens = []
for part in parts:
if part in non_lang_syms:
tokens.append(part)
else:
if bpe_model is not None:
tokens.extend(__tokenize_by_bpe_model(sp, part))
else:
if split_with_space:
part = part.split(" ")
for ch in part:
if ch == ' ':
ch = "<space>"
tokens.append(ch)
for ch in tokens:
if ch in symbol_table:
label.append(symbol_table[ch])
elif '<unk>' in symbol_table:
label.append(symbol_table['<unk>'])
sample['tokens'] = tokens
sample['label'] = label
yield sample
audio_tokenize = pipelinefilter(_audio_tokenize)
def _audio_resample(source, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
        source: Iterable[{fname, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = paddle.to_tensor(
backends.soundfile_backend.resample(
waveform.numpy(),
src_sr=sample_rate,
target_sr=resample_rate))
yield sample
audio_resample = pipelinefilter(_audio_resample)
def _audio_compute_fbank(source,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
source: Iterable[{fname, wav, label, sample_rate}]
num_mel_bins: number of mel filter bank
frame_length: length of one frame (ms)
frame_shift: length of frame shift (ms)
dither: value of dither
Returns:
Iterable[{fname, feat, label}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'fname' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
        # scale the [-1, 1] float waveform up to 16-bit magnitude, as
        # expected by the kaldi-style fbank computation
        waveform = waveform * (1 << 15)
        # Only keep fname, feat, label
mat = kaldi.fbank(
waveform,
n_mels=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sr=sample_rate)
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
def _audio_spec_aug(
source,
max_w=5,
w_inplace=True,
w_mode="PIL",
max_f=30,
num_f_mask=2,
f_inplace=True,
f_replace_with_zero=False,
max_t=40,
num_t_mask=2,
t_inplace=True,
t_replace_with_zero=False, ):
""" Do spec augmentation
Inplace operation
Args:
source: Iterable[{fname, feat, label}]
max_w: max width of time warp
        w_inplace: whether to apply time warping to the original data in place
        w_mode: time warp mode
        max_f: max width of freq mask
        num_f_mask: number of freq masks to apply
        f_inplace: whether to apply frequency masking to the original data in place
        f_replace_with_zero: use zeros for masking
        max_t: max width of time mask
        num_t_mask: number of time masks to apply
        t_inplace: whether to apply time masking to the original data in place
        t_replace_with_zero: use zeros for masking
    Returns:
        Iterable[{fname, feat, label}]
    """
for sample in source:
x = sample['feat']
x = x.numpy()
x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
x = freq_mask(
x,
F=max_f,
n_mask=num_f_mask,
inplace=f_inplace,
replace_with_zero=f_replace_with_zero)
x = time_mask(
x,
T=max_t,
n_mask=num_t_mask,
inplace=t_inplace,
replace_with_zero=t_replace_with_zero)
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample
audio_spec_aug = pipelinefilter(_audio_spec_aug)
def _sort(source, sort_size=500):
""" Sort the data by feature length.
    Sort is used after shuffle and before batching, so that utterances of
    similar length are grouped into the same batch; `sort_size` should be
    less than `shuffle_size`.
Args:
source: Iterable[{fname, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{fname, feat, label}]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
sort = pipelinefilter(_sort)
def _batched(source, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{fname, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
batched = pipelinefilter(_batched)
def dynamic_batched(source, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
source: Iterable[{fname, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in source:
assert 'feat' in sample
assert isinstance(sample['feat'], paddle.Tensor)
        # paddle.Tensor has no torch-style size(dim) method; use shape[0]
        new_sample_frames = sample['feat'].shape[0]
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
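# Worked example (illustrative): with max_frames_in_batch=12000 and samples
# of ~1100 frames each, the buffer grows while
# longest_frames * (len(buf) + 1) stays within 12000; on the 11th sample
# 1100 * 11 = 12100 > 12000, so the first 10 samples are emitted as a batch
# and the buffer restarts with the current sample.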
def _audio_padding(source):
""" Padding the data into training data
Args:
source: Iterable[List[{fname, feat, label}]]
Returns:
Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
"""
for sample in source:
assert isinstance(sample, list)
feats_length = paddle.to_tensor(
[x['feat'].shape[0] for x in sample], dtype="int64")
order = paddle.argsort(feats_length, descending=True)
feats_lengths = paddle.to_tensor(
[sample[i]['feat'].shape[0] for i in order], dtype="int64")
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['fname'] for i in order]
sorted_labels = [
paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
]
label_lengths = paddle.to_tensor(
[x.shape[0] for x in sorted_labels], dtype="int64")
padded_feats = pad_sequence(
sorted_feats, batch_first=True, padding_value=0)
padding_labels = pad_sequence(
sorted_labels, batch_first=True, padding_value=-1)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
audio_padding = pipelinefilter(_audio_padding)
def _audio_cmvn(source, cmvn_file):
global_cmvn = GlobalCMVN(cmvn_file)
for batch in source:
sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
padded_feats = padded_feats.numpy()
padded_feats = global_cmvn(padded_feats)
padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
audio_cmvn = pipelinefilter(_audio_cmvn)
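# End-to-end sketch of how these stages compose (hypothetical `source`,
# `symbol_table`, and cmvn file path; not a runnable recipe):
#
#     stages = [
#         audio_tokenize(symbol_table),
#         audio_data_filter(max_length=10240),
#         audio_resample(16000),
#         audio_compute_fbank(num_mel_bins=80),
#         sort(sort_size=500),
#         batched(batch_size=16),
#         audio_padding(),
#         audio_cmvn("data/mean_std.json"),
#     ]
#     for keys, feats, feat_lens, labels, label_lens in pipeline(source, *stages):
#         ...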
def _placeholder(source):
for data in source:
yield data
placeholder = pipelinefilter(_placeholder)