updata, test=tts

pull/2588/head
liangym 3 years ago
parent d92852aef7
commit 1c5471e4b0

@ -145,17 +145,27 @@ def train_sp(args, config):
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
if "enable_speaker_classifier" in config.model:
enable_spk_cls = config.model.enable_speaker_classifier
else:
enable_spk_cls = False
updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"])
**config["updater"],
enable_spk_cls=enable_spk_cls)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
model,
dev_dataloader,
output_dir=output_dir,
**config["updater"],
enable_spk_cls=enable_spk_cls)
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))

@ -33,7 +33,8 @@ logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
def __init__(self,
def __init__(
self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
@ -41,7 +42,8 @@ class FastSpeech2Updater(StandardUpdater):
use_masking: bool=False,
spk_loss_scale: float=0.02,
use_weighted_masking: bool=False,
output_dir: Path=None):
output_dir: Path=None,
enable_spk_cls: bool=False, ):
super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = FastSpeech2Loss(
@ -54,6 +56,7 @@ class FastSpeech2Updater(StandardUpdater):
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
@ -118,7 +121,7 @@ class FastSpeech2Updater(StandardUpdater):
report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
if speaker_loss != 0.0:
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scale_speaker_loss",
float(self.spk_loss_scale * speaker_loss))
@ -128,7 +131,7 @@ class FastSpeech2Updater(StandardUpdater):
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.0:
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
speaker_loss)
@ -144,7 +147,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
use_masking: bool=False,
use_weighted_masking: bool=False,
spk_loss_scale: float=0.02,
output_dir: Path=None):
output_dir: Path=None,
enable_spk_cls: bool=False):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@ -153,6 +157,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
@ -213,7 +218,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
if speaker_loss != 0.0:
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scale_speaker_loss",
float(self.spk_loss_scale * speaker_loss))
@ -222,7 +227,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.0:
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
speaker_loss)

Loading…
Cancel
Save