@@ -80,54 +80,60 @@ class U2Trainer(Trainer):
         self.model.train()
 
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
 
-        if self.iteration % train_conf.accum_grad == 0:
+        losses_np = {
+            'train_loss': float(loss) * train_conf.accum_grad,
+            'train_att_loss': float(attention_loss),
+            'train_ctc_loss': float(ctc_loss),
+        }
+
+        if (self.iteration + 1) % train_conf.accum_grad == 0:
-            if dist.get_rank() == 0 and self.visualizer:
-                for k, v in losses_np.items():
-                    self.visualizer.add_scalar("train/{}".format(k), v,
-                                               self.iteration)
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
+            self.iteration += 1
 
         iteration_time = time.time() - start
 
-        losses_np = {
-            'train_loss': float(loss),
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
-        msg += "time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
-        msg += "accum: {}, ".format(train_conf.accum_grad)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in losses_np.items())
-        if self.iteration % train_conf.log_interval == 0:
+        if (self.iteration + 1) % train_conf.log_interval == 0:
+            msg += "time: {:>.3f}s, ".format(iteration_time)
+            msg += "batch size: {}, ".format(self.config.data.batch_size)
+            msg += "accum: {}, ".format(train_conf.accum_grad)
+            msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                             for k, v in losses_np.items())
             self.logger.info(msg)
+            # display
+            if dist.get_rank() == 0 and self.visualizer:
+                for k, v in losses_np.items():
+                    self.visualizer.add_scalar("train/{}".format(k), v,
+                                               self.iteration)
 
     def train(self):
-        """The training process.
-
-        It includes forward/backward/update and periodical validation and
-        saving.
-        """
+        """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
         # the code to satisfy the script export requirements
         # script_model = paddle.jit.to_static(self.model)
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
 
         from_scratch = self.resume_or_scratch()
         if from_scratch:
             # save init model, i.e. 0 epoch
             self.save(tag='init')
 
         self.lr_scheduler.step(self.iteration)
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
@@ -135,19 +141,18 @@ class U2Trainer(Trainer):
msg = "Train: Rank: {}, ".format(dist.get_rank())
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += "lr: {}, ".foramt(self.lr_scheduler())
|
|
|
|
|
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
|
|
|
|
|
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
|
-                    self.iteration += 1
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
             except Exception as e:
                 self.logger.error(e)
                 raise e
 
             self.valid()
             self.save()
             self.new_epoch()
 
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
@@ -263,12 +268,12 @@ class U2Trainer(Trainer):
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=optim_conf.lr,
                 gamma=scheduler_conf.lr_decay,
-                verbose=True)
+                verbose=False)
         elif scheduler_type == 'warmuplr':
             lr_scheduler = WarmupLR(
                 learning_rate=optim_conf.lr,
                 warmup_steps=scheduler_conf.warmup_steps,
-                verbose=True)
+                verbose=False)
         else:
             raise ValueError(f"Not support scheduler: {scheduler_type}")