From 32da83b098e1bce66ad0fd8205a15c8e97531eb0 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 19 May 2021 10:52:42 +0800 Subject: [PATCH] train ds2 model (#622) * default cmvn compute config; more log of grad clip; diff ds2 cmvn compute and conf; ds2 lr step by epoch; * fix ds2 config * fix install and egs link * sox speed pertrub shape (T, C), float64, process using int32 * fix libri ds2 scripts; add ngram and spm doc * aishell ds2 cer7.86 * fix ds2 result --- README.md | 4 ++-- README_cn.md | 4 ++-- deepspeech/exps/deepspeech2/model.py | 2 -- deepspeech/frontend/audio.py | 4 +++- deepspeech/training/gradclip.py | 15 +++++++++------ deepspeech/training/trainer.py | 3 +-- deepspeech/utils/layer_tools.py | 4 ++-- examples/aishell/s0/README.md | 11 ++++++----- examples/aishell/s0/conf/augmentation.json | 9 +++++++++ examples/aishell/s0/conf/deepspeech2.yaml | 8 ++++---- examples/aishell/s0/local/data.sh | 6 +++--- examples/librispeech/s0/README.md | 2 +- examples/librispeech/s0/conf/augmentation.json | 9 +++++++++ examples/librispeech/s0/conf/deepspeech2.yaml | 10 +++++----- examples/librispeech/s0/local/data.sh | 12 ++++++------ examples/librispeech/s0/run.sh | 2 +- examples/ngram_lm/README.md | 7 +++++++ examples/spm/README.md | 4 +++- examples/tiny/s1/conf/augmentation.json | 16 ++++++++-------- tests/mask_test.py | 18 ++++++++++-------- utils/compute_mean_std.py | 4 ++-- 21 files changed, 93 insertions(+), 61 deletions(-) create mode 100644 examples/ngram_lm/README.md diff --git a/README.md b/README.md index a2d2f9a56..eb1814906 100644 --- a/README.md +++ b/README.md @@ -21,11 +21,11 @@ * python>=3.7 * paddlepaddle>=2.1.0 -Please see [install](doc/install.md). +Please see [install](doc/src/install.md). ## Getting Started -Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/README.md). +Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md). 
## More Information diff --git a/README_cn.md b/README_cn.md index 3c1111b5e..e1a38906e 100644 --- a/README_cn.md +++ b/README_cn.md @@ -22,11 +22,11 @@ * python>=3.7 * paddlepaddle>=2.1.0 -参看 [安装](doc/install.md)。 +参看 [安装](doc/src/install.md)。 ## 开始 -请查看 [Getting Started](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/README.md)。 +请查看 [Getting Started](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/s0/README.md)。 ## 更多信息 diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index c1fe82250..643936f17 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -43,13 +43,11 @@ class DeepSpeech2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): start = time.time() - loss = self.model(*batch_data) loss.backward() layer_tools.print_grads(self.model, print_func=None) self.optimizer.step() self.optimizer.clear_grad() - iteration_time = time.time() - start losses_np = { diff --git a/deepspeech/frontend/audio.py b/deepspeech/frontend/audio.py index 4488f5f2e..ffdcd4b3a 100644 --- a/deepspeech/frontend/audio.py +++ b/deepspeech/frontend/audio.py @@ -351,7 +351,9 @@ class AudioSegment(object): tfm.set_globals(multithread=False) tfm.speed(speed_rate) self._samples = tfm.build_array( - input_array=self._samples, sample_rate_in=self._sample_rate).copy() + input_array=self._samples, + sample_rate_in=self._sample_rate).squeeze(-1).astype( + np.float32).copy() def normalize(self, target_db=-20, max_gain_db=300.0): """Normalize audio to be of the desired RMS value in decibels. 
diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index 6c106f340..d0f9803d2 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -31,7 +31,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): def _dygraph_clip(self, params_grads): params_and_grads = [] sum_square_list = [] - for p, g in params_grads: + for i, (p, g) in enumerate(params_grads): if g is None: continue if getattr(p, 'need_clip', True) is False: @@ -45,7 +45,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square_list.append(sum_square) # debug log - # logger.debug(f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") + if i < 10: + logger.debug( + f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") # all parameters have been filterd out if len(sum_square_list) == 0: @@ -62,7 +64,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): clip_var = layers.elementwise_div( x=max_global_norm, y=layers.elementwise_max(x=global_norm_var, y=max_global_norm)) - for p, g in params_grads: + for i, (p, g) in enumerate(params_grads): if g is None: continue if getattr(p, 'need_clip', True) is False: @@ -72,8 +74,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): params_and_grads.append((p, new_grad)) # debug log - # logger.debug( - # f"Grad After Clip: {p.name}: {float(merge_grad.square().sum().sqrt())}" - # ) + if i < 10: + logger.debug( + f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" + ) return params_and_grads diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index e630febbc..56de32617 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -226,6 +226,7 @@ class Trainer(): 'lr': self.lr_scheduler()}, self.epoch) self.save(tag=self.epoch, infos={'val_loss': cv_loss}) + # step lr every epoch self.lr_scheduler.step() self.new_epoch() @@ -283,7 +284,6 @@ class Trainer(): """ # 
visualizer visualizer = SummaryWriter(logdir=str(self.output_dir)) - self.visualizer = visualizer @mp_tools.rank_zero_only @@ -301,7 +301,6 @@ class Trainer(): """ raise NotImplementedError("train_batch should be implemented.") - @mp_tools.rank_zero_only @paddle.no_grad() def valid(self): """The validation. A subclass should implement this method. diff --git a/deepspeech/utils/layer_tools.py b/deepspeech/utils/layer_tools.py index 67f3c9396..fb076c0c7 100644 --- a/deepspeech/utils/layer_tools.py +++ b/deepspeech/utils/layer_tools.py @@ -33,7 +33,7 @@ def summary(layer: nn.Layer, print_func=print): if print_func: num_elements = num_elements / 1024**2 print_func( - f"Total parameters: {num_params}, {num_elements:.2f} M elements.") + f"Total parameters: {num_params}, {num_elements:.2f}M elements.") def print_grads(model, print_func=print): @@ -57,7 +57,7 @@ def print_params(model, print_func=print): print_func(msg) if print_func: total = total / 1024**2 - print_func(f"Total parameters: {num_params}, {total:.2f} M elements.") + print_func(f"Total parameters: {num_params}, {total:.2f}M elements.") def gradient_norm(layer: nn.Layer): diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index 6d67d19a9..004498799 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -1,7 +1,8 @@ # Aishell-1 -## CTC -| Model | Config | Test set | CER | -| --- | --- | --- | --- | -| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 | -| DeepSpeech2 | release 1.8.5 | test | 0.080447 | +## Deepspeech2 +| Model | release | Config | Test set | CER | +| --- | --- | --- | --- | --- | +| DeepSpeech2 | 2.1 | conf/deepspeech2.yaml | test | 0.078671 | +| DeepSpeech2 | 2.0 | conf/deepspeech2.yaml | test | 0.078977 | +| DeepSpeech2 | 1.8.5 | - | test | 0.080447 | diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index a1a759e67..5635d9c84 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ 
b/examples/aishell/s0/conf/augmentation.json @@ -1,4 +1,13 @@ [ + { + "type": "speed", + "params": { + "min_speed_rate": 0.9, + "max_speed_rate": 1.1, + "num_rates": 3 + }, + "prob": 0.0 + }, { "type": "shift", "params": { diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 02c68df9c..8b08ee308 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -10,9 +10,9 @@ data: min_input_len: 0.0 max_input_len: 27.0 # second min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf specgram_type: linear target_sample_rate: 16000 max_freq: None @@ -41,7 +41,7 @@ training: lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 - global_grad_clip: 5.0 + global_grad_clip: 3.0 log_interval: 100 decoding: diff --git a/examples/aishell/s0/local/data.sh b/examples/aishell/s0/local/data.sh index f2a5dfc36..c92152c7c 100755 --- a/examples/aishell/s0/local/data.sh +++ b/examples/aishell/s0/local/data.sh @@ -32,7 +32,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --unit_type="char" \ --count_threshold=0 \ --vocab_path="data/vocab.txt" \ - --manifest_paths "data/manifest.train.raw" + --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw" if [ $? -ne 0 ]; then echo "Build vocabulary failed. Terminated." 
@@ -51,8 +51,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --stride_ms=10.0 \ --window_ms=20.0 \ --sample_rate=16000 \ - --use_dB_normalization=False \ - --num_samples=-1 \ + --use_dB_normalization=True \ + --num_samples=2000 \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 1e694df1c..e71cc8340 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -1,6 +1,6 @@ # LibriSpeech -## CTC +## Deepspeech2 | Model | Config | Test set | WER | | --- | --- | --- | --- | | DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 | diff --git a/examples/librispeech/s0/conf/augmentation.json b/examples/librispeech/s0/conf/augmentation.json index a1a759e67..5635d9c84 100644 --- a/examples/librispeech/s0/conf/augmentation.json +++ b/examples/librispeech/s0/conf/augmentation.json @@ -1,4 +1,13 @@ [ + { + "type": "speed", + "params": { + "min_speed_rate": 0.9, + "max_speed_rate": 1.1, + "num_rates": 3 + }, + "prob": 0.0 + }, { "type": "shift", "params": { diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index 688f0cba9..80280f5cc 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -10,9 +10,9 @@ data: min_input_len: 0.0 max_input_len: 27.0 # second min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf specgram_type: linear target_sample_rate: 16000 max_freq: None @@ -21,7 +21,7 @@ data: window_ms: 20.0 delta_delta: False dither: 1.0 - use_dB_normalization: True + use_dB_normalization: True target_dB: -20 random_seed: 0 keep_transcription_text: False @@ -41,7 +41,7 @@ training: lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 - global_grad_clip: 5.0 + global_grad_clip: 3.0 
log_interval: 100 decoding: diff --git a/examples/librispeech/s0/local/data.sh b/examples/librispeech/s0/local/data.sh index 9c3ddcfac..921f1f49a 100755 --- a/examples/librispeech/s0/local/data.sh +++ b/examples/librispeech/s0/local/data.sh @@ -17,12 +17,12 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then --manifest_prefix="data/manifest" \ --target_dir="${TARGET_DIR}/librispeech" \ --full_download="True" - + if [ $? -ne 0 ]; then echo "Prepare LibriSpeech failed. Terminated." exit 1 fi - + for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do mv data/manifest.${set} data/manifest.${set}.raw done @@ -48,7 +48,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --count_threshold=0 \ --vocab_path="data/vocab.txt" \ --manifest_paths="data/manifest.train.raw" - + if [ $? -ne 0 ]; then echo "Build vocabulary failed. Terminated." exit 1 @@ -61,16 +61,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then num_workers=$(nproc) python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ - --num_samples=-1 \ + --num_samples=2000 \ --specgram_type="linear" \ --delta_delta=false \ --sample_rate=16000 \ --stride_ms=10.0 \ --window_ms=20.0 \ - --use_dB_normalization=False \ + --use_dB_normalization=True \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" - + if [ $? -ne 0 ]; then echo "Compute mean and stddev failed. Terminated." 
exit 1 diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index 472e6ebfb..1c55c6ea5 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -19,7 +19,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/ngram_lm/README.md b/examples/ngram_lm/README.md new file mode 100644 index 000000000..698d7c290 --- /dev/null +++ b/examples/ngram_lm/README.md @@ -0,0 +1,7 @@ +# Ngram LM + +Train a Chinese character ngram LM with [kenlm](https://github.com/kpu/kenlm). + +``` +bash run.sh +``` diff --git a/examples/spm/README.md b/examples/spm/README.md index 8b24b28e5..3109d3ffb 100644 --- a/examples/spm/README.md +++ b/examples/spm/README.md @@ -1,4 +1,6 @@ -# SPM demo +# [SentencePiece Model](https://github.com/google/sentencepiece) + +Train a `spm` model for English tokenizer. 
``` bash run.sh diff --git a/examples/tiny/s1/conf/augmentation.json b/examples/tiny/s1/conf/augmentation.json index c1078393d..f26c282e7 100644 --- a/examples/tiny/s1/conf/augmentation.json +++ b/examples/tiny/s1/conf/augmentation.json @@ -1,20 +1,20 @@ [ { - "type": "shift", + "type": "speed", "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 + "min_speed_rate": 0.9, + "max_speed_rate": 1.1, + "num_rates": 3 }, "prob": 1.0 }, { - "type": "speed", + "type": "shift", "params": { - "min_speed_rate": 0.9, - "max_speed_rate": 1.1, - "num_rates": 3 + "min_shift_ms": -5, + "max_shift_ms": 5 }, - "prob": 0.0 + "prob": 1.0 }, { "type": "specaug", diff --git a/tests/mask_test.py b/tests/mask_test.py index c4a843e32..ce1a673a5 100644 --- a/tests/mask_test.py +++ b/tests/mask_test.py @@ -26,25 +26,27 @@ class TestU2Model(unittest.TestCase): paddle.set_device('cpu') self.lengths = paddle.to_tensor([5, 3, 2]) self.masks = np.array([ - [1, 1, 1, 1, 1], - [1, 1, 1, 0, 0], - [1, 1, 0, 0, 0], + [True, True, True, True, True], + [True, True, True, False, False], + [True, True, False, False, False], ]) self.pad_masks = np.array([ - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1], + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, True, True, True], ]) def test_sequence_mask(self): - res = sequence_mask(self.lengths) + res = sequence_mask(self.lengths, dtype='bool') self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) def test_make_non_pad_mask(self): res = make_non_pad_mask(self.lengths) - res1 = sequence_mask(self.lengths) + res1 = sequence_mask(self.lengths, dtype='bool') + res2 = make_pad_mask(self.lengths).logical_not() self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), res1.numpy().tolist()) + self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist()) def test_make_pad_mask(self): res = make_pad_mask(self.lengths) diff --git 
a/utils/compute_mean_std.py b/utils/compute_mean_std.py index 8dfd3e590..aff6f47c6 100644 --- a/utils/compute_mean_std.py +++ b/utils/compute_mean_std.py @@ -24,7 +24,7 @@ from deepspeech.utils.utility import print_arguments parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('num_samples', int, -1, "# of samples to for statistics.") +add_arg('num_samples', int, 2000, "# of samples to for statistics.") add_arg('specgram_type', str, 'linear', @@ -35,7 +35,7 @@ add_arg('delta_delta', bool, False, "Audio feature with delta delta.") add_arg('stride_ms', float, 10.0, "stride length in ms.") add_arg('window_ms', float, 20.0, "stride length in ms.") add_arg('sample_rate', int, 16000, "target sample rate.") -add_arg('use_dB_normalization', bool, False, "do dB normalization.") +add_arg('use_dB_normalization', bool, True, "do dB normalization.") add_arg('target_dB', int, -20, "target dB.") add_arg('manifest_path', str,