Merge pull request #933 from PaddlePaddle/format

[format] format code
pull/936/head
TianYuan 3 years ago committed by GitHub
commit eb65793769
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,11 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.nets.lm_interface import dynamic_import_lm
logger = Log(__name__).getlog()
@ -49,12 +45,14 @@ def load_trained_model(args):
model = exp.model
return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
@ -65,6 +63,7 @@ def load_trained_lm(args):
lm.set_state_dict(model_dict)
return lm
def recog_v2(args):
"""Decode with custom models that implements ScorerInterface.

@ -21,6 +21,7 @@ from distutils.util import strtobool
import configargparse
import numpy as np
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(

@ -30,8 +30,7 @@ logger = Log(__name__).getlog()
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__(
self,
def __init__(self,
n_vocab: int,
pos_enc: str=None,
embed_unit: int=128,

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
@ -9,7 +22,7 @@ def delta(feat, window):
delta_feat[i:] += -i * feat[:-i]
delta_feat[-i:] += i * feat[-1]
delta_feat[:i] += -i * feat[0]
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
return delta_feat
@ -34,8 +47,7 @@ class AddDeltas():
def __repr__(self):
return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order
)
name=self.__class__.__name__, window=self.window, order=self.order)
def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
@ -10,15 +23,12 @@ class ChannelSelector():
self.axis = axis
def __repr__(self):
return (
"{name}(train_channel={train_channel}, "
return ("{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__,
train_channel=self.train_channel,
eval_channel=self.eval_channel,
axis=self.axis,
)
)
axis=self.axis, ))
def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default
@ -27,8 +37,8 @@ class ChannelSelector():
# If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1])
ind = tuple(
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
)
slice(None) if i < x.ndim else None
for i in range(self.axis + 1))
x = x[ind]
if train:
@ -41,5 +51,6 @@ class ChannelSelector():
else:
ch = channel
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
ind = tuple(
slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind]

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from deepspeech.transform.transform_interface import TransformInterface
@ -57,7 +70,8 @@ class FuncTrans(TransformInterface):
except ValueError:
d = dict()
return {
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
k: v.default
for k, v in d.items() if v.default != inspect.Parameter.empty
}
def __repr__(self):

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa
import numpy
import scipy
@ -5,6 +18,7 @@ import soundfile
from deepspeech.io.reader import SoundHDF5File
class SpeedPerturbation():
"""SpeedPerturbation
@ -28,8 +42,7 @@ class SpeedPerturbation():
utt2ratio=None,
keep_length=True,
res_type="kaiser_best",
seed=None,
):
seed=None, ):
self.res_type = res_type
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
@ -60,12 +73,10 @@ class SpeedPerturbation():
self.lower,
self.upper,
self.keep_length,
self.res_type,
)
self.res_type, )
else:
return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type
)
self.__class__.__name__, self.utt2ratio_file, self.res_type)
def __call__(self, x, uttid=None, train=True):
if not train:
@ -85,15 +96,14 @@ class SpeedPerturbation():
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2 : -((diff + 1) // 2)]
y = y[diff // 2:-((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant"
)
y, pad_width=pad_width, constant_values=0, mode="constant")
return y
@ -111,7 +121,7 @@ class BandpassPerturbation():
"""
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
@ -119,18 +129,16 @@ class BandpassPerturbation():
self.axes = axes
def __repr__(self):
return "{}(lower={}, upper={})".format(
self.__class__.__name__, self.lower, self.upper
)
return "{}(lower={}, upper={})".format(self.__class__.__name__,
self.lower, self.upper)
def __call__(self, x_stft, uttid=None, train=True):
if not train:
return x_stft
if x_stft.ndim == 1:
raise RuntimeError(
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
)
raise RuntimeError("Input in time-freq domain: "
"(Time, Channel, Freq) or (Time, Freq)")
ratio = self.state.uniform(self.lower, self.upper)
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
@ -142,7 +150,12 @@ class BandpassPerturbation():
class VolumePerturbation():
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
def __init__(self,
lower=-1.6,
upper=1.6,
utt2ratio=None,
dbunit=True,
seed=None):
self.dbunit = dbunit
self.utt2ratio_file = utt2ratio
self.lower = lower
@ -168,12 +181,10 @@ class VolumePerturbation():
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
@ -186,7 +197,7 @@ class VolumePerturbation():
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
ratio = 10**(ratio / 20)
return x * ratio
@ -201,8 +212,7 @@ class NoiseInjection():
utt2ratio=None,
filetype="list",
dbunit=True,
seed=None,
):
seed=None, ):
self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio
self.filetype = filetype
@ -242,19 +252,16 @@ class NoiseInjection():
if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError(
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
)
raise RuntimeError("The uttids mismatch between {} and {}".
format(utt2ratio, utt2noise))
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
@ -268,8 +275,8 @@ class NoiseInjection():
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
scale = ratio * numpy.sqrt((x ** 2).mean())
ratio = 10**(ratio / 20)
scale = ratio * numpy.sqrt((x**2).mean())
# 2. Get noise
if self.utt2noise is not None:
@ -280,16 +287,17 @@ class NoiseInjection():
# Randomly select the noise source
noise = self.state.choice(list(self.utt2noise.values()))
# Normalize the level
noise /= numpy.sqrt((noise ** 2).mean())
noise /= numpy.sqrt((noise**2).mean())
# Adjust the noise length
diff = abs(len(x) - len(noise))
offset = self.state.randint(0, diff)
if len(noise) > len(x):
# Truncate noise
noise = noise[offset : -(diff - offset)]
noise = noise[offset:-(diff - offset)]
else:
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
noise = numpy.pad(
noise, pad_width=[offset, diff - offset], mode="wrap")
else:
# Generate white noise
@ -329,15 +337,14 @@ class RIRConvolve():
if x.ndim != 1:
# Must be single channel
raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(x.shape)
)
"Input x must be one dimensional array, but got {}".format(
x.shape))
rir, rate = self.utt2rir[uttid]
if rir.ndim == 2:
# FIXME(kamo): Use chainer.convolution_1d?
# return [Time, Channel]
return numpy.stack(
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1
)
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
else:
return scipy.convolve(x, rir, mode="same")

@ -1,5 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
import numpy
@ -27,10 +39,12 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right
@ -44,11 +58,8 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
# TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else:
raise NotImplementedError(
"unknown resize mode: "
+ mode
+ ", choose one from (PIL, sparse_image_warp)."
)
raise NotImplementedError("unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
class TimeWarp(FuncTrans):
@ -153,8 +164,7 @@ def spec_augment(
max_time_width=100,
n_time_mask=2,
inplace=True,
replace_with_zero=True,
):
replace_with_zero=True, ):
"""spec agument
apply random time warping and time/freq masking
@ -180,15 +190,13 @@ def spec_augment(
max_freq_width,
n_freq_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
replace_with_zero=replace_with_zero, )
x = time_mask(
x,
max_time_width,
n_time_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
replace_with_zero=replace_with_zero, )
return x

