Merge pull request #858 from PaddlePaddle/ctc

ctc using nn.Dropout; ds2 libri valid batch_size / 4
pull/868/head
Jackwaterveg 3 years ago committed by GitHub
commit 4b225b7602
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -235,7 +235,7 @@ class DeepSpeech2Trainer(Trainer):
num_workers=config.collator.num_workers) num_workers=config.collator.num_workers)
self.valid_loader = DataLoader( self.valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=config.collator.batch_size, batch_size=int(config.collator.batch_size / 4),
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn_dev) collate_fn=collate_fn_dev)

@@ -56,7 +56,7 @@ class CTCDecoder(nn.Layer):
self.blank_id = blank_id self.blank_id = blank_id
self.odim = odim self.odim = odim
self.dropout_rate = dropout_rate self.dropout = nn.Dropout(dropout_rate)
self.ctc_lo = nn.Linear(enc_n_units, self.odim) self.ctc_lo = nn.Linear(enc_n_units, self.odim)
reduction_type = "sum" if reduction else "none" reduction_type = "sum" if reduction else "none"
self.criterion = CTCLoss( self.criterion = CTCLoss(
@@ -79,7 +79,7 @@ class CTCDecoder(nn.Layer):
Returns: Returns:
loss (Tenosr): ctc loss value, scalar. loss (Tenosr): ctc loss value, scalar.
""" """
logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) logits = self.ctc_lo(self.dropout(hs_pad))
loss = self.criterion(logits, ys_pad, hlens, ys_lens) loss = self.criterion(logits, ys_pad, hlens, ys_lens)
return loss return loss

@@ -45,9 +45,9 @@ class CTCLoss(nn.Layer):
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance': if grad_norm_type == 'instance':
self.norm_by_times = True self.norm_by_times = True
if grad_norm_type == 'batch': elif grad_norm_type == 'batch':
self.norm_by_batchsize = True self.norm_by_batchsize = True
if grad_norm_type == 'frame': elif grad_norm_type == 'frame':
self.norm_by_total_logits_len = True self.norm_by_total_logits_len = True
def forward(self, logits, ys_pad, hlens, ys_lens): def forward(self, logits, ys_pad, hlens, ys_lens):

@@ -8,7 +8,7 @@ data:
min_output_len: 0.0 # tokens min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
max_output_input_ratio: 10.0 max_output_input_ratio: 100.0
collator: collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
@@ -16,7 +16,7 @@ collator:
spm_model_prefix: 'data/bpe_unigram_5000' spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: "" mean_std_filepath: ""
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
batch_size: 64 batch_size: 32
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80 feat_dim: 80
@@ -75,7 +75,7 @@ model:
training: training:
n_epoch: 120 n_epoch: 120
accum_grad: 2 accum_grad: 4
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
optim_conf: optim_conf:

Loading…
Cancel
Save