diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 702a0576..b854a996 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -235,7 +235,7 @@ class DeepSpeech2Trainer(Trainer):
             num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.collator.batch_size,
+            batch_size=int(config.collator.batch_size / 4),
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
index b3ca2827..11ce871f 100644
--- a/deepspeech/modules/ctc.py
+++ b/deepspeech/modules/ctc.py
@@ -56,7 +56,7 @@ class CTCDecoder(nn.Layer):
 
         self.blank_id = blank_id
         self.odim = odim
-        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
         self.criterion = CTCLoss(
@@ -79,7 +79,7 @@
         Returns:
             loss (Tenosr): ctc loss value, scalar.
         """
-        logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        logits = self.ctc_lo(self.dropout(hs_pad))
         loss = self.criterion(logits, ys_pad, hlens, ys_lens)
         return loss
 
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 2c58be7e..7d24e170 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -45,9 +45,9 @@
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batch':
+        elif grad_norm_type == 'batch':
             self.norm_by_batchsize = True
-        if grad_norm_type == 'frame':
+        elif grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml
index fe9cab06..e4a06767 100644
--- a/examples/librispeech/s1/conf/transformer.yaml
+++ b/examples/librispeech/s1/conf/transformer.yaml
@@ -8,7 +8,7 @@ data:
   min_output_len: 0.0 # tokens
   max_output_len: 400.0 # tokens
   min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
+  max_output_input_ratio: 100.0
 
 collator:
   vocab_filepath: data/vocab.txt
@@ -16,7 +16,7 @@ collator:
   spm_model_prefix: 'data/bpe_unigram_5000'
   mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
-  batch_size: 64
+  batch_size: 32
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80
@@ -75,7 +75,7 @@ model:
 
 training:
   n_epoch: 120
-  accum_grad: 2
+  accum_grad: 4
   global_grad_clip: 5.0
   optim: adam
   optim_conf: