|
|
|
@ -33,7 +33,8 @@ logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSpeech2Updater(StandardUpdater):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
model: Layer,
|
|
|
|
|
optimizer: Optimizer,
|
|
|
|
|
dataloader: DataLoader,
|
|
|
|
@ -41,7 +42,8 @@ class FastSpeech2Updater(StandardUpdater):
|
|
|
|
|
use_masking: bool=False,
|
|
|
|
|
spk_loss_scale: float=0.02,
|
|
|
|
|
use_weighted_masking: bool=False,
|
|
|
|
|
output_dir: Path=None):
|
|
|
|
|
output_dir: Path=None,
|
|
|
|
|
enable_spk_cls: bool=False, ):
|
|
|
|
|
super().__init__(model, optimizer, dataloader, init_state=None)
|
|
|
|
|
|
|
|
|
|
self.criterion = FastSpeech2Loss(
|
|
|
|
@ -54,6 +56,7 @@ class FastSpeech2Updater(StandardUpdater):
|
|
|
|
|
self.logger = logger
|
|
|
|
|
self.msg = ""
|
|
|
|
|
self.spk_loss_scale = spk_loss_scale
|
|
|
|
|
self.enable_spk_cls = enable_spk_cls
|
|
|
|
|
|
|
|
|
|
def update_core(self, batch):
|
|
|
|
|
self.msg = "Rank: {}, ".format(dist.get_rank())
|
|
|
|
@ -118,7 +121,7 @@ class FastSpeech2Updater(StandardUpdater):
|
|
|
|
|
report("train/duration_loss", float(duration_loss))
|
|
|
|
|
report("train/pitch_loss", float(pitch_loss))
|
|
|
|
|
report("train/energy_loss", float(energy_loss))
|
|
|
|
|
if speaker_loss != 0.0:
|
|
|
|
|
if self.enable_spk_cls:
|
|
|
|
|
report("train/speaker_loss", float(speaker_loss))
|
|
|
|
|
report("train/scale_speaker_loss",
|
|
|
|
|
float(self.spk_loss_scale * speaker_loss))
|
|
|
|
@ -128,7 +131,7 @@ class FastSpeech2Updater(StandardUpdater):
|
|
|
|
|
losses_dict["pitch_loss"] = float(pitch_loss)
|
|
|
|
|
losses_dict["energy_loss"] = float(energy_loss)
|
|
|
|
|
losses_dict["energy_loss"] = float(energy_loss)
|
|
|
|
|
if speaker_loss != 0.0:
|
|
|
|
|
if self.enable_spk_cls:
|
|
|
|
|
losses_dict["speaker_loss"] = float(speaker_loss)
|
|
|
|
|
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
|
|
|
|
|
speaker_loss)
|
|
|
|
@ -144,7 +147,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|
|
|
|
use_masking: bool=False,
|
|
|
|
|
use_weighted_masking: bool=False,
|
|
|
|
|
spk_loss_scale: float=0.02,
|
|
|
|
|
output_dir: Path=None):
|
|
|
|
|
output_dir: Path=None,
|
|
|
|
|
enable_spk_cls: bool=False):
|
|
|
|
|
super().__init__(model, dataloader)
|
|
|
|
|
|
|
|
|
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
|
|
|
@ -153,6 +157,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|
|
|
|
self.logger = logger
|
|
|
|
|
self.msg = ""
|
|
|
|
|
self.spk_loss_scale = spk_loss_scale
|
|
|
|
|
self.enable_spk_cls = enable_spk_cls
|
|
|
|
|
|
|
|
|
|
self.criterion = FastSpeech2Loss(
|
|
|
|
|
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
|
|
|
|
@ -213,7 +218,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|
|
|
|
report("eval/duration_loss", float(duration_loss))
|
|
|
|
|
report("eval/pitch_loss", float(pitch_loss))
|
|
|
|
|
report("eval/energy_loss", float(energy_loss))
|
|
|
|
|
if speaker_loss != 0.0:
|
|
|
|
|
if self.enable_spk_cls:
|
|
|
|
|
report("train/speaker_loss", float(speaker_loss))
|
|
|
|
|
report("train/scale_speaker_loss",
|
|
|
|
|
float(self.spk_loss_scale * speaker_loss))
|
|
|
|
@ -222,7 +227,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|
|
|
|
losses_dict["duration_loss"] = float(duration_loss)
|
|
|
|
|
losses_dict["pitch_loss"] = float(pitch_loss)
|
|
|
|
|
losses_dict["energy_loss"] = float(energy_loss)
|
|
|
|
|
if speaker_loss != 0.0:
|
|
|
|
|
if self.enable_spk_cls:
|
|
|
|
|
losses_dict["speaker_loss"] = float(speaker_loss)
|
|
|
|
|
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
|
|
|
|
|
speaker_loss)
|
|
|
|
|