fix log for u2

pull/843/head
Hui Zhang 3 years ago
parent 98c0d43ae4
commit 7ae204eb0f

@ -123,13 +123,13 @@ class U2Trainer(Trainer):
iteration_time = time.time() - start iteration_time = time.time() - start
if (batch_index + 1) % train_conf.log_interval == 0:
for k, v in losses_np.items(): for k, v in losses_np.items():
report(k, v) report(k, v)
report("batch_size", self.config.collator.batch_size) report("batch_size", self.config.collator.batch_size)
report("accum", train_conf.accum_grad) report("accum", train_conf.accum_grad)
report("step_cost", iteration_time) report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
@ -223,6 +223,8 @@ class U2Trainer(Trainer):
msg += f"{v:>.8f}" if isinstance(v, msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}" float) else f"{v}"
msg += "," msg += ","
if (batch_index + 1
) % self.config.training.log_interval == 0:
logger.info(msg) logger.info(msg)
data_start_time = time.time() data_start_time = time.time()
except Exception as e: except Exception as e:

Loading…
Cancel
Save