|
|
|
@ -289,6 +289,7 @@ class Trainer():
|
|
|
|
|
float) else f"{v}"
|
|
|
|
|
msg += ","
|
|
|
|
|
msg = msg[:-1] # remove the last ","
|
|
|
|
|
if (batch_index + 1) % self.config.log_interval == 0:
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
data_start_time = time.time()
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -316,10 +317,10 @@ class Trainer():
|
|
|
|
|
self.visualizer.add_scalar(
|
|
|
|
|
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
|
|
|
|
|
|
|
|
|
|
# after epoch
|
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
|
# step lr every epoch
|
|
|
|
|
self.lr_scheduler.step()
|
|
|
|
|
# after epoch
|
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
|
self.new_epoch()
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|