Update tacotron2_updater.py

pull/2860/head
TianYuan 3 years ago committed by GitHub
parent b3934536ab
commit 7b7dd7158d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -114,7 +114,7 @@ class Tacotron2Updater(StandardUpdater):
optimizer.step()
if self.use_guided_attn_loss:
report("eval/attn_loss", float(attn_loss))
report("train/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("train/l1_loss", float(l1_loss))
@ -204,17 +204,19 @@ class Tacotron2Evaluator(StandardEvaluator):
attn_loss = self.attn_loss(
att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
loss = loss + attn_loss
if self.use_guided_attn_loss:
report("eval/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("eval/l1_loss", float(l1_loss))
report("eval/mse_loss", float(mse_loss))
report("eval/bce_loss", float(bce_loss))
report("eval/attn_loss", float(attn_loss))
report("eval/loss", float(loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["mse_loss"] = float(mse_loss)
losses_dict["bce_loss"] = float(bce_loss)
losses_dict["attn_loss"] = float(attn_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())

Loading…
Cancel
Save