add style_melgan, test=tts

pull/1068/head
TianYuan 3 years ago
parent 075aeee7f0
commit 7b2ecb6eed

@ -40,8 +40,9 @@ class StyleMelGANUpdater(StandardUpdater):
criterions: Dict[str, Layer], criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler], schedulers: Dict[str, LRScheduler],
dataloader: DataLoader, dataloader: DataLoader,
discriminator_train_start_steps: int, generator_train_start_steps: int=0,
lambda_adv: float, discriminator_train_start_steps: int=100000,
lambda_adv: float=1.0,
lambda_aux: float=1.0, lambda_aux: float=1.0,
output_dir: Path=None): output_dir: Path=None):
self.models = models self.models = models
@ -63,11 +64,12 @@ class StyleMelGANUpdater(StandardUpdater):
self.dataloader = dataloader self.dataloader = dataloader
self.generator_train_start_steps = generator_train_start_steps
self.discriminator_train_start_steps = discriminator_train_start_steps self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv self.lambda_adv = lambda_adv
self.lambda_aux = lambda_aux self.lambda_aux = lambda_aux
self.state = UpdaterState(iteration=0, epoch=0)
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader) self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@ -79,42 +81,45 @@ class StyleMelGANUpdater(StandardUpdater):
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank()) self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {} losses_dict = {}
# parse batch # parse batch
wav, mel = batch wav, mel = batch
# Generator # Generator
# (B, out_channels, T ** prod(upsample_scales) if self.state.iteration > self.generator_train_start_steps:
wav_ = self.generator(mel) # (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel)
# initialize # initialize
gen_loss = 0.0 gen_loss = 0.0
aux_loss = 0.0
# full band Multi-resolution stft loss # full band multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav) sc_loss, mag_loss = self.criterion_stft(wav_, wav)
gen_loss += sc_loss + mag_loss aux_loss += sc_loss + mag_loss
report("train/spectral_convergence_loss", float(sc_loss)) report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss)) report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss) losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss) losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss *= self.lambda_aux gen_loss += aux_loss * self.lambda_aux
## Adversarial loss # adversarial loss
if self.state.iteration > self.discriminator_train_start_steps: if self.state.iteration > self.discriminator_train_start_steps:
p_ = self.discriminator(wav_) p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_) adv_loss = self.criterion_gen_adv(p_)
report("train/adversarial_loss", float(adv_loss)) report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss) losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss)) gen_loss += self.lambda_adv * adv_loss
losses_dict["generator_loss"] = float(gen_loss)
self.optimizer_g.clear_grad() report("train/generator_loss", float(gen_loss))
gen_loss.backward() losses_dict["generator_loss"] = float(gen_loss)
self.optimizer_g.step() self.optimizer_g.clear_grad()
self.scheduler_g.step() gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# Disctiminator # Disctiminator
if self.state.iteration > self.discriminator_train_start_steps: if self.state.iteration > self.discriminator_train_start_steps:
@ -148,7 +153,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
models: Dict[str, Layer], models: Dict[str, Layer],
criterions: Dict[str, Layer], criterions: Dict[str, Layer],
dataloader: DataLoader, dataloader: DataLoader,
lambda_adv: float, lambda_adv: float=1.0,
lambda_aux: float=1.0, lambda_aux: float=1.0,
output_dir: Path=None): output_dir: Path=None):
self.models = models self.models = models
@ -161,6 +166,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
self.criterion_dis_adv = criterions["dis_adv"] self.criterion_dis_adv = criterions["dis_adv"]
self.dataloader = dataloader self.dataloader = dataloader
self.lambda_adv = lambda_adv self.lambda_adv = lambda_adv
self.lambda_aux = lambda_aux self.lambda_aux = lambda_aux
@ -171,26 +177,27 @@ class StyleMelGANEvaluator(StandardEvaluator):
self.msg = "" self.msg = ""
def evaluate_core(self, batch): def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: " self.msg = "Evaluate: "
losses_dict = {} losses_dict = {}
wav, mel = batch wav, mel = batch
# Generator # Generator
# (B, out_channels, T ** prod(upsample_scales) # (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel) wav_ = self.generator(mel)
## Adversarial loss # initialize
gen_loss = 0.0
aux_loss = 0.0
# adversarial loss
p_ = self.discriminator(wav_) p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_) adv_loss = self.criterion_gen_adv(p_)
report("eval/adversarial_loss", float(adv_loss)) report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss) losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# initialize gen_loss += self.lambda_adv * adv_loss
aux_loss = 0.0
# Multi-resolution stft loss # multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav) sc_loss, mag_loss = self.criterion_stft(wav_, wav)
aux_loss += sc_loss + mag_loss aux_loss += sc_loss + mag_loss
report("eval/spectral_convergence_loss", float(sc_loss)) report("eval/spectral_convergence_loss", float(sc_loss))
@ -198,8 +205,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
losses_dict["spectral_convergence_loss"] = float(sc_loss) losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss) losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
aux_loss *= self.lambda_aux gen_loss += aux_loss * self.lambda_aux
gen_loss += aux_loss
report("eval/generator_loss", float(gen_loss)) report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss) losses_dict["generator_loss"] = float(gen_loss)

Loading…
Cancel
Save