fix train valid log

pull/522/head
Hui Zhang 5 years ago
parent bb947eecb6
commit d79ae3824a

@ -114,7 +114,6 @@ class AudioFeaturizer(object):
else:
raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type)
print('feat_dim:', feat_dim)
return feat_dim
def _compute_specgram(self, samples, sample_rate):

@ -150,8 +150,8 @@ class DeepSpeech2Trainer(Trainer):
self.logger.info("Setup model/optimizer/criterion!")
def compute_losses(self, inputs, outputs):
_, texts, _, texts_len = inputs
logits, logits_len = outputs
del inputs
logits, texts, logits_len, texts_len = outputs
loss = self.criterion(logits, texts, logits_len, texts_len)
return loss
@ -169,8 +169,8 @@ class DeepSpeech2Trainer(Trainer):
self.optimizer.step()
iteration_time = time.time() - start
losses_np = {'loss': float(loss)}
msg = "Rank: {}, ".format(dist.get_rank())
losses_np = {'train_loss': float(loss)}
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
@ -185,6 +185,9 @@ class DeepSpeech2Trainer(Trainer):
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
def compute_metrics(self, inputs, outputs):
pass
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
@ -192,15 +195,17 @@ class DeepSpeech2Trainer(Trainer):
for i, batch in enumerate(self.valid_loader):
audio, text, audio_len, text_len = batch
outputs = self.model(audio, text, audio_len, text_len)
losses = self.compute_losses(batch, outputs)
loss = self.compute_losses(batch, outputs)
metrics = self.compute_metrics(batch, outputs)
valid_losses['val_loss'].append(float(v))
valid_losses['val_loss'].append(float(loss))
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
# logging
msg = "Valid: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_losses.items())

@ -532,7 +532,7 @@ class DeepSpeech2(nn.Layer):
text_len: shape [B]
"""
logits, _, audio_len = self.predict(audio, audio_len)
return logits, audio_len
return logits, text, audio_len, text_len
class DeepSpeech2Loss(nn.Layer):

Loading…
Cancel
Save