|
|
|
@ -139,7 +139,8 @@ class U2STTrainer(Trainer):
|
|
|
|
|
losses_np_v = losses_np.copy()
|
|
|
|
|
losses_np_v.update({"lr": self.lr_scheduler()})
|
|
|
|
|
for key, val in losses_np_v.items():
|
|
|
|
|
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag="train/" + key, value=val, step=self.iteration - 1)
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def valid(self):
|
|
|
|
@ -235,8 +236,10 @@ class U2STTrainer(Trainer):
|
|
|
|
|
logger.info(
|
|
|
|
|
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
|
|
|
|
|
if self.visualizer:
|
|
|
|
|
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
|
|
|
|
|
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
|
|
|
|
|
|
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
|
self.new_epoch()
|
|
|
|
|