diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py
index 10e023d0c..0f5edb37e 100644
--- a/paddlespeech/t2s/exps/fastspeech2/train.py
+++ b/paddlespeech/t2s/exps/fastspeech2/train.py
@@ -145,17 +145,27 @@ def train_sp(args, config):
     # copy conf to output_dir
     shutil.copyfile(args.config, output_dir / config_name)
 
+    if "enable_speaker_classifier" in config.model:
+        enable_spk_cls = config.model.enable_speaker_classifier
+    else:
+        enable_spk_cls = False
+
     updater = FastSpeech2Updater(
         model=model,
         optimizer=optimizer,
         dataloader=train_dataloader,
         output_dir=output_dir,
-        **config["updater"])
+        **config["updater"],
+        enable_spk_cls=enable_spk_cls)
 
     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
 
     evaluator = FastSpeech2Evaluator(
-        model, dev_dataloader, output_dir=output_dir, **config["updater"])
+        model,
+        dev_dataloader,
+        output_dir=output_dir,
+        **config["updater"],
+        enable_spk_cls=enable_spk_cls)
 
     if dist.get_rank() == 0:
         trainer.extend(evaluator, trigger=(1, "epoch"))
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
index 2b25b6a62..bbff927b7 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
@@ -33,15 +33,17 @@ logger.setLevel(logging.INFO)
 
 
 class FastSpeech2Updater(StandardUpdater):
-    def __init__(self,
-                 model: Layer,
-                 optimizer: Optimizer,
-                 dataloader: DataLoader,
-                 init_state=None,
-                 use_masking: bool=False,
-                 spk_loss_scale: float=0.02,
-                 use_weighted_masking: bool=False,
-                 output_dir: Path=None):
+    def __init__(
+            self,
+            model: Layer,
+            optimizer: Optimizer,
+            dataloader: DataLoader,
+            init_state=None,
+            use_masking: bool=False,
+            spk_loss_scale: float=0.02,
+            use_weighted_masking: bool=False,
+            output_dir: Path=None,
+            enable_spk_cls: bool=False, ):
         super().__init__(model, optimizer, dataloader, init_state=None)
 
         self.criterion = FastSpeech2Loss(
@@ -54,6 +56,7 @@ class FastSpeech2Updater(StandardUpdater):
         self.logger = logger
         self.msg = ""
         self.spk_loss_scale = spk_loss_scale
+        self.enable_spk_cls = enable_spk_cls
 
     def update_core(self, batch):
         self.msg = "Rank: {}, ".format(dist.get_rank())
@@ -118,7 +121,7 @@ class FastSpeech2Updater(StandardUpdater):
         report("train/duration_loss", float(duration_loss))
         report("train/pitch_loss", float(pitch_loss))
         report("train/energy_loss", float(energy_loss))
-        if speaker_loss != 0.0:
+        if self.enable_spk_cls:
             report("train/speaker_loss", float(speaker_loss))
             report("train/scale_speaker_loss",
                    float(self.spk_loss_scale * speaker_loss))
@@ -128,7 +131,7 @@ class FastSpeech2Updater(StandardUpdater):
         losses_dict["pitch_loss"] = float(pitch_loss)
         losses_dict["energy_loss"] = float(energy_loss)
         losses_dict["energy_loss"] = float(energy_loss)
-        if speaker_loss != 0.0:
+        if self.enable_spk_cls:
             losses_dict["speaker_loss"] = float(speaker_loss)
             losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
                                                       speaker_loss)
@@ -144,7 +147,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
                  use_masking: bool=False,
                  use_weighted_masking: bool=False,
                  spk_loss_scale: float=0.02,
-                 output_dir: Path=None):
+                 output_dir: Path=None,
+                 enable_spk_cls: bool=False):
         super().__init__(model, dataloader)
 
         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@@ -153,6 +157,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
         self.logger = logger
         self.msg = ""
         self.spk_loss_scale = spk_loss_scale
+        self.enable_spk_cls = enable_spk_cls
 
         self.criterion = FastSpeech2Loss(
             use_masking=use_masking, use_weighted_masking=use_weighted_masking)
@@ -213,7 +218,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
         report("eval/duration_loss", float(duration_loss))
         report("eval/pitch_loss", float(pitch_loss))
         report("eval/energy_loss", float(energy_loss))
-        if speaker_loss != 0.0:
+        if self.enable_spk_cls:
             report("train/speaker_loss", float(speaker_loss))
             report("train/scale_speaker_loss",
                    float(self.spk_loss_scale * speaker_loss))
@@ -222,7 +227,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
         losses_dict["duration_loss"] = float(duration_loss)
         losses_dict["pitch_loss"] = float(pitch_loss)
         losses_dict["energy_loss"] = float(energy_loss)
-        if speaker_loss != 0.0:
+        if self.enable_spk_cls:
             losses_dict["speaker_loss"] = float(speaker_loss)
             losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
                                                       speaker_loss)