|
|
@ -123,13 +123,13 @@ class U2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
|
|
|
|
|
|
|
if (batch_index + 1) % train_conf.log_interval == 0:
|
|
|
|
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
report(k, v)
|
|
|
|
report(k, v)
|
|
|
|
report("batch_size", self.config.collator.batch_size)
|
|
|
|
report("batch_size", self.config.collator.batch_size)
|
|
|
|
report("accum", train_conf.accum_grad)
|
|
|
|
report("accum", train_conf.accum_grad)
|
|
|
|
report("step_cost", iteration_time)
|
|
|
|
report("step_cost", iteration_time)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (batch_index + 1) % train_conf.accum_grad == 0:
|
|
|
|
if dist.get_rank() == 0 and self.visualizer:
|
|
|
|
if dist.get_rank() == 0 and self.visualizer:
|
|
|
|
losses_np_v = losses_np.copy()
|
|
|
|
losses_np_v = losses_np.copy()
|
|
|
|
losses_np_v.update({"lr": self.lr_scheduler()})
|
|
|
|
losses_np_v.update({"lr": self.lr_scheduler()})
|
|
|
@ -223,6 +223,8 @@ class U2Trainer(Trainer):
|
|
|
|
msg += f"{v:>.8f}" if isinstance(v,
|
|
|
|
msg += f"{v:>.8f}" if isinstance(v,
|
|
|
|
float) else f"{v}"
|
|
|
|
float) else f"{v}"
|
|
|
|
msg += ","
|
|
|
|
msg += ","
|
|
|
|
|
|
|
|
if (batch_index + 1
|
|
|
|
|
|
|
|
) % self.config.training.log_interval == 0:
|
|
|
|
logger.info(msg)
|
|
|
|
logger.info(msg)
|
|
|
|
data_start_time = time.time()
|
|
|
|
data_start_time = time.time()
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|