|
|
|
@ -34,7 +34,7 @@ from speechtask.punctuation_restoration.model.lstm import RnnLm
|
|
|
|
|
from speechtask.punctuation_restoration.utils import layer_tools
|
|
|
|
|
from speechtask.punctuation_restoration.utils import mp_tools
|
|
|
|
|
from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint
|
|
|
|
|
from tensorboardX import SummaryWriter
|
|
|
|
|
from visualdl import LogWriter
|
|
|
|
|
|
|
|
|
|
__all__ = ["Trainer", "Tester"]
|
|
|
|
|
|
|
|
|
@ -252,10 +252,10 @@ class Trainer():
|
|
|
|
|
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
|
|
|
|
|
format(self.epoch, total_loss, F1_score))
|
|
|
|
|
if self.visualizer:
|
|
|
|
|
self.visualizer.add_scalars("epoch", {
|
|
|
|
|
"total_loss": total_loss,
|
|
|
|
|
"lr": self.lr_scheduler()
|
|
|
|
|
}, self.epoch)
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
|
|
|
|
|
|
|
|
|
|
self.save(
|
|
|
|
|
tag=self.epoch, infos={"val_loss": total_loss,
|
|
|
|
@ -341,7 +341,7 @@ class Trainer():
|
|
|
|
|
unexpected behaviors.
|
|
|
|
|
"""
|
|
|
|
|
# visualizer
|
|
|
|
|
visualizer = SummaryWriter(logdir=str(self.output_dir))
|
|
|
|
|
visualizer = LogWriter(logdir=str(self.output_dir))
|
|
|
|
|
self.visualizer = visualizer
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
|