diff --git a/paddlespeech/t2s/models/melgan/style_melgan_updater.py b/paddlespeech/t2s/models/melgan/style_melgan_updater.py index 49054aa7..b0cb4ed6 100644 --- a/paddlespeech/t2s/models/melgan/style_melgan_updater.py +++ b/paddlespeech/t2s/models/melgan/style_melgan_updater.py @@ -40,8 +40,9 @@ class StyleMelGANUpdater(StandardUpdater): criterions: Dict[str, Layer], schedulers: Dict[str, LRScheduler], dataloader: DataLoader, - discriminator_train_start_steps: int, - lambda_adv: float, + generator_train_start_steps: int=0, + discriminator_train_start_steps: int=100000, + lambda_adv: float=1.0, lambda_aux: float=1.0, output_dir: Path=None): self.models = models @@ -63,11 +64,12 @@ class StyleMelGANUpdater(StandardUpdater): self.dataloader = dataloader + self.generator_train_start_steps = generator_train_start_steps self.discriminator_train_start_steps = discriminator_train_start_steps self.lambda_adv = lambda_adv self.lambda_aux = lambda_aux - self.state = UpdaterState(iteration=0, epoch=0) + self.state = UpdaterState(iteration=0, epoch=0) self.train_iterator = iter(self.dataloader) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) @@ -79,42 +81,45 @@ class StyleMelGANUpdater(StandardUpdater): def update_core(self, batch): self.msg = "Rank: {}, ".format(dist.get_rank()) losses_dict = {} - # parse batch wav, mel = batch + # Generator - # (B, out_channels, T ** prod(upsample_scales) - wav_ = self.generator(mel) + if self.state.iteration > self.generator_train_start_steps: + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) - # initialize - gen_loss = 0.0 + # initialize + gen_loss = 0.0 + aux_loss = 0.0 - # full band Multi-resolution stft loss - sc_loss, mag_loss = self.criterion_stft(wav_, wav) - gen_loss += sc_loss + mag_loss - report("train/spectral_convergence_loss", float(sc_loss)) - report("train/log_stft_magnitude_loss", float(mag_loss)) - losses_dict["spectral_convergence_loss"] = float(sc_loss) - losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + # full band multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + aux_loss += sc_loss + mag_loss + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) - gen_loss *= self.lambda_aux + gen_loss += aux_loss * self.lambda_aux - ## Adversarial loss - if self.state.iteration > self.discriminator_train_start_steps: - p_ = self.discriminator(wav_) - adv_loss = self.criterion_gen_adv(p_) - report("train/adversarial_loss", float(adv_loss)) - losses_dict["adversarial_loss"] = float(adv_loss) - gen_loss += self.lambda_adv * adv_loss + # adversarial loss + if self.state.iteration > self.discriminator_train_start_steps: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_gen_adv(p_) + report("train/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) - report("train/generator_loss", float(gen_loss)) - losses_dict["generator_loss"] = float(gen_loss) + gen_loss += self.lambda_adv * adv_loss - self.optimizer_g.clear_grad() - gen_loss.backward() + report("train/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) - self.optimizer_g.step() - self.scheduler_g.step() + self.optimizer_g.clear_grad() + gen_loss.backward() + + self.optimizer_g.step() + self.scheduler_g.step() # Disctiminator if self.state.iteration > self.discriminator_train_start_steps: @@ -148,7 +153,7 @@ class StyleMelGANEvaluator(StandardEvaluator): models: Dict[str, Layer], criterions: Dict[str, Layer], dataloader: DataLoader, - lambda_adv: float, + lambda_adv: float=1.0, lambda_aux: float=1.0, output_dir: Path=None): self.models = models @@ -161,6 +166,7 @@ class StyleMelGANEvaluator(StandardEvaluator): self.criterion_dis_adv = criterions["dis_adv"] self.dataloader = dataloader + self.lambda_adv = lambda_adv self.lambda_aux = lambda_aux @@ -171,26 +177,27 @@ class StyleMelGANEvaluator(StandardEvaluator): self.msg = "" def evaluate_core(self, batch): - # logging.debug("Evaluate: ") self.msg = "Evaluate: " losses_dict = {} - wav, mel = batch + # Generator # (B, out_channels, T ** prod(upsample_scales) wav_ = self.generator(mel) - ## Adversarial loss + # initialize + gen_loss = 0.0 + aux_loss = 0.0 + + # adversarial loss p_ = self.discriminator(wav_) adv_loss = self.criterion_gen_adv(p_) - report("eval/adversarial_loss", float(adv_loss)) losses_dict["adversarial_loss"] = float(adv_loss) - gen_loss = self.lambda_adv * adv_loss - # initialize - aux_loss = 0.0 - # Multi-resolution stft loss + gen_loss += self.lambda_adv * adv_loss + + # multi-resolution stft loss sc_loss, mag_loss = self.criterion_stft(wav_, wav) aux_loss += sc_loss + mag_loss report("eval/spectral_convergence_loss", float(sc_loss)) @@ -198,8 +205,7 @@ class StyleMelGANEvaluator(StandardEvaluator): losses_dict["spectral_convergence_loss"] = float(sc_loss) losses_dict["log_stft_magnitude_loss"] = float(mag_loss) - aux_loss *= self.lambda_aux - gen_loss += aux_loss + gen_loss += aux_loss * self.lambda_aux report("eval/generator_loss", float(gen_loss)) losses_dict["generator_loss"] = float(gen_loss)