@ -1,10 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa
import numpy as np
def stft(
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
):
def stft(x,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
@ -25,12 +42,9 @@ def stft(
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode,
).T
for ch in range(x.shape[1])
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
],
axis=1,
)
axis=1, )
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
@ -55,12 +69,9 @@ def istft(x, n_shift, win_length=None, window="hann", center=True):
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
)
for ch in range(x.shape[1])
center=center, ) for ch in range(x.shape[1])
],
axis=1,
)
axis=1, )
if single_channel:
# x: [Time, Channel] -> [Time]
@ -68,7 +79,13 @@ def istft(x, n_shift, win_length=None, window="hann", center=True):
return x
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
def stft2logmelspectrogram(x_stft,
fs,
n_mels,
n_fft,
fmin=None,
fmax=None,
eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
@ -100,8 +117,7 @@ def logmelspectrogram(
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect",
):
pad_mode="reflect", ):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
@ -109,12 +125,16 @@ def logmelspectrogram(
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode,
)
pad_mode=pad_mode, )
return stft2logmelspectrogram(
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
)
x_stft,
fs=fs,
n_mels=n_mels,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
eps=eps)
class Spectrogram():
@ -125,16 +145,13 @@ class Spectrogram():
self.window = window
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
)
window=self.window, ))
def __call__(self, x):
return spectrogram(
@ -142,8 +159,7 @@ class Spectrogram():
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
window=self.window, )
class LogMelSpectrogram():
@ -157,8 +173,7 @@ class LogMelSpectrogram():
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
):
eps=1e-10, ):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
@ -170,8 +185,7 @@ class LogMelSpectrogram():
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
@ -183,9 +197,7 @@ class LogMelSpectrogram():
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
eps=self.eps, ))
def __call__(self, x):
return logmelspectrogram(
@ -195,8 +207,7 @@ class LogMelSpectrogram():
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
window=self.window, )
class Stft2LogMelSpectrogram():
@ -209,8 +220,7 @@ class Stft2LogMelSpectrogram():
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
@ -218,9 +228,7 @@ class Stft2LogMelSpectrogram():
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
eps=self.eps, ))
def __call__(self, x):
return stft2logmelspectrogram(
@ -229,8 +237,7 @@ class Stft2LogMelSpectrogram():
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
)
fmax=self.fmax, )
class Stft():
@ -241,8 +248,7 @@ class Stft():
win_length=None,
window="hann",
center=True,
pad_mode="reflect",
):
pad_mode="reflect", ):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
@ -251,8 +257,7 @@ class Stft():
self.pad_mode = pad_mode
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
@ -261,9 +266,7 @@ class Stft():
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
)
pad_mode=self.pad_mode, ))
def __call__(self, x):
return stft(
@ -273,8 +276,7 @@ class Stft():
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
pad_mode=self.pad_mode, )
class IStft():
@ -285,17 +287,14 @@ class IStft():
self.center = center
def __repr__(self):
return (
"{name}(n_shift={n_shift}, "
return ("{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)
)
center=self.center, ))
def __call__(self, x):
return istft(
@ -303,5 +302,4 @@ class IStft():
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)
center=self.center, )

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(karita): add this to all the transform impl.
class TransformInterface:
"""Transform Interface"""

