|
|
|
@ -150,8 +150,8 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
self.logger.info("Setup model/optimizer/criterion!")
|
|
|
|
|
|
|
|
|
|
def compute_losses(self, inputs, outputs):
|
|
|
|
|
_, texts, _, texts_len = inputs
|
|
|
|
|
logits, logits_len = outputs
|
|
|
|
|
del inputs
|
|
|
|
|
logits, texts, logits_len, texts_len = outputs
|
|
|
|
|
loss = self.criterion(logits, texts, logits_len, texts_len)
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
@ -169,8 +169,8 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
self.optimizer.step()
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
|
|
|
|
|
losses_np = {'loss': float(loss)}
|
|
|
|
|
msg = "Rank: {}, ".format(dist.get_rank())
|
|
|
|
|
losses_np = {'train_loss': float(loss)}
|
|
|
|
|
msg = "Train: Rank: {}, ".format(dist.get_rank())
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
|
|
|
|
@ -185,6 +185,9 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
self.visualizer.add_scalar("train/{}".format(k), v,
|
|
|
|
|
self.iteration)
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self, inputs, outputs):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def valid(self):
|
|
|
|
@ -192,15 +195,17 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
for i, batch in enumerate(self.valid_loader):
|
|
|
|
|
audio, text, audio_len, text_len = batch
|
|
|
|
|
outputs = self.model(audio, text, audio_len, text_len)
|
|
|
|
|
losses = self.compute_losses(batch, outputs)
|
|
|
|
|
loss = self.compute_losses(batch, outputs)
|
|
|
|
|
metrics = self.compute_metrics(batch, outputs)
|
|
|
|
|
|
|
|
|
|
valid_losses['val_loss'].append(float(v))
|
|
|
|
|
valid_losses['val_loss'].append(float(loss))
|
|
|
|
|
|
|
|
|
|
# write visual log
|
|
|
|
|
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
|
|
|
|
|
|
|
|
|
# logging
|
|
|
|
|
msg = "Valid: "
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
|
for k, v in valid_losses.items())
|
|
|
|
|