|
|
|
@ -199,8 +199,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
report("Rank", dist.get_rank())
|
|
|
|
|
report("epoch", self.epoch)
|
|
|
|
|
report('step', self.iteration)
|
|
|
|
|
report('step/total',
|
|
|
|
|
(batch_index + 1) / len(self.train_loader))
|
|
|
|
|
report('iter', batch_index + 1)
|
|
|
|
|
report('total',len(self.train_loader))
|
|
|
|
|
report("lr", self.lr_scheduler())
|
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
|
self.after_train_batch()
|
|
|
|
|