diff --git a/deepspeech/transform/add_deltas.py b/deepspeech/transform/add_deltas.py
new file mode 100644
index 00000000..68f44d41
--- /dev/null
+++ b/deepspeech/transform/add_deltas.py
@@ -0,0 +1,41 @@
+import numpy as np
+
+
+def delta(feat, window):
+    assert window > 0
+    delta_feat = np.zeros_like(feat)
+    for i in range(1, window + 1):
+        delta_feat[:-i] += i * feat[i:]
+        delta_feat[i:] += -i * feat[:-i]
+        delta_feat[-i:] += i * feat[-1]
+        delta_feat[:i] += -i * feat[0]
+    delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
+    return delta_feat
+
+
+def add_deltas(x, window=2, order=2):
+    """
+    Args:
+        x (np.ndarray): speech feat, (T, D).
+
+    Return:
+        np.ndarray: (T, (1+order)*D)
+    """
+    feats = [x]
+    for _ in range(order):
+        feats.append(delta(feats[-1], window))
+    return np.concatenate(feats, axis=1)
+
+
+class AddDeltas():
+    def __init__(self, window=2, order=2):
+        self.window = window
+        self.order = order
+
+    def __repr__(self):
+        return "{name}(window={window}, order={order})".format(
+            name=self.__class__.__name__, window=self.window, order=self.order
+        )
+
+    def __call__(self, x):
+        return add_deltas(x, window=self.window, order=self.order)
diff --git a/deepspeech/transform/channel_selector.py b/deepspeech/transform/channel_selector.py
new file mode 100644
index 00000000..1ac9e350
--- /dev/null
+++ b/deepspeech/transform/channel_selector.py
@@ -0,0 +1,45 @@
+import numpy
+
+
+class ChannelSelector():
+    """Select 1ch from multi-channel signal"""
+
+    def __init__(self, train_channel="random", eval_channel=0, axis=1):
+        self.train_channel = train_channel
+        self.eval_channel = eval_channel
+        self.axis = axis
+
+    def __repr__(self):
+        return (
+            "{name}(train_channel={train_channel}, "
+            "eval_channel={eval_channel}, axis={axis})".format(
+                name=self.__class__.__name__,
+                train_channel=self.train_channel,
+                eval_channel=self.eval_channel,
+                axis=self.axis,
+            )
+        )
+
+    def __call__(self, x, train=True):
+        # Assuming x: [Time, Channel] by default
+
+        if x.ndim <= self.axis:
+            # If the dimension is insufficient, then unsqueeze
+            # (e.g. [Time] -> [Time, 1])
+            ind = tuple(
+                slice(None) if i < x.ndim else None for i in range(self.axis + 1)
+            )
+            x = x[ind]
+
+        if train:
+            channel = self.train_channel
+        else:
+            channel = self.eval_channel
+
+        if channel == "random":
+            ch = numpy.random.randint(0, x.shape[self.axis])
+        else:
+            ch = channel
+
+        ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
+        return x[ind]
diff --git a/deepspeech/transform/functional.py b/deepspeech/transform/functional.py
new file mode 100644
index 00000000..5eec6cc1
--- /dev/null
+++ b/deepspeech/transform/functional.py
@@ -0,0 +1,71 @@
+import inspect
+
+from deepspeech.transform.transform_interface import TransformInterface
+from deepspeech.utils.check_kwargs import check_kwargs
+
+
+class FuncTrans(TransformInterface):
+    """Functional Transformation
+
+    WARNING:
+        Builtin or C/C++ functions may not work properly
+        because this class heavily depends on the `inspect` module.
+
+    Usage:
+
+    >>> def foo_bar(x, a=1, b=2):
+    ...     '''Foo bar
+    ...     :param x: input
+    ...     :param int a: default 1
+    ...     :param int b: default 2
+    ...     '''
+    ...     return x + a - b
+
+
+    >>> class FooBar(FuncTrans):
+    ...     _func = foo_bar
+    ...     
__doc__ = foo_bar.__doc__
+    """
+
+    _func = None
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+        check_kwargs(self.func, kwargs)
+
+    def __call__(self, x):
+        return self.func(x, **self.kwargs)
+
+    @classmethod
+    def add_arguments(cls, parser):
+        fname = cls._func.__name__.replace("_", "-")
+        group = parser.add_argument_group(fname + " transformation setting")
+        for k, v in cls.default_params().items():
+            # TODO(karita): get help and choices from docstring?
+            attr = k.replace("_", "-")
+            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
+        return parser
+
+    @property
+    def func(self):
+        return type(self)._func
+
+    @classmethod
+    def default_params(cls):
+        try:
+            d = dict(inspect.signature(cls._func).parameters)
+        except ValueError:
+            d = dict()
+        return {
+            k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
+        }
+
+    def __repr__(self):
+        params = self.default_params()
+        params.update(**self.kwargs)
+        ret = self.__class__.__name__ + "("
+        if len(params) == 0:
+            return ret + ")"
+        for k, v in params.items():
+            ret += "{}={}, ".format(k, v)
+        return ret[:-2] + ")"
diff --git a/deepspeech/transform/perturb.py b/deepspeech/transform/perturb.py
new file mode 100644
index 00000000..05766cad
--- /dev/null
+++ b/deepspeech/transform/perturb.py
@@ -0,0 +1,343 @@
+import librosa
+import numpy
+import scipy.signal
+import soundfile
+
+from deepspeech.io.reader import SoundHDF5File
+
+class SpeedPerturbation():
+    """SpeedPerturbation
+
+    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
+    and sox-speed simply resamples the input,
+    i.e., both pitch and tempo are changed.
+
+    "Why use speed option instead of tempo -s in SoX for speed perturbation"
+    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
+
+    Warning:
+        This function is very slow because of resampling.
+        We recommend applying speed perturbation outside of training using sox.
+
+    """
+
+    def __init__(
+        self,
+        lower=0.9,
+        upper=1.1,
+        utt2ratio=None,
+        keep_length=True,
+        res_type="kaiser_best",
+        seed=None,
+    ):
+        self.res_type = res_type
+        self.keep_length = keep_length
+        self.state = numpy.random.RandomState(seed)
+        self.accept_uttid = False
+
+        if utt2ratio is not None:
+            self.utt2ratio = {}
+            # Use the scheduled ratio for each utterance
+            self.utt2ratio_file = utt2ratio
+            self.lower = None
+            self.upper = None
+            self.accept_uttid = True
+
+            with open(utt2ratio, "r") as f:
+                for line in f:
+                    utt, ratio = line.rstrip().split(None, 1)
+                    ratio = float(ratio)
+                    self.utt2ratio[utt] = ratio
+        else:
+            self.utt2ratio = None
+            # The ratio is given on runtime randomly
+            self.lower = lower
+            self.upper = upper
+
+    def __repr__(self):
+        if self.utt2ratio is None:
+            return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
+                self.__class__.__name__,
+                self.lower,
+                self.upper,
+                self.keep_length,
+                self.res_type,
+            )
+        else:
+            return "{}({}, res_type={})".format(
+                self.__class__.__name__, self.utt2ratio_file, self.res_type
+            )
+
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+
+        x = x.astype(numpy.float32)
+        if self.accept_uttid:
+            ratio = self.utt2ratio[uttid]
+        else:
+            ratio = self.state.uniform(self.lower, self.upper)
+
+        # Note1: resample requires the sampling-rate of input and output,
+        # but actually only the ratio is used.
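+        # Note2: with librosa's (orig_sr, target_sr) convention, resampling
+        # from rate `ratio` to rate 1 scales the duration by 1/ratio; e.g.
+        # ratio=1.1 yields roughly len(x)/1.1 samples (faster playback and
+        # higher pitch), which the keep_length branch below pads or trims
+        # back to len(x).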
+        y = librosa.resample(x, ratio, 1, res_type=self.res_type)
+
+        if self.keep_length:
+            diff = abs(len(x) - len(y))
+            if len(y) > len(x):
+                # Truncate the resampled signal
+                y = y[diff // 2 : -((diff + 1) // 2)]
+            elif len(y) < len(x):
+                # Assume the time-axis is the first: (Time, Channel)
+                pad_width = [(diff // 2, (diff + 1) // 2)] + [
+                    (0, 0) for _ in range(y.ndim - 1)
+                ]
+                y = numpy.pad(
+                    y, pad_width=pad_width, constant_values=0, mode="constant"
+                )
+        return y
+
+
+class BandpassPerturbation():
+    """BandpassPerturbation
+
+    Randomly dropout along the frequency axis.
+
+    The original idea comes from the following:
+        "randomly-selected frequency band was cut off under the constraint of
+        leaving at least 1,000 Hz band within the range of less than 4,000Hz."
+        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
+        everyday home environments using multiple microphone arrays;
+        http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
+
+    """
+
+    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
+        self.lower = lower
+        self.upper = upper
+        self.state = numpy.random.RandomState(seed)
+        # x_stft: (Time, Channel, Freq)
+        self.axes = axes
+
+    def __repr__(self):
+        return "{}(lower={}, upper={})".format(
+            self.__class__.__name__, self.lower, self.upper
+        )
+
+    def __call__(self, x_stft, uttid=None, train=True):
+        if not train:
+            return x_stft
+
+        if x_stft.ndim == 1:
+            raise RuntimeError(
+                "Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
+            )
+
+        ratio = self.state.uniform(self.lower, self.upper)
+        # Convert negative axis indices, e.g. -1 -> x_stft.ndim - 1
+        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
+        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
+
+        mask = self.state.randn(*shape) > ratio
+        x_stft *= mask
+        return x_stft
+
+
+class VolumePerturbation():
+    def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
+        self.dbunit = dbunit
+        self.utt2ratio_file = utt2ratio
+        self.lower = lower
+        self.upper = upper
+        self.state = numpy.random.RandomState(seed)
+        self.accept_uttid = False
+
+        if utt2ratio is not None:
+            # Use the scheduled ratio for each utterance
+            self.utt2ratio = {}
+            self.lower = None
+            self.upper = None
+            self.accept_uttid = True
+
+            with open(utt2ratio, "r") as f:
+                for line in f:
+                    utt, ratio = line.rstrip().split(None, 1)
+                    ratio = float(ratio)
+                    self.utt2ratio[utt] = ratio
+        else:
+            # The ratio is given on runtime randomly
+            self.utt2ratio = None
+
+    def __repr__(self):
+        if self.utt2ratio is None:
+            return "{}(lower={}, upper={}, dbunit={})".format(
+                self.__class__.__name__, self.lower, self.upper, self.dbunit
+            )
+        else:
+            return '{}("{}", dbunit={})'.format(
+                self.__class__.__name__, self.utt2ratio_file, self.dbunit
+            )
+
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+
+        x = x.astype(numpy.float32)
+
+        if self.accept_uttid:
+            ratio = self.utt2ratio[uttid]
+        else:
+            ratio = self.state.uniform(self.lower, self.upper)
+        if self.dbunit:
+            ratio = 10 ** (ratio / 20)
+        return x * ratio
+
+
+class NoiseInjection():
+    """Add isotropic noise"""
+
+    def __init__(
+        self,
+        utt2noise=None,
+        lower=-20,
+        upper=-5,
+        utt2ratio=None,
+        filetype="list",
+        dbunit=True,
+        seed=None,
+    ):
+        self.utt2noise_file = utt2noise
+        self.utt2ratio_file = utt2ratio
+        self.filetype = filetype
+        self.dbunit = dbunit
+        self.lower = lower
+        self.upper = upper
+        self.state = numpy.random.RandomState(seed)
+
+        if utt2ratio is not None:
+            # Use the scheduled SNR for each utterance
+            self.utt2ratio = {}
+            with open(utt2ratio, "r") as 
f:
+                for line in f:
+                    utt, snr = line.rstrip().split(None, 1)
+                    snr = float(snr)
+                    self.utt2ratio[utt] = snr
+        else:
+            # The ratio is given on runtime randomly
+            self.utt2ratio = None
+
+        if utt2noise is not None:
+            self.utt2noise = {}
+            if filetype == "list":
+                with open(utt2noise, "r") as f:
+                    for line in f:
+                        utt, filename = line.rstrip().split(None, 1)
+                        signal, rate = soundfile.read(filename, dtype="int16")
+                        # Load all files in memory
+                        self.utt2noise[utt] = (signal, rate)
+
+            elif filetype == "sound.hdf5":
+                self.utt2noise = SoundHDF5File(utt2noise, "r")
+            else:
+                raise ValueError(filetype)
+        else:
+            self.utt2noise = None
+
+        if utt2noise is not None and utt2ratio is not None:
+            if set(self.utt2ratio) != set(self.utt2noise):
+                raise RuntimeError(
+                    "The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
+                )
+
+    def __repr__(self):
+        if self.utt2ratio is None:
+            return "{}(lower={}, upper={}, dbunit={})".format(
+                self.__class__.__name__, self.lower, self.upper, self.dbunit
+            )
+        else:
+            return '{}("{}", dbunit={})'.format(
+                self.__class__.__name__, self.utt2ratio_file, self.dbunit
+            )
+
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        x = x.astype(numpy.float32)
+
+        # 1. Get ratio of noise to signal in sound pressure level
+        if uttid is not None and self.utt2ratio is not None:
+            ratio = self.utt2ratio[uttid]
+        else:
+            ratio = self.state.uniform(self.lower, self.upper)
+
+        if self.dbunit:
+            ratio = 10 ** (ratio / 20)
+        scale = ratio * numpy.sqrt((x ** 2).mean())
+
+        # 2. Get noise
+        if self.utt2noise is not None:
+            # Get noise from the external source
+            if uttid is not None:
+                noise, rate = self.utt2noise[uttid]
+            else:
+                # Randomly select the noise source by key
+                noise, rate = self.utt2noise[self.state.choice(list(self.utt2noise))]
+            noise = noise.astype(numpy.float32)
+            # Normalize the level
+            noise /= numpy.sqrt((noise ** 2).mean())
+
+            # Adjust the noise length
+            diff = abs(len(x) - len(noise))
+            offset = self.state.randint(0, diff) if diff > 0 else 0
+            if len(noise) > len(x):
+                # Truncate noise
+                noise = noise[offset : -(diff - offset)]
+            else:
+                noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
+
+        else:
+            # Generate white noise
+            noise = self.state.normal(0, 1, x.shape)
+
+        # 3. Add noise to signal
+        return x + noise * scale
+
+
+class RIRConvolve():
+    def __init__(self, utt2rir, filetype="list"):
+        self.utt2rir_file = utt2rir
+        self.filetype = filetype
+
+        self.utt2rir = {}
+        if filetype == "list":
+            with open(utt2rir, "r") as f:
+                for line in f:
+                    utt, filename = line.rstrip().split(None, 1)
+                    signal, rate = soundfile.read(filename, dtype="int16")
+                    self.utt2rir[utt] = (signal, rate)
+
+        elif filetype == "sound.hdf5":
+            self.utt2rir = SoundHDF5File(utt2rir, "r")
+        else:
+            raise NotImplementedError(filetype)
+
+    def __repr__(self):
+        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
+
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+
+        x = x.astype(numpy.float32)
+
+        if x.ndim != 1:
+            # Must be single channel
+            raise RuntimeError(
+                "Input x must be one dimensional array, but got {}".format(x.shape)
+            )
+
+        rir, rate = self.utt2rir[uttid]
+        if rir.ndim == 2:
+            # FIXME(kamo): Use chainer.convolution_1d?
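+            # NOTE: for long impulse responses an FFT-based convolution
+            # (e.g. scipy.signal.fftconvolve) would likely be much faster;
+            # the direct form is kept here for simplicity.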
+            # return [Time, Channel]
+            return numpy.stack(
+                [scipy.signal.convolve(x, r, mode="same") for r in rir], axis=-1
+            )
+        else:
+            return scipy.signal.convolve(x, rir, mode="same")
diff --git a/deepspeech/transform/spec_augment.py b/deepspeech/transform/spec_augment.py
new file mode 100644
index 00000000..feb712df
--- /dev/null
+++ b/deepspeech/transform/spec_augment.py
@@ -0,0 +1,202 @@
+"""Spec Augment module for preprocessing, i.e., data augmentation."""
+
+import random
+
+import numpy
+from PIL import Image
+from PIL.Image import BICUBIC
+
+from deepspeech.transform.functional import FuncTrans
+
+
+def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
+    """time warp for spec augment
+
+    move the random center frame by a random width ~ uniform(-window, window)
+    :param numpy.ndarray x: spectrogram (time, freq)
+    :param int max_time_warp: maximum time frames to warp
+    :param bool inplace: overwrite x with the result
+    :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
+        (slow, differentiable)
+    :returns numpy.ndarray: time warped spectrogram (time, freq)
+    """
+    window = max_time_warp
+    if mode == "PIL":
+        t = x.shape[0]
+        if t - window <= window:
+            return x
+        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
+        center = random.randrange(window, t - window)
+        warped = random.randrange(center - window, center + window) + 1  # 1 ... t - 1
+
+        left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
+        right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
+        if inplace:
+            x[:warped] = left
+            x[warped:] = right
+            return x
+        return numpy.concatenate((left, right), 0)
+    elif mode == "sparse_image_warp":
+        import paddle
+
+        from deepspeech.utils import spec_augment
+
+        # TODO(karita): make this differentiable again
+        return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
+    else:
+        raise NotImplementedError(
+            "unknown resize mode: "
+            + mode
+            + ", choose one from (PIL, sparse_image_warp)."
+        )
+
+
+class TimeWarp(FuncTrans):
+    _func = time_warp
+    __doc__ = time_warp.__doc__
+
+    def __call__(self, x, train):
+        if not train:
+            return x
+        return super().__call__(x)
+
+
+def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
+    """freq mask for spec augment
+
+    :param numpy.ndarray x: (time, freq)
+    :param int F: maximum width of each freq mask
+    :param int n_mask: the number of masks
+    :param bool inplace: overwrite
+    :param bool replace_with_zero: pad zero on mask if true else use mean
+    """
+    if inplace:
+        cloned = x
+    else:
+        cloned = x.copy()
+
+    num_mel_channels = cloned.shape[1]
+    fs = numpy.random.randint(0, F, size=(n_mask, 2))
+
+    for f, mask_end in fs:
+        f_zero = random.randrange(0, num_mel_channels - f)
+        mask_end += f_zero
+
+        # avoids randrange error if values are equal and range is empty
+        if f_zero == f_zero + f:
+            continue
+
+        if replace_with_zero:
+            cloned[:, f_zero:mask_end] = 0
+        else:
+            cloned[:, f_zero:mask_end] = cloned.mean()
+    return cloned
+
+
+class FreqMask(FuncTrans):
+    _func = freq_mask
+    __doc__ = freq_mask.__doc__
+
+    def __call__(self, x, train):
+        if not train:
+            return x
+        return super().__call__(x)
+
+
+def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
+    """time mask for spec augment
+
+    :param numpy.ndarray spec: (time, freq)
+    :param int T: maximum width of each time mask
+    :param int n_mask: the number of masks
+    :param bool inplace: overwrite
+    :param bool replace_with_zero: pad zero on mask if true else use mean
+    """
+    if inplace:
+        cloned = spec
+    else:
+        cloned = spec.copy()
+    len_spectro = cloned.shape[0]
+    ts = numpy.random.randint(0, T, size=(n_mask, 2))
+    for t, mask_end in ts:
+        # avoid randint range error
+        if len_spectro - t <= 0:
+            continue
+        t_zero = random.randrange(0, len_spectro - t)
+
+        # avoids randrange error if values are equal and range is empty
+        if t_zero == t_zero + t:
+            continue
+
+        mask_end += t_zero
+        if replace_with_zero:
+            cloned[t_zero:mask_end] = 0
+        else:
+            cloned[t_zero:mask_end] = cloned.mean()
+    return cloned
+
+
+class TimeMask(FuncTrans):
+    _func = time_mask
+    __doc__ = time_mask.__doc__
+
+    def __call__(self, x, train):
+        if not train:
+            return x
+        return super().__call__(x)
+
+
+def spec_augment(
+    x,
+    resize_mode="PIL",
+    max_time_warp=80,
+    max_freq_width=27,
+    n_freq_mask=2,
+    max_time_width=100,
+    n_time_mask=2,
+    inplace=True,
+    replace_with_zero=True,
+):
+    """spec augment
+
+    apply random time warping and time/freq masking
+    default setting is based on LD (Librispeech double) in Table 2
+    https://arxiv.org/pdf/1904.08779.pdf
+
+    :param numpy.ndarray x: (time, freq)
+    :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
+        (slow, differentiable)
+    :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
+    :param int max_freq_width: maximum width of the random freq mask (F)
+    :param int n_freq_mask: the number of the random freq masks (m_F)
+    :param int max_time_width: maximum width of the random time mask (T)
+    :param int n_time_mask: the number of the random time masks (m_T)
+    :param bool inplace: overwrite intermediate array
+    :param bool replace_with_zero: pad zero on mask if true else use mean
+    """
+    assert isinstance(x, numpy.ndarray)
+    assert x.ndim == 2
+    x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
+    x = freq_mask(
+        x,
+        max_freq_width,
+        n_freq_mask,
+        inplace=inplace,
+        replace_with_zero=replace_with_zero,
+    )
+    x = time_mask(
+        x,
+        max_time_width,
+        n_time_mask,
+        inplace=inplace,
+        replace_with_zero=replace_with_zero,
+    )
+    return x
+
+
+class 
SpecAugment(FuncTrans): + _func = spec_augment + __doc__ = spec_augment.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) diff --git a/deepspeech/transform/spectrogram.py b/deepspeech/transform/spectrogram.py new file mode 100644 index 00000000..68d47627 --- /dev/null +++ b/deepspeech/transform/spectrogram.py @@ -0,0 +1,307 @@ +import librosa +import numpy as np + + +def stft( + x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect" +): + # x: [Time, Channel] + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? + # x: [Time, Channel, Freq] + x = np.stack( + [ + librosa.stft( + x[:, ch], + n_fft=n_fft, + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + ).T + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel, Freq] -> [Time, Freq] + x = x[:, 0] + return x + + +def istft(x, n_shift, win_length=None, window="hann", center=True): + # x: [Time, Channel, Freq] + if x.ndim == 2: + single_channel = True + # x: [Time, Freq] -> [Time, Channel, Freq] + x = x[:, None, :] + else: + single_channel = False + + # x: [Time, Channel] + x = np.stack( + [ + librosa.istft( + x[:, ch].T, # [Time, Freq] -> [Freq, Time] + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + ) + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel] -> [Time] + x = x[:, 0] + return x + + +def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + # x_stft: (Time, Channel, Freq) or (Time, Freq) + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + + # spc: (Time, Channel, Freq) or (Time, Freq) + spc = np.abs(x_stft) + # mel_basis: (Mel_freq, Freq) + mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax) + # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) + lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + return lmspc + + +def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): + # x: (Time, Channel) -> spc: (Time, Channel, Freq) + spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) + return spc + + +def logmelspectrogram( + x, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + pad_mode="reflect", +): + # stft: (Time, Channel, Freq) or (Time, Freq) + x_stft = stft( + x, + n_fft=n_fft, + n_shift=n_shift, + win_length=win_length, + window=window, + pad_mode=pad_mode, + ) + + return stft2logmelspectrogram( + x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps + ) + + +class Spectrogram(): + def __init__(self, n_fft, n_shift, win_length=None, window="hann"): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + + def __repr__(self): + return ( + "{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + ) + + def __call__(self, x): + return spectrogram( + x, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + + +class LogMelSpectrogram(): + def __init__( + self, + fs, + n_mels, + n_fft, + 
n_shift,
+        win_length=None,
+        window="hann",
+        fmin=None,
+        fmax=None,
+        eps=1e-10,
+    ):
+        self.fs = fs
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.n_shift = n_shift
+        self.win_length = win_length
+        self.window = window
+        self.fmin = fmin
+        self.fmax = fmax
+        self.eps = eps
+
+    def __repr__(self):
+        return (
+            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
+            "n_shift={n_shift}, win_length={win_length}, window={window}, "
+            "fmin={fmin}, fmax={fmax}, eps={eps})".format(
+                name=self.__class__.__name__,
+                fs=self.fs,
+                n_mels=self.n_mels,
+                n_fft=self.n_fft,
+                n_shift=self.n_shift,
+                win_length=self.win_length,
+                window=self.window,
+                fmin=self.fmin,
+                fmax=self.fmax,
+                eps=self.eps,
+            )
+        )
+
+    def __call__(self, x):
+        return logmelspectrogram(
+            x,
+            fs=self.fs,
+            n_mels=self.n_mels,
+            n_fft=self.n_fft,
+            n_shift=self.n_shift,
+            win_length=self.win_length,
+            window=self.window,
+            fmin=self.fmin,
+            fmax=self.fmax,
+            eps=self.eps,
+        )
+
+
+class Stft2LogMelSpectrogram():
+    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
+        self.fs = fs
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.fmin = fmin
+        self.fmax = fmax
+        self.eps = eps
+
+    def __repr__(self):
+        return (
+            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
+            "fmin={fmin}, fmax={fmax}, eps={eps})".format(
+                name=self.__class__.__name__,
+                fs=self.fs,
+                n_mels=self.n_mels,
+                n_fft=self.n_fft,
+                fmin=self.fmin,
+                fmax=self.fmax,
+                eps=self.eps,
+            )
+        )
+
+    def __call__(self, x):
+        return stft2logmelspectrogram(
+            x,
+            fs=self.fs,
+            n_mels=self.n_mels,
+            n_fft=self.n_fft,
+            fmin=self.fmin,
+            fmax=self.fmax,
+            eps=self.eps,
+        )
+
+
+class Stft():
+    def __init__(
+        self,
+        n_fft,
+        n_shift,
+        win_length=None,
+        window="hann",
+        center=True,
+        pad_mode="reflect",
+    ):
+        self.n_fft = n_fft
+        self.n_shift = n_shift
+        self.win_length = win_length
+        self.window = window
+        self.center = center
+        self.pad_mode = pad_mode
+
+    def __repr__(self):
+        return (
+            "{name}(n_fft={n_fft}, n_shift={n_shift}, "
+            "win_length={win_length}, window={window}, "
+            "center={center}, pad_mode={pad_mode})".format(
+                name=self.__class__.__name__,
+                n_fft=self.n_fft,
+                n_shift=self.n_shift,
+                win_length=self.win_length,
+                window=self.window,
+                center=self.center,
+                pad_mode=self.pad_mode,
+            )
+        )
+
+    def __call__(self, x):
+        return stft(
+            x,
+            self.n_fft,
+            self.n_shift,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            pad_mode=self.pad_mode,
+        )
+
+
+class IStft():
+    def __init__(self, n_shift, win_length=None, window="hann", center=True):
+        self.n_shift = n_shift
+        self.win_length = win_length
+        self.window = window
+        self.center = center
+
+    def __repr__(self):
+        return (
+            "{name}(n_shift={n_shift}, "
+            "win_length={win_length}, window={window}, "
+            "center={center})".format(
+                name=self.__class__.__name__,
+                n_shift=self.n_shift,
+                win_length=self.win_length,
+                window=self.window,
+                center=self.center,
+            )
+        )
+
+    def __call__(self, x):
+        return istft(
+            x,
+            self.n_shift,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+        )
diff --git a/deepspeech/transform/transform_interface.py b/deepspeech/transform/transform_interface.py
new file mode 100644
index 00000000..8a6aba45
--- /dev/null
+++ b/deepspeech/transform/transform_interface.py
@@ -0,0 +1,20 @@
+# TODO(karita): add this to all the transform impl.
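+# A transform is a callable mapping one utterance's feature array to a new
+# array, e.g. ``y = AddDeltas(window=2, order=2)(x)``; Transformation in
+# transformation.py composes a list of such callables sequentially.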
+class TransformInterface:
+    """Transform Interface"""
+
+    def __call__(self, x):
+        raise NotImplementedError("__call__ method is not implemented")
+
+    @classmethod
+    def add_arguments(cls, parser):
+        return parser
+
+    def __repr__(self):
+        return self.__class__.__name__ + "()"
+
+
+class Identity(TransformInterface):
+    """Identity Function"""
+
+    def __call__(self, x):
+        return x
diff --git a/deepspeech/transform/transformation.py b/deepspeech/transform/transformation.py
new file mode 100644
index 00000000..0f8c39bb
--- /dev/null
+++ b/deepspeech/transform/transformation.py
@@ -0,0 +1,149 @@
+"""Transformation module."""
+from collections.abc import Sequence
+from collections import OrderedDict
+import copy
+from inspect import signature
+import io
+import logging
+
+import yaml
+
+from deepspeech.utils.dynamic_import import dynamic_import
+
+
+# TODO(karita): inherit TransformInterface
+# TODO(karita): register cmd arguments in asr_train.py
+import_alias = dict(
+    identity="deepspeech.transform.transform_interface:Identity",
+    time_warp="deepspeech.transform.spec_augment:TimeWarp",
+    time_mask="deepspeech.transform.spec_augment:TimeMask",
+    freq_mask="deepspeech.transform.spec_augment:FreqMask",
+    spec_augment="deepspeech.transform.spec_augment:SpecAugment",
+    speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
+    volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
+    noise_injection="deepspeech.transform.perturb:NoiseInjection",
+    bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
+    rir_convolve="deepspeech.transform.perturb:RIRConvolve",
+    delta="deepspeech.transform.add_deltas:AddDeltas",
+    cmvn="deepspeech.transform.cmvn:CMVN",
+    utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
+    fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
+    spectrogram="deepspeech.transform.spectrogram:Spectrogram",
+    stft="deepspeech.transform.spectrogram:Stft",
+    istft="deepspeech.transform.spectrogram:IStft",
+    stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
+    wpe="deepspeech.transform.wpe:WPE",
+    channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
+)
+
+
+class Transformation():
+    """Apply some functions to the mini-batch
+
+    Examples:
+        >>> import numpy as np
+        >>> kwargs = {"process": [{"type": "fbank",
+        ...                        "n_mels": 80,
+        ...                        "fs": 16000},
+        ...                       {"type": "cmvn",
+        ...                        "stats": "data/train/cmvn.ark",
+        ...                        "norm_vars": True},
+        ...                       {"type": "delta", "window": 2, "order": 2}]}
+        >>> transform = Transformation(kwargs)
+        >>> bs = 10
+        >>> xs = [np.random.randn(100, 80).astype(np.float32)
+        ...       for _ in range(bs)]
+        >>> xs = transform(xs)
+    """
+
+    def __init__(self, conffile=None):
+        if conffile is not None:
+            if isinstance(conffile, dict):
+                self.conf = copy.deepcopy(conffile)
+            else:
+                with io.open(conffile, encoding="utf-8") as f:
+                    self.conf = yaml.safe_load(f)
+                assert isinstance(self.conf, dict), type(self.conf)
+        else:
+            self.conf = {"mode": "sequential", "process": []}
+
+        self.functions = OrderedDict()
+        if self.conf.get("mode", "sequential") == "sequential":
+            for idx, process in enumerate(self.conf["process"]):
+                assert isinstance(process, dict), type(process)
+                opts = dict(process)
+                process_type = opts.pop("type")
+                class_obj = dynamic_import(process_type, import_alias)
+                # TODO(karita): assert issubclass(class_obj, TransformInterface)
+                try:
+                    self.functions[idx] = class_obj(**opts)
+                except TypeError:
+                    try:
+                        signa = signature(class_obj)
+                    except ValueError:
+                        # Some functions, e.g. 
built-in functions, do not provide a signature
+                        pass
+                    else:
+                        logging.error(
+                            "Expected signature: {}({})".format(
+                                class_obj.__name__, signa
+                            )
+                        )
+                    raise
+        else:
+            raise NotImplementedError(
+                "Not supporting mode={}".format(self.conf["mode"])
+            )
+
+    def __repr__(self):
+        rep = "\n" + "\n".join(
+            "    {}: {}".format(k, v) for k, v in self.functions.items()
+        )
+        return "{}({})".format(self.__class__.__name__, rep)
+
+    def __call__(self, xs, uttid_list=None, **kwargs):
+        """Return new mini-batch
+
+        :param Union[Sequence[np.ndarray], np.ndarray] xs:
+        :param Union[Sequence[str], str] uttid_list:
+        :return: batch:
+        :rtype: List[np.ndarray]
+        """
+        if not isinstance(xs, Sequence):
+            is_batch = False
+            xs = [xs]
+        else:
+            is_batch = True
+
+        if isinstance(uttid_list, str):
+            uttid_list = [uttid_list for _ in range(len(xs))]
+
+        if self.conf.get("mode", "sequential") == "sequential":
+            for idx in range(len(self.conf["process"])):
+                func = self.functions[idx]
+                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
+                # Derive only the args which the func has
+                try:
+                    param = signature(func).parameters
+                except ValueError:
+                    # Some functions, e.g. built-in functions, do not provide
+                    # a signature
+                    param = {}
+                _kwargs = {k: v for k, v in kwargs.items() if k in param}
+                try:
+                    if uttid_list is not None and "uttid" in param:
+                        xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
+                    else:
+                        xs = [func(x, **_kwargs) for x in xs]
+                except Exception:
+                    logging.fatal(
+                        "Caught an exception from the {}th func: {}".format(idx, func)
+                    )
+                    raise
+        else:
+            raise NotImplementedError(
+                "Not supporting mode={}".format(self.conf["mode"])
+            )
+
+        if is_batch:
+            return xs
+        else:
+            return xs[0]
diff --git a/deepspeech/transform/wpe.py b/deepspeech/transform/wpe.py
new file mode 100644
index 00000000..8aed97e6
--- /dev/null
+++ b/deepspeech/transform/wpe.py
@@ -0,0 +1,45 @@
+from nara_wpe.wpe import wpe
+
+
+class WPE(object):
+    def __init__(
+        self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
+    ):
+        self.taps = taps
+        self.delay = delay
+        self.iterations = iterations
+        self.psd_context = psd_context
+        self.statistics_mode = statistics_mode
+
+    def __repr__(self):
+        return (
+            "{name}(taps={taps}, delay={delay}, "
+            "iterations={iterations}, psd_context={psd_context}, "
+            "statistics_mode={statistics_mode})".format(
+                name=self.__class__.__name__,
+                taps=self.taps,
+                delay=self.delay,
+                iterations=self.iterations,
+                psd_context=self.psd_context,
+                statistics_mode=self.statistics_mode,
+            )
+        )
+
+    def __call__(self, xs):
+        """Return enhanced
+
+        :param np.ndarray xs: (Time, Channel, Frequency)
+        :return: enhanced_xs
+        :rtype: np.ndarray
+
+        """
+        # nara_wpe.wpe expects (F, C, T)
+        xs = wpe(
+            xs.transpose((2, 1, 0)),
+            taps=self.taps,
+            delay=self.delay,
+            iterations=self.iterations,
+            psd_context=self.psd_context,
+            statistics_mode=self.statistics_mode,
+        )
+        return xs.transpose(2, 1, 0)
diff --git a/deepspeech/utils/check_kwargs.py b/deepspeech/utils/check_kwargs.py
new file mode 100644
index 00000000..593bfa24
--- /dev/null
+++ b/deepspeech/utils/check_kwargs.py
@@ -0,0 +1,20 @@
+import inspect
+
+
+def check_kwargs(func, kwargs, name=None):
+    """check kwargs are valid for func
+
+    If kwargs are invalid, raise a TypeError, matching Python's default behavior
+    :param function func: function to be validated
+    :param dict kwargs: keyword arguments for func
+    :param str name: name used in the TypeError (default is func name)
+    """
+    try:
+        params = inspect.signature(func).parameters
+    except ValueError:
+        return
+    if name is 
None: + name = func.__name__ + for k in kwargs.keys(): + if k not in params: + raise TypeError(f"{name}() got an unexpected keyword argument '{k}'") diff --git a/deepspeech/utils/spec_augment.py b/deepspeech/utils/spec_augment.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/librispeech/s2/conf/chunk_conformer.yaml b/examples/librispeech/s2/conf/chunk_conformer.yaml deleted file mode 100644 index afd2b051..00000000 --- a/examples/librispeech/s2/conf/chunk_conformer.yaml +++ /dev/null @@ -1,122 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 - max_input_len: 20.0 - min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: conformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: True - use_cnn_module: True - cnn_module_kernel: 15 - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' - causal: True - use_dynamic_chunk: true - cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster - use_dynamic_left_chunk: false - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 240 - accum_grad: 8 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.001 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 128 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. 
- # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: true # simulate streaming inference. Defaults to False. - - diff --git a/examples/librispeech/s2/conf/chunk_transformer.yaml b/examples/librispeech/s2/conf/chunk_transformer.yaml deleted file mode 100644 index 721bb7d9..00000000 --- a/examples/librispeech/s2/conf/chunk_transformer.yaml +++ /dev/null @@ -1,115 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 # second - max_input_len: 20.0 # second - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: transformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: true - use_dynamic_chunk: true - use_dynamic_left_chunk: false - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 120 - accum_grad: 1 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.001 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 64 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: true # simulate streaming inference. Defaults to False. 
- - diff --git a/examples/librispeech/s2/conf/conformer.yaml b/examples/librispeech/s2/conf/conformer.yaml deleted file mode 100644 index ef87753c..00000000 --- a/examples/librispeech/s2/conf/conformer.yaml +++ /dev/null @@ -1,118 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test-clean - min_input_len: 0.5 # seconds - max_input_len: 20.0 # seconds - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: conformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: True - use_cnn_module: True - cnn_module_kernel: 15 - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 120 - accum_grad: 8 - global_grad_clip: 3.0 - optim: adam - optim_conf: - lr: 0.004 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 64 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. 
- - diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index c9eed4f9..b2babca7 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -5,9 +5,9 @@ data: test_manifest: data/manifest.test-clean collator: - vocab_filepath: data/bpe_unigram_5000_units.txt + vocab_filepath: data/lang_char/train_960_unigram5000_units.txt unit_type: spm - spm_model_prefix: data/bpe_unigram_5000 + spm_model_prefix: data/lang_char/train_960_unigram5000 feat_dim: 83 stride_ms: 10.0 window_ms: 25.0 diff --git a/examples/librispeech/s2/local/data.sh b/examples/librispeech/s2/local/data.sh index 4d8c6fdd..b232f35a 100755 --- a/examples/librispeech/s2/local/data.sh +++ b/examples/librispeech/s2/local/data.sh @@ -20,7 +20,6 @@ datadir=${MAIN_ROOT}/examples/dataset/ # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" source ${MAIN_ROOT}/utils/parse_options.sh @@ -153,26 +152,15 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then fi -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for set in train dev test dev-clean dev-other test-clean test-other; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --feat_type "raw" \ - --cmvn_path "data/mean_std.json" \ - --unit_type "spm" \ - --spm_model_prefix ${bpeprefix} \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${set}.raw" \ - --output_path="data/manifest.${set}" - - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 - fi - }& +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # make json labels + python3 local/espnet_json_to_manifest.py --json-file ${feat_sp_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.train + python3 local/espnet_json_to_manifest.py --json-file ${feat_dt_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.dev + + for rtask in ${recog_set}; do + feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} + python3 local/espnet_json_to_manifest.py --json-file ${feat_recog_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.${rtask//_/-} done - wait fi echo "LibriSpeech Data preparation done." 
diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh index df3846c0..62c1479e 100755 --- a/examples/librispeech/s2/local/recog.sh +++ b/examples/librispeech/s2/local/recog.sh @@ -15,18 +15,19 @@ lang_model=rnnlm.model.best lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ lmtag='nolm' +train_set=train_960 recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean" # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe} bpemodel=${bpeprefix}.model # bin params config_path=conf/transformer.yaml -dict=data/bpe_unigram_5000_units.txt +dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt ckpt_prefix= source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 5f662d29..23670f74 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -8,17 +8,18 @@ nj=32 lmtag='nolm' +train_set=train_960 recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean" # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe} bpemodel=${bpeprefix}.model config_path=conf/transformer.yaml -dict=data/bpe_unigram_5000_units.txt +dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt ckpt_prefix= source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; diff --git a/requirements.txt b/requirements.txt index d654ef3d..4878dbe3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,3 +42,4 @@ visualdl==2.2.0 webrtcvad yacs yq +nara_wpe \ No newline at end of file diff --git a/setup.py b/setup.py index be17e0a4..a2e4c031 100644 --- a/setup.py +++ b/setup.py @@ -65,13 +65,6 @@ def _remove(files: str): def _post_install(install_lib_dir): - # apt - check_call("apt-get update -y") - check_call("apt-get install -y " + 'vim tig tree sox pkg-config ' + - 'libsndfile1 libflac-dev libogg-dev ' + - 'libvorbis-dev libboost-dev swig python3-dev ') - print("apt install.") - # tools/make tool_dir = HERE / "tools" _remove(tool_dir.glob("*.done")) diff --git a/setup.sh b/setup.sh index d3dd8207..aefdab98 100644 --- a/setup.sh +++ b/setup.sh @@ -10,7 +10,7 @@ fi if [ -e /etc/lsb-release ];then ${SUDO} apt-get update -y - ${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + ${SUDO} apt-get install -y bc jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev if [ $? != 0 ]; then error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user." 
exit -1
diff --git a/tools/Makefile b/tools/Makefile
index 39a619c5..87107a53 100644
--- a/tools/Makefile
+++ b/tools/Makefile
@@ -24,7 +24,7 @@ clean:
 
 apt.done:
 	apt update -y
-	apt install -y flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
+	apt install -y bc flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
 	echo "check_certificate = off" >> ~/.wgetrc
 	touch apt.done
diff --git a/utils/avg_model.py b/utils/avg_model.py
index 7c05ec78..6ee16408 100755
--- a/utils/avg_model.py
+++ b/utils/avg_model.py
@@ -47,8 +47,10 @@ def main(args):
         beat_val_scores = sorted_val_scores[:args.num, 1]
         selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
+        avg_val_score = np.mean(beat_val_scores)
         print("selected val scores = " + str(beat_val_scores))
         print("selected epochs = " + str(selected_epochs))
+        print("averaged val score = " + str(avg_val_score))
 
         path_list = [
             args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
@@ -80,7 +82,7 @@ def main(args):
         data = json.dumps({
             "mode": 'val_best' if args.val_best else 'latest',
             "avg_ckpt": args.dst_model,
-            "val_loss_mean": np.mean(beat_val_scores),
+            "val_loss_mean": avg_val_score,
             "ckpts": path_list,
             "epochs": selected_epochs.tolist(),
             "val_losses": beat_val_scores.tolist(),
diff --git a/utils/feat-to-shape.py b/utils/feat-to-shape.py
new file mode 100755
index 00000000..911bf5cf
--- /dev/null
+++ b/utils/feat-to-shape.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import sys
+
+from deepspeech.transform.transformation import Transformation
+from deepspeech.utils.cli_readers import file_reader_helper
+from deepspeech.utils.cli_utils import get_commandline_args
+from deepspeech.utils.cli_utils import is_scipy_wav_style
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="convert feature to its shape",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
+    parser.add_argument(
+        "--filetype",
+        type=str,
+        default="mat",
+        choices=["mat", "hdf5", "sound.hdf5", "sound"],
+        help="Specify the file format for the rspecifier. "
+        '"mat" is the matrix format in kaldi',
+    )
+    parser.add_argument(
+        "--preprocess-conf",
+        type=str,
+        default=None,
+        help="The configuration file for the pre-processing",
+    )
+    parser.add_argument(
+        "rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
+    )
+    parser.add_argument(
+        "out",
+        nargs="?",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="The output filename. " "If omitted, then output to sys.stdout",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # logging info
+    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+    if args.verbose > 0:
+        logging.basicConfig(level=logging.INFO, format=logfmt)
+    else:
+        logging.basicConfig(level=logging.WARN, format=logfmt)
+    logging.info(get_commandline_args())
+
+    if args.preprocess_conf is not None:
+        preprocessing = Transformation(args.preprocess_conf)
+        logging.info("Apply preprocessing: {}".format(preprocessing))
+    else:
+        preprocessing = None
+
+    # The matrix data itself is not needed when there is no preprocessing,
+    # so ask file_reader_helper to return only the shape.
+    # This makes sense only with filetype="hdf5".
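+    # (With return_shape=True the helper is assumed to yield the array shape
+    # instead of loading the full matrix, avoiding unnecessary I/O.)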
+    for utt, mat in file_reader_helper(
+        args.rspecifier, args.filetype, return_shape=preprocessing is None
+    ):
+        if preprocessing is not None:
+            if is_scipy_wav_style(mat):
+                # If the data is a sound file, it is loaded as Tuple[int, ndarray]
+                rate, mat = mat
+            mat = preprocessing(mat, uttid_list=utt)
+            shape_str = ",".join(map(str, mat.shape))
+        else:
+            if len(mat) == 2 and isinstance(mat[1], tuple):
+                # If the data is a sound file, the reader returns
+                # Tuple[int, Tuple[int, ...]], i.e. (rate, shape)
+                rate, mat = mat
+            shape_str = ",".join(map(str, mat))
+        args.out.write("{} {}\n".format(utt, shape_str))
+
+
+if __name__ == "__main__":
+    main()
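+
+# Illustrative invocation (paths are hypothetical):
+#   python3 utils/feat-to-shape.py --filetype mat \
+#       --preprocess-conf conf/preprocess.yaml ark:some.ark > feats_shape.txt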