@ -1,16 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformation module."""
from collections.abc import Sequence
from collections import OrderedDict
import copy
from inspect import signature
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature
import yaml
from deepspeech.utils.dynamic_import import dynamic_import
# TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py
import_alias = dict(
@ -33,8 +45,7 @@ import_alias = dict(
istft="deepspeech.transform.spectrogram:IStft",
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="deepspeech.transform.wpe:WPE",
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
)
channel_selector="deepspeech.transform.channel_selector:ChannelSelector", )
class Transformation():
@ -83,21 +94,16 @@ class Transformation():
# Some function, e.g. built-in function, are failed
pass
else:
logging.error(
"Expected signature: {}({})".format(
class_obj.__name__, signa
)
)
logging.error("Expected signature: {}({})".format(
class_obj.__name__, signa))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
"Not supporting mode={}".format(self.conf["mode"]))
def __repr__(self):
rep = "\n" + "\n".join(
" {}: {}".format(k, v) for k, v in self.functions.items()
)
rep = "\n" + "\n".join(" {}: {}".format(k, v)
for k, v in self.functions.items())
return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs):
@ -130,18 +136,19 @@ class Transformation():
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logging.fatal(
"Catch a exception from {}th func: {}".format(idx, func)
)
logging.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs

@ -1,10 +1,26 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nara_wpe.wpe import wpe
class WPE(object):
def __init__(
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
):
def __init__(self,
taps=10,
delay=3,
iterations=3,
psd_context=0,
statistics_mode="full"):
self.taps = taps
self.delay = delay
self.iterations = iterations
@ -12,8 +28,7 @@ class WPE(object):
self.statistics_mode = statistics_mode
def __repr__(self):
return (
"{name}(taps={taps}, delay={delay}"
return ("{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format(
name=self.__class__.__name__,
@ -21,9 +36,7 @@ class WPE(object):
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
)
statistics_mode=self.statistics_mode, ))
def __call__(self, xs):
"""Return enhanced
@ -40,6 +53,5 @@ class WPE(object):
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
statistics_mode=self.statistics_mode, )
return xs.transpose(2, 1, 0)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
@ -17,4 +30,5 @@ def check_kwargs(func, kwargs, name=None):
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
raise TypeError(
f"{name}() got an unexpected keyword argument '{k}'")

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -53,8 +53,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
peek_example = minibatch[0]
assert len(peek_example.shape) == 1, "text example is an 1D tensor"
lengths = [example.shape[0] for example in minibatch
] # assume (channel, n_samples) or (n_samples, )
lengths = [example.shape[0] for example in
minibatch] # assume (channel, n_samples) or (n_samples, )
max_len = np.max(lengths)
batch = []

@ -67,19 +67,16 @@ class LJSpeechCollector(object):
# Sort by text_len in descending order
texts = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
]
mels = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
]
mel_lens = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
]

@ -13,7 +13,7 @@ librosa
llvmlite
loguru
matplotlib
nltk
nara_wpenltk
numba
numpy==1.20.0
pandas
@ -42,4 +42,3 @@ visualdl==2.2.0
webrtcvad
yacs
yq
nara_wpe

@ -12,33 +12,32 @@ from deepspeech.utils.cli_utils import is_scipy_wav_style
def get_parser():
parser = argparse.ArgumentParser(
description="convert feature to its shape",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"out",
nargs="?",
type=argparse.FileType("w"),
default=sys.stdout,
help="The output filename. " "If omitted, then output to sys.stdout",
)
help="The output filename. "
"If omitted, then output to sys.stdout", )
return parser
@ -64,8 +63,7 @@ def main():
# so change to file_reader_helper to return shape.
# This make sense only with filetype="hdf5".
for utt, mat in file_reader_helper(
args.rspecifier, args.filetype, return_shape=preprocessing is None
):
args.rspecifier, args.filetype, return_shape=preprocessing is None):
if preprocessing is not None:
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]

Loading…
Cancel
Save