@@ -81,11 +81,11 @@ class U2Trainer(Trainer):
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
-        losses_np = {
-            'train_loss': float(loss) * train_conf.accum_grad,
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
+        losses_np = {'loss': float(loss) * train_conf.accum_grad}
+        if attention_loss:
+            losses_np['att_loss'] = float(attention_loss)
+        if ctc_loss:
+            losses_np['ctc_loss'] = float(ctc_loss)

         if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
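This hunk does two things: it builds `losses_np` conditionally, so a configuration that disables the attention or CTC branch (leaving that loss `None`) no longer logs a meaningless entry, and it keeps multiplying the logged loss by `accum_grad` to undo the scaling applied before `loss.backward()`. A minimal sketch of that gradient-accumulation pattern, with a stand-in model, optimizer, and `accum_grad` value rather than the trainer's real objects:

```python
import paddle

# Hypothetical: apply a weight update once every 4 micro-batches.
accum_grad = 4

model = paddle.nn.Linear(8, 1)
optimizer = paddle.optimizer.SGD(learning_rate=0.1,
                                 parameters=model.parameters())

for batch_index in range(8):
    x = paddle.randn([16, 8])
    loss = model(x).mean()
    # Scale down so gradients summed over `accum_grad` micro-batches
    # approximate one gradient over the combined batch.
    loss = loss / accum_grad
    loss.backward()
    if (batch_index + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.clear_grad()
    # Multiply back for logging, mirroring float(loss) * accum_grad above.
    print(float(loss) * accum_grad)
```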
@@ -135,6 +135,8 @@ class U2Trainer(Trainer):
             msg = "Train: Rank: {}, ".format(dist.get_rank())
             msg += "epoch: {}, ".format(self.epoch)
             msg += "step: {}, ".format(self.iteration)
+            msg += "batch : {}/{}, ".format(batch_index + 1,
+                                            len(self.train_loader))
             msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
             msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
             self.train_batch(batch_index, batch, msg)
@@ -143,8 +145,9 @@ class U2Trainer(Trainer):
                 logger.error(e)
                 raise e

-            valid_losses = self.valid()
-            self.save(tag=self.epoch, infos=valid_losses)
+            total_loss, num_seen_utts = self.valid()
+            self.save(
+                tag=self.epoch, infos={'val_loss': total_loss / num_seen_utts})
             self.new_epoch()

     @mp_tools.rank_zero_only
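With this change `valid()` hands back the raw loss sum and utterance count instead of a pre-averaged dict, and the epoch loop stores their ratio, the per-utterance average, as the checkpoint's `val_loss`. With made-up numbers standing in for what `valid()` returns:

```python
# Made-up numbers; valid() returns this pair in the new code.
total_loss, num_seen_utts = 512.0, 200
infos = {'val_loss': total_loss / num_seen_utts}  # 2.56 per utterance
print(infos)
```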
@@ -153,29 +156,42 @@ class U2Trainer(Trainer):
         self.model.eval()
         logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
-            total_loss, attention_loss, ctc_loss = self.model(*batch)
-            valid_losses['val_loss'].append(float(total_loss))
-            valid_losses['val_att_loss'].append(float(attention_loss))
-            valid_losses['val_ctc_loss'].append(float(ctc_loss))
+            loss, attention_loss, ctc_loss = self.model(*batch)
+            if paddle.isfinite(loss):
+                num_utts = batch[0].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses = {'val_loss': float(loss)}
+                if attention_loss:
+                    valid_losses['val_att_loss'] = float(attention_loss)
+                if ctc_loss:
+                    valid_losses['val_ctc_loss'] = float(ctc_loss)

-        # write visual log
-        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
+            if (i + 1) % self.config.training.log_interval == 0:
+                valid_losses['val_history_loss'] = total_loss / num_seen_utts

-        # logging
-        msg = f"Valid: Rank: {dist.get_rank()}, "
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in valid_losses.items())
-        logger.info(msg)
+                # write visual log
+                valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

-        if self.visualizer:
-            valid_losses_v = valid_losses.copy()
-            valid_losses_v.update({"lr": self.lr_scheduler()})
-            self.visualizer.add_scalars('epoch', valid_losses_v, self.epoch)
-        return valid_losses
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                                 for k, v in valid_losses.items())
+                logger.info(msg)

+                if self.visualizer:
+                    valid_losses_v = valid_losses.copy()
+                    valid_losses_v.update({"lr": self.lr_scheduler()})
+                    self.visualizer.add_scalars('epoch', valid_losses_v,
+                                                self.epoch)
+
+        return total_loss, num_seen_utts

     def setup_dataloader(self):
         config = self.config.clone()
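The rewritten `valid()` replaces a plain mean over batch losses with an utterance-weighted running average, and the `paddle.isfinite` guard keeps a divergent (NaN/inf) batch from poisoning it. A self-contained sketch of that reduction, using stand-in lists rather than a real dataloader:

```python
import math

def weighted_val_loss(batch_losses, batch_sizes):
    """Utterance-weighted loss sum, mirroring total_loss / num_seen_utts.

    Non-finite batch losses are skipped, as in the paddle.isfinite guard.
    """
    num_seen_utts = 1  # starts at 1 in the diff, avoiding division by zero
    total_loss = 0.0
    for loss, num_utts in zip(batch_losses, batch_sizes):
        if math.isfinite(loss):
            num_seen_utts += num_utts
            total_loss += loss * num_utts
    return total_loss, num_seen_utts

# Two clean batches and one divergent batch that is ignored.
total, seen = weighted_val_loss([2.0, float('inf'), 4.0], [16, 16, 8])
print(total / seen)  # 64.0 / 25 = 2.56
```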