parent
614a004c37
commit
a107b75bac
@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def delta(feat, window):
|
||||
assert window > 0
|
||||
delta_feat = np.zeros_like(feat)
|
||||
for i in range(1, window + 1):
|
||||
delta_feat[:-i] += i * feat[i:]
|
||||
delta_feat[i:] += -i * feat[:-i]
|
||||
delta_feat[-i:] += i * feat[-1]
|
||||
delta_feat[:i] += -i * feat[0]
|
||||
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
|
||||
return delta_feat
|
||||
|
||||
|
||||
def add_deltas(x, window=2, order=2):
|
||||
"""
|
||||
Args:
|
||||
x (np.ndarray): speech feat, (T, D).
|
||||
|
||||
Return:
|
||||
np.ndarray: (T, (1+order)*D)
|
||||
"""
|
||||
feats = [x]
|
||||
for _ in range(order):
|
||||
feats.append(delta(feats[-1], window))
|
||||
return np.concatenate(feats, axis=1)
|
||||
|
||||
|
||||
class AddDeltas():
|
||||
def __init__(self, window=2, order=2):
|
||||
self.window = window
|
||||
self.order = order
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(window={window}, order={order}".format(
|
||||
name=self.__class__.__name__, window=self.window, order=self.order
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return add_deltas(x, window=self.window, order=self.order)
|
@ -0,0 +1,45 @@
|
||||
import numpy
|
||||
|
||||
|
||||
class ChannelSelector():
|
||||
"""Select 1ch from multi-channel signal"""
|
||||
|
||||
def __init__(self, train_channel="random", eval_channel=0, axis=1):
|
||||
self.train_channel = train_channel
|
||||
self.eval_channel = eval_channel
|
||||
self.axis = axis
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(train_channel={train_channel}, "
|
||||
"eval_channel={eval_channel}, axis={axis})".format(
|
||||
name=self.__class__.__name__,
|
||||
train_channel=self.train_channel,
|
||||
eval_channel=self.eval_channel,
|
||||
axis=self.axis,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x, train=True):
|
||||
# Assuming x: [Time, Channel] by default
|
||||
|
||||
if x.ndim <= self.axis:
|
||||
# If the dimension is insufficient, then unsqueeze
|
||||
# (e.g [Time] -> [Time, 1])
|
||||
ind = tuple(
|
||||
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
|
||||
)
|
||||
x = x[ind]
|
||||
|
||||
if train:
|
||||
channel = self.train_channel
|
||||
else:
|
||||
channel = self.eval_channel
|
||||
|
||||
if channel == "random":
|
||||
ch = numpy.random.randint(0, x.shape[self.axis])
|
||||
else:
|
||||
ch = channel
|
||||
|
||||
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
|
||||
return x[ind]
|
@ -0,0 +1,71 @@
|
||||
import inspect
|
||||
|
||||
from deepspeech.transform.transform_interface import TransformInterface
|
||||
from deepspeech.utils.check_kwargs import check_kwargs
|
||||
|
||||
|
||||
class FuncTrans(TransformInterface):
|
||||
"""Functional Transformation
|
||||
|
||||
WARNING:
|
||||
Builtin or C/C++ functions may not work properly
|
||||
because this class heavily depends on the `inspect` module.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> def foo_bar(x, a=1, b=2):
|
||||
... '''Foo bar
|
||||
... :param x: input
|
||||
... :param int a: default 1
|
||||
... :param int b: default 2
|
||||
... '''
|
||||
... return x + a - b
|
||||
|
||||
|
||||
>>> class FooBar(FuncTrans):
|
||||
... _func = foo_bar
|
||||
... __doc__ = foo_bar.__doc__
|
||||
"""
|
||||
|
||||
_func = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
check_kwargs(self.func, kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.func(x, **self.kwargs)
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
fname = cls._func.__name__.replace("_", "-")
|
||||
group = parser.add_argument_group(fname + " transformation setting")
|
||||
for k, v in cls.default_params().items():
|
||||
# TODO(karita): get help and choices from docstring?
|
||||
attr = k.replace("_", "-")
|
||||
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
|
||||
return parser
|
||||
|
||||
@property
|
||||
def func(self):
|
||||
return type(self)._func
|
||||
|
||||
@classmethod
|
||||
def default_params(cls):
|
||||
try:
|
||||
d = dict(inspect.signature(cls._func).parameters)
|
||||
except ValueError:
|
||||
d = dict()
|
||||
return {
|
||||
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
params = self.default_params()
|
||||
params.update(**self.kwargs)
|
||||
ret = self.__class__.__name__ + "("
|
||||
if len(params) == 0:
|
||||
return ret + ")"
|
||||
for k, v in params.items():
|
||||
ret += "{}={}, ".format(k, v)
|
||||
return ret[:-2] + ")"
|
@ -0,0 +1,343 @@
|
||||
import librosa
|
||||
import numpy
|
||||
import scipy
|
||||
import soundfile
|
||||
|
||||
from deepspeech.io.reader import SoundHDF5File
|
||||
|
||||
class SpeedPerturbation():
|
||||
"""SpeedPerturbation
|
||||
|
||||
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
|
||||
and sox-speed just to resample the input,
|
||||
i.e pitch and tempo are changed both.
|
||||
|
||||
"Why use speed option instead of tempo -s in SoX for speed perturbation"
|
||||
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
|
||||
|
||||
Warning:
|
||||
This function is very slow because of resampling.
|
||||
I recommmend to apply speed-perturb outside the training using sox.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower=0.9,
|
||||
upper=1.1,
|
||||
utt2ratio=None,
|
||||
keep_length=True,
|
||||
res_type="kaiser_best",
|
||||
seed=None,
|
||||
):
|
||||
self.res_type = res_type
|
||||
self.keep_length = keep_length
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
self.utt2ratio = {}
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
self.utt2ratio = None
|
||||
# The ratio is given on runtime randomly
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
|
||||
self.__class__.__name__,
|
||||
self.lower,
|
||||
self.upper,
|
||||
self.keep_length,
|
||||
self.res_type,
|
||||
)
|
||||
else:
|
||||
return "{}({}, res_type={})".format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.res_type
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
# Note1: resample requires the sampling-rate of input and output,
|
||||
# but actually only the ratio is used.
|
||||
y = librosa.resample(x, ratio, 1, res_type=self.res_type)
|
||||
|
||||
if self.keep_length:
|
||||
diff = abs(len(x) - len(y))
|
||||
if len(y) > len(x):
|
||||
# Truncate noise
|
||||
y = y[diff // 2 : -((diff + 1) // 2)]
|
||||
elif len(y) < len(x):
|
||||
# Assume the time-axis is the first: (Time, Channel)
|
||||
pad_width = [(diff // 2, (diff + 1) // 2)] + [
|
||||
(0, 0) for _ in range(y.ndim - 1)
|
||||
]
|
||||
y = numpy.pad(
|
||||
y, pad_width=pad_width, constant_values=0, mode="constant"
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
class BandpassPerturbation():
|
||||
"""BandpassPerturbation
|
||||
|
||||
Randomly dropout along the frequency axis.
|
||||
|
||||
The original idea comes from the following:
|
||||
"randomly-selected frequency band was cut off under the constraint of
|
||||
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
|
||||
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
|
||||
everyday home environments using multiple microphone arrays;
|
||||
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
# x_stft: (Time, Channel, Freq)
|
||||
self.axes = axes
|
||||
|
||||
def __repr__(self):
|
||||
return "{}(lower={}, upper={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper
|
||||
)
|
||||
|
||||
def __call__(self, x_stft, uttid=None, train=True):
|
||||
if not train:
|
||||
return x_stft
|
||||
|
||||
if x_stft.ndim == 1:
|
||||
raise RuntimeError(
|
||||
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
|
||||
)
|
||||
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
|
||||
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
|
||||
|
||||
mask = self.state.randn(*shape) > ratio
|
||||
x_stft *= mask
|
||||
return x_stft
|
||||
|
||||
|
||||
class VolumePerturbation():
|
||||
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
|
||||
self.dbunit = dbunit
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||
)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
if self.dbunit:
|
||||
ratio = 10 ** (ratio / 20)
|
||||
return x * ratio
|
||||
|
||||
|
||||
class NoiseInjection():
|
||||
"""Add isotropic noise"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
utt2noise=None,
|
||||
lower=-20,
|
||||
upper=-5,
|
||||
utt2ratio=None,
|
||||
filetype="list",
|
||||
dbunit=True,
|
||||
seed=None,
|
||||
):
|
||||
self.utt2noise_file = utt2noise
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.filetype = filetype
|
||||
self.dbunit = dbunit
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, snr = line.rstrip().split(None, 1)
|
||||
snr = float(snr)
|
||||
self.utt2ratio[utt] = snr
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
if utt2noise is not None:
|
||||
self.utt2noise = {}
|
||||
if filetype == "list":
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
# Load all files in memory
|
||||
self.utt2noise[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2noise = SoundHDF5File(utt2noise, "r")
|
||||
else:
|
||||
raise ValueError(filetype)
|
||||
else:
|
||||
self.utt2noise = None
|
||||
|
||||
if utt2noise is not None and utt2ratio is not None:
|
||||
if set(self.utt2ratio) != set(self.utt2noise):
|
||||
raise RuntimeError(
|
||||
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||
)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
# 1. Get ratio of noise to signal in sound pressure level
|
||||
if uttid is not None and self.utt2ratio is not None:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
if self.dbunit:
|
||||
ratio = 10 ** (ratio / 20)
|
||||
scale = ratio * numpy.sqrt((x ** 2).mean())
|
||||
|
||||
# 2. Get noise
|
||||
if self.utt2noise is not None:
|
||||
# Get noise from the external source
|
||||
if uttid is not None:
|
||||
noise, rate = self.utt2noise[uttid]
|
||||
else:
|
||||
# Randomly select the noise source
|
||||
noise = self.state.choice(list(self.utt2noise.values()))
|
||||
# Normalize the level
|
||||
noise /= numpy.sqrt((noise ** 2).mean())
|
||||
|
||||
# Adjust the noise length
|
||||
diff = abs(len(x) - len(noise))
|
||||
offset = self.state.randint(0, diff)
|
||||
if len(noise) > len(x):
|
||||
# Truncate noise
|
||||
noise = noise[offset : -(diff - offset)]
|
||||
else:
|
||||
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
|
||||
|
||||
else:
|
||||
# Generate white noise
|
||||
noise = self.state.normal(0, 1, x.shape)
|
||||
|
||||
# 3. Add noise to signal
|
||||
return x + noise * scale
|
||||
|
||||
|
||||
class RIRConvolve():
|
||||
def __init__(self, utt2rir, filetype="list"):
|
||||
self.utt2rir_file = utt2rir
|
||||
self.filetype = filetype
|
||||
|
||||
self.utt2rir = {}
|
||||
if filetype == "list":
|
||||
with open(utt2rir, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
self.utt2rir[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2rir = SoundHDF5File(utt2rir, "r")
|
||||
else:
|
||||
raise NotImplementedError(filetype)
|
||||
|
||||
def __repr__(self):
|
||||
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if x.ndim != 1:
|
||||
# Must be single channel
|
||||
raise RuntimeError(
|
||||
"Input x must be one dimensional array, but got {}".format(x.shape)
|
||||
)
|
||||
|
||||
rir, rate = self.utt2rir[uttid]
|
||||
if rir.ndim == 2:
|
||||
# FIXME(kamo): Use chainer.convolution_1d?
|
||||
# return [Time, Channel]
|
||||
return numpy.stack(
|
||||
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1
|
||||
)
|
||||
else:
|
||||
return scipy.convolve(x, rir, mode="same")
|
@ -0,0 +1,202 @@
|
||||
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from PIL.Image import BICUBIC
|
||||
|
||||
from deepspeech.transform.functional import FuncTrans
|
||||
|
||||
|
||||
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
|
||||
"""time warp for spec augment
|
||||
|
||||
move random center frame by the random width ~ uniform(-window, window)
|
||||
:param numpy.ndarray x: spectrogram (time, freq)
|
||||
:param int max_time_warp: maximum time frames to warp
|
||||
:param bool inplace: overwrite x with the result
|
||||
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:returns numpy.ndarray: time warped spectrogram (time, freq)
|
||||
"""
|
||||
window = max_time_warp
|
||||
if mode == "PIL":
|
||||
t = x.shape[0]
|
||||
if t - window <= window:
|
||||
return x
|
||||
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
|
||||
center = random.randrange(window, t - window)
|
||||
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
|
||||
|
||||
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
|
||||
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
|
||||
if inplace:
|
||||
x[:warped] = left
|
||||
x[warped:] = right
|
||||
return x
|
||||
return numpy.concatenate((left, right), 0)
|
||||
elif mode == "sparse_image_warp":
|
||||
import paddle
|
||||
|
||||
from espnet.utils import spec_augment
|
||||
|
||||
# TODO(karita): make this differentiable again
|
||||
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"unknown resize mode: "
|
||||
+ mode
|
||||
+ ", choose one from (PIL, sparse_image_warp)."
|
||||
)
|
||||
|
||||
|
||||
class TimeWarp(FuncTrans):
|
||||
_func = time_warp
|
||||
__doc__ = time_warp.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = x
|
||||
else:
|
||||
cloned = x.copy()
|
||||
|
||||
num_mel_channels = cloned.shape[1]
|
||||
fs = numpy.random.randint(0, F, size=(n_mask, 2))
|
||||
|
||||
for f, mask_end in fs:
|
||||
f_zero = random.randrange(0, num_mel_channels - f)
|
||||
mask_end += f_zero
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if f_zero == f_zero + f:
|
||||
continue
|
||||
|
||||
if replace_with_zero:
|
||||
cloned[:, f_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[:, f_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class FreqMask(FuncTrans):
|
||||
_func = freq_mask
|
||||
__doc__ = freq_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray spec: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = spec
|
||||
else:
|
||||
cloned = spec.copy()
|
||||
len_spectro = cloned.shape[0]
|
||||
ts = numpy.random.randint(0, T, size=(n_mask, 2))
|
||||
for t, mask_end in ts:
|
||||
# avoid randint range error
|
||||
if len_spectro - t <= 0:
|
||||
continue
|
||||
t_zero = random.randrange(0, len_spectro - t)
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if t_zero == t_zero + t:
|
||||
continue
|
||||
|
||||
mask_end += t_zero
|
||||
if replace_with_zero:
|
||||
cloned[t_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[t_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class TimeMask(FuncTrans):
|
||||
_func = time_mask
|
||||
__doc__ = time_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def spec_augment(
|
||||
x,
|
||||
resize_mode="PIL",
|
||||
max_time_warp=80,
|
||||
max_freq_width=27,
|
||||
n_freq_mask=2,
|
||||
max_time_width=100,
|
||||
n_time_mask=2,
|
||||
inplace=True,
|
||||
replace_with_zero=True,
|
||||
):
|
||||
"""spec agument
|
||||
|
||||
apply random time warping and time/freq masking
|
||||
default setting is based on LD (Librispeech double) in Table 2
|
||||
https://arxiv.org/pdf/1904.08779.pdf
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
|
||||
:param int freq_mask_width: maximum width of the random freq mask (F)
|
||||
:param int n_freq_mask: the number of the random freq mask (m_F)
|
||||
:param int time_mask_width: maximum width of the random time mask (T)
|
||||
:param int n_time_mask: the number of the random time mask (m_T)
|
||||
:param bool inplace: overwrite intermediate array
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
assert isinstance(x, numpy.ndarray)
|
||||
assert x.ndim == 2
|
||||
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
|
||||
x = freq_mask(
|
||||
x,
|
||||
max_freq_width,
|
||||
n_freq_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero,
|
||||
)
|
||||
x = time_mask(
|
||||
x,
|
||||
max_time_width,
|
||||
n_time_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class SpecAugment(FuncTrans):
|
||||
_func = spec_augment
|
||||
__doc__ = spec_augment.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
@ -0,0 +1,307 @@
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
|
||||
def stft(
|
||||
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
|
||||
):
|
||||
# x: [Time, Channel]
|
||||
if x.ndim == 1:
|
||||
single_channel = True
|
||||
# x: [Time] -> [Time, Channel]
|
||||
x = x[:, None]
|
||||
else:
|
||||
single_channel = False
|
||||
x = x.astype(np.float32)
|
||||
|
||||
# FIXME(kamo): librosa.stft can't use multi-channel?
|
||||
# x: [Time, Channel, Freq]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.stft(
|
||||
x[:, ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode,
|
||||
).T
|
||||
for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel, Freq] -> [Time, Freq]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def istft(x, n_shift, win_length=None, window="hann", center=True):
|
||||
# x: [Time, Channel, Freq]
|
||||
if x.ndim == 2:
|
||||
single_channel = True
|
||||
# x: [Time, Freq] -> [Time, Channel, Freq]
|
||||
x = x[:, None, :]
|
||||
else:
|
||||
single_channel = False
|
||||
|
||||
# x: [Time, Channel]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.istft(
|
||||
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
)
|
||||
for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel] -> [Time]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||
# x_stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
|
||||
# spc: (Time, Channel, Freq) or (Time, Freq)
|
||||
spc = np.abs(x_stft)
|
||||
# mel_basis: (Mel_freq, Freq)
|
||||
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
|
||||
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
|
||||
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||
|
||||
return lmspc
|
||||
|
||||
|
||||
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
|
||||
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
|
||||
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||
return spc
|
||||
|
||||
|
||||
def logmelspectrogram(
|
||||
x,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
pad_mode="reflect",
|
||||
):
|
||||
# stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
x_stft = stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
n_shift=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
return stft2logmelspectrogram(
|
||||
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
|
||||
)
|
||||
|
||||
|
||||
class Spectrogram():
|
||||
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return spectrogram(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
|
||||
|
||||
class LogMelSpectrogram():
|
||||
def __init__(
|
||||
self,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"n_shift={n_shift}, win_length={win_length}, window={window}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
|
||||
|
||||
class Stft2LogMelSpectrogram():
|
||||
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return stft2logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
)
|
||||
|
||||
|
||||
class Stft():
|
||||
def __init__(
|
||||
self,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center}, pad_mode={pad_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return stft(
|
||||
x,
|
||||
self.n_fft,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
)
|
||||
|
||||
|
||||
class IStft():
|
||||
def __init__(self, n_shift, win_length=None, window="hann", center=True):
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return istft(
|
||||
x,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
)
|
@ -0,0 +1,20 @@
|
||||
# TODO(karita): add this to all the transform impl.
|
||||
class TransformInterface:
|
||||
"""Transform Interface"""
|
||||
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError("__call__ method is not implemented")
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
return parser
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
class Identity(TransformInterface):
|
||||
"""Identity Function"""
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
@ -0,0 +1,149 @@
|
||||
"""Transformation module."""
|
||||
from collections.abc import Sequence
|
||||
from collections import OrderedDict
|
||||
import copy
|
||||
from inspect import signature
|
||||
import io
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
# TODO(karita): inherit TransformInterface
|
||||
# TODO(karita): register cmd arguments in asr_train.py
|
||||
import_alias = dict(
|
||||
identity="deepspeech.transform.transform_interface:Identity",
|
||||
time_warp="deepspeech.transform.spec_augment:TimeWarp",
|
||||
time_mask="deepspeech.transform.spec_augment:TimeMask",
|
||||
freq_mask="deepspeech.transform.spec_augment:FreqMask",
|
||||
spec_augment="deepspeech.transform.spec_augment:SpecAugment",
|
||||
speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
|
||||
volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
|
||||
noise_injection="deepspeech.transform.perturb:NoiseInjection",
|
||||
bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
|
||||
rir_convolve="deepspeech.transform.perturb:RIRConvolve",
|
||||
delta="deepspeech.transform.add_deltas:AddDeltas",
|
||||
cmvn="deepspeech.transform.cmvn:CMVN",
|
||||
utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
|
||||
fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
|
||||
spectrogram="deepspeech.transform.spectrogram:Spectrogram",
|
||||
stft="deepspeech.transform.spectrogram:Stft",
|
||||
istft="deepspeech.transform.spectrogram:IStft",
|
||||
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
|
||||
wpe="deepspeech.transform.wpe:WPE",
|
||||
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
|
||||
)
|
||||
|
||||
|
||||
class Transformation():
|
||||
"""Apply some functions to the mini-batch
|
||||
|
||||
Examples:
|
||||
>>> kwargs = {"process": [{"type": "fbank",
|
||||
... "n_mels": 80,
|
||||
... "fs": 16000},
|
||||
... {"type": "cmvn",
|
||||
... "stats": "data/train/cmvn.ark",
|
||||
... "norm_vars": True},
|
||||
... {"type": "delta", "window": 2, "order": 2}]}
|
||||
>>> transform = Transformation(kwargs)
|
||||
>>> bs = 10
|
||||
>>> xs = [np.random.randn(100, 80).astype(np.float32)
|
||||
... for _ in range(bs)]
|
||||
>>> xs = transform(xs)
|
||||
"""
|
||||
|
||||
def __init__(self, conffile=None):
|
||||
if conffile is not None:
|
||||
if isinstance(conffile, dict):
|
||||
self.conf = copy.deepcopy(conffile)
|
||||
else:
|
||||
with io.open(conffile, encoding="utf-8") as f:
|
||||
self.conf = yaml.safe_load(f)
|
||||
assert isinstance(self.conf, dict), type(self.conf)
|
||||
else:
|
||||
self.conf = {"mode": "sequential", "process": []}
|
||||
|
||||
self.functions = OrderedDict()
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx, process in enumerate(self.conf["process"]):
|
||||
assert isinstance(process, dict), type(process)
|
||||
opts = dict(process)
|
||||
process_type = opts.pop("type")
|
||||
class_obj = dynamic_import(process_type, import_alias)
|
||||
# TODO(karita): assert issubclass(class_obj, TransformInterface)
|
||||
try:
|
||||
self.functions[idx] = class_obj(**opts)
|
||||
except TypeError:
|
||||
try:
|
||||
signa = signature(class_obj)
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
pass
|
||||
else:
|
||||
logging.error(
|
||||
"Expected signature: {}({})".format(
|
||||
class_obj.__name__, signa
|
||||
)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"])
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
rep = "\n" + "\n".join(
|
||||
" {}: {}".format(k, v) for k, v in self.functions.items()
|
||||
)
|
||||
return "{}({})".format(self.__class__.__name__, rep)
|
||||
|
||||
def __call__(self, xs, uttid_list=None, **kwargs):
|
||||
"""Return new mini-batch
|
||||
|
||||
:param Union[Sequence[np.ndarray], np.ndarray] xs:
|
||||
:param Union[Sequence[str], str] uttid_list:
|
||||
:return: batch:
|
||||
:rtype: List[np.ndarray]
|
||||
"""
|
||||
if not isinstance(xs, Sequence):
|
||||
is_batch = False
|
||||
xs = [xs]
|
||||
else:
|
||||
is_batch = True
|
||||
|
||||
if isinstance(uttid_list, str):
|
||||
uttid_list = [uttid_list for _ in range(len(xs))]
|
||||
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx in range(len(self.conf["process"])):
|
||||
func = self.functions[idx]
|
||||
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
|
||||
# Derive only the args which the func has
|
||||
try:
|
||||
param = signature(func).parameters
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
param = {}
|
||||
_kwargs = {k: v for k, v in kwargs.items() if k in param}
|
||||
try:
|
||||
if uttid_list is not None and "uttid" in param:
|
||||
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
|
||||
else:
|
||||
xs = [func(x, **_kwargs) for x in xs]
|
||||
except Exception:
|
||||
logging.fatal(
|
||||
"Catch a exception from {}th func: {}".format(idx, func)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"])
|
||||
)
|
||||
|
||||
if is_batch:
|
||||
return xs
|
||||
else:
|
||||
return xs[0]
|
@ -0,0 +1,45 @@
|
||||
from nara_wpe.wpe import wpe
|
||||
|
||||
|
||||
class WPE(object):
|
||||
def __init__(
|
||||
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
|
||||
):
|
||||
self.taps = taps
|
||||
self.delay = delay
|
||||
self.iterations = iterations
|
||||
self.psd_context = psd_context
|
||||
self.statistics_mode = statistics_mode
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(taps={taps}, delay={delay}"
|
||||
"iterations={iterations}, psd_context={psd_context}, "
|
||||
"statistics_mode={statistics_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, xs):
|
||||
"""Return enhanced
|
||||
|
||||
:param np.ndarray xs: (Time, Channel, Frequency)
|
||||
:return: enhanced_xs
|
||||
:rtype: np.ndarray
|
||||
|
||||
"""
|
||||
# nara_wpe.wpe: (F, C, T)
|
||||
xs = wpe(
|
||||
xs.transpose((2, 1, 0)),
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode,
|
||||
)
|
||||
return xs.transpose(2, 1, 0)
|
@ -0,0 +1,20 @@
|
||||
import inspect
|
||||
|
||||
|
||||
def check_kwargs(func, kwargs, name=None):
|
||||
"""check kwargs are valid for func
|
||||
|
||||
If kwargs are invalid, raise TypeError as same as python default
|
||||
:param function func: function to be validated
|
||||
:param dict kwargs: keyword arguments for func
|
||||
:param str name: name used in TypeError (default is func name)
|
||||
"""
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
except ValueError:
|
||||
return
|
||||
if name is None:
|
||||
name = func.__name__
|
||||
for k in kwargs.keys():
|
||||
if k not in params:
|
||||
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
|
@ -1,122 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
min_input_len: 0.5
|
||||
max_input_len: 20.0
|
||||
min_output_len: 0.0
|
||||
max_output_len: 400.0
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 16
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
causal: True
|
||||
use_dynamic_chunk: true
|
||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
||||
use_dynamic_left_chunk: false
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 240
|
||||
accum_grad: 8
|
||||
global_grad_clip: 5.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 128
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -1,115 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
min_input_len: 0.5 # second
|
||||
max_input_len: 20.0 # second
|
||||
min_output_len: 0.0 # tokens
|
||||
max_output_len: 400.0 # tokens
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 64
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: transformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: true
|
||||
use_dynamic_chunk: true
|
||||
use_dynamic_left_chunk: false
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 120
|
||||
accum_grad: 1
|
||||
global_grad_clip: 5.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 64
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -1,118 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test-clean
|
||||
min_input_len: 0.5 # seconds
|
||||
max_input_len: 20.0 # seconds
|
||||
min_output_len: 0.0 # tokens
|
||||
max_output_len: 400.0 # tokens
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 16
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 120
|
||||
accum_grad: 8
|
||||
global_grad_clip: 3.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.004
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 64
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from deepspeech.transform.transformation import Transformation
|
||||
from deepspeech.utils.cli_readers import file_reader_helper
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="convert feature to its shape",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocess-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration file for the pre-processing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"out",
|
||||
nargs="?",
|
||||
type=argparse.FileType("w"),
|
||||
default=sys.stdout,
|
||||
help="The output filename. " "If omitted, then output to sys.stdout",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
if args.preprocess_conf is not None:
|
||||
preprocessing = Transformation(args.preprocess_conf)
|
||||
logging.info("Apply preprocessing: {}".format(preprocessing))
|
||||
else:
|
||||
preprocessing = None
|
||||
|
||||
# There are no necessary for matrix without preprocessing,
|
||||
# so change to file_reader_helper to return shape.
|
||||
# This make sense only with filetype="hdf5".
|
||||
for utt, mat in file_reader_helper(
|
||||
args.rspecifier, args.filetype, return_shape=preprocessing is None
|
||||
):
|
||||
if preprocessing is not None:
|
||||
if is_scipy_wav_style(mat):
|
||||
# If data is sound file, then got as Tuple[int, ndarray]
|
||||
rate, mat = mat
|
||||
mat = preprocessing(mat, uttid_list=utt)
|
||||
shape_str = ",".join(map(str, mat.shape))
|
||||
else:
|
||||
if len(mat) == 2 and isinstance(mat[1], tuple):
|
||||
# If data is sound file, Tuple[int, Tuple[int, ...]]
|
||||
rate, mat = mat
|
||||
shape_str = ",".join(map(str, mat))
|
||||
args.out.write("{} {}\n".format(utt, shape_str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in new issue