From bd47de824ceb996e742faab47fadc66d803c215e Mon Sep 17 00:00:00 2001 From: lym0302 Date: Fri, 10 Mar 2023 04:05:59 +0000 Subject: [PATCH] update --- examples/opencpop/svs1/conf/default.yaml | 11 ++- paddlespeech/t2s/exps/diffsinger/train.py | 6 +- paddlespeech/t2s/exps/syn_utils.py | 6 +- .../t2s/models/diffsinger/diffsinger.py | 31 ++++--- .../models/diffsinger/diffsinger_updater.py | 10 ++- .../t2s/models/diffsinger/fastspeech2midi.py | 87 +++++++++++-------- 6 files changed, 89 insertions(+), 62 deletions(-) diff --git a/examples/opencpop/svs1/conf/default.yaml b/examples/opencpop/svs1/conf/default.yaml index 32ad115a7..5d8060630 100644 --- a/examples/opencpop/svs1/conf/default.yaml +++ b/examples/opencpop/svs1/conf/default.yaml @@ -34,7 +34,9 @@ model: # music score related note_num: 300 # number of note is_slur_num: 2 # number of slur - stretch: True # whether to stretch before diffusion + # fastspeech2 module options + use_energy_pred: False # whether use energy predictor + use_postnet: False # whether use postnet # fastspeech2 module fastspeech2_params: @@ -105,8 +107,8 @@ model: beta_start: 0.0001 # beta start parameter for the scheduler beta_end: 0.06 # beta end parameter for the scheduler beta_schedule: "linear" # beta schedule parameter for the scheduler - num_max_timesteps: 100 # The max timestep transition from real to noise - + num_max_timesteps: 100 # The max timestep transition from real to noise + stretch: True # whether to stretch before diffusion ########################################################### @@ -136,7 +138,7 @@ ds_optimizer_params: ds_scheduler_params: learning_rate: 0.001 gamma: 0.5 - step_size: 50000 + step_size: 50000 ds_grad_norm: 1 @@ -150,6 +152,7 @@ save_interval_steps: 2000 # Interval steps to save checkpoint. eval_interval_steps: 2000 # Interval steps to evaluate the network. num_snapshots: 5 + ########################################################### # OTHER SETTING # ########################################################### diff --git a/paddlespeech/t2s/exps/diffsinger/train.py b/paddlespeech/t2s/exps/diffsinger/train.py index 3f062eefc..e79104c4a 100644 --- a/paddlespeech/t2s/exps/diffsinger/train.py +++ b/paddlespeech/t2s/exps/diffsinger/train.py @@ -137,11 +137,11 @@ def train_sp(args, config): odim = config.n_mels config["model"]["fastspeech2_params"]["spk_num"] = spk_num model = DiffSinger( + spec_min=spec_min, + spec_max=spec_max, idim=vocab_size, odim=odim, - **config["model"], - spec_min=spec_min, - spec_max=spec_max) + **config["model"], ) model_fs2 = model.fs2 model_ds = model.diffusion if world_size > 1: diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index df8e0659b..6ab8a1674 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -373,11 +373,11 @@ def get_am_inference( spec_max = paddle.to_tensor(spec_max) am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num am = am_class( + spec_min=spec_min, + spec_max=spec_max, idim=vocab_size, odim=odim, - **am_config["model"], - spec_min=spec_min, - spec_max=spec_max, ) + **am_config["model"], ) elif am_name == 'speedyspeech': am = am_class( vocab_size=vocab_size, diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py index 50f719918..990cfc56a 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py @@ -42,9 +42,14 @@ class DiffSinger(nn.Layer): def __init__( self, + # min and max spec for stretching before diffusion + spec_min: paddle.Tensor, + spec_max: paddle.Tensor, # fastspeech2midi config idim: int, odim: int, + use_energy_pred: bool=False, + use_postnet: bool=False, # music score related note_num: int=300, is_slur_num: int=2, @@ -134,24 +139,23 @@ class DiffSinger(nn.Layer): "beta_start": 0.0001, "beta_end": 0.06, "beta_schedule": "squaredcos_cap_v2", - "num_max_timesteps": 60 - }, - stretch: bool=True, - spec_min: paddle.Tensor=None, - spec_max: paddle.Tensor=None, ): + "num_max_timesteps": 60, + "stretch": True, + }, ): """Initialize DiffSinger module. Args: - idim (int): - Dimension of the inputs (Input vocabrary size.). - odim (int): - Dimension of the outputs (Acoustic feature dimension.). + spec_min (paddle.Tensor): The minimum value of the feature(mel) to stretch before diffusion. + spec_max (paddle.Tensor): The maximum value of the feature(mel) to stretch before diffusion. + idim (int): Dimension of the inputs (Input vocabrary size.). + odim (int): Dimension of the outputs (Acoustic feature dimension.). + use_energy_pred (bool, optional): whether use energy predictor. Defaults False. + use_postnet (bool, optional): whether use postnet. Defaults False. note_num (int, optional): The number of note. Defaults to 300. is_slur_num (int, optional): The number of slur. Defaults to 2. fastspeech2_params (Dict[str, Any]): Parameter dict for fastspeech2 module. denoiser_params (Dict[str, Any]): Parameter dict for dinoiser module. diffusion_params (Dict[str, Any]): Parameter dict for diffusion module. - stretch (bool): Whether to stretch before diffusion. Defaults True. """ assert check_argument_types() super().__init__() @@ -160,12 +164,13 @@ class DiffSinger(nn.Layer): odim=odim, fastspeech2_params=fastspeech2_params, note_num=note_num, - is_slur_num=is_slur_num) + is_slur_num=is_slur_num, + use_energy_pred=use_energy_pred, + use_postnet=use_postnet, ) denoiser = DiffNet(**denoiser_params) self.diffusion = GaussianDiffusion( denoiser, **diffusion_params, - stretch=stretch, min_values=spec_min, max_values=spec_max, ) @@ -319,7 +324,7 @@ class DiffSingerInference(nn.Layer): logmel(Tensor(float32)): denorm logmel, [T, mel_bin] """ normalized_mel = self.acoustic_model.inference( - text, + text=text, note=note, note_dur=note_dur, is_slur=is_slur, diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py index 018b781d1..d89b09b2a 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py @@ -121,17 +121,18 @@ class DiffSingerUpdater(StandardUpdater): report("train/ssim_loss_fs2", float(ssim_loss_fs2)) report("train/duration_loss", float(duration_loss)) report("train/pitch_loss", float(pitch_loss)) - report("train/energy_loss", float(energy_loss)) losses_dict["l1_loss_fs2"] = float(l1_loss_fs2) losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) - losses_dict["energy_loss"] = float(energy_loss) if speaker_loss != 0.: report("train/speaker_loss", float(speaker_loss)) losses_dict["speaker_loss"] = float(speaker_loss) + if energy_loss != 0.: + report("train/energy_loss", float(energy_loss)) + losses_dict["energy_loss"] = float(energy_loss) losses_dict["loss_fs2"] = float(loss_fs2) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) @@ -250,17 +251,18 @@ class DiffSingerEvaluator(StandardEvaluator): report("eval/ssim_loss_fs2", float(ssim_loss_fs2)) report("eval/duration_loss", float(duration_loss)) report("eval/pitch_loss", float(pitch_loss)) - report("eval/energy_loss", float(energy_loss)) losses_dict["l1_loss_fs2"] = float(l1_loss_fs2) losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) - losses_dict["energy_loss"] = float(energy_loss) if speaker_loss != 0.: report("eval/speaker_loss", float(speaker_loss)) losses_dict["speaker_loss"] = float(speaker_loss) + if energy_loss != 0.: + report("eval/energy_loss", float(energy_loss)) + losses_dict["energy_loss"] = float(energy_loss) losses_dict["loss_fs2"] = float(loss_fs2) diff --git a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py index 7846779db..cce88d8a0 100644 --- a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py +++ b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py @@ -42,7 +42,9 @@ class FastSpeech2MIDI(FastSpeech2): # note emb note_num: int=300, # is_slur emb - is_slur_num: int=2, ): + is_slur_num: int=2, + use_energy_pred: bool=False, + use_postnet: bool=False, ): """Initialize FastSpeech2 module for svs. Args: fastspeech2_params (Dict): @@ -57,6 +59,10 @@ class FastSpeech2MIDI(FastSpeech2): """ assert check_argument_types() super().__init__(idim=idim, odim=odim, **fastspeech2_params) + self.use_energy_pred = use_energy_pred + self.use_postnet = use_postnet + if not self.use_postnet: + self.postnet = None self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[ "adim"] @@ -214,12 +220,14 @@ class FastSpeech2MIDI(FastSpeech2): if is_train_diffusion: hs = self.length_regulator(hs, ds, is_inference=False) p_outs = self.pitch_predictor(hs.detach(), pitch_masks) - e_outs = self.energy_predictor(hs.detach(), pitch_masks) p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + p_embs + e_embs + hs += p_embs + if self.use_energy_pred: + e_outs = self.energy_predictor(hs.detach(), pitch_masks) + e_embs = self.energy_embed( + e_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + hs += e_embs elif is_inference: # (B, Tmax) @@ -238,20 +246,21 @@ class FastSpeech2MIDI(FastSpeech2): p_outs = self.pitch_predictor(hs.detach(), pitch_masks) else: p_outs = self.pitch_predictor(hs, pitch_masks) - - if es is not None: - e_outs = es - else: - if self.stop_gradient_from_energy_predictor: - e_outs = self.energy_predictor(hs.detach(), pitch_masks) - else: - e_outs = self.energy_predictor(hs, pitch_masks) - p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + p_embs + e_embs + hs += p_embs + + if self.use_energy_pred: + if es is not None: + e_outs = es + else: + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), pitch_masks) + else: + e_outs = self.energy_predictor(hs, pitch_masks) + e_embs = self.energy_embed( + e_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + hs += e_embs # training else: @@ -262,15 +271,18 @@ class FastSpeech2MIDI(FastSpeech2): p_outs = self.pitch_predictor(hs.detach(), pitch_masks) else: p_outs = self.pitch_predictor(hs, pitch_masks) - if self.stop_gradient_from_energy_predictor: - e_outs = self.energy_predictor(hs.detach(), pitch_masks) - else: - e_outs = self.energy_predictor(hs, pitch_masks) p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + p_embs + e_embs + hs += p_embs + + if self.use_energy_pred: + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), pitch_masks) + else: + e_outs = self.energy_predictor(hs, pitch_masks) + e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + hs += e_embs # forward decoder if olens is not None and not is_inference: @@ -304,7 +316,6 @@ class FastSpeech2MIDI(FastSpeech2): else: after_outs = before_outs + self.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - after_outs = before_outs return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits @@ -475,6 +486,9 @@ class FastSpeech2MIDI(FastSpeech2): spk_emb=spk_emb, spk_id=spk_id, ) + if e_outs is None: + e_outs = [None] + return outs[0], d_outs[0], p_outs[0], e_outs[0] @@ -552,14 +566,14 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): # make feature for ssim loss out_pad_masks = make_pad_mask(olens).unsqueeze(-1) before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0) - if after_outs is not None: + if not paddle.equal_all(after_outs, before_outs): after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0) ys_ssim = masked_fill(ys, out_pad_masks, 0.0) out_masks = make_non_pad_mask(olens).unsqueeze(-1) before_outs = before_outs.masked_select( out_masks.broadcast_to(before_outs.shape)) - if after_outs is not None: + if not paddle.equal_all(after_outs, before_outs): after_outs = after_outs.masked_select( out_masks.broadcast_to(after_outs.shape)) ys = ys.masked_select(out_masks.broadcast_to(ys.shape)) @@ -570,10 +584,11 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): pitch_masks = out_masks p_outs = p_outs.masked_select( pitch_masks.broadcast_to(p_outs.shape)) - e_outs = e_outs.masked_select( - pitch_masks.broadcast_to(e_outs.shape)) ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape)) - es = es.masked_select(pitch_masks.broadcast_to(es.shape)) + if e_outs is not None: + e_outs = e_outs.masked_select( + pitch_masks.broadcast_to(e_outs.shape)) + es = es.masked_select(pitch_masks.broadcast_to(es.shape)) if spk_logits is not None and spk_ids is not None: batch_size = spk_ids.shape[0] @@ -589,7 +604,7 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): l1_loss = self.l1_criterion(before_outs, ys) ssim_loss = 1.0 - ssim( before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)) - if after_outs is not None: + if not paddle.equal_all(after_outs, before_outs): l1_loss += self.l1_criterion(after_outs, ys) ssim_loss += ( 1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))) @@ -598,7 +613,8 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): duration_loss = self.duration_criterion(d_outs, ds) pitch_loss = self.l1_criterion(p_outs, ps) - energy_loss = self.l1_criterion(e_outs, es) + if e_outs is not None: + energy_loss = self.l1_criterion(e_outs, es) if spk_logits is not None and spk_ids is not None: speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size @@ -630,8 +646,9 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): pitch_loss = pitch_loss.multiply(pitch_weights) pitch_loss = pitch_loss.masked_select( pitch_masks.broadcast_to(pitch_loss.shape)).sum() - energy_loss = energy_loss.multiply(pitch_weights) - energy_loss = energy_loss.masked_select( - pitch_masks.broadcast_to(energy_loss.shape)).sum() + if e_outs is not None: + energy_loss = energy_loss.multiply(pitch_weights) + energy_loss = energy_loss.masked_select( + pitch_masks.broadcast_to(energy_loss.shape)).sum() return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss