|
|
|
@ -40,8 +40,9 @@ class StyleMelGANUpdater(StandardUpdater):
|
|
|
|
|
criterions: Dict[str, Layer],
|
|
|
|
|
schedulers: Dict[str, LRScheduler],
|
|
|
|
|
dataloader: DataLoader,
|
|
|
|
|
discriminator_train_start_steps: int,
|
|
|
|
|
lambda_adv: float,
|
|
|
|
|
generator_train_start_steps: int=0,
|
|
|
|
|
discriminator_train_start_steps: int=100000,
|
|
|
|
|
lambda_adv: float=1.0,
|
|
|
|
|
lambda_aux: float=1.0,
|
|
|
|
|
output_dir: Path=None):
|
|
|
|
|
self.models = models
|
|
|
|
@ -63,11 +64,12 @@ class StyleMelGANUpdater(StandardUpdater):
|
|
|
|
|
|
|
|
|
|
self.dataloader = dataloader
|
|
|
|
|
|
|
|
|
|
self.generator_train_start_steps = generator_train_start_steps
|
|
|
|
|
self.discriminator_train_start_steps = discriminator_train_start_steps
|
|
|
|
|
self.lambda_adv = lambda_adv
|
|
|
|
|
self.lambda_aux = lambda_aux
|
|
|
|
|
self.state = UpdaterState(iteration=0, epoch=0)
|
|
|
|
|
|
|
|
|
|
self.state = UpdaterState(iteration=0, epoch=0)
|
|
|
|
|
self.train_iterator = iter(self.dataloader)
|
|
|
|
|
|
|
|
|
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
|
|
|
@ -79,42 +81,45 @@ class StyleMelGANUpdater(StandardUpdater):
|
|
|
|
|
def update_core(self, batch):
|
|
|
|
|
self.msg = "Rank: {}, ".format(dist.get_rank())
|
|
|
|
|
losses_dict = {}
|
|
|
|
|
|
|
|
|
|
# parse batch
|
|
|
|
|
wav, mel = batch
|
|
|
|
|
|
|
|
|
|
# Generator
|
|
|
|
|
# (B, out_channels, T ** prod(upsample_scales)
|
|
|
|
|
wav_ = self.generator(mel)
|
|
|
|
|
if self.state.iteration > self.generator_train_start_steps:
|
|
|
|
|
# (B, out_channels, T ** prod(upsample_scales)
|
|
|
|
|
wav_ = self.generator(mel)
|
|
|
|
|
|
|
|
|
|
# initialize
|
|
|
|
|
gen_loss = 0.0
|
|
|
|
|
# initialize
|
|
|
|
|
gen_loss = 0.0
|
|
|
|
|
aux_loss = 0.0
|
|
|
|
|
|
|
|
|
|
# full band Multi-resolution stft loss
|
|
|
|
|
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
|
|
|
|
gen_loss += sc_loss + mag_loss
|
|
|
|
|
report("train/spectral_convergence_loss", float(sc_loss))
|
|
|
|
|
report("train/log_stft_magnitude_loss", float(mag_loss))
|
|
|
|
|
losses_dict["spectral_convergence_loss"] = float(sc_loss)
|
|
|
|
|
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
|
|
|
|
|
# full band multi-resolution stft loss
|
|
|
|
|
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
|
|
|
|
aux_loss += sc_loss + mag_loss
|
|
|
|
|
report("train/spectral_convergence_loss", float(sc_loss))
|
|
|
|
|
report("train/log_stft_magnitude_loss", float(mag_loss))
|
|
|
|
|
losses_dict["spectral_convergence_loss"] = float(sc_loss)
|
|
|
|
|
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
|
|
|
|
|
|
|
|
|
|
gen_loss *= self.lambda_aux
|
|
|
|
|
gen_loss += aux_loss * self.lambda_aux
|
|
|
|
|
|
|
|
|
|
## Adversarial loss
|
|
|
|
|
if self.state.iteration > self.discriminator_train_start_steps:
|
|
|
|
|
p_ = self.discriminator(wav_)
|
|
|
|
|
adv_loss = self.criterion_gen_adv(p_)
|
|
|
|
|
report("train/adversarial_loss", float(adv_loss))
|
|
|
|
|
losses_dict["adversarial_loss"] = float(adv_loss)
|
|
|
|
|
gen_loss += self.lambda_adv * adv_loss
|
|
|
|
|
# adversarial loss
|
|
|
|
|
if self.state.iteration > self.discriminator_train_start_steps:
|
|
|
|
|
p_ = self.discriminator(wav_)
|
|
|
|
|
adv_loss = self.criterion_gen_adv(p_)
|
|
|
|
|
report("train/adversarial_loss", float(adv_loss))
|
|
|
|
|
losses_dict["adversarial_loss"] = float(adv_loss)
|
|
|
|
|
|
|
|
|
|
report("train/generator_loss", float(gen_loss))
|
|
|
|
|
losses_dict["generator_loss"] = float(gen_loss)
|
|
|
|
|
gen_loss += self.lambda_adv * adv_loss
|
|
|
|
|
|
|
|
|
|
self.optimizer_g.clear_grad()
|
|
|
|
|
gen_loss.backward()
|
|
|
|
|
report("train/generator_loss", float(gen_loss))
|
|
|
|
|
losses_dict["generator_loss"] = float(gen_loss)
|
|
|
|
|
|
|
|
|
|
self.optimizer_g.step()
|
|
|
|
|
self.scheduler_g.step()
|
|
|
|
|
self.optimizer_g.clear_grad()
|
|
|
|
|
gen_loss.backward()
|
|
|
|
|
|
|
|
|
|
self.optimizer_g.step()
|
|
|
|
|
self.scheduler_g.step()
|
|
|
|
|
|
|
|
|
|
# Disctiminator
|
|
|
|
|
if self.state.iteration > self.discriminator_train_start_steps:
|
|
|
|
@ -148,7 +153,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
|
|
|
|
|
models: Dict[str, Layer],
|
|
|
|
|
criterions: Dict[str, Layer],
|
|
|
|
|
dataloader: DataLoader,
|
|
|
|
|
lambda_adv: float,
|
|
|
|
|
lambda_adv: float=1.0,
|
|
|
|
|
lambda_aux: float=1.0,
|
|
|
|
|
output_dir: Path=None):
|
|
|
|
|
self.models = models
|
|
|
|
@ -161,6 +166,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
|
|
|
|
|
self.criterion_dis_adv = criterions["dis_adv"]
|
|
|
|
|
|
|
|
|
|
self.dataloader = dataloader
|
|
|
|
|
|
|
|
|
|
self.lambda_adv = lambda_adv
|
|
|
|
|
self.lambda_aux = lambda_aux
|
|
|
|
|
|
|
|
|
@ -171,26 +177,27 @@ class StyleMelGANEvaluator(StandardEvaluator):
|
|
|
|
|
self.msg = ""
|
|
|
|
|
|
|
|
|
|
def evaluate_core(self, batch):
|
|
|
|
|
# logging.debug("Evaluate: ")
|
|
|
|
|
self.msg = "Evaluate: "
|
|
|
|
|
losses_dict = {}
|
|
|
|
|
|
|
|
|
|
wav, mel = batch
|
|
|
|
|
|
|
|
|
|
# Generator
|
|
|
|
|
# (B, out_channels, T ** prod(upsample_scales)
|
|
|
|
|
wav_ = self.generator(mel)
|
|
|
|
|
|
|
|
|
|
## Adversarial loss
|
|
|
|
|
# initialize
|
|
|
|
|
gen_loss = 0.0
|
|
|
|
|
aux_loss = 0.0
|
|
|
|
|
|
|
|
|
|
# adversarial loss
|
|
|
|
|
p_ = self.discriminator(wav_)
|
|
|
|
|
adv_loss = self.criterion_gen_adv(p_)
|
|
|
|
|
|
|
|
|
|
report("eval/adversarial_loss", float(adv_loss))
|
|
|
|
|
losses_dict["adversarial_loss"] = float(adv_loss)
|
|
|
|
|
gen_loss = self.lambda_adv * adv_loss
|
|
|
|
|
|
|
|
|
|
# initialize
|
|
|
|
|
aux_loss = 0.0
|
|
|
|
|
# Multi-resolution stft loss
|
|
|
|
|
gen_loss += self.lambda_adv * adv_loss
|
|
|
|
|
|
|
|
|
|
# multi-resolution stft loss
|
|
|
|
|
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
|
|
|
|
aux_loss += sc_loss + mag_loss
|
|
|
|
|
report("eval/spectral_convergence_loss", float(sc_loss))
|
|
|
|
@ -198,8 +205,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
|
|
|
|
|
losses_dict["spectral_convergence_loss"] = float(sc_loss)
|
|
|
|
|
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
|
|
|
|
|
|
|
|
|
|
aux_loss *= self.lambda_aux
|
|
|
|
|
gen_loss += aux_loss
|
|
|
|
|
gen_loss += aux_loss * self.lambda_aux
|
|
|
|
|
|
|
|
|
|
report("eval/generator_loss", float(gen_loss))
|
|
|
|
|
losses_dict["generator_loss"] = float(gen_loss)
|
|
|
|
|