@@ -113,16 +113,18 @@ class Tacotron2Updater(StandardUpdater):
         loss.backward()
         optimizer.step()
 
+        if self.use_guided_attn_loss:
+            report("train/attn_loss", float(attn_loss))
+            losses_dict["attn_loss"] = float(attn_loss)
+
         report("train/l1_loss", float(l1_loss))
         report("train/mse_loss", float(mse_loss))
         report("train/bce_loss", float(bce_loss))
-        report("train/attn_loss", float(attn_loss))
         report("train/loss", float(loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["mse_loss"] = float(mse_loss)
         losses_dict["bce_loss"] = float(bce_loss)
-        losses_dict["attn_loss"] = float(attn_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
@@ -202,17 +204,19 @@ class Tacotron2Evaluator(StandardEvaluator):
             attn_loss = self.attn_loss(
                 att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
             loss = loss + attn_loss
 
+        if self.use_guided_attn_loss:
+            report("eval/attn_loss", float(attn_loss))
+            losses_dict["attn_loss"] = float(attn_loss)
+
         report("eval/l1_loss", float(l1_loss))
         report("eval/mse_loss", float(mse_loss))
         report("eval/bce_loss", float(bce_loss))
-        report("eval/attn_loss", float(attn_loss))
         report("eval/loss", float(loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["mse_loss"] = float(mse_loss)
         losses_dict["bce_loss"] = float(bce_loss)
-        losses_dict["attn_loss"] = float(attn_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())