|
|
|
@ -96,7 +96,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
|
|
|
|
|
if (batch_index + 1) % train_conf.log_interval == 0:
|
|
|
|
|
msg += "time: {:>.3f}s, ".format(iteration_time)
|
|
|
|
|
msg += "train time: {:>.3f}s, ".format(iteration_time)
|
|
|
|
|
msg += "batch size: {}, ".format(self.config.data.batch_size)
|
|
|
|
|
msg += "accum: {}, ".format(train_conf.accum_grad)
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
@ -177,7 +177,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
msg += "batch : {}/{}, ".format(batch_index + 1,
|
|
|
|
|
len(self.train_loader))
|
|
|
|
|
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
|
|
|
|
|
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
|
msg += "data time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
|
data_start_time = time.time()
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -275,6 +275,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
if self.parallel:
|
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
|
|
|
|
|
|
logger.info(f"{model}")
|
|
|
|
|
layer_tools.print_params(model, logger.info)
|
|
|
|
|
|
|
|
|
|
train_config = config.training
|
|
|
|
|