update, test=tts

pull/2588/head
liangym 3 years ago
parent d92852aef7
commit 1c5471e4b0

@ -145,17 +145,27 @@ def train_sp(args, config):
# copy conf to output_dir # copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name) shutil.copyfile(args.config, output_dir / config_name)
if "enable_speaker_classifier" in config.model:
enable_spk_cls = config.model.enable_speaker_classifier
else:
enable_spk_cls = False
updater = FastSpeech2Updater( updater = FastSpeech2Updater(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
dataloader=train_dataloader, dataloader=train_dataloader,
output_dir=output_dir, output_dir=output_dir,
**config["updater"]) **config["updater"],
enable_spk_cls=enable_spk_cls)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator( evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"]) model,
dev_dataloader,
output_dir=output_dir,
**config["updater"],
enable_spk_cls=enable_spk_cls)
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch")) trainer.extend(evaluator, trigger=(1, "epoch"))

@ -33,15 +33,17 @@ logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater): class FastSpeech2Updater(StandardUpdater):
def __init__(self, def __init__(
model: Layer, self,
optimizer: Optimizer, model: Layer,
dataloader: DataLoader, optimizer: Optimizer,
init_state=None, dataloader: DataLoader,
use_masking: bool=False, init_state=None,
spk_loss_scale: float=0.02, use_masking: bool=False,
use_weighted_masking: bool=False, spk_loss_scale: float=0.02,
output_dir: Path=None): use_weighted_masking: bool=False,
output_dir: Path=None,
enable_spk_cls: bool=False, ):
super().__init__(model, optimizer, dataloader, init_state=None) super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = FastSpeech2Loss( self.criterion = FastSpeech2Loss(
@ -54,6 +56,7 @@ class FastSpeech2Updater(StandardUpdater):
self.logger = logger self.logger = logger
self.msg = "" self.msg = ""
self.spk_loss_scale = spk_loss_scale self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank()) self.msg = "Rank: {}, ".format(dist.get_rank())
@ -118,7 +121,7 @@ class FastSpeech2Updater(StandardUpdater):
report("train/duration_loss", float(duration_loss)) report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss)) report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss)) report("train/energy_loss", float(energy_loss))
if speaker_loss != 0.0: if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss)) report("train/speaker_loss", float(speaker_loss))
report("train/scale_speaker_loss", report("train/scale_speaker_loss",
float(self.spk_loss_scale * speaker_loss)) float(self.spk_loss_scale * speaker_loss))
@ -128,7 +131,7 @@ class FastSpeech2Updater(StandardUpdater):
losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss) losses_dict["energy_loss"] = float(energy_loss)
losses_dict["energy_loss"] = float(energy_loss) losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.0: if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss) losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
speaker_loss) speaker_loss)
@ -144,7 +147,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
use_masking: bool=False, use_masking: bool=False,
use_weighted_masking: bool=False, use_weighted_masking: bool=False,
spk_loss_scale: float=0.02, spk_loss_scale: float=0.02,
output_dir: Path=None): output_dir: Path=None,
enable_spk_cls: bool=False):
super().__init__(model, dataloader) super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@ -153,6 +157,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
self.logger = logger self.logger = logger
self.msg = "" self.msg = ""
self.spk_loss_scale = spk_loss_scale self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
self.criterion = FastSpeech2Loss( self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking) use_masking=use_masking, use_weighted_masking=use_weighted_masking)
@ -213,7 +218,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss)) report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss)) report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss)) report("eval/energy_loss", float(energy_loss))
if speaker_loss != 0.0: if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss)) report("train/speaker_loss", float(speaker_loss))
report("train/scale_speaker_loss", report("train/scale_speaker_loss",
float(self.spk_loss_scale * speaker_loss)) float(self.spk_loss_scale * speaker_loss))
@ -222,7 +227,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
losses_dict["duration_loss"] = float(duration_loss) losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss) losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.0: if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss) losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
speaker_loss) speaker_loss)

Loading…
Cancel
Save