[TTS]Avoid using variable "attn_loss" before assignment (#2860)

* Avoid using variable "attn_loss" before assignment

* Update tacotron2_updater.py

---------

Co-authored-by: TianYuan <white-sky@qq.com>
pull/2879/head
章宏彬 3 years ago committed by GitHub
parent a283f8a57e
commit c764710aa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -113,16 +113,18 @@ class Tacotron2Updater(StandardUpdater):
loss.backward()
optimizer.step()
if self.use_guided_attn_loss:
report("train/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("train/l1_loss", float(l1_loss))
report("train/mse_loss", float(mse_loss))
report("train/bce_loss", float(bce_loss))
report("train/attn_loss", float(attn_loss))
report("train/loss", float(loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["mse_loss"] = float(mse_loss)
losses_dict["bce_loss"] = float(bce_loss)
losses_dict["attn_loss"] = float(attn_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
@ -202,17 +204,19 @@ class Tacotron2Evaluator(StandardEvaluator):
attn_loss = self.attn_loss(
att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
loss = loss + attn_loss
if self.use_guided_attn_loss:
report("eval/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("eval/l1_loss", float(l1_loss))
report("eval/mse_loss", float(mse_loss))
report("eval/bce_loss", float(bce_loss))
report("eval/attn_loss", float(attn_loss))
report("eval/loss", float(loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["mse_loss"] = float(mse_loss)
losses_dict["bce_loss"] = float(bce_loss)
losses_dict["attn_loss"] = float(attn_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())

Loading…
Cancel
Save