|
|
@ -109,6 +109,43 @@ class U2Trainer(Trainer):
|
|
|
|
self.visualizer.add_scalars("step", losses_np_v,
|
|
|
|
self.visualizer.add_scalars("step", losses_np_v,
|
|
|
|
self.iteration - 1)
|
|
|
|
self.iteration - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
def valid(self):
    """Run one full pass over the validation set and accumulate the loss.

    Iterates ``self.valid_loader``, skipping batches whose loss is not
    finite, and periodically logs running means of the per-batch losses.

    Returns:
        tuple: ``(total_loss, num_seen_utts)`` where ``total_loss`` is the
        utterance-weighted sum of finite batch losses and ``num_seen_utts``
        is the utterance count (initialized to 1 so the average
        ``total_loss / num_seen_utts`` never divides by zero).
    """
    self.model.eval()
    logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
    valid_losses = defaultdict(list)
    num_seen_utts = 1  # start at 1 to avoid division by zero below
    total_loss = 0.0
    for i, batch in enumerate(self.valid_loader):
        loss, attention_loss, ctc_loss = self.model(*batch)
        if paddle.isfinite(loss):
            num_utts = batch[0].shape[0]
            num_seen_utts += num_utts
            total_loss += float(loss) * num_utts
            valid_losses['val_loss'].append(float(loss))
            # attention/ctc losses may be absent (None/0) depending on
            # the model configuration — record them only when present.
            if attention_loss:
                valid_losses['val_att_loss'].append(float(attention_loss))
            if ctc_loss:
                valid_losses['val_ctc_loss'].append(float(ctc_loss))

        if (i + 1) % self.config.training.log_interval == 0:
            # BUG FIX: the original rebound ``valid_losses`` itself to a
            # plain dict of float means here, so on the next iteration
            # ``valid_losses['val_loss'].append(...)`` raised
            # AttributeError ('float' object has no attribute 'append')
            # as soon as a second logging interval was reached. Build the
            # report in a separate dict and keep the accumulator intact.
            report = {k: np.mean(v) for k, v in valid_losses.items()}
            report['val_history_loss'] = total_loss / num_seen_utts

            # logging
            msg = f"Valid: Rank: {dist.get_rank()}, "
            msg += "epoch: {}, ".format(self.epoch)
            msg += "step: {}, ".format(self.iteration)
            msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
            msg += ', '.join('{}: {:>.6f}'.format(k, v)
                             for k, v in report.items())
            logger.info(msg)

    logger.info('Rank {} Val info val_loss {}'.format(
        dist.get_rank(), total_loss / num_seen_utts))
    return total_loss, num_seen_utts
|
|
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
def train(self):
|
|
|
|
"""The training process control by step."""
|
|
|
|
"""The training process control by step."""
|
|
|
|
# !!!IMPORTANT!!!
|
|
|
|
# !!!IMPORTANT!!!
|
|
|
@ -148,52 +185,25 @@ class U2Trainer(Trainer):
|
|
|
|
raise e
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
|
|
total_loss, num_seen_utts = self.valid()
|
|
|
|
total_loss, num_seen_utts = self.valid()
|
|
|
|
self.save(
|
|
|
|
if dist.get_world_size() > 1:
|
|
|
|
tag=self.epoch, infos={'val_loss': total_loss / num_seen_utts})
|
|
|
|
num_seen_utts = paddle.to_tensor(num_seen_utts)
|
|
|
|
self.new_epoch()
|
|
|
|
# the default operator in all_reduce function is sum.
|
|
|
|
|
|
|
|
dist.all_reduce(num_seen_utts)
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
total_loss = paddle.to_tensor(total_loss)
|
|
|
|
@paddle.no_grad()
|
|
|
|
dist.all_reduce(total_loss)
|
|
|
|
def valid(self):
|
|
|
|
cv_loss = total_loss / num_seen_utts
|
|
|
|
self.model.eval()
|
|
|
|
cv_loss = float(cv_loss)
|
|
|
|
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
|
|
|
|
else:
|
|
|
|
valid_losses = defaultdict(list)
|
|
|
|
cv_loss = total_loss / num_seen_utts
|
|
|
|
num_seen_utts = 1
|
|
|
|
|
|
|
|
total_loss = 0.0
|
|
|
|
|
|
|
|
for i, batch in enumerate(self.valid_loader):
|
|
|
|
|
|
|
|
loss, attention_loss, ctc_loss = self.model(*batch)
|
|
|
|
|
|
|
|
if paddle.isfinite(loss):
|
|
|
|
|
|
|
|
num_utts = batch[0].shape[0]
|
|
|
|
|
|
|
|
num_seen_utts += num_utts
|
|
|
|
|
|
|
|
total_loss += float(loss) * num_utts
|
|
|
|
|
|
|
|
valid_losses = {'val_loss': float(loss)}
|
|
|
|
|
|
|
|
if attention_loss:
|
|
|
|
|
|
|
|
valid_losses['val_att_loss'] = float(attention_loss)
|
|
|
|
|
|
|
|
if ctc_loss:
|
|
|
|
|
|
|
|
valid_losses['val_ctc_loss'] = float(ctc_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (i + 1) % self.config.training.log_interval == 0:
|
|
|
|
|
|
|
|
valid_losses['val_history_loss'] = total_loss / num_seen_utts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# write visual log
|
|
|
|
|
|
|
|
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# logging
|
|
|
|
|
|
|
|
msg = f"Valid: Rank: {dist.get_rank()}, "
|
|
|
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
|
|
|
msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
|
|
|
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
|
|
|
|
for k, v in valid_losses.items())
|
|
|
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
|
|
|
|
if self.visualizer:
|
|
|
|
if self.visualizer:
|
|
|
|
valid_losses_v = valid_losses.copy()
|
|
|
|
self.visualizer.add_scalars(
|
|
|
|
valid_losses_v.update({"lr": self.lr_scheduler()})
|
|
|
|
'epoch', {'cv_loss': cv_loss,
|
|
|
|
self.visualizer.add_scalars('epoch', valid_losses_v,
|
|
|
|
'lr': self.lr_scheduler()}, self.epoch)
|
|
|
|
self.epoch)
|
|
|
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
|
|
|
|
|
|
|
self.new_epoch()
|
|
|
|
return total_loss, num_seen_utts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_dataloader(self):
|
|
|
|
def setup_dataloader(self):
|
|
|
|
config = self.config.clone()
|
|
|
|
config = self.config.clone()
|
|
|
|