train ds2 model (#622)

* default CMVN compute config; more grad-clip logging; separate DS2 CMVN compute and config; step DS2 LR by epoch

* fix ds2 config

* fix install and egs link

* sox speed perturb: output shape (T, C), float64; processed using int32

* fix libri ds2 scripts; add ngram and spm doc

* aishell DS2 CER 7.86

* fix ds2 result
Branch: pull/625/head
Author: Hui Zhang, 4 years ago, committed by GitHub
Parent: 4853761542
Commit: 295f8bdad5

@@ -21,11 +21,11 @@
 * python>=3.7
 * paddlepaddle>=2.1.0

-Please see [install](doc/install.md).
+Please see [install](doc/src/install.md).

 ## Getting Started

-Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/README.md).
+Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).

 ## More Information

@@ -22,11 +22,11 @@
 * python>=3.7
 * paddlepaddle>=2.1.0

-参看 [安装](doc/install.md)。
+参看 [安装](doc/src/install.md)。

 ## 开始

-请查看 [Getting Started](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/README.md)。
+请查看 [Getting Started](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/s0/README.md)。

 ## 更多信息

@@ -43,13 +43,11 @@ class DeepSpeech2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         start = time.time()
         loss = self.model(*batch_data)
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
         self.optimizer.step()
         self.optimizer.clear_grad()
         iteration_time = time.time() - start

         losses_np = {

@@ -351,7 +351,9 @@ class AudioSegment(object):
         tfm.set_globals(multithread=False)
         tfm.speed(speed_rate)
         self._samples = tfm.build_array(
-            input_array=self._samples, sample_rate_in=self._sample_rate).copy()
+            input_array=self._samples,
+            sample_rate_in=self._sample_rate).squeeze(-1).astype(
+                np.float32).copy()

     def normalize(self, target_db=-20, max_gain_db=300.0):
         """Normalize audio to be of the desired RMS value in decibels.

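For reference, a minimal standalone sketch of the sox-based speed perturbation shown above, assuming the pysox package (`import sox`); the 16 kHz rate and 1.1 speed factor are illustrative, not taken from the diff:

```
import numpy as np
import sox  # pysox


def speed_perturb(samples: np.ndarray, sample_rate: int, speed_rate: float) -> np.ndarray:
    """Speed-perturb a mono float32 waveform of shape (T,) with sox.

    Per the commit notes, sox returns a float64 array shaped (T, C),
    so the result is squeezed back to (T,) and cast to float32.
    """
    tfm = sox.Transformer()
    tfm.set_globals(multithread=False)
    tfm.speed(speed_rate)
    out = tfm.build_array(input_array=samples, sample_rate_in=sample_rate)
    return np.squeeze(out).astype(np.float32)


# illustrative usage
wav = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)
perturbed = speed_perturb(wav, sample_rate=16000, speed_rate=1.1)
```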
@@ -31,7 +31,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         sum_square_list = []
-        for p, g in params_grads:
+        for i, (p, g) in enumerate(params_grads):
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
@@ -45,7 +45,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             sum_square_list.append(sum_square)

             # debug log
-            # logger.debug(f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")
+            if i < 10:
+                logger.debug(
+                    f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")

         # all parameters have been filterd out
         if len(sum_square_list) == 0:
@@ -62,7 +64,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
         clip_var = layers.elementwise_div(
             x=max_global_norm,
             y=layers.elementwise_max(x=global_norm_var, y=max_global_norm))
-        for p, g in params_grads:
+        for i, (p, g) in enumerate(params_grads):
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
@@ -72,8 +74,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             params_and_grads.append((p, new_grad))

             # debug log
-            # logger.debug(
-            #     f"Grad After Clip: {p.name}: {float(merge_grad.square().sum().sqrt())}"
-            # )
+            if i < 10:
+                logger.debug(
+                    f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
+                )

         return params_and_grads

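The clipping math above is unchanged; the diff only adds per-parameter logging, capped at the first 10 entries. A framework-independent NumPy sketch of the same idea (global-norm clipping with the new logging; `clip_norm=3.0` matches the updated configs, the rest is illustrative):

```
import logging

import numpy as np

logger = logging.getLogger(__name__)


def clip_by_global_norm(named_grads, clip_norm=3.0, log_first_n=10):
    """Scale all gradients so their combined L2 norm is at most clip_norm."""
    global_sq = 0.0
    for i, (name, grad) in enumerate(named_grads):
        sq = float(np.sum(grad * grad))
        global_sq += sq
        if i < log_first_n:
            logger.debug(f"Grad Before Clip: {name}: {np.sqrt(sq)}")

    global_norm = np.sqrt(global_sq)
    scale = min(1.0, clip_norm / (global_norm + 1e-12))

    clipped = []
    for i, (name, grad) in enumerate(named_grads):
        new_grad = grad * scale
        clipped.append((name, new_grad))
        if i < log_first_n:
            logger.debug(f"Grad After Clip: {name}: {np.linalg.norm(new_grad)}")
    return clipped
```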
@@ -226,6 +226,7 @@ class Trainer():
                       'lr': self.lr_scheduler()}, self.epoch)
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
+            # step lr every epoch
             self.lr_scheduler.step()
             self.new_epoch()
@@ -283,7 +284,6 @@ class Trainer():
         """
         # visualizer
         visualizer = SummaryWriter(logdir=str(self.output_dir))
         self.visualizer = visualizer

     @mp_tools.rank_zero_only
@@ -301,7 +301,6 @@ class Trainer():
         """
         raise NotImplementedError("train_batch should be implemented.")

-    @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
         """The validation. A subclass should implement this method.

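The scheduler is now stepped once per epoch (the new `# step lr every epoch` comment marks the call site). Assuming the config's `lr_decay` acts as an exponential decay factor per scheduler step, the resulting per-epoch schedule looks like this (base LR and decay come from the aishell config in this commit; the epoch count is illustrative):

```
base_lr = 2e-3   # training.lr in the aishell config
lr_decay = 0.83  # training.lr_decay

# one scheduler step per epoch -> lr = base_lr * lr_decay ** epoch
for epoch in range(5):
    print(f"epoch {epoch}: lr = {base_lr * lr_decay ** epoch:.6f}")
```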
@@ -33,7 +33,7 @@ def summary(layer: nn.Layer, print_func=print):
     if print_func:
         num_elements = num_elements / 1024**2
         print_func(
-            f"Total parameters: {num_params}, {num_elements:.2f} M elements.")
+            f"Total parameters: {num_params}, {num_elements:.2f}M elements.")


 def print_grads(model, print_func=print):
@@ -57,7 +57,7 @@ def print_params(model, print_func=print):
             print_func(msg)
     if print_func:
         total = total / 1024**2
-        print_func(f"Total parameters: {num_params}, {total:.2f} M elements.")
+        print_func(f"Total parameters: {num_params}, {total:.2f}M elements.")


 def gradient_norm(layer: nn.Layer):

@@ -1,7 +1,8 @@
 # Aishell-1

-## CTC
-| Model | Config | Test set | CER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 |
-| DeepSpeech2 | release 1.8.5 | test | 0.080447 |
+## Deepspeech2
+| Model | release | Config | Test set | CER |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | 2.1 | conf/deepspeech2.yaml | test | 0.078671 |
+| DeepSpeech2 | 2.0 | conf/deepspeech2.yaml | test | 0.078977 |
+| DeepSpeech2 | 1.8.5 | - | test | 0.080447 |

@@ -1,4 +1,13 @@
 [
+    {
+        "type": "speed",
+        "params": {
+            "min_speed_rate": 0.9,
+            "max_speed_rate": 1.1,
+            "num_rates": 3
+        },
+        "prob": 0.0
+    },
     {
         "type": "shift",
         "params": {

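The new `speed` entry is added with `"prob": 0.0`, so it is listed in the config but never fires for aishell. A small sketch of how a `prob` field like this is typically interpreted when the JSON is read (the helper names are illustrative, not the project's API):

```
import json
import random


def load_augmentors(config_path):
    """Return (type, params, prob) triples from an augmentation JSON list."""
    with open(config_path) as f:
        return [(e["type"], e["params"], float(e["prob"])) for e in json.load(f)]


def choose_augmentations(augmentors, rng=random):
    """Each augmentor fires independently with its configured probability,
    so "prob": 0.0 keeps the entry in the file but disables it."""
    return [(atype, params) for atype, params, prob in augmentors
            if rng.random() < prob]
```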
@@ -10,9 +10,9 @@ data:
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
-  max_output_len: 400.0
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
+  max_output_len: .inf
+  min_output_input_ratio: 0.00
+  max_output_input_ratio: .inf
   specgram_type: linear
   target_sample_rate: 16000
   max_freq: None
@@ -41,7 +41,7 @@ training:
   lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
-  global_grad_clip: 5.0
+  global_grad_clip: 3.0
   log_interval: 100

 decoding:

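The data section above loosens the transcript-length limits: `max_output_len` and `max_output_input_ratio` become `.inf` and `min_output_input_ratio` drops to 0, so utterances are no longer dropped for their output length. A sketch of the kind of filter these fields drive, assuming the ratio is transcript length over audio duration (the function and key names are assumptions, not the project's code):

```
import math


def keep_utterance(duration_s, text_len, conf):
    """Apply the data-section limits from the YAML above to one utterance."""
    if not (conf["min_input_len"] <= duration_s <= conf["max_input_len"]):
        return False
    if not (conf["min_output_len"] <= text_len <= conf["max_output_len"]):
        return False
    ratio = text_len / duration_s if duration_s > 0 else math.inf
    return conf["min_output_input_ratio"] <= ratio <= conf["max_output_input_ratio"]


conf = {
    "min_input_len": 0.0, "max_input_len": 27.0,        # seconds
    "min_output_len": 0.0, "max_output_len": math.inf,  # tokens
    "min_output_input_ratio": 0.0,
    "max_output_input_ratio": math.inf,
}
print(keep_utterance(duration_s=5.2, text_len=40, conf=conf))  # True
```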
@@ -32,7 +32,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     --unit_type="char" \
     --count_threshold=0 \
     --vocab_path="data/vocab.txt" \
-    --manifest_paths "data/manifest.train.raw"
+    --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"

     if [ $? -ne 0 ]; then
         echo "Build vocabulary failed. Terminated."
@@ -51,8 +51,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     --stride_ms=10.0 \
     --window_ms=20.0 \
     --sample_rate=16000 \
-    --use_dB_normalization=False \
-    --num_samples=-1 \
+    --use_dB_normalization=True \
+    --num_samples=2000 \
     --num_workers=${num_workers} \
     --output_path="data/mean_std.json"

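CMVN statistics are now computed from a 2000-utterance sample (`--num_samples=2000` instead of `-1` for all) with dB normalization enabled. The statistic itself is just a per-dimension mean and standard deviation over frames; a plain NumPy sketch (feature extraction is left abstract, and the JSON keys are illustrative rather than the exact `mean_std.json` schema):

```
import json

import numpy as np


def compute_mean_std(feature_list, output_path):
    """feature_list: iterable of (T_i, D) spectrogram arrays from a random
    sample of utterances (e.g. 2000 of them)."""
    feats = np.concatenate(list(feature_list), axis=0)  # (sum of T_i, D)
    mean = feats.mean(axis=0)
    std = feats.std(axis=0)
    with open(output_path, "w") as f:
        json.dump({"mean": mean.tolist(), "std": std.tolist()}, f)
    return mean, std
```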
@@ -1,6 +1,6 @@
 # LibriSpeech

-## CTC
+## Deepspeech2
 | Model | Config | Test set | WER |
 | --- | --- | --- | --- |
 | DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |

@@ -1,4 +1,13 @@
 [
+    {
+        "type": "speed",
+        "params": {
+            "min_speed_rate": 0.9,
+            "max_speed_rate": 1.1,
+            "num_rates": 3
+        },
+        "prob": 0.0
+    },
     {
         "type": "shift",
         "params": {

@@ -10,9 +10,9 @@ data:
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
-  max_output_len: 400.0
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
+  max_output_len: .inf
+  min_output_input_ratio: 0.00
+  max_output_input_ratio: .inf
   specgram_type: linear
   target_sample_rate: 16000
   max_freq: None
@@ -41,7 +41,7 @@ training:
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
-  global_grad_clip: 5.0
+  global_grad_clip: 3.0
   log_interval: 100

 decoding:

@@ -61,13 +61,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     num_workers=$(nproc)
     python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
     --manifest_path="data/manifest.train.raw" \
-    --num_samples=-1 \
+    --num_samples=2000 \
     --specgram_type="linear" \
     --delta_delta=false \
     --sample_rate=16000 \
     --stride_ms=10.0 \
     --window_ms=20.0 \
-    --use_dB_normalization=False \
+    --use_dB_normalization=True \
     --num_workers=${num_workers} \
     --output_path="data/mean_std.json"

@@ -19,7 +19,7 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

@@ -0,0 +1,7 @@
+# Ngram LM
+
+Train chinese chararctor ngram lm by [kenlm](https://github.com/kpu/kenlm).
+
+```
+bash run.sh
+```

@@ -1,4 +1,6 @@
-# SPM demo
+# [SentencePiece Model](https://github.com/google/sentencepiece)

+Train a `spm` model for English tokenizer.
+
 ```
 bash run.sh

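The SPM example now links to SentencePiece directly. As an aside, a minimal way to train and use such a tokenizer from Python, assuming the `sentencepiece` package; the paths, vocab size, and model type are illustrative (the example's own `run.sh` performs the actual training):

```
import sentencepiece as spm

# train a unigram SentencePiece model on a plain-text corpus (one sentence per line)
spm.SentencePieceTrainer.train(
    input="data/corpus.txt",             # illustrative path
    model_prefix="data/spm_unigram_200",
    vocab_size=200,
    model_type="unigram")

# tokenize with the trained model
sp = spm.SentencePieceProcessor(model_file="data/spm_unigram_200.model")
print(sp.encode("speech recognition", out_type=str))
```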
@@ -1,20 +1,20 @@
 [
     {
-        "type": "shift",
+        "type": "speed",
         "params": {
-            "min_shift_ms": -5,
-            "max_shift_ms": 5
+            "min_speed_rate": 0.9,
+            "max_speed_rate": 1.1,
+            "num_rates": 3
         },
         "prob": 1.0
     },
     {
-        "type": "speed",
+        "type": "shift",
         "params": {
-            "min_speed_rate": 0.9,
-            "max_speed_rate": 1.1,
-            "num_rates": 3
+            "min_shift_ms": -5,
+            "max_shift_ms": 5
         },
-        "prob": 0.0
+        "prob": 1.0
     },
     {
         "type": "specaug",

@@ -26,25 +26,27 @@ class TestU2Model(unittest.TestCase):
         paddle.set_device('cpu')
         self.lengths = paddle.to_tensor([5, 3, 2])
         self.masks = np.array([
-            [1, 1, 1, 1, 1],
-            [1, 1, 1, 0, 0],
-            [1, 1, 0, 0, 0],
+            [True, True, True, True, True],
+            [True, True, True, False, False],
+            [True, True, False, False, False],
         ])
         self.pad_masks = np.array([
-            [0, 0, 0, 0, 0],
-            [0, 0, 0, 1, 1],
-            [0, 0, 1, 1, 1],
+            [False, False, False, False, False],
+            [False, False, False, True, True],
+            [False, False, True, True, True],
         ])

     def test_sequence_mask(self):
-        res = sequence_mask(self.lengths)
+        res = sequence_mask(self.lengths, dtype='bool')
         self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())

     def test_make_non_pad_mask(self):
         res = make_non_pad_mask(self.lengths)
-        res1 = sequence_mask(self.lengths)
+        res1 = sequence_mask(self.lengths, dtype='bool')
+        res2 = make_pad_mask(self.lengths).logical_not()
         self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
         self.assertSequenceEqual(res.numpy().tolist(), res1.numpy().tolist())
+        self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())

     def test_make_pad_mask(self):
         res = make_pad_mask(self.lengths)

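The updated test compares boolean masks and asserts that `make_non_pad_mask(lengths)` equals `make_pad_mask(lengths).logical_not()`. The intended semantics, written out in plain NumPy (these helpers only mirror the tested functions; they are not the project's implementations):

```
import numpy as np


def sequence_mask(lengths, max_len=None):
    """True for valid (non-padded) positions in each row."""
    lengths = np.asarray(lengths)
    max_len = max_len or int(lengths.max())
    return np.arange(max_len)[None, :] < lengths[:, None]


lengths = np.array([5, 3, 2])
non_pad = sequence_mask(lengths)  # row for length 3 -> [T, T, T, F, F]
pad = ~non_pad                    # make_pad_mask is the complement
print(non_pad.tolist())
print(pad.tolist())
```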
@@ -24,7 +24,7 @@ from deepspeech.utils.utility import print_arguments
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)

 # yapf: disable
-add_arg('num_samples', int, -1, "# of samples to for statistics.")
+add_arg('num_samples', int, 2000, "# of samples to for statistics.")
 add_arg('specgram_type', str,
         'linear',
@@ -35,7 +35,7 @@ add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
 add_arg('stride_ms', float, 10.0, "stride length in ms.")
 add_arg('window_ms', float, 20.0, "stride length in ms.")
 add_arg('sample_rate', int, 16000, "target sample rate.")
-add_arg('use_dB_normalization', bool, False, "do dB normalization.")
+add_arg('use_dB_normalization', bool, True, "do dB normalization.")
 add_arg('target_dB', int, -20, "target dB.")
 add_arg('manifest_path', str,
