more ctc check; valid dataloader with num workers

pull/861/head
Hui Zhang 3 years ago
parent 4b225b7602
commit f7d7e70cb2

@ -243,6 +243,7 @@ class U2Trainer(Trainer):
self.visualizer.add_scalars(
'epoch', {'cv_loss': cv_loss,
'lr': self.lr_scheduler()}, self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
@ -291,7 +292,8 @@ class U2Trainer(Trainer):
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text
config.data.manifest = config.data.test_manifest

@ -49,7 +49,7 @@ class CTCDecoder(nn.Layer):
dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
batch_average (bool): do batch dim wise average.
grad_norm_type (str): one of 'instance', 'batchsize', 'frame', None.
grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
"""
assert check_argument_types()
super().__init__()

@ -49,6 +49,8 @@ class CTCLoss(nn.Layer):
self.norm_by_batchsize = True
elif grad_norm_type == 'frame':
self.norm_by_total_logits_len = True
else:
raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.

Loading…
Cancel
Save