pull/3005/head
lym0302 3 years ago
parent c9c6960f7e
commit bd47de824c

@ -34,7 +34,9 @@ model:
# music score related # music score related
note_num: 300 # number of note note_num: 300 # number of note
is_slur_num: 2 # number of slur is_slur_num: 2 # number of slur
stretch: True # whether to stretch before diffusion # fastspeech2 module options
use_energy_pred: False # whether use energy predictor
use_postnet: False # whether use postnet
# fastspeech2 module # fastspeech2 module
fastspeech2_params: fastspeech2_params:
@ -106,7 +108,7 @@ model:
beta_end: 0.06 # beta end parameter for the scheduler beta_end: 0.06 # beta end parameter for the scheduler
beta_schedule: "linear" # beta schedule parameter for the scheduler beta_schedule: "linear" # beta schedule parameter for the scheduler
num_max_timesteps: 100 # The max timestep transition from real to noise num_max_timesteps: 100 # The max timestep transition from real to noise
stretch: True # whether to stretch before diffusion
########################################################### ###########################################################
@ -150,6 +152,7 @@ save_interval_steps: 2000 # Interval steps to save checkpoint.
eval_interval_steps: 2000 # Interval steps to evaluate the network. eval_interval_steps: 2000 # Interval steps to evaluate the network.
num_snapshots: 5 num_snapshots: 5
########################################################### ###########################################################
# OTHER SETTING # # OTHER SETTING #
########################################################### ###########################################################

@ -137,11 +137,11 @@ def train_sp(args, config):
odim = config.n_mels odim = config.n_mels
config["model"]["fastspeech2_params"]["spk_num"] = spk_num config["model"]["fastspeech2_params"]["spk_num"] = spk_num
model = DiffSinger( model = DiffSinger(
spec_min=spec_min,
spec_max=spec_max,
idim=vocab_size, idim=vocab_size,
odim=odim, odim=odim,
**config["model"], **config["model"], )
spec_min=spec_min,
spec_max=spec_max)
model_fs2 = model.fs2 model_fs2 = model.fs2
model_ds = model.diffusion model_ds = model.diffusion
if world_size > 1: if world_size > 1:

