diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index c171089dc..717eea4bf 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -39,7 +39,6 @@
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.modules.loss import CTCLoss
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
 
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index ffe678a69..4e66a75f8 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -170,7 +170,8 @@ class DeepSpeech2Model(nn.Layer):
             odim=dict_size + 1,  # <blank> is append after vocab
             blank_id=dict_size,  # last token is <blank>
             dropout_rate=0.0,
-            reduction=True)
+            reduction=True,  # sum
+            batch_average=True)  # sum / batch_size
 
     def forward(self, audio, text, audio_len, text_len):
         """Compute Model loss
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
index 66737f599..74b21d395 100644
--- a/deepspeech/modules/ctc.py
+++ b/deepspeech/modules/ctc.py
@@ -36,14 +36,16 @@ class CTCDecoder(nn.Layer):
                  odim,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True):
+                 reduction: bool=True,
+                 batch_average: bool=False):
         """CTC decoder
 
         Args:
             enc_n_units ([int]): encoder output dimention
             vocab_size ([int]): text vocabulary size
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
-            reduction (bool): reduce the CTC loss into a scalar
+            reduction (bool): reduce the CTC loss into a scalar; True for 'sum', False for 'none'
+            batch_average (bool): divide the summed loss by the batch size.
         """
         assert check_argument_types()
         super().__init__()
@@ -53,7 +55,10 @@ class CTCDecoder(nn.Layer):
         self.dropout_rate = dropout_rate
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
-        self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
+        self.criterion = CTCLoss(
+            blank=self.blank_id,
+            reduction=reduction_type,
+            batch_average=batch_average)
 
         # CTCDecoder LM Score handle
         self._ext_scorer = None
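Reviewer note: the three hunks above thread the new `batch_average` switch from `DeepSpeech2Model` through `CTCDecoder` into `CTCLoss`. As a quick reference, here is a framework-free sketch of the reduction semantics the two flags select; `reduce_ctc_loss` and `per_utt_losses` are illustrative names, not code from this patch, and the real division happens in `CTCLoss.forward` (shown in the `loss.py` hunks below):

```python
# Minimal sketch of the reduction semantics this patch wires up (illustrative names).
def reduce_ctc_loss(per_utt_losses, reduction=True, batch_average=False):
    if not reduction:                 # reduction_type == "none"
        return per_utt_losses         # keep per-utterance losses
    total = sum(per_utt_losses)       # reduction_type == "sum"
    if batch_average:
        total /= len(per_utt_losses)  # sum / batch_size
    return total

losses = [12.0, 8.0, 10.0]  # pretend per-utterance CTC losses for a batch of B = 3
assert reduce_ctc_loss(losses) == 30.0                       # sum
assert reduce_ctc_loss(losses, batch_average=True) == 10.0   # sum / B
```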
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 0ef7e2f73..a229e7ebe 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -53,10 +53,11 @@ F.ctc_loss = ctc_loss
 
 
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum'):
+    def __init__(self, blank=0, reduction='sum', batch_average=False):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
+        self.batch_average = batch_average
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -76,8 +77,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         loss = self.loss(logits, ys_pad, hlens, ys_lens)
-
-        # wenet do batch-size average, deepspeech2 not do this
-        # Batch-size average
-        # loss = loss / B
+        if self.batch_average:
+            # Batch-size average
+            loss = loss / B
         return loss
diff --git a/examples/aishell/README.md b/examples/aishell/README.md
index 6d67d19a9..ded740d10 100644
--- a/examples/aishell/README.md
+++ b/examples/aishell/README.md
@@ -1,7 +1,7 @@
 # Aishell-1
 
 ## CTC
-| Model | Config | Test set | CER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 |
-| DeepSpeech2 | release 1.8.5 | test | 0.080447 |
+| Model | Config | Test Set | CER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 |
+| DeepSpeech2 | release 1.8.5 | test | 0.087004 | 8.575452 |
diff --git a/examples/aishell/conf/deepspeech2.yaml b/examples/aishell/conf/deepspeech2.yaml
index 821c183e5..a50a7ecf5 100644
--- a/examples/aishell/conf/deepspeech2.yaml
+++ b/examples/aishell/conf/deepspeech2.yaml
@@ -29,8 +29,8 @@ model:
   use_gru: True
   share_rnn_weights: False
 training:
-  n_epoch: 30
-  lr: 5e-4
+  n_epoch: 50
+  lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
-  alpha: 2.6
+  alpha: 1.9
   beta: 5.0
   beam_size: 300
   cutoff_prob: 0.99
diff --git a/examples/aishell/local/train.sh b/examples/aishell/local/train.sh
index c286566a8..245ed2172 100644
--- a/examples/aishell/local/train.sh
+++ b/examples/aishell/local/train.sh
@@ -2,7 +2,7 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
 
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md
index 1e694df1c..d553faecf 100644
--- a/examples/librispeech/README.md
+++ b/examples/librispeech/README.md
@@ -1,7 +1,7 @@
 # LibriSpeech
 
 ## CTC
-| Model | Config | Test set | WER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
-| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
+| Model | Config | Test Set | WER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.069357 | 15.078561 |
+| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | 15.351633 |
diff --git a/examples/librispeech/conf/deepspeech2.yaml b/examples/librispeech/conf/deepspeech2.yaml
index 15fd4cbe3..3368374b0 100644
--- a/examples/librispeech/conf/deepspeech2.yaml
+++ b/examples/librispeech/conf/deepspeech2.yaml
@@ -29,8 +29,8 @@ model:
   use_gru: False
   share_rnn_weights: True
 training:
-  n_epoch: 20
-  lr: 5e-4
+  n_epoch: 50
+  lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: wer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-  alpha: 2.5
+  alpha: 1.9
   beta: 0.3
   beam_size: 500
   cutoff_prob: 1.0
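Reviewer note: since `batch_average=True` divides the summed CTC loss (and therefore the gradients) by the batch size, the higher learning rates in both configs (5e-4 to 2e-3 on Aishell, 5e-4 to 1e-3 on LibriSpeech) plausibly compensate for that rescaling; the lower beam-search `alpha` values are a separate decoding retune. A back-of-the-envelope check, using a hypothetical batch size that is not taken from these configs:

```python
# Hypothetical arithmetic only: batch_size is illustrative, not from the configs.
batch_size = 64
old_lr = 5e-4                             # previously tuned against the batch-summed loss
full_compensation = old_lr * batch_size   # lr that would exactly undo the 1/B scaling
print(full_compensation)                  # 0.032 -- the PR's 2e-3 and 1e-3 stay well
                                          # below this, so they are a retune, not a pure rescale
```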
diff --git a/examples/librispeech/local/train.sh b/examples/librispeech/local/train.sh
index 758098679..cbccb1896 100644
--- a/examples/librispeech/local/train.sh
+++ b/examples/librispeech/local/train.sh
@@ -1,8 +1,9 @@
 #! /usr/bin/env bash
 
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
+
 # https://github.com/PaddlePaddle/Paddle/pull/28484
-export NCCL_SHM_DISABLE=1
+#export NCCL_SHM_DISABLE=1
 
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
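Reviewer note: both `train.sh` scripts keep the `ngpu` one-liner, which counts comma-separated entries in `CUDA_VISIBLE_DEVICES`. The same logic as standalone Python, including its edge case: `''.split(',')` is `['']`, so an empty variable still reports one GPU:

```python
import os

# Standalone version of the ngpu one-liner in train.sh.
devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
ngpu = len(devices.split(","))  # "0,1,2,3" -> 4; "" -> 1 (empty-split quirk)
print(f"using {ngpu} gpus...")
```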