ctc using nn.Dropout; ds2 libri valid batch_size / 4

pull/858/head
Hui Zhang 3 years ago
parent 156a840216
commit 0c08915207

@@ -235,7 +235,7 @@ class DeepSpeech2Trainer(Trainer):
             num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.collator.batch_size,
+            batch_size=int(config.collator.batch_size / 4),
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
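
For reference, a minimal sketch of the validation-loader setup this hunk produces, using paddle.io.DataLoader. The helper name build_valid_loader and the memory rationale in the comment are illustrative assumptions, not part of the commit; dev_dataset and collate_fn_dev stand in for the trainer's real objects.

    from paddle.io import DataLoader

    def build_valid_loader(dev_dataset, collate_fn_dev, config):
        # Quarter the configured batch size for validation (assumption:
        # trades eval throughput for lower peak memory on long LibriSpeech
        # dev utterances).
        return DataLoader(
            dev_dataset,
            batch_size=int(config.collator.batch_size / 4),
            shuffle=False,    # deterministic evaluation order
            drop_last=False,  # score every dev utterance
            collate_fn=collate_fn_dev)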

@@ -56,7 +56,7 @@ class CTCDecoder(nn.Layer):
         self.blank_id = blank_id
         self.odim = odim
-        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
         self.criterion = CTCLoss(
@@ -79,7 +79,7 @@ class CTCDecoder(nn.Layer):
         Returns:
             loss (Tensor): ctc loss value, scalar.
         """
-        logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        logits = self.ctc_lo(self.dropout(hs_pad))
         loss = self.criterion(logits, ys_pad, hlens, ys_lens)
         return loss
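
The switch from F.dropout to an nn.Dropout sublayer matters for eval mode: paddle.nn.functional.dropout defaults to training=True, so the old call kept dropping activations during validation unless the flag was threaded through by hand, whereas an nn.Dropout module follows the layer's train()/eval() state automatically and shows up in the layer tree. A minimal sketch of the contrast; ToyCTCHead is a hypothetical stand-in for CTCDecoder:

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    class ToyCTCHead(nn.Layer):
        def __init__(self, enc_n_units, odim, dropout_rate):
            super().__init__()
            # nn.Dropout is a registered sublayer: it becomes a no-op
            # automatically after self.eval().
            self.dropout = nn.Dropout(dropout_rate)
            self.ctc_lo = nn.Linear(enc_n_units, odim)

        def forward(self, hs_pad):
            return self.ctc_lo(self.dropout(hs_pad))

    head = ToyCTCHead(4, 8, 0.5)
    x = paddle.randn([2, 3, 4])
    head.eval()
    # In eval mode the module path leaves activations untouched ...
    y_module = head.dropout(x)
    # ... while the bare functional call still drops with its default
    # training=True, which is what the old code did during validation.
    y_functional = F.dropout(x, p=0.5)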

@@ -45,9 +45,9 @@ class CTCLoss(nn.Layer):
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batch':
+        elif grad_norm_type == 'batch':
             self.norm_by_batchsize = True
-        if grad_norm_type == 'frame':
+        elif grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True

     def forward(self, logits, ys_pad, hlens, ys_lens):
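
The if-to-elif change makes the three normalization flags explicitly mutually exclusive per CTCLoss instance and leaves room for a trailing else. A minimal free-function sketch of the selection logic; the ValueError guard is an illustrative addition, not part of the commit:

    def select_grad_norm(grad_norm_type):
        # Exactly one normalization flag may be set.
        norm_by_times = norm_by_batchsize = norm_by_total_logits_len = False
        if grad_norm_type == 'instance':
            norm_by_times = True
        elif grad_norm_type == 'batch':
            norm_by_batchsize = True
        elif grad_norm_type == 'frame':
            norm_by_total_logits_len = True
        elif grad_norm_type is not None:
            # Illustrative guard, not in the commit: fail fast on typos.
            raise ValueError(f"unknown grad_norm_type: {grad_norm_type}")
        return norm_by_times, norm_by_batchsize, norm_by_total_logits_len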
