From 1a8c5278a111e5da2df8f5d46c9eacb77a46a1bc Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 9 Sep 2021 02:52:38 +0000
Subject: [PATCH] export ctc grad norm config

Add a `grad_norm_type` argument to CTCLoss and CTCDecoder that selects
which normalization flag is forwarded to the underlying nn.CTCLoss call:

  'instance'  -> norm_by_times
  'batchsize' -> norm_by_batchsize
  'frame'     -> norm_by_total_logits_len
  None        -> no normalization flag set

DeepSpeech2Model exposes the option as `ctc_grad_norm_type` (default
'instance') in its params and constructor; DeepSpeech2ModelOnline,
U2Model and U2STModel pass 'instance' explicitly. The example
DeepSpeech2 configs gain a `ctc_grad_norm_type: instance` entry.
---
 deepspeech/models/ds2/deepspeech2.py          | 10 +++---
 deepspeech/models/ds2_online/deepspeech2.py   |  3 +-
 deepspeech/models/u2.py                       |  3 +-
 deepspeech/models/u2_st.py                    |  3 +-
 deepspeech/modules/ctc.py                     |  7 ++--
 deepspeech/modules/loss.py                    | 32 +++++++++++++++++--
 examples/aishell/s0/conf/deepspeech2.yaml     |  1 +
 .../aishell/s0/conf/deepspeech2_online.yaml   |  1 +
 examples/librispeech/s0/conf/deepspeech2.yaml |  1 +
 .../s0/conf/deepspeech2_online.yaml           |  1 +
 examples/tiny/s0/conf/deepspeech2.yaml        |  1 +
 examples/tiny/s0/conf/deepspeech2_online.yaml |  1 +
 12 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py
index 620d9008..dda26358 100644
--- a/deepspeech/models/ds2/deepspeech2.py
+++ b/deepspeech/models/ds2/deepspeech2.py
@@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer):
                 num_rnn_layers=3,  #Number of stacking RNN layers.
                 rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
-            ))
+                share_rnn_weights=True,  #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
+                ctc_grad_norm_type='instance', ))
         if config is not None:
             config.merge_from_other_cfg(default)
         return default
@@ -142,7 +142,8 @@ class DeepSpeech2Model(nn.Layer):
                  rnn_size=1024,
                  use_gru=False,
                  share_rnn_weights=True,
-                 blank_id=0):
+                 blank_id=0,
+                 ctc_grad_norm_type='instance'):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,
@@ -160,7 +161,8 @@ class DeepSpeech2Model(nn.Layer):
             blank_id=blank_id,
             dropout_rate=0.0,
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type=ctc_grad_norm_type)
 
     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss
diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py
index f049f415..29d207c4 100644
--- a/deepspeech/models/ds2_online/deepspeech2.py
+++ b/deepspeech/models/ds2_online/deepspeech2.py
@@ -289,7 +289,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
             blank_id=blank_id,
             dropout_rate=0.0,
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type='instance')
 
     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 1ca6a4fe..a01766da 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -864,7 +864,8 @@ class U2Model(U2BaseModel):
             blank_id=0,
             dropout_rate=0.0,
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type='instance')
 
         return vocab_size, encoder, decoder, ctc
 
diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py
index 531fafd0..7dae3745 100644
--- a/deepspeech/models/u2_st.py
+++ b/deepspeech/models/u2_st.py
@@ -649,7 +649,8 @@ class U2STModel(U2STBaseModel):
                 blank_id=0,
                 dropout_rate=0.0,
                 reduction=True,  # sum
-                batch_average=True)  # sum / batch_size
+                batch_average=True,  # sum / batch_size
+                grad_norm_type='instance')
 
             return vocab_size, encoder, (st_decoder, decoder, ctc)
         else:
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
index c330caf1..b3ca2827 100644
--- a/deepspeech/modules/ctc.py
+++ b/deepspeech/modules/ctc.py
@@ -39,7 +39,8 @@ class CTCDecoder(nn.Layer):
                  blank_id=0,
                  dropout_rate: float=0.0,
                  reduction: bool=True,
