format code

pull/933/head
Hui Zhang 3 years ago
parent e8bc9a2a08
commit b878027c9a

@ -24,11 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.nets.lm_interface import dynamic_import_lm
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -49,12 +45,14 @@ def load_trained_model(args):
model = exp.model model = exp.model
return model, char_list, exp, confs return model, char_list, exp, confs
def get_config(config_path): def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8") stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader) config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close() stream.close()
return config return config
def load_trained_lm(args): def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf) lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models # NOTE: for a compatibility with less than 0.5.0 version models
@ -65,6 +63,7 @@ def load_trained_lm(args):
lm.set_state_dict(model_dict) lm.set_state_dict(model_dict)
return lm return lm
def recog_v2(args): def recog_v2(args):
"""Decode with custom models that implements ScorerInterface. """Decode with custom models that implements ScorerInterface.

@ -21,6 +21,7 @@ from distutils.util import strtobool
import configargparse import configargparse
import numpy as np import numpy as np
def get_parser(): def get_parser():
"""Get default arguments.""" """Get default arguments."""
parser = configargparse.ArgumentParser( parser = configargparse.ArgumentParser(

@ -30,8 +30,7 @@ logger = Log(__name__).getlog()
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__( def __init__(self,
self,
n_vocab: int, n_vocab: int,
pos_enc: str=None, pos_enc: str=None,
embed_unit: int=128, embed_unit: int=128,

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
@ -34,8 +47,7 @@ class AddDeltas():
def __repr__(self): def __repr__(self):
return "{name}(window={window}, order={order}".format( return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order name=self.__class__.__name__, window=self.window, order=self.order)
)
def __call__(self, x): def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order) return add_deltas(x, window=self.window, order=self.order)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy import numpy
@ -10,15 +23,12 @@ class ChannelSelector():
self.axis = axis self.axis = axis
def __repr__(self): def __repr__(self):
return ( return ("{name}(train_channel={train_channel}, "
"{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format( "eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
train_channel=self.train_channel, train_channel=self.train_channel,
eval_channel=self.eval_channel, eval_channel=self.eval_channel,
axis=self.axis, axis=self.axis, ))
)
)
def __call__(self, x, train=True): def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default # Assuming x: [Time, Channel] by default
@ -27,8 +37,8 @@ class ChannelSelector():
# If the dimension is insufficient, then unsqueeze # If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1]) # (e.g [Time] -> [Time, 1])
ind = tuple( ind = tuple(
slice(None) if i < x.ndim else None for i in range(self.axis + 1) slice(None) if i < x.ndim else None
) for i in range(self.axis + 1))
x = x[ind] x = x[ind]
if train: if train:
@ -41,5 +51,6 @@ class ChannelSelector():
else: else:
ch = channel ch = channel
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim)) ind = tuple(
slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind] return x[ind]

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
from deepspeech.transform.transform_interface import TransformInterface from deepspeech.transform.transform_interface import TransformInterface
@ -57,7 +70,8 @@ class FuncTrans(TransformInterface):
except ValueError: except ValueError:
d = dict() d = dict()
return { return {
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty k: v.default
for k, v in d.items() if v.default != inspect.Parameter.empty
} }
def __repr__(self): def __repr__(self):

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa import librosa
import numpy import numpy
import scipy import scipy
@ -5,6 +18,7 @@ import soundfile
from deepspeech.io.reader import SoundHDF5File from deepspeech.io.reader import SoundHDF5File
class SpeedPerturbation(): class SpeedPerturbation():
"""SpeedPerturbation """SpeedPerturbation
@ -28,8 +42,7 @@ class SpeedPerturbation():
utt2ratio=None, utt2ratio=None,
keep_length=True, keep_length=True,
res_type="kaiser_best", res_type="kaiser_best",
seed=None, seed=None, ):
):
self.res_type = res_type self.res_type = res_type
self.keep_length = keep_length self.keep_length = keep_length
self.state = numpy.random.RandomState(seed) self.state = numpy.random.RandomState(seed)
@ -60,12 +73,10 @@ class SpeedPerturbation():
self.lower, self.lower,
self.upper, self.upper,
self.keep_length, self.keep_length,
self.res_type, self.res_type, )
)
else: else:
return "{}({}, res_type={})".format( return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type self.__class__.__name__, self.utt2ratio_file, self.res_type)
)
def __call__(self, x, uttid=None, train=True): def __call__(self, x, uttid=None, train=True):
if not train: if not train:
@ -92,8 +103,7 @@ class SpeedPerturbation():
(0, 0) for _ in range(y.ndim - 1) (0, 0) for _ in range(y.ndim - 1)
] ]
y = numpy.pad( y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant" y, pad_width=pad_width, constant_values=0, mode="constant")
)
return y return y
@ -119,18 +129,16 @@ class BandpassPerturbation():
self.axes = axes self.axes = axes
def __repr__(self): def __repr__(self):
return "{}(lower={}, upper={})".format( return "{}(lower={}, upper={})".format(self.__class__.__name__,
self.__class__.__name__, self.lower, self.upper self.lower, self.upper)
)
def __call__(self, x_stft, uttid=None, train=True): def __call__(self, x_stft, uttid=None, train=True):
if not train: if not train:
return x_stft return x_stft
if x_stft.ndim == 1: if x_stft.ndim == 1:
raise RuntimeError( raise RuntimeError("Input in time-freq domain: "
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)" "(Time, Channel, Freq) or (Time, Freq)")
)
ratio = self.state.uniform(self.lower, self.upper) ratio = self.state.uniform(self.lower, self.upper)
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
@ -142,7 +150,12 @@ class BandpassPerturbation():
class VolumePerturbation(): class VolumePerturbation():
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None): def __init__(self,
lower=-1.6,
upper=1.6,
utt2ratio=None,
dbunit=True,
seed=None):
self.dbunit = dbunit self.dbunit = dbunit
self.utt2ratio_file = utt2ratio self.utt2ratio_file = utt2ratio
self.lower = lower self.lower = lower
@ -168,12 +181,10 @@ class VolumePerturbation():
def __repr__(self): def __repr__(self):
if self.utt2ratio is None: if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format( return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit self.__class__.__name__, self.lower, self.upper, self.dbunit)
)
else: else:
return '{}("{}", dbunit={})'.format( return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit self.__class__.__name__, self.utt2ratio_file, self.dbunit)
)
def __call__(self, x, uttid=None, train=True): def __call__(self, x, uttid=None, train=True):
if not train: if not train:
@ -201,8 +212,7 @@ class NoiseInjection():
utt2ratio=None, utt2ratio=None,
filetype="list", filetype="list",
dbunit=True, dbunit=True,
seed=None, seed=None, ):
):
self.utt2noise_file = utt2noise self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio self.utt2ratio_file = utt2ratio
self.filetype = filetype self.filetype = filetype
@ -242,19 +252,16 @@ class NoiseInjection():
if utt2noise is not None and utt2ratio is not None: if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise): if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError( raise RuntimeError("The uttids mismatch between {} and {}".
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise) format(utt2ratio, utt2noise))
)
def __repr__(self): def __repr__(self):
if self.utt2ratio is None: if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format( return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit self.__class__.__name__, self.lower, self.upper, self.dbunit)
)
else: else:
return '{}("{}", dbunit={})'.format( return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit self.__class__.__name__, self.utt2ratio_file, self.dbunit)
)
def __call__(self, x, uttid=None, train=True): def __call__(self, x, uttid=None, train=True):
if not train: if not train:
@ -289,7 +296,8 @@ class NoiseInjection():
# Truncate noise # Truncate noise
noise = noise[offset:-(diff - offset)] noise = noise[offset:-(diff - offset)]
else: else:
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap") noise = numpy.pad(
noise, pad_width=[offset, diff - offset], mode="wrap")
else: else:
# Generate white noise # Generate white noise
@ -329,15 +337,14 @@ class RIRConvolve():
if x.ndim != 1: if x.ndim != 1:
# Must be single channel # Must be single channel
raise RuntimeError( raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(x.shape) "Input x must be one dimensional array, but got {}".format(
) x.shape))
rir, rate = self.utt2rir[uttid] rir, rate = self.utt2rir[uttid]
if rir.ndim == 2: if rir.ndim == 2:
# FIXME(kamo): Use chainer.convolution_1d? # FIXME(kamo): Use chainer.convolution_1d?
# return [Time, Channel] # return [Time, Channel]
return numpy.stack( return numpy.stack(
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1 [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
)
else: else:
return scipy.convolve(x, rir, mode="same") return scipy.convolve(x, rir, mode="same")

@ -1,5 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spec Augment module for preprocessing i.e., data augmentation""" """Spec Augment module for preprocessing i.e., data augmentation"""
import random import random
import numpy import numpy
@ -27,10 +39,12 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
return x return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window) center = random.randrange(window, t - window)
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1 warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC) right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if inplace: if inplace:
x[:warped] = left x[:warped] = left
x[warped:] = right x[warped:] = right
@ -44,11 +58,8 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
# TODO(karita): make this differentiable again # TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else: else:
raise NotImplementedError( raise NotImplementedError("unknown resize mode: " + mode +
"unknown resize mode: " ", choose one from (PIL, sparse_image_warp).")
+ mode
+ ", choose one from (PIL, sparse_image_warp)."
)
class TimeWarp(FuncTrans): class TimeWarp(FuncTrans):
@ -153,8 +164,7 @@ def spec_augment(
max_time_width=100, max_time_width=100,
n_time_mask=2, n_time_mask=2,
inplace=True, inplace=True,
replace_with_zero=True, replace_with_zero=True, ):
):
"""spec agument """spec agument
apply random time warping and time/freq masking apply random time warping and time/freq masking
@ -180,15 +190,13 @@ def spec_augment(
max_freq_width, max_freq_width,
n_freq_mask, n_freq_mask,
inplace=inplace, inplace=inplace,
replace_with_zero=replace_with_zero, replace_with_zero=replace_with_zero, )
)
x = time_mask( x = time_mask(
x, x,
max_time_width, max_time_width,
n_time_mask, n_time_mask,
inplace=inplace, inplace=inplace,
replace_with_zero=replace_with_zero, replace_with_zero=replace_with_zero, )
)
return x return x

@ -1,10 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa import librosa
import numpy as np import numpy as np
def stft( def stft(x,
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect" n_fft,
): n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
# x: [Time, Channel] # x: [Time, Channel]
if x.ndim == 1: if x.ndim == 1:
single_channel = True single_channel = True
@ -25,12 +42,9 @@ def stft(
win_length=win_length, win_length=win_length,
window=window, window=window,
center=center, center=center,
pad_mode=pad_mode, pad_mode=pad_mode, ).T for ch in range(x.shape[1])
).T
for ch in range(x.shape[1])
], ],
axis=1, axis=1, )
)
if single_channel: if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq] # x: [Time, Channel, Freq] -> [Time, Freq]
@ -55,12 +69,9 @@ def istft(x, n_shift, win_length=None, window="hann", center=True):
hop_length=n_shift, hop_length=n_shift,
win_length=win_length, win_length=win_length,
window=window, window=window,
center=center, center=center, ) for ch in range(x.shape[1])
)
for ch in range(x.shape[1])
], ],
axis=1, axis=1, )
)
if single_channel: if single_channel:
# x: [Time, Channel] -> [Time] # x: [Time, Channel] -> [Time]
@ -68,7 +79,13 @@ def istft(x, n_shift, win_length=None, window="hann", center=True):
return x return x
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): def stft2logmelspectrogram(x_stft,
fs,
n_mels,
n_fft,
fmin=None,
fmax=None,
eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq) # x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax fmax = fs / 2 if fmax is None else fmax
@ -100,8 +117,7 @@ def logmelspectrogram(
fmin=None, fmin=None,
fmax=None, fmax=None,
eps=1e-10, eps=1e-10,
pad_mode="reflect", pad_mode="reflect", ):
):
# stft: (Time, Channel, Freq) or (Time, Freq) # stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft( x_stft = stft(
x, x,
@ -109,12 +125,16 @@ def logmelspectrogram(
n_shift=n_shift, n_shift=n_shift,
win_length=win_length, win_length=win_length,
window=window, window=window,
pad_mode=pad_mode, pad_mode=pad_mode, )
)
return stft2logmelspectrogram( return stft2logmelspectrogram(
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps x_stft,
) fs=fs,
n_mels=n_mels,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
eps=eps)
class Spectrogram(): class Spectrogram():
@ -125,16 +145,13 @@ class Spectrogram():
self.window = window self.window = window
def __repr__(self): def __repr__(self):
return ( return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format( "win_length={win_length}, window={window})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
n_fft=self.n_fft, n_fft=self.n_fft,
n_shift=self.n_shift, n_shift=self.n_shift,
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window, ))
)
)
def __call__(self, x): def __call__(self, x):
return spectrogram( return spectrogram(
@ -142,8 +159,7 @@ class Spectrogram():
n_fft=self.n_fft, n_fft=self.n_fft,
n_shift=self.n_shift, n_shift=self.n_shift,
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window, )
)
class LogMelSpectrogram(): class LogMelSpectrogram():
@ -157,8 +173,7 @@ class LogMelSpectrogram():
window="hann", window="hann",
fmin=None, fmin=None,
fmax=None, fmax=None,
eps=1e-10, eps=1e-10, ):
):
self.fs = fs self.fs = fs
self.n_mels = n_mels self.n_mels = n_mels
self.n_fft = n_fft self.n_fft = n_fft
@ -170,8 +185,7 @@ class LogMelSpectrogram():
self.eps = eps self.eps = eps
def __repr__(self): def __repr__(self):
return ( return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, " "n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format( "fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__, name=self.__class__.__name__,
@ -183,9 +197,7 @@ class LogMelSpectrogram():
window=self.window, window=self.window,
fmin=self.fmin, fmin=self.fmin,
fmax=self.fmax, fmax=self.fmax,
eps=self.eps, eps=self.eps, ))
)
)
def __call__(self, x): def __call__(self, x):
return logmelspectrogram( return logmelspectrogram(
@ -195,8 +207,7 @@ class LogMelSpectrogram():
n_fft=self.n_fft, n_fft=self.n_fft,
n_shift=self.n_shift, n_shift=self.n_shift,
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window, )
)
class Stft2LogMelSpectrogram(): class Stft2LogMelSpectrogram():
@ -209,8 +220,7 @@ class Stft2LogMelSpectrogram():
self.eps = eps self.eps = eps
def __repr__(self): def __repr__(self):
return ( return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format( "fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__, name=self.__class__.__name__,
fs=self.fs, fs=self.fs,
@ -218,9 +228,7 @@ class Stft2LogMelSpectrogram():
n_fft=self.n_fft, n_fft=self.n_fft,
fmin=self.fmin, fmin=self.fmin,
fmax=self.fmax, fmax=self.fmax,
eps=self.eps, eps=self.eps, ))
)
)
def __call__(self, x): def __call__(self, x):
return stft2logmelspectrogram( return stft2logmelspectrogram(
@ -229,8 +237,7 @@ class Stft2LogMelSpectrogram():
n_mels=self.n_mels, n_mels=self.n_mels,
n_fft=self.n_fft, n_fft=self.n_fft,
fmin=self.fmin, fmin=self.fmin,
fmax=self.fmax, fmax=self.fmax, )
)
class Stft(): class Stft():
@ -241,8 +248,7 @@ class Stft():
win_length=None, win_length=None,
window="hann", window="hann",
center=True, center=True,
pad_mode="reflect", pad_mode="reflect", ):
):
self.n_fft = n_fft self.n_fft = n_fft
self.n_shift = n_shift self.n_shift = n_shift
self.win_length = win_length self.win_length = win_length
@ -251,8 +257,7 @@ class Stft():
self.pad_mode = pad_mode self.pad_mode = pad_mode
def __repr__(self): def __repr__(self):
return ( return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window}," "win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format( "center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
@ -261,9 +266,7 @@ class Stft():
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window,
center=self.center, center=self.center,
pad_mode=self.pad_mode, pad_mode=self.pad_mode, ))
)
)
def __call__(self, x): def __call__(self, x):
return stft( return stft(
@ -273,8 +276,7 @@ class Stft():
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window,
center=self.center, center=self.center,
pad_mode=self.pad_mode, pad_mode=self.pad_mode, )
)
class IStft(): class IStft():
@ -285,17 +287,14 @@ class IStft():
self.center = center self.center = center
def __repr__(self): def __repr__(self):
return ( return ("{name}(n_shift={n_shift}, "
"{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window}," "win_length={win_length}, window={window},"
"center={center})".format( "center={center})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
n_shift=self.n_shift, n_shift=self.n_shift,
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window,
center=self.center, center=self.center, ))
)
)
def __call__(self, x): def __call__(self, x):
return istft( return istft(
@ -303,5 +302,4 @@ class IStft():
self.n_shift, self.n_shift,
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window,
center=self.center, center=self.center, )
)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(karita): add this to all the transform impl. # TODO(karita): add this to all the transform impl.
class TransformInterface: class TransformInterface:
"""Transform Interface""" """Transform Interface"""

@ -1,16 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformation module.""" """Transformation module."""
from collections.abc import Sequence
from collections import OrderedDict
import copy import copy
from inspect import signature
import io import io
import logging import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature
import yaml import yaml
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import
# TODO(karita): inherit TransformInterface # TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py # TODO(karita): register cmd arguments in asr_train.py
import_alias = dict( import_alias = dict(
@ -33,8 +45,7 @@ import_alias = dict(
istft="deepspeech.transform.spectrogram:IStft", istft="deepspeech.transform.spectrogram:IStft",
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram", stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="deepspeech.transform.wpe:WPE", wpe="deepspeech.transform.wpe:WPE",
channel_selector="deepspeech.transform.channel_selector:ChannelSelector", channel_selector="deepspeech.transform.channel_selector:ChannelSelector", )
)
class Transformation(): class Transformation():
@ -83,21 +94,16 @@ class Transformation():
# Some function, e.g. built-in function, are failed # Some function, e.g. built-in function, are failed
pass pass
else: else:
logging.error( logging.error("Expected signature: {}({})".format(
"Expected signature: {}({})".format( class_obj.__name__, signa))
class_obj.__name__, signa
)
)
raise raise
else: else:
raise NotImplementedError( raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]) "Not supporting mode={}".format(self.conf["mode"]))
)
def __repr__(self): def __repr__(self):
rep = "\n" + "\n".join( rep = "\n" + "\n".join(" {}: {}".format(k, v)
" {}: {}".format(k, v) for k, v in self.functions.items() for k, v in self.functions.items())
)
return "{}({})".format(self.__class__.__name__, rep) return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs): def __call__(self, xs, uttid_list=None, **kwargs):
@ -130,18 +136,19 @@ class Transformation():
_kwargs = {k: v for k, v in kwargs.items() if k in param} _kwargs = {k: v for k, v in kwargs.items() if k in param}
try: try:
if uttid_list is not None and "uttid" in param: if uttid_list is not None and "uttid" in param:
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)] xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else: else:
xs = [func(x, **_kwargs) for x in xs] xs = [func(x, **_kwargs) for x in xs]
except Exception: except Exception:
logging.fatal( logging.fatal("Catch a exception from {}th func: {}".format(
"Catch a exception from {}th func: {}".format(idx, func) idx, func))
)
raise raise
else: else:
raise NotImplementedError( raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]) "Not supporting mode={}".format(self.conf["mode"]))
)
if is_batch: if is_batch:
return xs return xs

