parent
614a004c37
commit
a107b75bac
@ -0,0 +1,41 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def delta(feat, window):
|
||||||
|
assert window > 0
|
||||||
|
delta_feat = np.zeros_like(feat)
|
||||||
|
for i in range(1, window + 1):
|
||||||
|
delta_feat[:-i] += i * feat[i:]
|
||||||
|
delta_feat[i:] += -i * feat[:-i]
|
||||||
|
delta_feat[-i:] += i * feat[-1]
|
||||||
|
delta_feat[:i] += -i * feat[0]
|
||||||
|
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
|
||||||
|
return delta_feat
|
||||||
|
|
||||||
|
|
||||||
|
def add_deltas(x, window=2, order=2):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (np.ndarray): speech feat, (T, D).
|
||||||
|
|
||||||
|
Return:
|
||||||
|
np.ndarray: (T, (1+order)*D)
|
||||||
|
"""
|
||||||
|
feats = [x]
|
||||||
|
for _ in range(order):
|
||||||
|
feats.append(delta(feats[-1], window))
|
||||||
|
return np.concatenate(feats, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class AddDeltas():
|
||||||
|
def __init__(self, window=2, order=2):
|
||||||
|
self.window = window
|
||||||
|
self.order = order
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}(window={window}, order={order}".format(
|
||||||
|
name=self.__class__.__name__, window=self.window, order=self.order
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return add_deltas(x, window=self.window, order=self.order)
|
@ -0,0 +1,45 @@
|
|||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelSelector():
|
||||||
|
"""Select 1ch from multi-channel signal"""
|
||||||
|
|
||||||
|
def __init__(self, train_channel="random", eval_channel=0, axis=1):
|
||||||
|
self.train_channel = train_channel
|
||||||
|
self.eval_channel = eval_channel
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(train_channel={train_channel}, "
|
||||||
|
"eval_channel={eval_channel}, axis={axis})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
train_channel=self.train_channel,
|
||||||
|
eval_channel=self.eval_channel,
|
||||||
|
axis=self.axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, train=True):
|
||||||
|
# Assuming x: [Time, Channel] by default
|
||||||
|
|
||||||
|
if x.ndim <= self.axis:
|
||||||
|
# If the dimension is insufficient, then unsqueeze
|
||||||
|
# (e.g [Time] -> [Time, 1])
|
||||||
|
ind = tuple(
|
||||||
|
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
|
||||||
|
)
|
||||||
|
x = x[ind]
|
||||||
|
|
||||||
|
if train:
|
||||||
|
channel = self.train_channel
|
||||||
|
else:
|
||||||
|
channel = self.eval_channel
|
||||||
|
|
||||||
|
if channel == "random":
|
||||||
|
ch = numpy.random.randint(0, x.shape[self.axis])
|
||||||
|
else:
|
||||||
|
ch = channel
|
||||||
|
|
||||||
|
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
|
||||||
|
return x[ind]
|
@ -0,0 +1,71 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
from deepspeech.transform.transform_interface import TransformInterface
|
||||||
|
from deepspeech.utils.check_kwargs import check_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class FuncTrans(TransformInterface):
|
||||||
|
"""Functional Transformation
|
||||||
|
|
||||||
|
WARNING:
|
||||||
|
Builtin or C/C++ functions may not work properly
|
||||||
|
because this class heavily depends on the `inspect` module.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
>>> def foo_bar(x, a=1, b=2):
|
||||||
|
... '''Foo bar
|
||||||
|
... :param x: input
|
||||||
|
... :param int a: default 1
|
||||||
|
... :param int b: default 2
|
||||||
|
... '''
|
||||||
|
... return x + a - b
|
||||||
|
|
||||||
|
|
||||||
|
>>> class FooBar(FuncTrans):
|
||||||
|
... _func = foo_bar
|
||||||
|
... __doc__ = foo_bar.__doc__
|
||||||
|
"""
|
||||||
|
|
||||||
|
_func = None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
check_kwargs(self.func, kwargs)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.func(x, **self.kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser):
|
||||||
|
fname = cls._func.__name__.replace("_", "-")
|
||||||
|
group = parser.add_argument_group(fname + " transformation setting")
|
||||||
|
for k, v in cls.default_params().items():
|
||||||
|
# TODO(karita): get help and choices from docstring?
|
||||||
|
attr = k.replace("_", "-")
|
||||||
|
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@property
|
||||||
|
def func(self):
|
||||||
|
return type(self)._func
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_params(cls):
|
||||||
|
try:
|
||||||
|
d = dict(inspect.signature(cls._func).parameters)
|
||||||
|
except ValueError:
|
||||||
|
d = dict()
|
||||||
|
return {
|
||||||
|
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
params = self.default_params()
|
||||||
|
params.update(**self.kwargs)
|
||||||
|
ret = self.__class__.__name__ + "("
|
||||||
|
if len(params) == 0:
|
||||||
|
return ret + ")"
|
||||||
|
for k, v in params.items():
|
||||||
|
ret += "{}={}, ".format(k, v)
|
||||||
|
return ret[:-2] + ")"
|
@ -0,0 +1,343 @@
|
|||||||
|
import librosa
|
||||||
|
import numpy
|
||||||
|
import scipy
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from deepspeech.io.reader import SoundHDF5File
|
||||||
|
|
||||||
|
class SpeedPerturbation():
|
||||||
|
"""SpeedPerturbation
|
||||||
|
|
||||||
|
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
|
||||||
|
and sox-speed just to resample the input,
|
||||||
|
i.e pitch and tempo are changed both.
|
||||||
|
|
||||||
|
"Why use speed option instead of tempo -s in SoX for speed perturbation"
|
||||||
|
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This function is very slow because of resampling.
|
||||||
|
I recommmend to apply speed-perturb outside the training using sox.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lower=0.9,
|
||||||
|
upper=1.1,
|
||||||
|
utt2ratio=None,
|
||||||
|
keep_length=True,
|
||||||
|
res_type="kaiser_best",
|
||||||
|
seed=None,
|
||||||
|
):
|
||||||
|
self.res_type = res_type
|
||||||
|
self.keep_length = keep_length
|
||||||
|
self.state = numpy.random.RandomState(seed)
|
||||||
|
|
||||||
|
if utt2ratio is not None:
|
||||||
|
self.utt2ratio = {}
|
||||||
|
# Use the scheduled ratio for each utterances
|
||||||
|
self.utt2ratio_file = utt2ratio
|
||||||
|
self.lower = None
|
||||||
|
self.upper = None
|
||||||
|
self.accept_uttid = True
|
||||||
|
|
||||||
|
with open(utt2ratio, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
utt, ratio = line.rstrip().split(None, 1)
|
||||||
|
ratio = float(ratio)
|
||||||
|
self.utt2ratio[utt] = ratio
|
||||||
|
else:
|
||||||
|
self.utt2ratio = None
|
||||||
|
# The ratio is given on runtime randomly
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.utt2ratio is None:
|
||||||
|
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.lower,
|
||||||
|
self.upper,
|
||||||
|
self.keep_length,
|
||||||
|
self.res_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return "{}({}, res_type={})".format(
|
||||||
|
self.__class__.__name__, self.utt2ratio_file, self.res_type
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, uttid=None, train=True):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = x.astype(numpy.float32)
|
||||||
|
if self.accept_uttid:
|
||||||
|
ratio = self.utt2ratio[uttid]
|
||||||
|
else:
|
||||||
|
ratio = self.state.uniform(self.lower, self.upper)
|
||||||
|
|
||||||
|
# Note1: resample requires the sampling-rate of input and output,
|
||||||
|
# but actually only the ratio is used.
|
||||||
|
y = librosa.resample(x, ratio, 1, res_type=self.res_type)
|
||||||
|
|
||||||
|
if self.keep_length:
|
||||||
|
diff = abs(len(x) - len(y))
|
||||||
|
if len(y) > len(x):
|
||||||
|
# Truncate noise
|
||||||
|
y = y[diff // 2 : -((diff + 1) // 2)]
|
||||||
|
elif len(y) < len(x):
|
||||||
|
# Assume the time-axis is the first: (Time, Channel)
|
||||||
|
pad_width = [(diff // 2, (diff + 1) // 2)] + [
|
||||||
|
(0, 0) for _ in range(y.ndim - 1)
|
||||||
|
]
|
||||||
|
y = numpy.pad(
|
||||||
|
y, pad_width=pad_width, constant_values=0, mode="constant"
|
||||||
|
)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BandpassPerturbation():
|
||||||
|
"""BandpassPerturbation
|
||||||
|
|
||||||
|
Randomly dropout along the frequency axis.
|
||||||
|
|
||||||
|
The original idea comes from the following:
|
||||||
|
"randomly-selected frequency band was cut off under the constraint of
|
||||||
|
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
|
||||||
|
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
|
||||||
|
everyday home environments using multiple microphone arrays;
|
||||||
|
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.state = numpy.random.RandomState(seed)
|
||||||
|
# x_stft: (Time, Channel, Freq)
|
||||||
|
self.axes = axes
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}(lower={}, upper={})".format(
|
||||||
|
self.__class__.__name__, self.lower, self.upper
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x_stft, uttid=None, train=True):
|
||||||
|
if not train:
|
||||||
|
return x_stft
|
||||||
|
|
||||||
|
if x_stft.ndim == 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
|
||||||
|
)
|
||||||
|
|
||||||
|
ratio = self.state.uniform(self.lower, self.upper)
|
||||||
|
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
|
||||||
|
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
|
||||||
|
|
||||||
|
mask = self.state.randn(*shape) > ratio
|
||||||
|
x_stft *= mask
|
||||||
|
return x_stft
|
||||||
|
|
||||||
|
|
||||||
|
class VolumePerturbation():
|
||||||
|
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
|
||||||
|
self.dbunit = dbunit
|
||||||
|
self.utt2ratio_file = utt2ratio
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.state = numpy.random.RandomState(seed)
|
||||||
|
|
||||||
|
if utt2ratio is not None:
|
||||||
|
# Use the scheduled ratio for each utterances
|
||||||
|
self.utt2ratio = {}
|
||||||
|
self.lower = None
|
||||||
|
self.upper = None
|
||||||
|
self.accept_uttid = True
|
||||||
|
|
||||||
|
with open(utt2ratio, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
utt, ratio = line.rstrip().split(None, 1)
|
||||||
|
ratio = float(ratio)
|
||||||
|
self.utt2ratio[utt] = ratio
|
||||||
|
else:
|
||||||
|
# The ratio is given on runtime randomly
|
||||||
|
self.utt2ratio = None
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.utt2ratio is None:
|
||||||
|
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||||
|
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return '{}("{}", dbunit={})'.format(
|
||||||
|
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, uttid=None, train=True):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = x.astype(numpy.float32)
|
||||||
|
|
||||||
|
if self.accept_uttid:
|
||||||
|
ratio = self.utt2ratio[uttid]
|
||||||
|
else:
|
||||||
|
ratio = self.state.uniform(self.lower, self.upper)
|
||||||
|
if self.dbunit:
|
||||||
|
ratio = 10 ** (ratio / 20)
|
||||||
|
return x * ratio
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseInjection():
|
||||||
|
"""Add isotropic noise"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
utt2noise=None,
|
||||||
|
lower=-20,
|
||||||
|
upper=-5,
|
||||||
|
utt2ratio=None,
|
||||||
|
filetype="list",
|
||||||
|
dbunit=True,
|
||||||
|
seed=None,
|
||||||
|
):
|
||||||
|
self.utt2noise_file = utt2noise
|
||||||
|
self.utt2ratio_file = utt2ratio
|
||||||
|
self.filetype = filetype
|
||||||
|
self.dbunit = dbunit
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.state = numpy.random.RandomState(seed)
|
||||||
|
|
||||||
|
if utt2ratio is not None:
|
||||||
|
# Use the scheduled ratio for each utterances
|
||||||
|
self.utt2ratio = {}
|
||||||
|
with open(utt2noise, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
utt, snr = line.rstrip().split(None, 1)
|
||||||
|
snr = float(snr)
|
||||||
|
self.utt2ratio[utt] = snr
|
||||||
|
else:
|
||||||
|
# The ratio is given on runtime randomly
|
||||||
|
self.utt2ratio = None
|
||||||
|
|
||||||
|
if utt2noise is not None:
|
||||||
|
self.utt2noise = {}
|
||||||
|
if filetype == "list":
|
||||||
|
with open(utt2noise, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
utt, filename = line.rstrip().split(None, 1)
|
||||||
|
signal, rate = soundfile.read(filename, dtype="int16")
|
||||||
|
# Load all files in memory
|
||||||
|
self.utt2noise[utt] = (signal, rate)
|
||||||
|
|
||||||
|
elif filetype == "sound.hdf5":
|
||||||
|
self.utt2noise = SoundHDF5File(utt2noise, "r")
|
||||||
|
else:
|
||||||
|
raise ValueError(filetype)
|
||||||
|
else:
|
||||||
|
self.utt2noise = None
|
||||||
|
|
||||||
|
if utt2noise is not None and utt2ratio is not None:
|
||||||
|
if set(self.utt2ratio) != set(self.utt2noise):
|
||||||
|
raise RuntimeError(
|
||||||
|
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.utt2ratio is None:
|
||||||
|
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||||
|
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return '{}("{}", dbunit={})'.format(
|
||||||
|
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, uttid=None, train=True):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
x = x.astype(numpy.float32)
|
||||||
|
|
||||||
|
# 1. Get ratio of noise to signal in sound pressure level
|
||||||
|
if uttid is not None and self.utt2ratio is not None:
|
||||||
|
ratio = self.utt2ratio[uttid]
|
||||||
|
else:
|
||||||
|
ratio = self.state.uniform(self.lower, self.upper)
|
||||||
|
|
||||||
|
if self.dbunit:
|
||||||
|
ratio = 10 ** (ratio / 20)
|
||||||
|
scale = ratio * numpy.sqrt((x ** 2).mean())
|
||||||
|
|
||||||
|
# 2. Get noise
|
||||||
|
if self.utt2noise is not None:
|
||||||
|
# Get noise from the external source
|
||||||
|
if uttid is not None:
|
||||||
|
noise, rate = self.utt2noise[uttid]
|
||||||
|
else:
|
||||||
|
# Randomly select the noise source
|
||||||
|
noise = self.state.choice(list(self.utt2noise.values()))
|
||||||
|
# Normalize the level
|
||||||
|
noise /= numpy.sqrt((noise ** 2).mean())
|
||||||
|
|
||||||
|
# Adjust the noise length
|
||||||
|
diff = abs(len(x) - len(noise))
|
||||||
|
offset = self.state.randint(0, diff)
|
||||||
|
if len(noise) > len(x):
|
||||||
|
# Truncate noise
|
||||||
|
noise = noise[offset : -(diff - offset)]
|
||||||
|
else:
|
||||||
|
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Generate white noise
|
||||||
|
noise = self.state.normal(0, 1, x.shape)
|
||||||
|
|
||||||
|
# 3. Add noise to signal
|
||||||
|
return x + noise * scale
|
||||||
|
|
||||||
|
|
||||||
|
class RIRConvolve():
|
||||||
|
def __init__(self, utt2rir, filetype="list"):
|
||||||
|
self.utt2rir_file = utt2rir
|
||||||
|
self.filetype = filetype
|
||||||
|
|
||||||
|
self.utt2rir = {}
|
||||||
|
if filetype == "list":
|
||||||
|
with open(utt2rir, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
utt, filename = line.rstrip().split(None, 1)
|
||||||
|
signal, rate = soundfile.read(filename, dtype="int16")
|
||||||
|
self.utt2rir[utt] = (signal, rate)
|
||||||
|
|
||||||
|
elif filetype == "sound.hdf5":
|
||||||
|
self.utt2rir = SoundHDF5File(utt2rir, "r")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(filetype)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
|
||||||
|
|
||||||
|
def __call__(self, x, uttid=None, train=True):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = x.astype(numpy.float32)
|
||||||
|
|
||||||
|
if x.ndim != 1:
|
||||||
|
# Must be single channel
|
||||||
|
raise RuntimeError(
|
||||||
|
"Input x must be one dimensional array, but got {}".format(x.shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
rir, rate = self.utt2rir[uttid]
|
||||||
|
if rir.ndim == 2:
|
||||||
|
# FIXME(kamo): Use chainer.convolution_1d?
|
||||||
|
# return [Time, Channel]
|
||||||
|
return numpy.stack(
|
||||||
|
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return scipy.convolve(x, rir, mode="same")
|
@ -0,0 +1,202 @@
|
|||||||
|
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import BICUBIC
|
||||||
|
|
||||||
|
from deepspeech.transform.functional import FuncTrans
|
||||||
|
|
||||||
|
|
||||||
|
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
|
||||||
|
"""time warp for spec augment
|
||||||
|
|
||||||
|
move random center frame by the random width ~ uniform(-window, window)
|
||||||
|
:param numpy.ndarray x: spectrogram (time, freq)
|
||||||
|
:param int max_time_warp: maximum time frames to warp
|
||||||
|
:param bool inplace: overwrite x with the result
|
||||||
|
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
|
||||||
|
(slow, differentiable)
|
||||||
|
:returns numpy.ndarray: time warped spectrogram (time, freq)
|
||||||
|
"""
|
||||||
|
window = max_time_warp
|
||||||
|
if mode == "PIL":
|
||||||
|
t = x.shape[0]
|
||||||
|
if t - window <= window:
|
||||||
|
return x
|
||||||
|
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
|
||||||
|
center = random.randrange(window, t - window)
|
||||||
|
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
|
||||||
|
|
||||||
|
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
|
||||||
|
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
|
||||||
|
if inplace:
|
||||||
|
x[:warped] = left
|
||||||
|
x[warped:] = right
|
||||||
|
return x
|
||||||
|
return numpy.concatenate((left, right), 0)
|
||||||
|
elif mode == "sparse_image_warp":
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from espnet.utils import spec_augment
|
||||||
|
|
||||||
|
# TODO(karita): make this differentiable again
|
||||||
|
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"unknown resize mode: "
|
||||||
|
+ mode
|
||||||
|
+ ", choose one from (PIL, sparse_image_warp)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeWarp(FuncTrans):
|
||||||
|
_func = time_warp
|
||||||
|
__doc__ = time_warp.__doc__
|
||||||
|
|
||||||
|
def __call__(self, x, train):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
return super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
|
||||||
|
"""freq mask for spec agument
|
||||||
|
|
||||||
|
:param numpy.ndarray x: (time, freq)
|
||||||
|
:param int n_mask: the number of masks
|
||||||
|
:param bool inplace: overwrite
|
||||||
|
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||||
|
"""
|
||||||
|
if inplace:
|
||||||
|
cloned = x
|
||||||
|
else:
|
||||||
|
cloned = x.copy()
|
||||||
|
|
||||||
|
num_mel_channels = cloned.shape[1]
|
||||||
|
fs = numpy.random.randint(0, F, size=(n_mask, 2))
|
||||||
|
|
||||||
|
for f, mask_end in fs:
|
||||||
|
f_zero = random.randrange(0, num_mel_channels - f)
|
||||||
|
mask_end += f_zero
|
||||||
|
|
||||||
|
# avoids randrange error if values are equal and range is empty
|
||||||
|
if f_zero == f_zero + f:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if replace_with_zero:
|
||||||
|
cloned[:, f_zero:mask_end] = 0
|
||||||
|
else:
|
||||||
|
cloned[:, f_zero:mask_end] = cloned.mean()
|
||||||
|
return cloned
|
||||||
|
|
||||||
|
|
||||||
|
class FreqMask(FuncTrans):
|
||||||
|
_func = freq_mask
|
||||||
|
__doc__ = freq_mask.__doc__
|
||||||
|
|
||||||
|
def __call__(self, x, train):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
return super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
|
||||||
|
"""freq mask for spec agument
|
||||||
|
|
||||||
|
:param numpy.ndarray spec: (time, freq)
|
||||||
|
:param int n_mask: the number of masks
|
||||||
|
:param bool inplace: overwrite
|
||||||
|
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||||
|
"""
|
||||||
|
if inplace:
|
||||||
|
cloned = spec
|
||||||
|
else:
|
||||||
|
cloned = spec.copy()
|
||||||
|
len_spectro = cloned.shape[0]
|
||||||
|
ts = numpy.random.randint(0, T, size=(n_mask, 2))
|
||||||
|
for t, mask_end in ts:
|
||||||
|
# avoid randint range error
|
||||||
|
if len_spectro - t <= 0:
|
||||||
|
continue
|
||||||
|
t_zero = random.randrange(0, len_spectro - t)
|
||||||
|
|
||||||
|
# avoids randrange error if values are equal and range is empty
|
||||||
|
if t_zero == t_zero + t:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mask_end += t_zero
|
||||||
|
if replace_with_zero:
|
||||||
|
cloned[t_zero:mask_end] = 0
|
||||||
|
else:
|
||||||
|
cloned[t_zero:mask_end] = cloned.mean()
|
||||||
|
return cloned
|
||||||
|
|
||||||
|
|
||||||
|
class TimeMask(FuncTrans):
|
||||||
|
_func = time_mask
|
||||||
|
__doc__ = time_mask.__doc__
|
||||||
|
|
||||||
|
def __call__(self, x, train):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
return super().__call__(x)
|
||||||
|
|
||||||
|
|
||||||
|
def spec_augment(
|
||||||
|
x,
|
||||||
|
resize_mode="PIL",
|
||||||
|
max_time_warp=80,
|
||||||
|
max_freq_width=27,
|
||||||
|
n_freq_mask=2,
|
||||||
|
max_time_width=100,
|
||||||
|
n_time_mask=2,
|
||||||
|
inplace=True,
|
||||||
|
replace_with_zero=True,
|
||||||
|
):
|
||||||
|
"""spec agument
|
||||||
|
|
||||||
|
apply random time warping and time/freq masking
|
||||||
|
default setting is based on LD (Librispeech double) in Table 2
|
||||||
|
https://arxiv.org/pdf/1904.08779.pdf
|
||||||
|
|
||||||
|
:param numpy.ndarray x: (time, freq)
|
||||||
|
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
|
||||||
|
(slow, differentiable)
|
||||||
|
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
|
||||||
|
:param int freq_mask_width: maximum width of the random freq mask (F)
|
||||||
|
:param int n_freq_mask: the number of the random freq mask (m_F)
|
||||||
|
:param int time_mask_width: maximum width of the random time mask (T)
|
||||||
|
:param int n_time_mask: the number of the random time mask (m_T)
|
||||||
|
:param bool inplace: overwrite intermediate array
|
||||||
|
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||||
|
"""
|
||||||
|
assert isinstance(x, numpy.ndarray)
|
||||||
|
assert x.ndim == 2
|
||||||
|
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
|
||||||
|
x = freq_mask(
|
||||||
|
x,
|
||||||
|
max_freq_width,
|
||||||
|
n_freq_mask,
|
||||||
|
inplace=inplace,
|
||||||
|
replace_with_zero=replace_with_zero,
|
||||||
|
)
|
||||||
|
x = time_mask(
|
||||||
|
x,
|
||||||
|
max_time_width,
|
||||||
|
n_time_mask,
|
||||||
|
inplace=inplace,
|
||||||
|
replace_with_zero=replace_with_zero,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugment(FuncTrans):
|
||||||
|
_func = spec_augment
|
||||||
|
__doc__ = spec_augment.__doc__
|
||||||
|
|
||||||
|
def __call__(self, x, train):
|
||||||
|
if not train:
|
||||||
|
return x
|
||||||
|
return super().__call__(x)
|
@ -0,0 +1,307 @@
|
|||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def stft(
|
||||||
|
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
|
||||||
|
):
|
||||||
|
# x: [Time, Channel]
|
||||||
|
if x.ndim == 1:
|
||||||
|
single_channel = True
|
||||||
|
# x: [Time] -> [Time, Channel]
|
||||||
|
x = x[:, None]
|
||||||
|
else:
|
||||||
|
single_channel = False
|
||||||
|
x = x.astype(np.float32)
|
||||||
|
|
||||||
|
# FIXME(kamo): librosa.stft can't use multi-channel?
|
||||||
|
# x: [Time, Channel, Freq]
|
||||||
|
x = np.stack(
|
||||||
|
[
|
||||||
|
librosa.stft(
|
||||||
|
x[:, ch],
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=n_shift,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
).T
|
||||||
|
for ch in range(x.shape[1])
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if single_channel:
|
||||||
|
# x: [Time, Channel, Freq] -> [Time, Freq]
|
||||||
|
x = x[:, 0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def istft(x, n_shift, win_length=None, window="hann", center=True):
|
||||||
|
# x: [Time, Channel, Freq]
|
||||||
|
if x.ndim == 2:
|
||||||
|
single_channel = True
|
||||||
|
# x: [Time, Freq] -> [Time, Channel, Freq]
|
||||||
|
x = x[:, None, :]
|
||||||
|
else:
|
||||||
|
single_channel = False
|
||||||
|
|
||||||
|
# x: [Time, Channel]
|
||||||
|
x = np.stack(
|
||||||
|
[
|
||||||
|
librosa.istft(
|
||||||
|
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
|
||||||
|
hop_length=n_shift,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
center=center,
|
||||||
|
)
|
||||||
|
for ch in range(x.shape[1])
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if single_channel:
|
||||||
|
# x: [Time, Channel] -> [Time]
|
||||||
|
x = x[:, 0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||||
|
# x_stft: (Time, Channel, Freq) or (Time, Freq)
|
||||||
|
fmin = 0 if fmin is None else fmin
|
||||||
|
fmax = fs / 2 if fmax is None else fmax
|
||||||
|
|
||||||
|
# spc: (Time, Channel, Freq) or (Time, Freq)
|
||||||
|
spc = np.abs(x_stft)
|
||||||
|
# mel_basis: (Mel_freq, Freq)
|
||||||
|
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
|
||||||
|
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
|
||||||
|
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||||
|
|
||||||
|
return lmspc
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
|
||||||
|
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
|
||||||
|
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||||
|
return spc
|
||||||
|
|
||||||
|
|
||||||
|
def logmelspectrogram(
|
||||||
|
x,
|
||||||
|
fs,
|
||||||
|
n_mels,
|
||||||
|
n_fft,
|
||||||
|
n_shift,
|
||||||
|
win_length=None,
|
||||||
|
window="hann",
|
||||||
|
fmin=None,
|
||||||
|
fmax=None,
|
||||||
|
eps=1e-10,
|
||||||
|
pad_mode="reflect",
|
||||||
|
):
|
||||||
|
# stft: (Time, Channel, Freq) or (Time, Freq)
|
||||||
|
x_stft = stft(
|
||||||
|
x,
|
||||||
|
n_fft=n_fft,
|
||||||
|
n_shift=n_shift,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return stft2logmelspectrogram(
|
||||||
|
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Spectrogram():
|
||||||
|
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.n_shift = n_shift
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = window
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||||
|
"win_length={win_length}, window={window})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return spectrogram(
|
||||||
|
x,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogram():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fs,
|
||||||
|
n_mels,
|
||||||
|
n_fft,
|
||||||
|
n_shift,
|
||||||
|
win_length=None,
|
||||||
|
window="hann",
|
||||||
|
fmin=None,
|
||||||
|
fmax=None,
|
||||||
|
eps=1e-10,
|
||||||
|
):
|
||||||
|
self.fs = fs
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.n_shift = n_shift
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = window
|
||||||
|
self.fmin = fmin
|
||||||
|
self.fmax = fmax
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||||
|
"n_shift={n_shift}, win_length={win_length}, window={window}, "
|
||||||
|
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
fs=self.fs,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
fmin=self.fmin,
|
||||||
|
fmax=self.fmax,
|
||||||
|
eps=self.eps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return logmelspectrogram(
|
||||||
|
x,
|
||||||
|
fs=self.fs,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Stft2LogMelSpectrogram():
|
||||||
|
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||||
|
self.fs = fs
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.fmin = fmin
|
||||||
|
self.fmax = fmax
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||||
|
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
fs=self.fs,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
fmin=self.fmin,
|
||||||
|
fmax=self.fmax,
|
||||||
|
eps=self.eps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return stft2logmelspectrogram(
|
||||||
|
x,
|
||||||
|
fs=self.fs,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
fmin=self.fmin,
|
||||||
|
fmax=self.fmax,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Stft():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_fft,
|
||||||
|
n_shift,
|
||||||
|
win_length=None,
|
||||||
|
window="hann",
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
):
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.n_shift = n_shift
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = window
|
||||||
|
self.center = center
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||||
|
"win_length={win_length}, window={window},"
|
||||||
|
"center={center}, pad_mode={pad_mode})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
center=self.center,
|
||||||
|
pad_mode=self.pad_mode,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return stft(
|
||||||
|
x,
|
||||||
|
self.n_fft,
|
||||||
|
self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
center=self.center,
|
||||||
|
pad_mode=self.pad_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IStft():
|
||||||
|
def __init__(self, n_shift, win_length=None, window="hann", center=True):
|
||||||
|
self.n_shift = n_shift
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = window
|
||||||
|
self.center = center
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(n_shift={n_shift}, "
|
||||||
|
"win_length={win_length}, window={window},"
|
||||||
|
"center={center})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
n_shift=self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
center=self.center,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return istft(
|
||||||
|
x,
|
||||||
|
self.n_shift,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=self.window,
|
||||||
|
center=self.center,
|
||||||
|
)
|
@ -0,0 +1,20 @@
|
|||||||
|
# TODO(karita): add this to all the transform impl.
|
||||||
|
class TransformInterface:
|
||||||
|
"""Transform Interface"""
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
raise NotImplementedError("__call__ method is not implemented")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser):
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + "()"
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(TransformInterface):
|
||||||
|
"""Identity Function"""
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return x
|
@ -0,0 +1,149 @@
|
|||||||
|
"""Transformation module."""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from collections import OrderedDict
|
||||||
|
import copy
|
||||||
|
from inspect import signature
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from deepspeech.utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(karita): inherit TransformInterface
|
||||||
|
# TODO(karita): register cmd arguments in asr_train.py
|
||||||
|
import_alias = dict(
|
||||||
|
identity="deepspeech.transform.transform_interface:Identity",
|
||||||
|
time_warp="deepspeech.transform.spec_augment:TimeWarp",
|
||||||
|
time_mask="deepspeech.transform.spec_augment:TimeMask",
|
||||||
|
freq_mask="deepspeech.transform.spec_augment:FreqMask",
|
||||||
|
spec_augment="deepspeech.transform.spec_augment:SpecAugment",
|
||||||
|
speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
|
||||||
|
volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
|
||||||
|
noise_injection="deepspeech.transform.perturb:NoiseInjection",
|
||||||
|
bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
|
||||||
|
rir_convolve="deepspeech.transform.perturb:RIRConvolve",
|
||||||
|
delta="deepspeech.transform.add_deltas:AddDeltas",
|
||||||
|
cmvn="deepspeech.transform.cmvn:CMVN",
|
||||||
|
utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
|
||||||
|
fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
|
||||||
|
spectrogram="deepspeech.transform.spectrogram:Spectrogram",
|
||||||
|
stft="deepspeech.transform.spectrogram:Stft",
|
||||||
|
istft="deepspeech.transform.spectrogram:IStft",
|
||||||
|
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
|
||||||
|
wpe="deepspeech.transform.wpe:WPE",
|
||||||
|
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformation():
|
||||||
|
"""Apply some functions to the mini-batch
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kwargs = {"process": [{"type": "fbank",
|
||||||
|
... "n_mels": 80,
|
||||||
|
... "fs": 16000},
|
||||||
|
... {"type": "cmvn",
|
||||||
|
... "stats": "data/train/cmvn.ark",
|
||||||
|
... "norm_vars": True},
|
||||||
|
... {"type": "delta", "window": 2, "order": 2}]}
|
||||||
|
>>> transform = Transformation(kwargs)
|
||||||
|
>>> bs = 10
|
||||||
|
>>> xs = [np.random.randn(100, 80).astype(np.float32)
|
||||||
|
... for _ in range(bs)]
|
||||||
|
>>> xs = transform(xs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conffile=None):
|
||||||
|
if conffile is not None:
|
||||||
|
if isinstance(conffile, dict):
|
||||||
|
self.conf = copy.deepcopy(conffile)
|
||||||
|
else:
|
||||||
|
with io.open(conffile, encoding="utf-8") as f:
|
||||||
|
self.conf = yaml.safe_load(f)
|
||||||
|
assert isinstance(self.conf, dict), type(self.conf)
|
||||||
|
else:
|
||||||
|
self.conf = {"mode": "sequential", "process": []}
|
||||||
|
|
||||||
|
self.functions = OrderedDict()
|
||||||
|
if self.conf.get("mode", "sequential") == "sequential":
|
||||||
|
for idx, process in enumerate(self.conf["process"]):
|
||||||
|
assert isinstance(process, dict), type(process)
|
||||||
|
opts = dict(process)
|
||||||
|
process_type = opts.pop("type")
|
||||||
|
class_obj = dynamic_import(process_type, import_alias)
|
||||||
|
# TODO(karita): assert issubclass(class_obj, TransformInterface)
|
||||||
|
try:
|
||||||
|
self.functions[idx] = class_obj(**opts)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
signa = signature(class_obj)
|
||||||
|
except ValueError:
|
||||||
|
# Some function, e.g. built-in function, are failed
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logging.error(
|
||||||
|
"Expected signature: {}({})".format(
|
||||||
|
class_obj.__name__, signa
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Not supporting mode={}".format(self.conf["mode"])
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
rep = "\n" + "\n".join(
|
||||||
|
" {}: {}".format(k, v) for k, v in self.functions.items()
|
||||||
|
)
|
||||||
|
return "{}({})".format(self.__class__.__name__, rep)
|
||||||
|
|
||||||
|
def __call__(self, xs, uttid_list=None, **kwargs):
|
||||||
|
"""Return new mini-batch
|
||||||
|
|
||||||
|
:param Union[Sequence[np.ndarray], np.ndarray] xs:
|
||||||
|
:param Union[Sequence[str], str] uttid_list:
|
||||||
|
:return: batch:
|
||||||
|
:rtype: List[np.ndarray]
|
||||||
|
"""
|
||||||
|
if not isinstance(xs, Sequence):
|
||||||
|
is_batch = False
|
||||||
|
xs = [xs]
|
||||||
|
else:
|
||||||
|
is_batch = True
|
||||||
|
|
||||||
|
if isinstance(uttid_list, str):
|
||||||
|
uttid_list = [uttid_list for _ in range(len(xs))]
|
||||||
|
|
||||||
|
if self.conf.get("mode", "sequential") == "sequential":
|
||||||
|
for idx in range(len(self.conf["process"])):
|
||||||
|
func = self.functions[idx]
|
||||||
|
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
|
||||||
|
# Derive only the args which the func has
|
||||||
|
try:
|
||||||
|
param = signature(func).parameters
|
||||||
|
except ValueError:
|
||||||
|
# Some function, e.g. built-in function, are failed
|
||||||
|
param = {}
|
||||||
|
_kwargs = {k: v for k, v in kwargs.items() if k in param}
|
||||||
|
try:
|
||||||
|
if uttid_list is not None and "uttid" in param:
|
||||||
|
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
|
||||||
|
else:
|
||||||
|
xs = [func(x, **_kwargs) for x in xs]
|
||||||
|
except Exception:
|
||||||
|
logging.fatal(
|
||||||
|
"Catch a exception from {}th func: {}".format(idx, func)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Not supporting mode={}".format(self.conf["mode"])
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_batch:
|
||||||
|
return xs
|
||||||
|
else:
|
||||||
|
return xs[0]
|
@ -0,0 +1,45 @@
|
|||||||
|
from nara_wpe.wpe import wpe
|
||||||
|
|
||||||
|
|
||||||
|
class WPE(object):
|
||||||
|
def __init__(
|
||||||
|
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
|
||||||
|
):
|
||||||
|
self.taps = taps
|
||||||
|
self.delay = delay
|
||||||
|
self.iterations = iterations
|
||||||
|
self.psd_context = psd_context
|
||||||
|
self.statistics_mode = statistics_mode
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}(taps={taps}, delay={delay}"
|
||||||
|
"iterations={iterations}, psd_context={psd_context}, "
|
||||||
|
"statistics_mode={statistics_mode})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
taps=self.taps,
|
||||||
|
delay=self.delay,
|
||||||
|
iterations=self.iterations,
|
||||||
|
psd_context=self.psd_context,
|
||||||
|
statistics_mode=self.statistics_mode,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, xs):
|
||||||
|
"""Return enhanced
|
||||||
|
|
||||||
|
:param np.ndarray xs: (Time, Channel, Frequency)
|
||||||
|
:return: enhanced_xs
|
||||||
|
:rtype: np.ndarray
|
||||||
|
|
||||||
|
"""
|
||||||
|
# nara_wpe.wpe: (F, C, T)
|
||||||
|
xs = wpe(
|
||||||
|
xs.transpose((2, 1, 0)),
|
||||||
|
taps=self.taps,
|
||||||
|
delay=self.delay,
|
||||||
|
iterations=self.iterations,
|
||||||
|
psd_context=self.psd_context,
|
||||||
|
statistics_mode=self.statistics_mode,
|
||||||
|
)
|
||||||
|
return xs.transpose(2, 1, 0)
|
@ -0,0 +1,20 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def check_kwargs(func, kwargs, name=None):
|
||||||
|
"""check kwargs are valid for func
|
||||||
|
|
||||||
|
If kwargs are invalid, raise TypeError as same as python default
|
||||||
|
:param function func: function to be validated
|
||||||
|
:param dict kwargs: keyword arguments for func
|
||||||
|
:param str name: name used in TypeError (default is func name)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
if name is None:
|
||||||
|
name = func.__name__
|
||||||
|
for k in kwargs.keys():
|
||||||
|
if k not in params:
|
||||||
|
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
|
@ -1,122 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test
|
|
||||||
min_input_len: 0.5
|
|
||||||
max_input_len: 20.0
|
|
||||||
min_output_len: 0.0
|
|
||||||
max_output_len: 400.0
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 16
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: conformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
|
||||||
normalize_before: True
|
|
||||||
use_cnn_module: True
|
|
||||||
cnn_module_kernel: 15
|
|
||||||
activation_type: 'swish'
|
|
||||||
pos_enc_layer_type: 'rel_pos'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
causal: True
|
|
||||||
use_dynamic_chunk: true
|
|
||||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
|
||||||
use_dynamic_left_chunk: false
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 240
|
|
||||||
accum_grad: 8
|
|
||||||
global_grad_clip: 5.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.001
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 128
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -1,115 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test
|
|
||||||
min_input_len: 0.5 # second
|
|
||||||
max_input_len: 20.0 # second
|
|
||||||
min_output_len: 0.0 # tokens
|
|
||||||
max_output_len: 400.0 # tokens
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 64
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: transformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
|
||||||
normalize_before: true
|
|
||||||
use_dynamic_chunk: true
|
|
||||||
use_dynamic_left_chunk: false
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 120
|
|
||||||
accum_grad: 1
|
|
||||||
global_grad_clip: 5.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.001
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 64
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -1,118 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test-clean
|
|
||||||
min_input_len: 0.5 # seconds
|
|
||||||
max_input_len: 20.0 # seconds
|
|
||||||
min_output_len: 0.0 # tokens
|
|
||||||
max_output_len: 400.0 # tokens
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 16
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: conformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
|
||||||
normalize_before: True
|
|
||||||
use_cnn_module: True
|
|
||||||
cnn_module_kernel: 15
|
|
||||||
activation_type: 'swish'
|
|
||||||
pos_enc_layer_type: 'rel_pos'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 120
|
|
||||||
accum_grad: 8
|
|
||||||
global_grad_clip: 3.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.004
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 64
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,84 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from deepspeech.transform.transformation import Transformation
|
||||||
|
from deepspeech.utils.cli_readers import file_reader_helper
|
||||||
|
from deepspeech.utils.cli_utils import get_commandline_args
|
||||||
|
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="convert feature to its shape",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||||
|
parser.add_argument(
|
||||||
|
"--filetype",
|
||||||
|
type=str,
|
||||||
|
default="mat",
|
||||||
|
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||||
|
help="Specify the file format for the rspecifier. "
|
||||||
|
'"mat" is the matrix format in kaldi',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocess-conf",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The configuration file for the pre-processing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"out",
|
||||||
|
nargs="?",
|
||||||
|
type=argparse.FileType("w"),
|
||||||
|
default=sys.stdout,
|
||||||
|
help="The output filename. " "If omitted, then output to sys.stdout",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# logging info
|
||||||
|
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
if args.verbose > 0:
|
||||||
|
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||||
|
logging.info(get_commandline_args())
|
||||||
|
|
||||||
|
if args.preprocess_conf is not None:
|
||||||
|
preprocessing = Transformation(args.preprocess_conf)
|
||||||
|
logging.info("Apply preprocessing: {}".format(preprocessing))
|
||||||
|
else:
|
||||||
|
preprocessing = None
|
||||||
|
|
||||||
|
# There are no necessary for matrix without preprocessing,
|
||||||
|
# so change to file_reader_helper to return shape.
|
||||||
|
# This make sense only with filetype="hdf5".
|
||||||
|
for utt, mat in file_reader_helper(
|
||||||
|
args.rspecifier, args.filetype, return_shape=preprocessing is None
|
||||||
|
):
|
||||||
|
if preprocessing is not None:
|
||||||
|
if is_scipy_wav_style(mat):
|
||||||
|
# If data is sound file, then got as Tuple[int, ndarray]
|
||||||
|
rate, mat = mat
|
||||||
|
mat = preprocessing(mat, uttid_list=utt)
|
||||||
|
shape_str = ",".join(map(str, mat.shape))
|
||||||
|
else:
|
||||||
|
if len(mat) == 2 and isinstance(mat[1], tuple):
|
||||||
|
# If data is sound file, Tuple[int, Tuple[int, ...]]
|
||||||
|
rate, mat = mat
|
||||||
|
shape_str = ",".join(map(str, mat))
|
||||||
|
args.out.write("{} {}\n".format(utt, shape_str))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in new issue