@ -373,11 +373,11 @@ def get_am_inference(
spec_max = paddle.to_tensor(spec_max) spec_max = paddle.to_tensor(spec_max)
am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num
am = am_class( am = am_class(
spec_min=spec_min,
spec_max=spec_max,
idim=vocab_size, idim=vocab_size,
odim=odim, odim=odim,
**am_config["model"], **am_config["model"], )
spec_min=spec_min,
spec_max=spec_max, )
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
am = am_class( am = am_class(
vocab_size=vocab_size, vocab_size=vocab_size,

@ -42,9 +42,14 @@ class DiffSinger(nn.Layer):
def __init__( def __init__(
self, self,
# min and max spec for stretching before diffusion
spec_min: paddle.Tensor,
spec_max: paddle.Tensor,
# fastspeech2midi config # fastspeech2midi config
idim: int, idim: int,
odim: int, odim: int,
use_energy_pred: bool=False,
use_postnet: bool=False,
# music score related # music score related
note_num: int=300, note_num: int=300,
is_slur_num: int=2, is_slur_num: int=2,
@ -134,24 +139,23 @@ class DiffSinger(nn.Layer):
"beta_start": 0.0001, "beta_start": 0.0001,
"beta_end": 0.06, "beta_end": 0.06,
"beta_schedule": "squaredcos_cap_v2", "beta_schedule": "squaredcos_cap_v2",
"num_max_timesteps": 60 "num_max_timesteps": 60,
}, "stretch": True,
stretch: bool=True, }, ):
spec_min: paddle.Tensor=None,
spec_max: paddle.Tensor=None, ):
"""Initialize DiffSinger module. """Initialize DiffSinger module.
Args: Args:
idim (int): spec_min (paddle.Tensor): The minimum value of the feature(mel) to stretch before diffusion.
Dimension of the inputs (Input vocabrary size.). spec_max (paddle.Tensor): The maximum value of the feature(mel) to stretch before diffusion.
odim (int): idim (int): Dimension of the inputs (Input vocabrary size.).
Dimension of the outputs (Acoustic feature dimension.). odim (int): Dimension of the outputs (Acoustic feature dimension.).
use_energy_pred (bool, optional): whether use energy predictor. Defaults False.
use_postnet (bool, optional): whether use postnet. Defaults False.
note_num (int, optional): The number of note. Defaults to 300. note_num (int, optional): The number of note. Defaults to 300.
is_slur_num (int, optional): The number of slur. Defaults to 2. is_slur_num (int, optional): The number of slur. Defaults to 2.
fastspeech2_params (Dict[str, Any]): Parameter dict for fastspeech2 module. fastspeech2_params (Dict[str, Any]): Parameter dict for fastspeech2 module.
denoiser_params (Dict[str, Any]): Parameter dict for dinoiser module. denoiser_params (Dict[str, Any]): Parameter dict for dinoiser module.
diffusion_params (Dict[str, Any]): Parameter dict for diffusion module. diffusion_params (Dict[str, Any]): Parameter dict for diffusion module.
stretch (bool): Whether to stretch before diffusion. Defaults True.
""" """
assert check_argument_types() assert check_argument_types()
super().__init__() super().__init__()
@ -160,12 +164,13 @@ class DiffSinger(nn.Layer):
odim=odim, odim=odim,
fastspeech2_params=fastspeech2_params, fastspeech2_params=fastspeech2_params,
note_num=note_num, note_num=note_num,
is_slur_num=is_slur_num) is_slur_num=is_slur_num,
use_energy_pred=use_energy_pred,
use_postnet=use_postnet, )
denoiser = DiffNet(**denoiser_params) denoiser = DiffNet(**denoiser_params)
self.diffusion = GaussianDiffusion( self.diffusion = GaussianDiffusion(
denoiser, denoiser,
**diffusion_params, **diffusion_params,
stretch=stretch,
min_values=spec_min, min_values=spec_min,
max_values=spec_max, ) max_values=spec_max, )
@ -319,7 +324,7 @@ class DiffSingerInference(nn.Layer):
logmel(Tensor(float32)): denorm logmel, [T, mel_bin] logmel(Tensor(float32)): denorm logmel, [T, mel_bin]
""" """
normalized_mel = self.acoustic_model.inference( normalized_mel = self.acoustic_model.inference(
text, text=text,
note=note, note=note,
note_dur=note_dur, note_dur=note_dur,
is_slur=is_slur, is_slur=is_slur,

@ -121,17 +121,18 @@ class DiffSingerUpdater(StandardUpdater):
report("train/ssim_loss_fs2", float(ssim_loss_fs2)) report("train/ssim_loss_fs2", float(ssim_loss_fs2))
report("train/duration_loss", float(duration_loss)) report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss)) report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
losses_dict["l1_loss_fs2"] = float(l1_loss_fs2) losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2) losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
losses_dict["duration_loss"] = float(duration_loss) losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.: if speaker_loss != 0.:
report("train/speaker_loss", float(speaker_loss)) report("train/speaker_loss", float(speaker_loss))
losses_dict["speaker_loss"] = float(speaker_loss) losses_dict["speaker_loss"] = float(speaker_loss)
if energy_loss != 0.:
report("train/energy_loss", float(energy_loss))
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss_fs2"] = float(loss_fs2) losses_dict["loss_fs2"] = float(loss_fs2)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v) self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
@ -250,17 +251,18 @@ class DiffSingerEvaluator(StandardEvaluator):
report("eval/ssim_loss_fs2", float(ssim_loss_fs2)) report("eval/ssim_loss_fs2", float(ssim_loss_fs2))
report("eval/duration_loss", float(duration_loss)) report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss)) report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss_fs2"] = float(l1_loss_fs2) losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2) losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
losses_dict["duration_loss"] = float(duration_loss) losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.: if speaker_loss != 0.:
report("eval/speaker_loss", float(speaker_loss)) report("eval/speaker_loss", float(speaker_loss))
losses_dict["speaker_loss"] = float(speaker_loss) losses_dict["speaker_loss"] = float(speaker_loss)
if energy_loss != 0.:
report("eval/energy_loss", float(energy_loss))
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss_fs2"] = float(loss_fs2) losses_dict["loss_fs2"] = float(loss_fs2)

