Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into fix_bug

pull/791/head
huangyuxin 3 years ago
commit 03a50d7bf9

@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
select =
E,

@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn.functional #############
def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
"""The gated linear unit (GLU) activation."""
a, b = x.split(2, axis=axis)
act_b = F.sigmoid(b)
return a * act_b
if not hasattr(paddle.nn.functional, 'glu'):
logger.warn(
"register user glu to paddle.nn.functional, remove this when fixed!")
setattr(paddle.nn.functional, 'glu', glu)
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
########### hack paddle.nn #############
class GLU(nn.Layer):
@ -401,7 +362,7 @@ class GLU(nn.Layer):
self.dim = dim
def forward(self, xs):
return glu(xs, dim=self.dim)
return F.glu(xs, dim=self.dim)
if not hasattr(paddle.nn, 'GLU'):
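
For reference, a minimal sketch of how the registered glu hack behaves once patched in (only paddle is assumed; the shapes are illustrative):

    import paddle
    import paddle.nn.functional as F

    def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
        """Gated linear unit: split x into halves a, b along axis; return a * sigmoid(b)."""
        a, b = x.split(2, axis=axis)
        return a * F.sigmoid(b)

    if not hasattr(F, 'glu'):
        setattr(F, 'glu', glu)

    x = paddle.randn([2, 8])    # size along the gated axis must be even
    y = F.glu(x, axis=-1)       # shape [2, 4]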

@ -83,10 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable
FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
or fn.endswith('unittest.cc'))
fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
]
# yapf: enable
LIBS = ['stdc++']
if platform.system() != 'Darwin':

@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:

@ -31,6 +31,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:

@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())

@ -34,6 +34,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -14,6 +14,8 @@
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
@ -53,6 +55,14 @@ if __name__ == "__main__":
type=str,
default='test',
help='run mode, e.g. test, align, export')
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
# save asr result to
parser.add_argument(
"--result-file", type=str, help="path of save the asr result")
# save jit model to
parser.add_argument(
"--export-path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())

@ -25,6 +25,8 @@ import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from deepspeech.frontend.featurizer import TextFeaturizer
from deepspeech.frontend.utility import load_dict
from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
@ -80,8 +82,8 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
@ -124,6 +126,7 @@ class U2Trainer(Trainer):
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
@ -168,10 +171,7 @@ class U2Trainer(Trainer):
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
@ -225,7 +225,7 @@ class U2Trainer(Trainer):
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
mini_batch_size=self.args.nprocs,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
@ -244,7 +244,7 @@ class U2Trainer(Trainer):
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
mini_batch_size=self.args.nprocs,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
@ -260,7 +260,7 @@ class U2Trainer(Trainer):
json_file=config.data.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.collator.batch_size,
batch_size=config.decoding.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
@ -279,7 +279,7 @@ class U2Trainer(Trainer):
json_file=config.data.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.collator.batch_size,
batch_size=config.decoding.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
@ -305,10 +305,8 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.train_loader.vocab_size
model_conf.freeze()
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
@ -379,13 +377,13 @@ class U2Tester(U2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
trans.append(text_feature.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
@ -401,8 +399,11 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
target_transcripts = self.id2token(texts, texts_len, text_feature)
result_transcripts = self.model.decode(
audio,
audio_len,
@ -450,7 +451,7 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.collate_fn.stride_ms
stride_ms = self.config.collator.stride_ms
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
@ -525,8 +526,9 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.config.collate.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
stride_ms = self.config.collator.stride_ms
token_dict = self.args.char_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
@ -613,6 +615,11 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt:
sys.exit(-1)
def setup_dict(self):
# load dictionary for debug log
self.args.char_list = load_dict(self.args.dict_path,
"maskctc" in self.args.model_name)
def setup(self):
"""Setup the experiment.
"""
@ -624,6 +631,8 @@ class U2Tester(U2Trainer):
self.setup_dataloader()
self.setup_model()
self.setup_dict()
self.iteration = 0
self.epoch = 0

@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())

@ -34,6 +34,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x
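
The same one-line fix recurs in the sibling augmentor hunks below (noise perturb, online Bayesian normalization, resample, shift perturb, speed perturb, volume perturb): an eval-time call must return the input unchanged instead of None, which would drop the sample. A toy sketch of the corrected contract (the NoopAugmentor name is hypothetical):

    class NoopAugmentor():
        def transform_audio(self, x):
            pass                    # in-place augmentation would happen here

        def __call__(self, x, uttid=None, train=True):
            if not train:
                return x            # previously returned None, dropping the sample
            self.transform_audio(x)
            return x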

@ -38,7 +38,7 @@ class NoisePerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the volume perturb augmentation model."""
import random
import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.log import Log
@ -41,7 +45,9 @@ class SpecAugmentor(AugmentorBase):
W=40,
adaptive_number_ratio=0,
adaptive_size_ratio=0,
max_n_time_masks=20):
max_n_time_masks=20,
replace_with_zero=True,
warp_mode='PIL'):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
@ -54,10 +60,16 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
super().__init__()
self._rng = rng
self.inplace = True
self.replace_with_zero = replace_with_zero
self.mode = warp_mode
self.W = W
self.F = F
self.T = T
@ -123,21 +135,83 @@ class SpecAugmentor(AugmentorBase):
def __repr__(self):
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
def time_warp(xs, W=40):
raise NotImplementedError
def time_warp(self, x, mode='PIL'):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: if mode is 'sparse_image_warp' (not yet supported).
NotImplementedError: if mode is not one of PIL, sparse_image_warp.
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp = self.W
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if self.inplace:
x[:warped] = left
x[warped:] = right
return x
return np.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
raise NotImplementedError('sparse_image_warp')
else:
raise NotImplementedError(
"unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
def mask_freq(self, x, replace_with_zero=False):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
def mask_freq(self, xs, replace_with_zero=False):
n_bins = xs.shape[0]
Returns:
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins = x.shape[1]
for i in range(0, self.n_freq_masks):
f = int(self._rng.uniform(low=0, high=self.F))
f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
xs[f_0:f_0 + f, :] = 0
assert f_0 <= f_0 + f
if replace_with_zero:
x[:, f_0:f_0 + f] = 0
else:
x[:, f_0:f_0 + f] = x.mean()
self._freq_mask = (f_0, f_0 + f)
return xs
return x
def mask_time(self, xs, replace_with_zero=False):
n_frames = xs.shape[1]
def mask_time(self, x, replace_with_zero=False):
"""time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames = x.shape[0]
if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio)
@ -154,24 +228,29 @@ class SpecAugmentor(AugmentorBase):
t = int(self._rng.uniform(low=0, high=T))
t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
xs[:, t_0:t_0 + t] = 0
assert t_0 <= t_0 + t
if replace_with_zero:
x[t_0:t_0 + t, :] = 0
else:
x[t_0:t_0 + t, :] = x.mean()
self._time_mask = (t_0, t_0 + t)
return xs
return x
def __call__(self, x, train=True):
if not train:
return
return x
return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray):
def transform_feature(self, x: np.ndarray):
"""
Args:
xs (FloatTensor): `[F, T]`
x (np.ndarray): `[T, F]`
Returns:
xs (FloatTensor): `[F, T]`
x (np.ndarray): `[T, F]`
"""
# xs = self.time_warp(xs)
xs = self.mask_freq(xs)
xs = self.mask_time(xs)
return xs
assert isinstance(x, np.ndarray)
assert x.ndim == 2
x = self.time_warp(x, self.mode)
x = self.mask_freq(x, self.replace_with_zero)
x = self.mask_time(x, self.replace_with_zero)
return x
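
A standalone sketch of the new masking behavior on a (time, freq) spectrogram; mask_time_once is a hypothetical helper distilled from mask_time above, not part of the diff:

    import numpy as np

    def mask_time_once(x, t_0, t, replace_with_zero=False):
        x = x.copy()
        x[t_0:t_0 + t, :] = 0 if replace_with_zero else x.mean()
        return x

    spec = np.random.rand(100, 80).astype(np.float32)
    mean_filled = mask_time_once(spec, t_0=20, t=10)    # pad with global mean
    zero_filled = mask_time_once(spec, t_0=20, t=10, replace_with_zero=True)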

@ -81,7 +81,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
if not train:
return
return x
self.transform_audio(x)
return x

@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_featurizer import AudioFeaturizer #noqa: F401
from .speech_featurizer import SpeechFeaturizer
from .text_featurizer import TextFeaturizer

@ -18,7 +18,7 @@ from python_speech_features import logfbank
from python_speech_features import mfcc
class AudioFeaturizer(object):
class AudioFeaturizer():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
@ -167,32 +167,6 @@ class AudioFeaturizer(object):
raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type)
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy."""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
ind = np.where(freqs <= max_freq)[0][-1] + 1
return np.log(specgram[:ind, :] + eps)
def _specgram_real(self, samples, window_size, stride_size, sample_rate):
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
@ -217,26 +191,65 @@ class AudioFeaturizer(object):
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy.
Args:
samples (np.ndarray): audio samples.
sample_rate (int): sample rate of the audio.
stride_ms (float, optional): stride size in milliseconds. Defaults to 10.0.
window_ms (float, optional): window size in milliseconds. Defaults to 20.0.
max_freq (float, optional): highest frequency to keep. Defaults to None (half the sample rate).
eps (float, optional): small value added before taking the log. Defaults to 1e-14.
Raises:
ValueError: if max_freq is greater than half of the sample rate.
ValueError: if stride_ms is greater than window_ms.
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
ind = np.where(freqs <= max_freq)[0][-1] + 1
# (freq, time)
spec = np.log(specgram[:ind, :] + eps)
return np.transpose(spec)
def _concat_delta_delta(self, feat):
"""append delat, delta-delta feature.
Args:
feat (np.ndarray): (D, T)
feat (np.ndarray): (T, D)
Returns:
np.ndarray: feat with delta-delta, (3*D, T)
np.ndarray: feat with delta-delta, (T, 3*D)
"""
feat = np.transpose(feat)
# Deltas
d_feat = delta(feat, 2)
# Deltas-Deltas
dd_feat = delta(d_feat, 2)
# transpose
feat = np.transpose(feat)
d_feat = np.transpose(d_feat)
dd_feat = np.transpose(dd_feat)
# concat above three features
concat_feat = np.concatenate((feat, d_feat, dd_feat))
concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
return concat_feat
def _compute_mfcc(self,
@ -292,7 +305,6 @@ class AudioFeaturizer(object):
ceplifter=22,
useEnergy=True,
winfunc='povey')
mfcc_feat = np.transpose(mfcc_feat)
if delta_delta:
mfcc_feat = self._concat_delta_delta(mfcc_feat)
return mfcc_feat
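
With features now kept in (T, D) layout end to end, the delta computation no longer needs the transpose dance; a small sketch of the concat path using python_speech_features:

    import numpy as np
    from python_speech_features import delta

    feat = np.random.rand(100, 13)    # (T, D), e.g. MFCCs
    d_feat = delta(feat, 2)           # first-order deltas, (T, D)
    dd_feat = delta(d_feat, 2)        # second-order deltas, (T, D)
    concat = np.concatenate((feat, d_feat, dd_feat), axis=1)   # (T, 3*D)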
@ -346,8 +358,6 @@ class AudioFeaturizer(object):
remove_dc_offset=True,
preemph=0.97,
wintype='povey')
fbank_feat = np.transpose(fbank_feat)
if delta_delta:
fbank_feat = self._concat_delta_delta(fbank_feat)
return fbank_feat

@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
class SpeechFeaturizer():
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.

@ -14,12 +14,19 @@
"""Contains the text featurizer class."""
import sentencepiece as spm
from deepspeech.frontend.utility import EOS
from deepspeech.frontend.utility import UNK
from ..utility import EOS
from ..utility import load_dict
from ..utility import UNK
__all__ = ["TextFeaturizer"]
class TextFeaturizer(object):
def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
class TextFeaturizer():
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
@ -34,11 +41,12 @@ class TextFeaturizer(object):
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab_filepath:
self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
self.unk_id = self._vocab_list.index(self.unk)
self.eos_id = self._vocab_list.index(EOS)
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
vocab_filepath, maskctc)
self.vocab_size = len(self.vocab_list)
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
@ -67,7 +75,7 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
text (str): Text.
Returns:
List[int]: List of token indices.
@ -75,8 +83,8 @@ class TextFeaturizer(object):
tokens = self.tokenize(text)
ids = []
for token in tokens:
token = token if token in self._vocab_dict else self.unk
ids.append(self._vocab_dict[token])
token = token if token in self.vocab_dict else self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
@ -87,7 +95,7 @@ class TextFeaturizer(object):
idxs (List[int]): List of token indices.
Returns:
str: Text to process.
str: Text.
"""
tokens = []
for idx in idxs:
@ -97,33 +105,6 @@ class TextFeaturizer(object):
text = self.detokenize(tokens)
return text
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return self._vocab_dict
def char_tokenize(self, text):
"""Character tokenizer.
@ -206,14 +187,16 @@ class TextFeaturizer(object):
return decode(tokens)
def _load_vocabulary_from_file(self, vocab_filepath):
def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
return token2id, id2token, vocab_list
unk_id = vocab_list.index(UNK)
eos_id = vocab_list.index(EOS)
return token2id, id2token, vocab_list, unk_id, eos_id
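
The featurize/defeaturize pair now goes through the plain vocab_dict / _id2token mappings built above; a toy round trip with an illustrative vocabulary:

    vocab_list = ["<blank>", "h", "i", "<unk>", "<eos>"]
    token2id = {token: idx for idx, token in enumerate(vocab_list)}
    id2token = {idx: token for idx, token in enumerate(vocab_list)}

    text = "hi"
    ids = [token2id.get(t, token2id["<unk>"]) for t in text]   # [1, 2]
    back = "".join(id2token[i] for i in ids)                   # "hi"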

@ -40,21 +40,21 @@ class CollateFunc(object):
number = 0
for item in batch:
audioseg = AudioSegment.from_file(item['feat'])
feat = self.feature_func(audioseg) #(D, T)
feat = self.feature_func(audioseg) #(T, D)
sums = np.sum(feat, axis=1)
sums = np.sum(feat, axis=0)
if mean_stat is None:
mean_stat = sums
else:
mean_stat += sums
square_sums = np.sum(np.square(feat), axis=1)
square_sums = np.sum(np.square(feat), axis=0)
if var_stat is None:
var_stat = square_sums
else:
var_stat += square_sums
number += feat.shape[1]
number += feat.shape[0]
return number, mean_stat, var_stat
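
With (T, D) features the CMVN statistics are accumulated over axis 0; a condensed sketch of the loop above plus the mean/istd it ultimately feeds into (eps placement is illustrative):

    import numpy as np

    feats = [np.random.rand(50, 80) for _ in range(4)]   # each (T, D)
    number, mean_stat, var_stat = 0, None, None
    for feat in feats:
        sums = np.sum(feat, axis=0)
        square_sums = np.sum(np.square(feat), axis=0)
        mean_stat = sums if mean_stat is None else mean_stat + sums
        var_stat = square_sums if var_stat is None else var_stat + square_sums
        number += feat.shape[0]

    mean = mean_stat / number                            # (D,)
    istd = 1.0 / np.sqrt(var_stat / number - mean * mean + 1e-20)
    normalized = (feats[0] - mean[np.newaxis, :]) * istd[np.newaxis, :]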
@ -120,7 +120,7 @@ class FeatureNormalizer(object):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray, shape (D, T)
:type features: ndarray, shape (T, D)
:param eps: added to stddev to provide numerical stability.
:type eps: float
:return: Normalized features.
@ -131,8 +131,8 @@ class FeatureNormalizer(object):
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
mean, istd = load_cmvn(filepath, filetype='json')
self._mean = np.expand_dims(mean, axis=-1)
self._istd = np.expand_dims(istd, axis=-1)
self._mean = np.expand_dims(mean, axis=0)
self._istd = np.expand_dims(istd, axis=0)
def write_to_file(self, filepath):
"""Write the mean and stddev to the file.

@ -15,6 +15,9 @@
import codecs
import json
import math
from typing import List
from typing import Optional
from typing import Text
import numpy as np
@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
"mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
"BLANK"
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
"EOS", "UNK", "BLANK", "MASKCTC"
]
IGNORE_ID = -1
SOS = "<sos/eos>"
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
if dict_path is None:
return None
with open(dict_path, "r") as f:
dictionary = f.readlines()
char_list = [entry.strip().split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
char_list.append(EOS)
# for non-autoregressive maskctc model
if maskctc and MASKCTC not in char_list:
char_list.append(MASKCTC)
return char_list
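
A usage sketch for load_dict, assuming the common one-unit-per-line dict format ("<token> <id>"):

    with open("vocab.txt", "w") as f:
        f.write("a 2\nb 3\n")

    char_list = load_dict("vocab.txt")
    # ['<blank>', 'a', 'b', '<eos>']  -- BLANK prepended, EOS appended
    char_list = load_dict("vocab.txt", maskctc=True)
    # ['<blank>', 'a', 'b', '<eos>', '<mask>']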
def read_manifest(
@ -47,12 +69,20 @@ def read_manifest(
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
max_input_len (float, optional): maximum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum output seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum output seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length / input seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length / input seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.

@ -242,7 +242,6 @@ class SpeechCollator():
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, transcript_part
def __call__(self, batch):
@ -250,7 +249,7 @@ class SpeechCollator():
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
Returns:

@ -217,6 +217,34 @@ class SpeechCollator():
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data.
@ -244,7 +272,6 @@ class SpeechCollator():
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part
def __call__(self, batch):
@ -252,7 +279,7 @@ class SpeechCollator():
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
Returns:
@ -296,34 +323,6 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
class TripletSpeechCollator(SpeechCollator):
def process_utterance(self, audio_file, translation, transcript):
@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part, transcript_part
def __call__(self, batch):
@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
Returns:
@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
:rtype: tuple of (2darray, list)
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[1])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation
else:
text_ids = self._text_featurizer.featurize(translation)
return specgram, text_ids
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
return self._text_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
@property
def feature_size(self):
return self._feat_dim
@property
def stride_ms(self):
return self._stride_ms
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
def process_utterance(self, audio_file, translation, transcript):
@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
:rtype: tuple of (2darray, (list, list))
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[1])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation, transcript
else:
@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
audio (np.ndarray) shape (T, D)
translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,)

@ -43,7 +43,7 @@ class CustomConverter():
batch (list): The batch to transform.
Returns:
tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)
tuple(np.ndarray, np.ndarray, np.ndarray)
"""
# batch should be located in list

@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
return feat_dim, vocab_size
def batch_collate(x):
"""de-tuple.
Args:
x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
Returns:
Tuple: (utts, xs, ilens, ys, olens)
"""
return x[0]
class BatchDataLoader():
def __init__(self,
json_file: str,
@ -120,15 +132,15 @@ class BatchDataLoader():
# actual batch size is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self.dataset = TransformDataset(
self.minibaches,
lambda data: self.converter([self.reader(data, return_uttid=True)]))
self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader)
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )
shuffle=not self.use_sortagrad if self.train_mode else False,
collate_fn=batch_collate,
num_workers=self.n_iter_processes, )
def __repr__(self):
echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "

@ -129,15 +129,16 @@ class TransformDataset(Dataset):
Args:
data: list object from make_batchset
transfrom: transform function
converter: batch function
reader: read data
"""
def __init__(self, data, transform):
def __init__(self, data, converter, reader):
"""Init function."""
super().__init__()
self.data = data
self.transform = transform
self.converter = converter
self.reader = reader
def __len__(self):
"""Len function."""
@ -145,4 +146,4 @@ class TransformDataset(Dataset):
def __getitem__(self, idx):
"""[] operator."""
return self.transform(self.data[idx])
return self.converter([self.reader(self.data[idx], return_uttid=True)])
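
A toy sketch of the refactored pipeline, with stand-in reader/converter callables (both hypothetical) wired through the TransformDataset and batch_collate definitions above:

    from paddle.io import DataLoader

    def reader(data, return_uttid=False):   # stand-in for the espnet-style reader
        return data

    def converter(batches):                 # stand-in for CustomConverter
        return batches[0]

    dataset = TransformDataset([[1], [2]], converter, reader)
    loader = DataLoader(dataset, batch_size=1, collate_fn=batch_collate)
    for batch in loader:
        pass                                # each batch is the converted minibatch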

@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns:
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
if init_state_h_box is not None:
init_state_list = None
@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
if self.use_gru is True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)
final_chunk_state_c_box = init_state_c_box
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio: Audio spectrogram data layer.
:type audio: Variable
:param text: Transcription text data layer.
:type text: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.

@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
best_hyps = best_hyps[:, 1:]
return best_hyps
@jit.export
@jit.to_static
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
@jit.to_static
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
@jit.to_static
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
@jit.to_static
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@jit.export
@jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
@jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.export
@jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,

@ -16,23 +16,23 @@ import argparse
def default_argument_parser():
r"""A simple yet genral argument parser for experiments with parakeet.
This is used in examples with parakeet. And it is intended to be used by
other experiments with parakeet. It requires a minimal set of command line
This is used in examples with parakeet. And it is intended to be used by
other experiments with parakeet. It requires a minimal set of command line
arguments to start a training script.
The ``--config`` and ``--opts`` are used to overwrite the default
The ``--config`` and ``--opts`` are used to overwrite the default
configuration.
The ``--data`` and ``--output`` specifies the data path and output path.
Resuming training from existing progress at the output directory is the
The ``--data`` and ``--output`` specifies the data path and output path.
Resuming training from existing progress at the output directory is the
intended default behavior.
The ``--checkpoint_path`` specifies the checkpoint to load from.
The ``--device`` and ``--nprocs`` specifies how to run the training.
See Also
--------
parakeet.training.experiment
@ -47,28 +47,24 @@ def default_argument_parser():
# data and output
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
# parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
# load from saved checkpoint
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
# save jit model to
parser.add_argument("--export_path", type=str, help="path of the jit model to save")
# save asr result to
parser.add_argument("--result_file", type=str, help="path of save the asr result")
# running
parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
help="device type to use, cpu and gpu are supported.")
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
# overwrite extra config and default config
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--opts", type=str, default=[], nargs='+',
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapf: enable
return parser

@ -0,0 +1,41 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from .extension import Extension
def make_extension(trigger: Callable=None,
default_name: str=None,
priority: int=None,
finalizer: Callable=None,
initializer: Callable=None,
on_error: Callable=None):
"""Make an Extension-like object by injecting required attributes to it.
"""
if trigger is None:
trigger = Extension.trigger
if priority is None:
priority = Extension.priority
def decorator(ext):
ext.trigger = trigger
ext.default_name = default_name or ext.__name__
ext.priority = priority
ext.finalize = finalizer
ext.on_error = on_error
ext.initialize = initializer
return ext
return decorator
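
For instance, make_extension can decorate a plain function so the trainer treats it as an extension (the function body is illustrative):

    @make_extension(trigger=(1, 'epoch'))
    def log_epoch(trainer):
        print(f"finished epoch {trainer.updater.state.epoch}")

    assert log_epoch.trigger == (1, 'epoch')
    assert log_epoch.default_name == 'log_epoch'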

@ -0,0 +1,71 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import extension
import paddle
from paddle.io import DataLoader
from paddle.nn import Layer
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope
class StandardEvaluator(extension.Extension):
trigger = (1, 'epoch')
default_name = 'validation'
priority = extension.PRIORITY_WRITER
name = None
def __init__(self, model: Layer, dataloader: DataLoader):
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
self.model = model
# dataloaders
self.dataloader = dataloader
def evaluate_core(self, batch):
# compute
self.model(batch) # you may report here
def evaluate(self):
# switch to eval mode
for model in self.models.values():
model.eval()
# to average evaluation metrics
summary = DictSummary()
for batch in self.dataloader:
observation = {}
with scope(observation):
# main evaluation computation here.
with paddle.no_grad():
self.evaluate_core(batch)
summary.add(observation)
summary = summary.compute_mean()
return summary
def __call__(self, trainer=None):
# evaluate and report the averaged metric to current observation
# if it is used to extend a trainer, the metrics is reported to
# to observation of the trainer
# or otherwise, you can use your own observation
summary = self.evaluate()
for k, v in summary.items():
report(k, v)

@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
class Extension():
"""Extension to customize the behavior of Trainer."""
trigger = (1, 'iteration')
priority = PRIORITY_READER
name = None
@property
def default_name(self):
"""Default name of the extension, class name by default."""
return type(self).__name__
def __call__(self, trainer):
"""Main action of the extention. After each update, it is executed
when the trigger fires."""
raise NotImplementedError(
'Extension implementation must override __call__.')
def initialize(self, trainer):
"""Action that is executed once to get the corect trainer state.
It is called before training normally, but if the trainer restores
states with a Snapshot extension, this method should also be called.
"""
pass
def on_error(self, trainer, exc, tb):
"""Handles the error raised during training before finalization.
"""
pass
def finalize(self, trainer):
"""Action that is executed when training is done.
For example, visualizers would need to be closed.
"""
pass

@ -0,0 +1,114 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog()
def load_records(records_fp):
"""Load record files (json lines.)"""
with jsonlines.open(records_fp, 'r') as reader:
records = list(reader)
return records
class Snapshot(extension.Extension):
"""An extension to make snapshot of the updater object inside
the trainer. It is done by calling the updater's `save` method.
An Updater saves its state_dict by default, which contains the
updater state, (i.e. epoch and iteration) and all the model
parameters and optimizer states. If the updater inside the trainer
subclasses StandardUpdater, everything is good to go.
Parameters
----------
checkpoint_dir : Union[str, Path]
The directory to save checkpoints into.
"""
trigger = (1, 'epoch')
priority = -100
default_name = "snapshot"
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = []
self.max_size = max_size
self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1)
self.checkpoint_dir = None
def initialize(self, trainer: Trainer):
"""Setting up this extention."""
self.checkpoint_dir = trainer.out / "checkpoints"
# load existing records
record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists():
logger.debug("Loading from an existing checkpoint dir")
self.records = load_records(record_path)
trainer.updater.load(self.records[-1]['path'])
def on_error(self, trainer, exc, tb):
if self._snapshot_on_error:
self.save_checkpoint_and_update(trainer)
def __call__(self, trainer: Trainer):
self.save_checkpoint_and_update(trainer)
def full(self):
"""Whether the number of snapshots it keeps track of is greater
than the max_size."""
return (not self._save_all) and len(self.records) > self.max_size
@rank_zero_only
def save_checkpoint_and_update(self, trainer: Trainer):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz"
# add the new one
trainer.updater.save(path)
record = {
"time": str(datetime.now()),
'path': str(path.resolve()), # use absolute path
'iteration': iteration,
'epoch': epoch,
}
self.records.append(record)
# remove the earliest
if self.full():
earliest_record = self.records[0]
os.remove(earliest_record["path"])
self.records.pop(0)
# update the record file
record_path = self.checkpoint_dir / "records.jsonl"
with jsonlines.open(record_path, 'w') as writer:
for record in self.records:
# jsonlines.open may return a Writer or a Reader
writer.write(record) # pylint: disable=no-member

@ -0,0 +1,37 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
class VisualDL(extension.Extension):
"""A wrapper of visualdl log writer. It assumes that the metrics to be visualized
are all scalars which are recorded into the `.observation` dictionary of the
trainer object. The dictionary is created for each step, thus the visualdl log
writer uses the iteration from the updater's `iteration` as the global step to
add records.
"""
trigger = (1, 'iteration')
default_name = 'visualdl'
priority = extension.PRIORITY_READER
def __init__(self, writer):
self.writer = writer
def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items():
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer):
self.writer.close()

@ -0,0 +1,144 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from collections import defaultdict
OBSERVATIONS = None
@contextlib.contextmanager
def scope(observations):
# make `observation` the target to report to.
# it is basically a dictionary that stores temporary observations
global OBSERVATIONS
old = OBSERVATIONS
OBSERVATIONS = observations
try:
yield
finally:
OBSERVATIONS = old
def get_observations():
global OBSERVATIONS
return OBSERVATIONS
def report(name, value):
# a simple function to report named value
# you can use it everywhere, it will get the default target and write to it
# you can think of it as std.out
observations = get_observations()
if observations is None:
return
else:
observations[name] = value
class Summary():
"""Online summarization of a sequence of scalars.
Summary computes the statistics of given scalars online.
"""
def __init__(self):
self._x = 0.0
self._x2 = 0.0
self._n = 0
def add(self, value, weight=1):
"""Adds a scalar value.
Args:
value: Scalar value to accumulate. It is either a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
weight: An optional weight for the value. It is a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
Default is 1 (integer).
"""
self._x += weight * value
self._x2 += weight * value * value
self._n += weight
def compute_mean(self):
"""Computes the mean."""
x, n = self._x, self._n
return x / n
def make_statistics(self):
"""Computes and returns the mean and standard deviation values.
Returns:
tuple: Mean and standard deviation values.
"""
x, n = self._x, self._n
mean = x / n
var = self._x2 / n - mean * mean
std = math.sqrt(var)
return mean, std
class DictSummary():
"""Online summarization of a sequence of dictionaries.
``DictSummary`` computes the statistics of a given set of scalars online.
It only computes the statistics for scalar values and variables of scalar
values in the dictionaries.
"""
def __init__(self):
self._summaries = defaultdict(Summary)
def add(self, d):
"""Adds a dictionary of scalars.
Args:
d (dict): Dictionary of scalars to accumulate. Only elements of
scalars, zero-dimensional arrays, and variables of
zero-dimensional arrays are accumulated. When the value
is a tuple, the second element is interpreted as a weight.
"""
summaries = self._summaries
for k, v in d.items():
w = 1
if isinstance(v, tuple):
w = v[1]
v = v[0]
summaries[k].add(v, weight=w)
def compute_mean(self):
"""Creates a dictionary of mean values.
It returns a single dictionary that holds a mean value for each entry
added to the summary.
Returns:
dict: Dictionary of mean values.
"""
return {
name: summary.compute_mean()
for name, summary in self._summaries.items()
}
def make_statistics(self):
"""Creates a dictionary of statistics.
It returns a single dictionary that holds mean and standard deviation
values for every entry added to the summary. For an entry of name
``'key'``, these values are added to the dictionary by names ``'key'``
and ``'key.std'``, respectively.
Returns:
dict: Dictionary of statistics of all entries.
"""
stats = {}
for name, summary in self._summaries.items():
mean, std = summary.make_statistics()
stats[name] = mean
stats[name + '.std'] = std
return stats
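
End to end, the reporter pieces compose like this (a minimal sketch using the scope/report/DictSummary definitions above):

    summary = DictSummary()
    for step in range(3):
        observation = {}
        with scope(observation):
            report("loss", 0.5 * step)     # lands in `observation`
        summary.add(observation)

    print(summary.compute_mean())          # {'loss': 0.5}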

@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
__all__ = ["Trainer"]
@ -94,6 +95,10 @@ class Trainer():
self.iteration = 0
self.epoch = 0
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
def setup(self):
"""Setup the experiment.
"""
@ -172,8 +177,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""
self.epoch += 1
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def train(self):
"""The training process control by epoch."""
@ -182,7 +189,7 @@ class Trainer():
# save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch)
if self.parallel:
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")

@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger
if callable(trigger):
return trigger
else:
trigger = IntervalTrigger(*trigger)
return trigger
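
In short, get_trigger accepts None, a callable, or a (period, unit) tuple; a sketch of the three cases:

trigger = get_trigger(None)                  # never fires
trigger = get_trigger((10, "iteration"))     # IntervalTrigger firing every 10 iterations
trigger = get_trigger(lambda trainer: True)  # any callable is used as-is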

@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger():
"""A Predicate to do something every N cycle."""
def __init__(self, period: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
if period <= 0:
raise ValueError("period should be a positive integer.")
self.period = period
self.unit = unit
self.last_index = None
def __call__(self, trainer):
if self.last_index is None:
last_index = getattr(trainer.updater.state, self.unit)
self.last_index = last_index
last_index = self.last_index
index = getattr(trainer.updater.state, self.unit)
fire = index // self.period != last_index // self.period
self.last_index = index
return fire
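
A small sketch of the firing pattern (the trainer stub below is an assumption for illustration):

from types import SimpleNamespace

trainer = SimpleNamespace(updater=SimpleNamespace(
    state=SimpleNamespace(iteration=0, epoch=0)))
trigger = IntervalTrigger(3, "iteration")
fired = []
for i in range(1, 10):
    trainer.updater.state.iteration = i
    fired.append(trigger(trainer))
# fired == [False, False, True, False, False, True, False, False, True]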

@ -0,0 +1,31 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger():
"""A Predicate to decide whether to stop."""
def __init__(self, limit: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
if limit <= 0:
raise ValueError("limit should be a positive integer.")
self.limit = limit
self.unit = unit
def __call__(self, trainer):
state = trainer.updater.state
index = getattr(state, self.unit)
fire = index >= self.limit
return fire
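
Used as a stop condition, it fires once the index reaches the limit; a sketch with a stubbed trainer (the stub is an assumption for illustration):

from types import SimpleNamespace

stop_trigger = LimitTrigger(2, "epoch")
trainer = SimpleNamespace(updater=SimpleNamespace(
    state=SimpleNamespace(iteration=0, epoch=0)))
assert not stop_trigger(trainer)  # epoch 0: keep training
trainer.updater.state.epoch = 2
assert stop_trigger(trainer)      # limit reached: stop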

@ -0,0 +1,32 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger():
"""Trigger based on a fixed time interval.
This trigger fires once every ``period`` seconds of elapsed training time.
Args:
period (float): Interval time, in seconds.
"""
def __init__(self, period):
self._period = period
self._next_time = self._period
def __call__(self, trainer):
if self._next_time < trainer.elapsed_time:
self._next_time += self._period
return True
else:
return False
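
Note that the period is wall-clock time, not iterations; a sketch with a stubbed trainer (again an assumption for illustration):

from types import SimpleNamespace

trigger = TimeTrigger(60.0)  # fire roughly once a minute
trainer = SimpleNamespace(elapsed_time=61.0)
assert trigger(trainer)      # first 60 s have elapsed
assert not trigger(trainer)  # next firing waits until 120 s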

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,192 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional
from paddle import Tensor
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer
from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log
__all__ = ["StandardUpdater"]
logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
"""
def __init__(self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state: Optional[UpdaterState]=None):
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
self.model = model
# it is designed to hold multiple optimizers
optimizers = {"main": optimizer}
self.optimizer = optimizer
self.optimizers: Dict[str, Optimizer] = optimizers
# dataloaders
self.dataloader = dataloader
# init state
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
self.train_iterator = iter(dataloader)
def update(self):
# We increase the iteration index after updating and before running
# extensions. Here are the reasons.
# 1. Snapshotting (as well as other extensions, like visualizers) is
# executed after a step of updating.
# 2. If the iteration index were increased only after the extensions
# (including snapshotting) had run, resuming would be inconsistent:
# loading `snapshot_iter_100.pdz` should naturally continue with step
# `101`, but a `snapshot_iter_99` would be loaded instead, and an extra
# increment of the iteration index would be needed before training to
# avoid redoing iteration `99`, which was already done before
# snapshotting.
# 3. Thus the iteration index represents "how many updates have been
# done so far".
# NOTE: use report to capture the correct value. If you want to report
# the learning rate used for a step, you must report it before the
# learning rate scheduler's step() has been called. In paddle's
# convention, we do not use an extension to change the learning rate,
# so if you want to report it, do it in the updater.
# Then here comes the next question: when is the proper time to
# increase the epoch index? Since all extensions are executed after
# updating, the proper time to increase the epoch index is also after
# updating.
# 1. If we increased the epoch index before updating, an extension
# triggered by epoch would miss the correct timing; it could only be
# triggered after an extra update.
# 2. Theoretically, the epoch index should be increased when an epoch
# is done, so it is increased after updating.
# 3. Thus the epoch index represents "how many epochs have been done
# so far", and it starts from 0.
# switch to training mode
for model in self.models.values():
model.train()
# training for a step is implemented here
batch = self.read_batch()
self.update_core(batch)
self.state.iteration += 1
if self.updates_per_epoch is not None:
if self.state.iteration % self.updates_per_epoch == 0:
self.state.epoch += 1
def update_core(self, batch):
"""A simple case for a training step. Basic assumptions are:
Single model;
Single optimizer;
A batch from the dataloader is just the input of the model;
The model returns a single loss, or a dict containing several losses;
Parameters are updated at every batch, with no gradient accumulation.
"""
loss = self.model(*batch)
if isinstance(loss, Tensor):
loss_dict = {"main": loss}
else:
# Dict[str, Tensor]
loss_dict = loss
if "main" not in loss_dict:
main_loss = 0
for loss_item in loss.values():
main_loss += loss_item
loss_dict["main"] = main_loss
for name, loss_item in loss_dict.items():
report(name, float(loss_item))
self.optimizer.clear_grad()
loss_dict["main"].backward()
self.optimizer.step()
@property
def updates_per_epoch(self):
"""Number of updater per epoch, determined by the length of the
dataloader."""
length_of_dataloader = None
try:
length_of_dataloader = len(self.dataloader)
except TypeError:
logger.debug("This dataloader has no __len__.")
finally:
return length_of_dataloader
def new_epoch(self):
"""Start a new epoch."""
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
if hasattr(self.dataloader, "batch_sampler"):
batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch)
self.train_iterator = iter(self.dataloader)
def read_batch(self):
"""Read a batch from the data loader, auto renew when data is exhausted."""
with timer() as t:
try:
batch = next(self.train_iterator)
except StopIteration:
self.new_epoch()
batch = next(self.train_iterator)
logger.debug(f"Reading a batch took {t.elapse}s.")
return batch
def state_dict(self):
"""State dict of a Updater, model, optimizer and updater state are included."""
state_dict = super().state_dict()
for name, model in self.models.items():
state_dict[f"{name}_params"] = model.state_dict()
for name, optim in self.optimizers.items():
state_dict[f"{name}_optimizer"] = optim.state_dict()
return state_dict
def set_state_dict(self, state_dict):
"""Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored."""
for name, model in self.models.items():
model.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict)
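
An end-to-end sketch of driving StandardUpdater with a toy model; everything named here is an illustrative placeholder, not part of this module:

import paddle
from paddle.io import DataLoader, TensorDataset

class ToyModel(paddle.nn.Layer):
    """Forward returns the loss directly, as update_core expects."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 1)

    def forward(self, x, y):
        return paddle.nn.functional.mse_loss(self.linear(x), y)

model = ToyModel()
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())
dataset = TensorDataset([paddle.randn([32, 4]), paddle.randn([32, 1])])
loader = DataLoader(dataset, batch_size=8)
updater = StandardUpdater(model, optimizer, loader)
updater.update()  # one training step; state.iteration becomes 1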

@ -0,0 +1,184 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Callable
from typing import List
from typing import Union
import six
import tqdm
from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase
class _ExtensionEntry():
def __init__(self, extension, trigger, priority):
self.extension = extension
self.trigger = trigger
self.priority = priority
class Trainer():
def __init__(self,
updater: UpdaterBase,
stop_trigger: Callable=None,
out: Union[str, Path]='result',
extensions: List[Extension]=None):
self.updater = updater
self.extensions = OrderedDict()
self.stop_trigger = LimitTrigger(*stop_trigger)
self.out = Path(out)
self.observation = None
self._done = False
if extensions:
for ext in extensions:
self.extend(ext)
@property
def is_before_training(self):
return self.updater.state.iteration == 0
def extend(self, extension, name=None, trigger=None, priority=None):
# Get a name for the extension. Precedence:
# the `name` argument
# -> the extension's `name` attribute
# -> its `default_name` (class name, when it is an object)
# -> its `__name__` (when it is a function)
# -> error
if name is None:
name = getattr(extension, 'name', None)
if name is None:
name = getattr(extension, 'default_name', None)
if name is None:
name = getattr(extension, '__name__', None)
if name is None:
raise ValueError("Name is not given for the extension.")
if name == 'training':
raise ValueError("training is a reserved name.")
if trigger is None:
trigger = getattr(extension, 'trigger', (1, 'iteration'))
trigger = get_trigger(trigger)
if priority is None:
priority = getattr(extension, 'priority', PRIORITY_READER)
# add a suffix to avoid naming conflicts
ordinal = 0
modified_name = name
while modified_name in self.extensions:
ordinal += 1
modified_name = f"{name}_{ordinal}"
extension.name = modified_name
self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
priority)
def get_extension(self, name):
"""get extension by name."""
extensions = self.extensions
if name in extensions:
return extensions[name].extension
else:
raise ValueError(f'extension {name} not found')
def run(self):
if self._done:
raise RuntimeError("Training is already done!.")
self.out.mkdir(parents=True, exist_ok=True)
# sort extensions by priorities once
extension_order = sorted(
self.extensions.keys(),
key=lambda name: self.extensions[name].priority,
reverse=True)
extensions = [(name, self.extensions[name]) for name in extension_order]
# initializing all extensions
for name, entry in extensions:
if hasattr(entry.extension, "initialize"):
entry.extension.initialize(self)
update = self.updater.update # training step
stop_trigger = self.stop_trigger
# display only one progress bar
max_iteration = None
if isinstance(stop_trigger, LimitTrigger):
if stop_trigger.unit == 'epoch':
max_epoch = self.stop_trigger.limit
updates_per_epoch = getattr(self.updater, "updates_per_epoch",
None)
max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
else:
max_iteration = self.stop_trigger.limit
p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
try:
while not stop_trigger(self):
self.observation = {}
# set observation as the report target
# you can use report freely in Updater.update()
# updating parameters and state
with scope(self.observation):
update()
p.update()
# execute extension when necessary
for name, entry in extensions:
if entry.trigger(self):
entry.extension(self)
# print("###", self.observation)
except Exception as e:
f = sys.stderr
f.write(f"Exception in main training loop: {e}\n")
f.write("Traceback (most recent call last):\n")
traceback.print_tb(sys.exc_info()[2])
f.write(
"Trainer extensions will try to handle the exception. "
"Then all extensions will finalize.\n")
# capture the exception in the main training loop
exc_info = sys.exc_info()
# try to handle it
for name, entry in extensions:
if hasattr(entry.extension, "on_error"):
try:
entry.extension.on_error(self, e, sys.exc_info()[2])
except Exception as ee:
f.write(f"Exception in error handler: {ee}\n")
f.write('Traceback (most recent call last):\n')
traceback.print_tb(sys.exc_info()[2])
# raise exception in main training loop
six.reraise(*exc_info)
finally:
for name, entry in extensions:
if hasattr(entry.extension, "finalize"):
entry.extension.finalize(self)
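
Wiring it together, a minimal sketch (assuming an updater built as in the StandardUpdater example above; the output path is illustrative):

trainer = Trainer(
    updater,
    stop_trigger=(10, "epoch"),  # wrapped into a LimitTrigger
    out="exp/toy")
trainer.run()  # loops until the stop trigger fires, then finalizes extensions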

@ -0,0 +1,83 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import paddle
from deepspeech.utils.log import Log
__all__ = ["UpdaterBase", "UpdaterState"]
logger = Log(__name__).getlog()
@dataclass
class UpdaterState:
iteration: int = 0
epoch: int = 0
class UpdaterBase():
"""An updater is the abstraction of how a model is trained given the
dataloader and the optimizer.
The `update_core` method is a step in the training loop with only necessary
operations (get a batch, forward and backward, update the parameters).
Other stuffs are made extensions. Visualization, saving, loading and
periodical validation and evaluation are not considered here.
But even in such simplist case, things are not that simple. There is an
attempt to standardize this process and requires only the model and
dataset and do all the stuffs automatically. But this may hurt flexibility.
If we assume a batch yield from the dataloader is just the input to the
model, we will find that some model requires more arguments, or just some
keyword arguments. But this prevents us from over-simplifying it.
From another perspective, the batch may includes not just the input, but
also the target. But the model's forward method may just need the input.
We can pass a dict or a super-long tuple to the model and let it pick what
it really needs. But this is an abuse of lazy interface.
After all, we care about how a model is trained. But just how the model is
used for inference. We want to control how a model is trained. We just
don't want to be messed up with other auxiliary code.
So the best practice is to define a model and define a updater for it.
"""
def __init__(self, init_state=None):
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
def update(self, batch):
raise NotImplementedError(
"Implement your own `update` method for training a step.")
def state_dict(self):
state_dict = {
"epoch": self.state.epoch,
"iteration": self.state.iteration,
}
return state_dict
def set_state_dict(self, state_dict):
self.state.epoch = state_dict["epoch"]
self.state.iteration = state_dict["iteration"]
def save(self, path):
logger.debug(f"Saving to {path}.")
archive = self.state_dict()
paddle.save(archive, str(path))
def load(self, path):
logger.debug(f"Loading from {path}.")
archive = paddle.load(str(path))
self.set_state_dict(archive)
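
As the docstring suggests, the intended use is subclassing; a minimal custom updater sketch (the model, optimizer and dataloader are assumed to exist):

class MyUpdater(UpdaterBase):
    def __init__(self, model, optimizer, dataloader):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.iterator = iter(dataloader)

    def update(self):
        x, y = next(self.iterator)  # assumes (input, target) batches
        loss = self.model(x, y)     # forward computes the loss
        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
        self.state.iteration += 1   # keep state consistent for snapshots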

@ -15,9 +15,19 @@
import distutils.util
import math
import os
import random
from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]
import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
def seed_all(seed: int=210329):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None):
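
seed_all seeds NumPy, Python's random module and Paddle in one call, which is what the --seed flag in the training scripts below relies on; a usage sketch:

from deepspeech.utils.utility import seed_all

seed_all(1024)  # make numpy, random and paddle reproducible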

@ -1,10 +1,16 @@
# Aishell-1
## Data
| Data Subset | Duration in Seconds |
| --- | --- |
| data/manifest.train | 1.23 ~ 14.53125 |
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 |
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382,0.073507 |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |

@ -19,15 +19,17 @@
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -19,12 +19,22 @@ fi
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--model_type ${model_type}
--model_type ${model_type} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -27,7 +27,9 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
},
"prob": 1.0
}

@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -1,10 +0,0 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -52,16 +52,18 @@
{
"type": "specaug",
"params": {
"W": 80,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": false
},
"prob": 0.0
"prob": 1.0
}
]

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -19,15 +19,17 @@
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -20,12 +20,22 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--model_type ${model_type}
--model_type ${model_type} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -27,7 +27,9 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
},
"prob": 1.0
}

@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -2,15 +2,17 @@
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"W": 5,
"warp_mode": "PIL",
"F": 30,
"n_freq_masks": 2,
"T": 40,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": false
},
"prob": 1.0
}

@ -3,37 +3,28 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
vocab_filepath: data/train_960_unigram5000_units.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
spm_model_prefix: 'data/train_960_unigram5000'
feat_dim: 83
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
augmentation_config: conf/augmentation.json
num_workers: 2
subsampling_factor: 1
num_encs: 1
# network architecture

@ -1,7 +1,7 @@
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
fi
@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
dict_path=$2
ckpt_prefix=$3
batch_size=1
output_dir=${ckpt_prefix}
@ -22,11 +23,13 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/test.py \
--run_mode 'align' \
--model-name 'u2_kaldi' \
--run-mode 'align' \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--result-file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}

@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode 'export' \
--model-name 'u2_kaldi' \
--run-mode 'export' \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \

@ -1,7 +1,7 @@
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
fi
@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
fi
config_path=$1
ckpt_prefix=$2
dict_path=$2
ckpt_prefix=$3
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--model-name u2_kaldi \
--run-mode test \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--model-name u2_kaldi \
--run-mode test \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}

@ -19,12 +19,22 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--model-name u2_kaldi \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -5,6 +5,7 @@ source path.sh
stage=0
stop_stage=100
conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt
avg_num=5
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -29,12 +30,12 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then

@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Prepare THCHS-30 failed. Terminated."
exit 1
fi
fi
# dump manifest to data/
python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
# copy files to data/dict to gen word.lexicon
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# dump manifest to data/
python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
fi
# copy phone.lexicon to data/dict
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# copy files to data/dict to gen word.lexicon
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
fi
# gen word.lexicon
python local/gen_word2phone.py --root-dir=data/dict --output-dir=data/dict
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# gen word.lexicon
python local/gen_word2phone.py --lexicon-files="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2" --output-path=data/dict/word.lexicon
fi
# reorganize dataset for MFA
if [ ! -d $EXP_DIR/thchs30_corpus ]; then
echo "reorganizing thchs30 corpus..."
python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
echo "reorganization done."
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# reorganize dataset for MFA
if [ ! -d $EXP_DIR/thchs30_corpus ]; then
echo "reorganizing thchs30 corpus..."
python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
echo "reorganization done."
fi
fi
echo "THCHS-30 data preparation done."

@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt
import argparse
from collections import defaultdict
from pathlib import Path
from typing import List
from typing import Union
# key: (cn, ('ee', 'er4'))value: count
@ -34,7 +35,7 @@ def is_Chinese(ch):
return False
def proc_line(line):
def proc_line(line: str):
line = line.strip()
if is_Chinese(line[0]):
line_list = line.split()
@ -49,20 +50,25 @@ def proc_line(line):
cn_phones_counter[(cn, phones)] += 1
def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
root_dir = Path(root_dir).expanduser()
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
file1 = root_dir / "lm_word_lexicon_1"
file2 = root_dir / "lm_word_lexicon_2"
write_file = output_dir / "word.lexicon"
"""
example lines of output
the first column is a Chinese character
the second is the probability of this pronunciation
and the rest are the phones of this pronunciation
0.22 ii i1
0.45 ii i4
0.32 ii i2
0.01 ii i5
"""
def gen_lexicon(lexicon_files: List[Union[str, Path]],
output_path: Union[str, Path]):
for file_path in lexicon_files:
with open(file_path, "r") as f1:
for line in f1:
proc_line(line)
with open(file1, "r") as f1:
for line in f1:
proc_line(line)
with open(file2, "r") as f2:
for line in f2:
proc_line(line)
for key in cn_phones_counter:
cn = key[0]
cn_counter[cn].append((key[1], cn_phones_counter[key]))
@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
p = round(p, 2)
if p > 0:
cn_counter_p[key].append((item[0], p))
with open(write_file, "w") as wf:
with open(output_path, "w") as wf:
for key in cn_counter_p:
phone_p_list = cn_counter_p[key]
for item in phone_p_list:
@ -87,8 +94,21 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Gen Chinese characters to phone lexicon for THCHS-30 dataset"
)
# A line of word_lexicon:
# 一丁点 ii i4 d ing1 d ian3
# the first item is the word, the rest are its phones, and the number of phones is twice the word's length
parser.add_argument(
"--lexicon-files",
type=str,
default="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2",
help="lm_word_lexicon files")
parser.add_argument(
"--root-dir", type=str, help="dir to thchs30 lm_word_lexicons")
parser.add_argument("--output-dir", type=str, help="path to save outputs")
"--output-path",
type=str,
default="data/dict/word.lexicon",
help="path to save output word2phone lexicon")
args = parser.parse_args()
gen_lexicon(args.root_dir, args.output_dir)
lexicon_files = args.lexicon_files.split(" ")
output_path = Path(args.output_path).expanduser()
gen_lexicon(lexicon_files, output_path)

@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path],
def reorganize_thchs30(root_dir: Union[str, Path],
output_dir: Union[str, Path]=None,
script_type='phone'):
root_dir = Path(root_dir).expanduser()
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
link_wav(root_dir, output_dir)
write_lab(root_dir, output_dir, script_type)
@ -72,12 +70,15 @@ if __name__ == "__main__":
parser.add_argument(
"--output-dir",
type=str,
help="path to save outputs(audio and transcriptions)")
help="path to save outputs (audio and transcriptions)")
parser.add_argument(
"--script-type",
type=str,
default="phone",
help="type of lab ('word'/'syllable'/'phone')")
args = parser.parse_args()
reorganize_thchs30(args.root_dir, args.output_dir, args.script_type)
root_dir = Path(args.root_dir).expanduser()
output_dir = Path(args.output_dir).expanduser()
reorganize_thchs30(root_dir, output_dir, args.script_type)

@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# gen lexicon relink gen dump
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh $LEXICON_NAME|| exit -1
echo "Start prepare thchs30 data for MFA ..."
bash ./local/data.sh $LEXICON_NAME || exit -1
fi
# run MFA
if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
echo "Start MFA training..."
mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
echo "training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# run MFA
if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
echo "Start MFA training ..."
mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
echo "MFA training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
fi
fi

@ -27,7 +27,9 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
},
"prob": 1.0
}

@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -1,4 +1,13 @@
[
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "shift",
"params": {
@ -6,5 +15,22 @@
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"W": 5,
"warp_mode": "PIL",
"F": 30,
"n_freq_masks": 2,
"T": 40,
"n_time_masks": 2,
"p": 1.0,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}
]

@ -19,12 +19,22 @@ fi
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--model_type ${model_type}
--model_type ${model_type} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -27,7 +27,9 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
},
"prob": 1.0
}

@ -18,11 +18,21 @@ fi
mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -1,6 +1,8 @@
coverage
gpustat
jsonlines
kaldiio
Pillow
pre-commit
pybind11
resampy==0.2.2

@ -4,7 +4,7 @@
test -d Montreal-Forced-Aligner || git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
pushd Montreal-Forced-Aligner && git checkout v2.0.0a7 && python setup.py install
pushd Montreal-Forced-Aligner && python setup.py install && popd
test -d kaldi || { echo "need install kaldi first"; exit 1;}
