large librispeech lr for batch_average ctc loss

pull/567/head
Hui Zhang 5 years ago
parent 1ab5bc2b24
commit f8719971b5

@ -39,7 +39,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.modules.loss import CTCLoss
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.deepspeech2 import DeepSpeech2InferModel from deepspeech.models.deepspeech2 import DeepSpeech2InferModel

@ -170,7 +170,8 @@ class DeepSpeech2Model(nn.Layer):
odim=dict_size + 1, # <blank> is appended after vocab odim=dict_size + 1, # <blank> is appended after vocab
blank_id=dict_size, # last token is <blank> blank_id=dict_size, # last token is <blank>
dropout_rate=0.0, dropout_rate=0.0,
reduction=True) reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, text, audio_len, text_len): def forward(self, audio, text, audio_len, text_len):
"""Compute Model loss """Compute Model loss

@ -36,14 +36,16 @@ class CTCDecoder(nn.Layer):
odim, odim,
blank_id=0, blank_id=0,
dropout_rate: float=0.0, dropout_rate: float=0.0,
reduction: bool=True): reduction: bool=True,
batch_average: bool=False):
"""CTC decoder """CTC decoder
Args: Args:
enc_n_units ([int]): encoder output dimension enc_n_units ([int]): encoder output dimension
vocab_size ([int]): text vocabulary size vocab_size ([int]): text vocabulary size
dropout_rate (float): dropout rate (0.0 ~ 1.0) dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar reduction (bool): reduce the CTC loss into a scalar; True for 'sum', False for 'none'
batch_average (bool): if True, divide the loss by the batch size (batch-wise average).
""" """
assert check_argument_types() assert check_argument_types()
super().__init__() super().__init__()
@ -53,7 +55,10 @@ class CTCDecoder(nn.Layer):
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.ctc_lo = nn.Linear(enc_n_units, self.odim) self.ctc_lo = nn.Linear(enc_n_units, self.odim)
reduction_type = "sum" if reduction else "none" reduction_type = "sum" if reduction else "none"
self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type) self.criterion = CTCLoss(
blank=self.blank_id,
reduction=reduction_type,
batch_average=batch_average)
# CTCDecoder LM Score handle # CTCDecoder LM Score handle
self._ext_scorer = None self._ext_scorer = None

@ -53,10 +53,11 @@ F.ctc_loss = ctc_loss
class CTCLoss(nn.Layer): class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum'): def __init__(self, blank=0, reduction='sum', batch_average=False):
super().__init__() super().__init__()
# last token id as blank id # last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average
def forward(self, logits, ys_pad, hlens, ys_lens): def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss. """Compute CTC loss.
@ -76,8 +77,7 @@ class CTCLoss(nn.Layer):
# logits: (B, L, D) -> (L, B, D) # logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2]) logits = logits.transpose([1, 0, 2])
loss = self.loss(logits, ys_pad, hlens, ys_lens) loss = self.loss(logits, ys_pad, hlens, ys_lens)
if self.batch_average:
# wenet does batch-size averaging; deepspeech2 does not
# Batch-size average # Batch-size average
# loss = loss / B loss = loss / B
return loss return loss

@ -4,4 +4,4 @@
| Model | Config | Test Set | CER | Valid Loss | | Model | Config | Test Set | CER | Valid Loss |
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 | | DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 |
| DeepSpeech2 | release 1.8.5 | test | 0.080447 | - | | DeepSpeech2 | release 1.8.5 | test | 0.087004 | 8.575452 |

@ -1,7 +1,7 @@
# LibriSpeech # LibriSpeech
## CTC ## CTC
| Model | Config | Test set | WER | | Model | Config | Test Set | WER | Valid Loss |
| --- | --- | --- | --- | | --- | --- | --- | --- | --- |
| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 | | DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.071391 | 15.078561 |
| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | | DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | 15.351633 |

@ -29,8 +29,8 @@ model:
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
training: training:
n_epoch: 20 n_epoch: 50
lr: 5e-4 lr: 1e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 5.0 global_grad_clip: 5.0

@ -1,8 +1,9 @@
#! /usr/bin/env bash #! /usr/bin/env bash
export FLAGS_sync_nccl_allreduce=0 #export FLAGS_sync_nccl_allreduce=0
# https://github.com/PaddlePaddle/Paddle/pull/28484 # https://github.com/PaddlePaddle/Paddle/pull/28484
export NCCL_SHM_DISABLE=1 #export NCCL_SHM_DISABLE=1
ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."

Loading…
Cancel
Save