@ -42,7 +42,9 @@ class FastSpeech2MIDI(FastSpeech2):
# note emb # note emb
note_num: int=300, note_num: int=300,
# is_slur emb # is_slur emb
is_slur_num: int=2, ): is_slur_num: int=2,
use_energy_pred: bool=False,
use_postnet: bool=False, ):
"""Initialize FastSpeech2 module for svs. """Initialize FastSpeech2 module for svs.
Args: Args:
fastspeech2_params (Dict): fastspeech2_params (Dict):
@ -57,6 +59,10 @@ class FastSpeech2MIDI(FastSpeech2):
""" """
assert check_argument_types() assert check_argument_types()
super().__init__(idim=idim, odim=odim, **fastspeech2_params) super().__init__(idim=idim, odim=odim, **fastspeech2_params)
self.use_energy_pred = use_energy_pred
self.use_postnet = use_postnet
if not self.use_postnet:
self.postnet = None
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[ self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
"adim"] "adim"]
@ -214,12 +220,14 @@ class FastSpeech2MIDI(FastSpeech2):
if is_train_diffusion: if is_train_diffusion:
hs = self.length_regulator(hs, ds, is_inference=False) hs = self.length_regulator(hs, ds, is_inference=False)
p_outs = self.pitch_predictor(hs.detach(), pitch_masks) p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1)) (0, 2, 1))
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( hs += p_embs
(0, 2, 1)) if self.use_energy_pred:
hs = hs + p_embs + e_embs e_outs = self.energy_predictor(hs.detach(), pitch_masks)
e_embs = self.energy_embed(
e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
hs += e_embs
elif is_inference: elif is_inference:
# (B, Tmax) # (B, Tmax)
@ -238,7 +246,11 @@ class FastSpeech2MIDI(FastSpeech2):
p_outs = self.pitch_predictor(hs.detach(), pitch_masks) p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else: else:
p_outs = self.pitch_predictor(hs, pitch_masks) p_outs = self.pitch_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += p_embs
if self.use_energy_pred:
if es is not None: if es is not None:
e_outs = es e_outs = es
else: else:
@ -246,12 +258,9 @@ class FastSpeech2MIDI(FastSpeech2):
e_outs = self.energy_predictor(hs.detach(), pitch_masks) e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else: else:
e_outs = self.energy_predictor(hs, pitch_masks) e_outs = self.energy_predictor(hs, pitch_masks)
e_embs = self.energy_embed(
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
(0, 2, 1)) hs += e_embs
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + p_embs + e_embs
# training # training
else: else:
@ -262,15 +271,18 @@ class FastSpeech2MIDI(FastSpeech2):
p_outs = self.pitch_predictor(hs.detach(), pitch_masks) p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else: else:
p_outs = self.pitch_predictor(hs, pitch_masks) p_outs = self.pitch_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += p_embs
if self.use_energy_pred:
if self.stop_gradient_from_energy_predictor: if self.stop_gradient_from_energy_predictor:
e_outs = self.energy_predictor(hs.detach(), pitch_masks) e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else: else:
e_outs = self.energy_predictor(hs, pitch_masks) e_outs = self.energy_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
(0, 2, 1))
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1)) (0, 2, 1))
hs = hs + p_embs + e_embs hs += e_embs
# forward decoder # forward decoder
if olens is not None and not is_inference: if olens is not None and not is_inference:
@ -304,7 +316,6 @@ class FastSpeech2MIDI(FastSpeech2):
else: else:
after_outs = before_outs + self.postnet( after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
after_outs = before_outs
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
@ -475,6 +486,9 @@ class FastSpeech2MIDI(FastSpeech2):
spk_emb=spk_emb, spk_emb=spk_emb,
spk_id=spk_id, ) spk_id=spk_id, )
if e_outs is None:
e_outs = [None]
return outs[0], d_outs[0], p_outs[0], e_outs[0] return outs[0], d_outs[0], p_outs[0], e_outs[0]
@ -552,14 +566,14 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
# make feature for ssim loss # make feature for ssim loss
out_pad_masks = make_pad_mask(olens).unsqueeze(-1) out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0) before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
if after_outs is not None: if not paddle.equal_all(after_outs, before_outs):
after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0) after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
ys_ssim = masked_fill(ys, out_pad_masks, 0.0) ys_ssim = masked_fill(ys, out_pad_masks, 0.0)
out_masks = make_non_pad_mask(olens).unsqueeze(-1) out_masks = make_non_pad_mask(olens).unsqueeze(-1)
before_outs = before_outs.masked_select( before_outs = before_outs.masked_select(
out_masks.broadcast_to(before_outs.shape)) out_masks.broadcast_to(before_outs.shape))
if after_outs is not None: if not paddle.equal_all(after_outs, before_outs):
after_outs = after_outs.masked_select( after_outs = after_outs.masked_select(
out_masks.broadcast_to(after_outs.shape)) out_masks.broadcast_to(after_outs.shape))
ys = ys.masked_select(out_masks.broadcast_to(ys.shape)) ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
@ -570,9 +584,10 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
pitch_masks = out_masks pitch_masks = out_masks
p_outs = p_outs.masked_select( p_outs = p_outs.masked_select(
pitch_masks.broadcast_to(p_outs.shape)) pitch_masks.broadcast_to(p_outs.shape))
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
if e_outs is not None:
e_outs = e_outs.masked_select( e_outs = e_outs.masked_select(
pitch_masks.broadcast_to(e_outs.shape)) pitch_masks.broadcast_to(e_outs.shape))
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape)) es = es.masked_select(pitch_masks.broadcast_to(es.shape))
if spk_logits is not None and spk_ids is not None: if spk_logits is not None and spk_ids is not None:
@ -589,7 +604,7 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
l1_loss = self.l1_criterion(before_outs, ys) l1_loss = self.l1_criterion(before_outs, ys)
ssim_loss = 1.0 - ssim( ssim_loss = 1.0 - ssim(
before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)) before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
if after_outs is not None: if not paddle.equal_all(after_outs, before_outs):
l1_loss += self.l1_criterion(after_outs, ys) l1_loss += self.l1_criterion(after_outs, ys)
ssim_loss += ( ssim_loss += (
1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))) 1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
@ -598,6 +613,7 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
duration_loss = self.duration_criterion(d_outs, ds) duration_loss = self.duration_criterion(d_outs, ds)
pitch_loss = self.l1_criterion(p_outs, ps) pitch_loss = self.l1_criterion(p_outs, ps)
if e_outs is not None:
energy_loss = self.l1_criterion(e_outs, es) energy_loss = self.l1_criterion(e_outs, es)
if spk_logits is not None and spk_ids is not None: if spk_logits is not None and spk_ids is not None:
@ -630,6 +646,7 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
pitch_loss = pitch_loss.multiply(pitch_weights) pitch_loss = pitch_loss.multiply(pitch_weights)
pitch_loss = pitch_loss.masked_select( pitch_loss = pitch_loss.masked_select(
pitch_masks.broadcast_to(pitch_loss.shape)).sum() pitch_masks.broadcast_to(pitch_loss.shape)).sum()
if e_outs is not None:
energy_loss = energy_loss.multiply(pitch_weights) energy_loss = energy_loss.multiply(pitch_weights)
energy_loss = energy_loss.masked_select( energy_loss = energy_loss.masked_select(
pitch_masks.broadcast_to(energy_loss.shape)).sum() pitch_masks.broadcast_to(energy_loss.shape)).sum()

Loading…
Cancel
Save