-                 batch_average: bool=True):
+                 batch_average: bool=True,
+                 grad_norm_type: str="instance"):
         """CTC decoder
 
         Args:
@@ -48,6 +49,7 @@ class CTCDecoder(nn.Layer):
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
             reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
             batch_average (bool): do batch dim wise average.
+            grad_norm_type (str): one of 'instance', 'batchsize', 'frame', None.
         """
         assert check_argument_types()
         super().__init__()
@@ -60,7 +62,8 @@ class CTCDecoder(nn.Layer):
         self.criterion = CTCLoss(
             blank=self.blank_id,
             reduction=reduction_type,
-            batch_average=batch_average)
+            batch_average=batch_average,
+            grad_norm_type=grad_norm_type)
 
         # CTCDecoder LM Score handle
         self._ext_scorer = None
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index f692a818..399e84e2 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
 
 
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum', batch_average=False):
+    def __init__(self,
+                 blank=0,
+                 reduction='sum',
+                 batch_average=False,
+                 grad_norm_type=None):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
         self.batch_average = batch_average
+        logger.info(
+            f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
+
+        # instance for norm_by_times
+        # batchsize for norm_by_batchsize
+        # frame for norm_by_total_logits_len
+        assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
+        self.norm_by_times = False
+        self.norm_by_batchsize = False
+        self.norm_by_total_logits_len = False
+        logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
+        if grad_norm_type == 'instance':
+            self.norm_by_times = True
+        if grad_norm_type == 'batchsize':
+            self.norm_by_batchsize = True
+        if grad_norm_type == 'frame':
+            self.norm_by_total_logits_len = True
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -46,10 +67,15 @@ class CTCLoss(nn.Layer):
         # warp-ctc need activation with shape [T, B, V + 1]
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
-        # (TODO:Hui Zhang) ctc loss does not support int64 labels
         ys_pad = ys_pad.astype(paddle.int32)
         loss = self.loss(
-            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
+            logits,
+            ys_pad,
+            hlens,
+            ys_lens,
+            norm_by_times=self.norm_by_times,
+            norm_by_batchsize=self.norm_by_batchsize,
+            norm_by_total_logits_len=self.norm_by_total_logits_len)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml
index 4bf03ec6..9560930a 100644
--- a/examples/aishell/s0/conf/deepspeech2.yaml
+++ b/examples/aishell/s0/conf/deepspeech2.yaml
@@ -41,6 +41,7 @@ model:
   use_gru: True
   share_rnn_weights: False
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 80
diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml
index 9946852d..7e87594c 100644
--- a/examples/aishell/s0/conf/deepspeech2_online.yaml
+++ b/examples/aishell/s0/conf/deepspeech2_online.yaml
@@ -43,6 +43,7 @@ model:
   fc_layers_size_list: -1,
   use_gru: False
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 50
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index 0e6ed5ba..d5b1ed91 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -41,6 +41,7 @@ model:
   use_gru: False
   share_rnn_weights: True
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 50
diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml
index 6e74f704..180a6205 100644
--- a/examples/librispeech/s0/conf/deepspeech2_online.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml
@@ -43,6 +43,7 @@ model:
   fc_layers_size_list: 512, 256
   use_gru: False
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 50
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index 5c9436e3..64598b4b 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -42,6 +42,7 @@ model:
   use_gru: False
   share_rnn_weights: True
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 10
diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml
index e435ff96..0098a226 100644
--- a/examples/tiny/s0/conf/deepspeech2_online.yaml
+++ b/examples/tiny/s0/conf/deepspeech2_online.yaml
@@ -44,6 +44,7 @@ model:
   fc_layers_size_list: 512, 256
   use_gru: True
   blank_id: 0
+  ctc_grad_norm_type: instance
 
 training:
   n_epoch: 10