batch average ctc loss (#567)

* When the loss is divided by the batch size, then with a larger learning rate and more epochs the loss decreases further and the CER is lower than before.

* Since the loss decreases further with batch averaging, a smaller LM alpha can work better.

* A smaller LM alpha reduces the CER further.

* alpha 2.2, CER 0.077478

* alpha 1.9, CER 0.077249

* Larger LibriSpeech learning rate for the batch-averaged CTC loss.

* Since the loss is lower and the model more confident, use a smaller LM alpha.
Hui Zhang 5 years ago committed by GitHub
parent 258307df9b
commit e0a87a5ab1
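The gist of the change, sketched outside the diff for readability: the CTC criterion keeps its 'sum' reduction and gains an optional division by the batch size. A minimal framework-free sketch (function and variable names here are illustrative, not the repo's API):

```python
# Minimal sketch of the new reduction behavior (illustrative, not the repo's API).
def reduce_ctc_loss(per_utt_losses, batch_average=True):
    """per_utt_losses: per-utterance CTC losses, i.e. reduction='none'."""
    loss = sum(per_utt_losses)  # reduction='sum', the previous behavior
    if batch_average:
        # New: dividing by batch size keeps the loss (and its gradients)
        # from growing linearly with the batch size.
        loss = loss / len(per_utt_losses)
    return loss
```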

@@ -39,7 +39,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.modules.loss import CTCLoss
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.models.deepspeech2 import DeepSpeech2InferModel

@@ -170,7 +170,8 @@ class DeepSpeech2Model(nn.Layer):
             odim=dict_size + 1,  # <blank> is append after vocab
             blank_id=dict_size,  # last token is <blank>
             dropout_rate=0.0,
-            reduction=True)
+            reduction=True,  # sum
+            batch_average=True)  # sum / batch_size
 
     def forward(self, audio, text, audio_len, text_len):
         """Compute Model loss

@@ -36,14 +36,16 @@ class CTCDecoder(nn.Layer):
                  odim,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True):
+                 reduction: bool=True,
+                 batch_average: bool=False):
         """CTC decoder
 
         Args:
             enc_n_units ([int]): encoder output dimention
             vocab_size ([int]): text vocabulary size
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
-            reduction (bool): reduce the CTC loss into a scalar
+            reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
+            batch_average (bool): do batch dim wise average.
         """
         assert check_argument_types()
         super().__init__()
@@ -53,7 +55,10 @@ class CTCDecoder(nn.Layer):
         self.dropout_rate = dropout_rate
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
-        self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
+        self.criterion = CTCLoss(
+            blank=self.blank_id,
+            reduction=reduction_type,
+            batch_average=batch_average)
 
         # CTCDecoder LM Score handle
         self._ext_scorer = None

@@ -53,10 +53,11 @@ F.ctc_loss = ctc_loss
 
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum'):
+    def __init__(self, blank=0, reduction='sum', batch_average=False):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
+        self.batch_average = batch_average
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -76,8 +77,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         loss = self.loss(logits, ys_pad, hlens, ys_lens)
-        # wenet do batch-size average, deepspeech2 not do this
-        # Batch-size average
-        # loss = loss / B
+        if self.batch_average:
+            loss = loss / B
         return loss
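Note that batch_average (sum / batch_size) is not the same as a typical reduction='mean' for CTC, which usually also normalizes each utterance by its label length; that is presumably why the commit adds a separate flag instead of switching reductions. A toy comparison with made-up numbers:

```python
# Toy comparison (made-up numbers): batch_average vs. a length-normalizing mean.
per_utt = [24.0, 36.0]  # hypothetical per-utterance CTC losses
label_lens = [8, 12]    # hypothetical label lengths
B = len(per_utt)

batch_average = sum(per_utt) / B  # 30.0  (what this commit computes)
length_mean = sum(l / n for l, n in zip(per_utt, label_lens)) / B  # 3.0
print(batch_average, length_mean)
```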

@@ -1,7 +1,7 @@
 # Aishell-1
 
 ## CTC
-| Model | Config | Test set | CER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 |
-| DeepSpeech2 | release 1.8.5 | test | 0.080447 |
+| Model | Config | Test Set | CER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 |
+| DeepSpeech2 | release 1.8.5 | test | 0.087004 | 8.575452 |

@@ -29,8 +29,8 @@ model:
   use_gru: True
   share_rnn_weights: False
 training:
-  n_epoch: 30
-  lr: 5e-4
+  n_epoch: 50
+  lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
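The lr and epoch changes pair with the new loss scale: with reduction='sum' the gradient magnitude grows roughly linearly with the batch size, and batch averaging removes that factor, so the learning rate is raised (5e-4 to 2e-3) and training runs longer (30 to 50 epochs). A back-of-envelope illustration (the batch size below is hypothetical, not taken from this config):

```python
# Back-of-envelope illustration; B is hypothetical, not from this config.
B = 64
g = 1.0                    # pretend per-utterance gradient magnitude
step_sum = 5e-4 * (B * g)  # old: sum-reduced loss, gradient ~ B * g
step_avg = 2e-3 * g        # new: batch-averaged loss, gradient ~ g
print(step_sum, step_avg)  # 0.032 vs. 0.002
```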
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
-  alpha: 2.6
+  alpha: 1.9
   beta: 5.0
   beam_size: 300
   cutoff_prob: 0.99
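On the alpha change: in DeepSpeech2-style CTC beam search, alpha scales the external language model's log-probability in the hypothesis score, so a more confident acoustic model warrants less LM weight. A sketch of the common formulation (not this repo's exact code):

```python
# Common DeepSpeech2-style scoring sketch, not this repo's exact code.
def hypothesis_score(log_p_ctc, log_p_lm, n_words, alpha=1.9, beta=5.0):
    # alpha weights the external LM, beta rewards longer transcriptions;
    # a lower CTC loss (more confident acoustic model) suits a smaller alpha.
    return log_p_ctc + alpha * log_p_lm + beta * n_words
```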

@@ -2,7 +2,7 @@
 
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."

@@ -1,7 +1,7 @@
 # LibriSpeech
 
 ## CTC
-| Model | Config | Test set | WER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
-| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
+| Model | Config | Test Set | WER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.069357 | 15.078561 |
+| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | 15.351633 |

@@ -29,8 +29,8 @@ model:
   use_gru: False
   share_rnn_weights: True
 training:
-  n_epoch: 20
-  lr: 5e-4
+  n_epoch: 50
+  lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: wer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-  alpha: 2.5
+  alpha: 1.9
   beta: 0.3
   beam_size: 500
   cutoff_prob: 1.0

@@ -1,8 +1,9 @@
 #! /usr/bin/env bash
 
-export FLAGS_sync_nccl_allreduce=0
-export NCCL_SHM_DISABLE=1
+#export FLAGS_sync_nccl_allreduce=0
+# https://github.com/PaddlePaddle/Paddle/pull/28484
+#export NCCL_SHM_DISABLE=1
 
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
