Merge pull request #858 from PaddlePaddle/ctc

ctc using nn.Dropout; ds2 libri valid batch_size / 4
Jackwaterveg authored 3 years ago, committed by GitHub
commit 4b225b7602

@@ -235,7 +235,7 @@ class DeepSpeech2Trainer(Trainer):
             num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.collator.batch_size,
+            batch_size=int(config.collator.batch_size / 4),
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
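A note on the arithmetic: int(batch_size / 4) truncates toward zero, which for positive batch sizes matches floor division, so the validation loader simply runs at a quarter of the training batch size to cut peak eval memory. A minimal sketch (the value 64 is only an example, not taken from this file):

# Hypothetical illustration of the quartered validation batch size.
train_batch_size = 64                         # example value for config.collator.batch_size
valid_batch_size = int(train_batch_size / 4)  # -> 16
assert valid_batch_size == train_batch_size // 4  # equivalent for positive ints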

@@ -56,7 +56,7 @@ class CTCDecoder(nn.Layer):
         self.blank_id = blank_id
         self.odim = odim
-        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
         self.criterion = CTCLoss(
@@ -79,7 +79,7 @@ class CTCDecoder(nn.Layer):
         Returns:
             loss (Tensor): ctc loss value, scalar.
         """
-        logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        logits = self.ctc_lo(self.dropout(hs_pad))
         loss = self.criterion(logits, ys_pad, hlens, ys_lens)
         return loss
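The switch from the functional call to a submodule matters for train/eval behavior: paddle.nn.functional.dropout defaults to training=True, so the old call kept dropping activations even during evaluation unless the flag was threaded through by hand, while an nn.Dropout layer is toggled automatically by .train()/.eval(). A minimal sketch of the difference:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

x = paddle.ones([2, 3])

layer = nn.Dropout(p=0.5)
layer.eval()                 # nn.Dropout becomes a no-op in eval mode
y_module = layer(x)          # equals x

y_func = F.dropout(x, p=0.5)                     # still drops: training defaults to True
y_func_ok = F.dropout(x, p=0.5, training=False)  # equals x only with the explicit flag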

@@ -45,9 +45,9 @@ class CTCLoss(nn.Layer):
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batch':
+        elif grad_norm_type == 'batch':
             self.norm_by_batchsize = True
-        if grad_norm_type == 'frame':
+        elif grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True

     def forward(self, logits, ys_pad, hlens, ys_lens):
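With independent if statements the remaining comparisons were still evaluated after a match; elif makes the dispatch explicitly mutually exclusive and leaves room for a trailing else to reject unknown values. A small standalone sketch (the helper name is hypothetical):

# Hypothetical helper mirroring the elif dispatch above.
def grad_norm_flags(grad_norm_type: str):
    norm_by_times = norm_by_batchsize = norm_by_total_logits_len = False
    if grad_norm_type == 'instance':
        norm_by_times = True
    elif grad_norm_type == 'batch':
        norm_by_batchsize = True
    elif grad_norm_type == 'frame':
        norm_by_total_logits_len = True
    # else: raise ValueError(f"unknown grad_norm_type: {grad_norm_type}")
    return norm_by_times, norm_by_batchsize, norm_by_total_logits_len

assert grad_norm_flags('batch') == (False, True, False)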

@@ -8,7 +8,7 @@ data:
   min_output_len: 0.0  # tokens
   max_output_len: 400.0  # tokens
   min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
+  max_output_input_ratio: 100.0

 collator:
   vocab_filepath: data/vocab.txt
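Raising max_output_input_ratio from 10.0 to 100.0 loosens the utterance-length filter so far fewer samples are discarded. Assuming the ratio compares output tokens to input frames, as the key names suggest, a hypothetical sketch of such a filter:

# Hypothetical filter; semantics inferred from the config key names only.
def keep_utterance(n_tokens: int, n_frames: int,
                   min_ratio: float = 0.05, max_ratio: float = 100.0) -> bool:
    return min_ratio <= n_tokens / n_frames <= max_ratio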
@@ -16,7 +16,7 @@ collator:
   spm_model_prefix: 'data/bpe_unigram_5000'
   mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
-  batch_size: 64
+  batch_size: 32
   raw_wav: True  # use raw_wav or kaldi feature
   specgram_type: fbank  # linear, mfcc, fbank
   feat_dim: 80
@@ -75,7 +75,7 @@ model:

 training:
   n_epoch: 120
-  accum_grad: 2
+  accum_grad: 4
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
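Halving batch_size while doubling accum_grad keeps the effective number of samples per optimizer step unchanged; only the peak memory of each forward/backward pass drops:

# Effective samples per optimizer step = batch_size * accum_grad
old_effective = 64 * 2  # 128 before this PR
new_effective = 32 * 4  # 128 after it, unchanged
assert old_effective == new_effective == 128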