@ -1,10 +1,26 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nara_wpe.wpe import wpe from nara_wpe.wpe import wpe
class WPE(object): class WPE(object):
def __init__( def __init__(self,
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full" taps=10,
): delay=3,
iterations=3,
psd_context=0,
statistics_mode="full"):
self.taps = taps self.taps = taps
self.delay = delay self.delay = delay
self.iterations = iterations self.iterations = iterations
@ -12,8 +28,7 @@ class WPE(object):
self.statistics_mode = statistics_mode self.statistics_mode = statistics_mode
def __repr__(self): def __repr__(self):
return ( return ("{name}(taps={taps}, delay={delay}"
"{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, " "iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format( "statistics_mode={statistics_mode})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
@ -21,9 +36,7 @@ class WPE(object):
delay=self.delay, delay=self.delay,
iterations=self.iterations, iterations=self.iterations,
psd_context=self.psd_context, psd_context=self.psd_context,
statistics_mode=self.statistics_mode, statistics_mode=self.statistics_mode, ))
)
)
def __call__(self, xs): def __call__(self, xs):
"""Return enhanced """Return enhanced
@ -40,6 +53,5 @@ class WPE(object):
delay=self.delay, delay=self.delay,
iterations=self.iterations, iterations=self.iterations,
psd_context=self.psd_context, psd_context=self.psd_context,
statistics_mode=self.statistics_mode, statistics_mode=self.statistics_mode, )
)
return xs.transpose(2, 1, 0) return xs.transpose(2, 1, 0)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
@ -17,4 +30,5 @@ def check_kwargs(func, kwargs, name=None):
name = func.__name__ name = func.__name__
for k in kwargs.keys(): for k in kwargs.keys():
if k not in params: if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'") raise TypeError(
f"{name}() got an unexpected keyword argument '{k}'")

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -53,8 +53,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
peek_example = minibatch[0] peek_example = minibatch[0]
assert len(peek_example.shape) == 1, "text example is an 1D tensor" assert len(peek_example.shape) == 1, "text example is an 1D tensor"
lengths = [example.shape[0] for example in minibatch lengths = [example.shape[0] for example in
] # assume (channel, n_samples) or (n_samples, ) minibatch] # assume (channel, n_samples) or (n_samples, )
max_len = np.max(lengths) max_len = np.max(lengths)
batch = [] batch = []

@ -67,19 +67,16 @@ class LJSpeechCollector(object):
# Sort by text_len in descending order # Sort by text_len in descending order
texts = [ texts = [
i i for i, _ in sorted(
for i, _ in sorted(
zip(texts, text_lens), key=lambda x: x[1], reverse=True) zip(texts, text_lens), key=lambda x: x[1], reverse=True)
] ]
mels = [ mels = [
i i for i, _ in sorted(
for i, _ in sorted(
zip(mels, text_lens), key=lambda x: x[1], reverse=True) zip(mels, text_lens), key=lambda x: x[1], reverse=True)
] ]
mel_lens = [ mel_lens = [
i i for i, _ in sorted(
for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True) zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
] ]

@ -13,7 +13,7 @@ librosa
llvmlite llvmlite
loguru loguru
matplotlib matplotlib
nltk nara_wpenltk
numba numba
numpy==1.20.0 numpy==1.20.0
pandas pandas
@ -42,4 +42,3 @@ visualdl==2.2.0
webrtcvad webrtcvad
yacs yacs
yq yq
nara_wpe

@ -12,33 +12,32 @@ from deepspeech.utils.cli_utils import is_scipy_wav_style
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="convert feature to its shape", description="convert feature to its shape",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
) parser.add_argument(
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") "--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument( parser.add_argument(
"--filetype", "--filetype",
type=str, type=str,
default="mat", default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"], choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. " help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', '"mat" is the matrix format in kaldi', )
)
parser.add_argument( parser.add_argument(
"--preprocess-conf", "--preprocess-conf",
type=str, type=str,
default=None, default=None,
help="The configuration file for the pre-processing", help="The configuration file for the pre-processing", )
)
parser.add_argument( parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark" "rspecifier",
) type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument( parser.add_argument(
"out", "out",
nargs="?", nargs="?",
type=argparse.FileType("w"), type=argparse.FileType("w"),
default=sys.stdout, default=sys.stdout,
help="The output filename. " "If omitted, then output to sys.stdout", help="The output filename. "
) "If omitted, then output to sys.stdout", )
return parser return parser
@ -64,8 +63,7 @@ def main():
# so change to file_reader_helper to return shape. # so change to file_reader_helper to return shape.
# This make sense only with filetype="hdf5". # This make sense only with filetype="hdf5".
for utt, mat in file_reader_helper( for utt, mat in file_reader_helper(
args.rspecifier, args.filetype, return_shape=preprocessing is None args.rspecifier, args.filetype, return_shape=preprocessing is None):
):
if preprocessing is not None: if preprocessing is not None:
if is_scipy_wav_style(mat): if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray] # If data is sound file, then got as Tuple[int, ndarray]

Loading…
Cancel
Save