|
|
@ -243,6 +243,7 @@ class U2Trainer(Trainer):
|
|
|
|
self.visualizer.add_scalars(
|
|
|
|
self.visualizer.add_scalars(
|
|
|
|
'epoch', {'cv_loss': cv_loss,
|
|
|
|
'epoch', {'cv_loss': cv_loss,
|
|
|
|
'lr': self.lr_scheduler()}, self.epoch)
|
|
|
|
'lr': self.lr_scheduler()}, self.epoch)
|
|
|
|
|
|
|
|
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
self.new_epoch()
|
|
|
|
self.new_epoch()
|
|
|
|
|
|
|
|
|
|
|
@ -291,7 +292,8 @@ class U2Trainer(Trainer):
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
shuffle=False,
|
|
|
|
shuffle=False,
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=collate_fn_dev)
|
|
|
|
collate_fn=collate_fn_dev,
|
|
|
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
|
|
|
|
|
|
|
# test dataset, return raw text
|
|
|
|
# test dataset, return raw text
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|