|
|
|
@ -22,9 +22,11 @@ from paddle import nn
|
|
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
|
|
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
|
|
|
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
|
|
|
|
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss
|
|
|
|
|
from paddlespeech.t2s.modules.losses import ssim
|
|
|
|
|
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
@ -36,14 +38,14 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
# fastspeech2 network structure related
|
|
|
|
|
idim: int,
|
|
|
|
|
odim: int,
|
|
|
|
|
fastspeech2_config: Dict[str, Any],
|
|
|
|
|
fastspeech2_params: Dict[str, Any],
|
|
|
|
|
# note emb
|
|
|
|
|
note_num: int=300,
|
|
|
|
|
# is_slur emb
|
|
|
|
|
is_slur_num: int=2, ):
|
|
|
|
|
"""Initialize FastSpeech2 module for svs.
|
|
|
|
|
Args:
|
|
|
|
|
fastspeech2_config (Dict):
|
|
|
|
|
fastspeech2_params (Dict):
|
|
|
|
|
The config of FastSpeech2 module on DiffSinger model
|
|
|
|
|
note_num (Optional[int]):
|
|
|
|
|
Number of note. If not None, assume that the
|
|
|
|
@ -54,9 +56,9 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
assert check_argument_types()
|
|
|
|
|
super().__init__(idim=idim, odim=odim, **fastspeech2_config)
|
|
|
|
|
super().__init__(idim=idim, odim=odim, **fastspeech2_params)
|
|
|
|
|
|
|
|
|
|
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_config[
|
|
|
|
|
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
|
|
|
|
|
"adim"]
|
|
|
|
|
|
|
|
|
|
if note_num is not None:
|
|
|
|
@ -133,15 +135,15 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
spk_id = paddle.cast(spk_id, 'int64')
|
|
|
|
|
# forward propagation
|
|
|
|
|
before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
|
|
|
|
|
xs,
|
|
|
|
|
note,
|
|
|
|
|
note_dur,
|
|
|
|
|
is_slur,
|
|
|
|
|
ilens,
|
|
|
|
|
olens,
|
|
|
|
|
ds,
|
|
|
|
|
ps,
|
|
|
|
|
es,
|
|
|
|
|
xs=xs,
|
|
|
|
|
note=note,
|
|
|
|
|
note_dur=note_dur,
|
|
|
|
|
is_slur=is_slur,
|
|
|
|
|
ilens=ilens,
|
|
|
|
|
olens=olens,
|
|
|
|
|
ds=ds,
|
|
|
|
|
ps=ps,
|
|
|
|
|
es=es,
|
|
|
|
|
is_inference=False,
|
|
|
|
|
spk_emb=spk_emb,
|
|
|
|
|
spk_id=spk_id, )
|
|
|
|
@ -170,6 +172,8 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
alpha: float=1.0,
|
|
|
|
|
spk_emb=None,
|
|
|
|
|
spk_id=None, ) -> Sequence[paddle.Tensor]:
|
|
|
|
|
|
|
|
|
|
before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
|
|
|
|
|
# forward encoder
|
|
|
|
|
x_masks = self._source_mask(ilens)
|
|
|
|
|
note_emb = self.note_embedding_table(note)
|
|
|
|
@ -206,16 +210,17 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
else:
|
|
|
|
|
pitch_masks = None
|
|
|
|
|
|
|
|
|
|
        # inference for decoder input for diffusion
|
|
|
|
|
# inference for decoder input for diffusion
|
|
|
|
|
if is_train_diffusion:
|
|
|
|
|
hs = self.length_regulator(hs, ds, is_inference=False)
|
|
|
|
|
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + e_embs + p_embs
|
|
|
|
|
# e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
|
|
|
|
|
elif is_inference:
|
|
|
|
|
# (B, Tmax)
|
|
|
|
@ -235,19 +240,20 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
else:
|
|
|
|
|
p_outs = self.pitch_predictor(hs, pitch_masks)
|
|
|
|
|
|
|
|
|
|
if es is not None:
|
|
|
|
|
e_outs = es
|
|
|
|
|
else:
|
|
|
|
|
if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
# if es is not None:
|
|
|
|
|
# e_outs = es
|
|
|
|
|
# else:
|
|
|
|
|
# if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# else:
|
|
|
|
|
# e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
|
|
|
|
|
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + e_embs + p_embs
|
|
|
|
|
# e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
|
|
|
|
|
# training
|
|
|
|
|
else:
|
|
|
|
@ -258,15 +264,16 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
p_outs = self.pitch_predictor(hs, pitch_masks)
|
|
|
|
|
if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
# if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# else:
|
|
|
|
|
# e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + e_embs + p_embs
|
|
|
|
|
# e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
|
|
|
|
|
# forward decoder
|
|
|
|
|
if olens is not None and not is_inference:
|
|
|
|
@ -295,11 +302,12 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
(paddle.shape(zs)[0], -1, self.odim))
|
|
|
|
|
|
|
|
|
|
# postnet -> (B, Lmax//r * r, odim)
|
|
|
|
|
if self.postnet is None:
|
|
|
|
|
after_outs = before_outs
|
|
|
|
|
else:
|
|
|
|
|
after_outs = before_outs + self.postnet(
|
|
|
|
|
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
|
|
|
|
|
# if self.postnet is None:
|
|
|
|
|
# after_outs = before_outs
|
|
|
|
|
# else:
|
|
|
|
|
# after_outs = before_outs + self.postnet(
|
|
|
|
|
# before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
|
|
|
|
|
after_outs = before_outs
|
|
|
|
|
|
|
|
|
|
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
|
|
|
|
|
|
|
|
|
@ -326,11 +334,11 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
# (1, L, odim)
|
|
|
|
|
# use *_ to avoid bug in dygraph to static graph
|
|
|
|
|
hs, _ = self._forward(
|
|
|
|
|
xs,
|
|
|
|
|
note,
|
|
|
|
|
note_dur,
|
|
|
|
|
is_slur,
|
|
|
|
|
ilens,
|
|
|
|
|
xs=xs,
|
|
|
|
|
note=note,
|
|
|
|
|
note_dur=note_dur,
|
|
|
|
|
is_slur=is_slur,
|
|
|
|
|
ilens=ilens,
|
|
|
|
|
is_inference=True,
|
|
|
|
|
return_after_enc=True,
|
|
|
|
|
alpha=alpha,
|
|
|
|
@ -367,15 +375,15 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
# (1, L, odim)
|
|
|
|
|
# use *_ to avoid bug in dygraph to static graph
|
|
|
|
|
hs, h_masks = self._forward(
|
|
|
|
|
xs,
|
|
|
|
|
note,
|
|
|
|
|
note_dur,
|
|
|
|
|
is_slur,
|
|
|
|
|
ilens,
|
|
|
|
|
olens,
|
|
|
|
|
ds,
|
|
|
|
|
ps,
|
|
|
|
|
es,
|
|
|
|
|
xs=xs,
|
|
|
|
|
note=note,
|
|
|
|
|
note_dur=note_dur,
|
|
|
|
|
is_slur=is_slur,
|
|
|
|
|
ilens=ilens,
|
|
|
|
|
olens=olens,
|
|
|
|
|
ds=ds,
|
|
|
|
|
ps=ps,
|
|
|
|
|
es=es,
|
|
|
|
|
return_after_enc=True,
|
|
|
|
|
is_train_diffusion=True,
|
|
|
|
|
alpha=alpha,
|
|
|
|
@ -446,11 +454,11 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
|
|
|
|
|
# (1, L, odim)
|
|
|
|
|
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
|
|
|
|
|
xs,
|
|
|
|
|
note,
|
|
|
|
|
note_dur,
|
|
|
|
|
is_slur,
|
|
|
|
|
ilens,
|
|
|
|
|
xs=xs,
|
|
|
|
|
note=note,
|
|
|
|
|
note_dur=note_dur,
|
|
|
|
|
is_slur=is_slur,
|
|
|
|
|
ilens=ilens,
|
|
|
|
|
ds=ds,
|
|
|
|
|
ps=ps,
|
|
|
|
|
es=es,
|
|
|
|
@ -460,20 +468,21 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
else:
|
|
|
|
|
# (1, L, odim)
|
|
|
|
|
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
|
|
|
|
|
xs,
|
|
|
|
|
note,
|
|
|
|
|
note_dur,
|
|
|
|
|
is_slur,
|
|
|
|
|
ilens,
|
|
|
|
|
xs=xs,
|
|
|
|
|
note=note,
|
|
|
|
|
note_dur=note_dur,
|
|
|
|
|
is_slur=is_slur,
|
|
|
|
|
ilens=ilens,
|
|
|
|
|
is_inference=True,
|
|
|
|
|
alpha=alpha,
|
|
|
|
|
spk_emb=spk_emb,
|
|
|
|
|
spk_id=spk_id, )
|
|
|
|
|
|
|
|
|
|
return outs[0], d_outs[0], p_outs[0], e_outs[0]
|
|
|
|
|
# return outs[0], d_outs[0], p_outs[0], e_outs[0]
|
|
|
|
|
return outs[0], d_outs[0], p_outs[0], None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
"""Loss function module for DiffSinger."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, use_masking: bool=True,
|
|
|
|
@ -486,18 +495,7 @@ class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
Whether to weighted masking in loss calculation.
|
|
|
|
|
"""
|
|
|
|
|
assert check_argument_types()
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
assert (use_masking != use_weighted_masking) or not use_masking
|
|
|
|
|
self.use_masking = use_masking
|
|
|
|
|
self.use_weighted_masking = use_weighted_masking
|
|
|
|
|
|
|
|
|
|
# define criterions
|
|
|
|
|
reduction = "none" if self.use_weighted_masking else "mean"
|
|
|
|
|
self.l1_criterion = nn.L1Loss(reduction=reduction)
|
|
|
|
|
self.mse_criterion = nn.MSELoss(reduction=reduction)
|
|
|
|
|
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
|
|
|
|
|
self.ce_criterion = nn.CrossEntropyLoss()
|
|
|
|
|
super().__init__(use_masking, use_weighted_masking)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
@ -551,15 +549,23 @@ class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
speaker_loss = 0.0
|
|
|
|
|
l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0
|
|
|
|
|
|
|
|
|
|
out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
before_outs_batch = masked_fill(before_outs, out_pad_masks, 0.0)
|
|
|
|
|
# print(before_outs.shape, ys.shape)
|
|
|
|
|
ssim_loss = 1.0 - ssim(before_outs_batch.unsqueeze(1), ys.unsqueeze(1))
|
|
|
|
|
ssim_loss = ssim_loss * 0.5
|
|
|
|
|
|
|
|
|
|
# apply mask to remove padded part
|
|
|
|
|
if self.use_masking:
|
|
|
|
|
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
before_outs = before_outs.masked_select(
|
|
|
|
|
out_masks.broadcast_to(before_outs.shape))
|
|
|
|
|
if after_outs is not None:
|
|
|
|
|
after_outs = after_outs.masked_select(
|
|
|
|
|
out_masks.broadcast_to(after_outs.shape))
|
|
|
|
|
|
|
|
|
|
# if after_outs is not None:
|
|
|
|
|
# after_outs = after_outs.masked_select(
|
|
|
|
|
# out_masks.broadcast_to(after_outs.shape))
|
|
|
|
|
ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
|
|
|
|
|
duration_masks = make_non_pad_mask(ilens)
|
|
|
|
|
d_outs = d_outs.masked_select(
|
|
|
|
@ -568,8 +574,8 @@ class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
pitch_masks = out_masks
|
|
|
|
|
p_outs = p_outs.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(p_outs.shape))
|
|
|
|
|
e_outs = e_outs.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(e_outs.shape))
|
|
|
|
|
# e_outs = e_outs.masked_select(
|
|
|
|
|
# pitch_masks.broadcast_to(e_outs.shape))
|
|
|
|
|
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
|
|
|
|
|
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
|
|
|
|
|
|
|
|
|
@ -585,11 +591,17 @@ class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
|
|
|
|
|
# calculate loss
|
|
|
|
|
l1_loss = self.l1_criterion(before_outs, ys)
|
|
|
|
|
if after_outs is not None:
|
|
|
|
|
l1_loss += self.l1_criterion(after_outs, ys)
|
|
|
|
|
# if after_outs is not None:
|
|
|
|
|
# l1_loss += self.l1_criterion(after_outs, ys)
|
|
|
|
|
# ssim_loss += (1.0 - ssim(after_outs, ys))
|
|
|
|
|
l1_loss = l1_loss * 0.5
|
|
|
|
|
|
|
|
|
|
duration_loss = self.duration_criterion(d_outs, ds)
|
|
|
|
|
pitch_loss = self.mse_criterion(p_outs, ps)
|
|
|
|
|
energy_loss = self.mse_criterion(e_outs, es)
|
|
|
|
|
# print("ppppppppppoooooooooooo: ", p_outs, p_outs.shape)
|
|
|
|
|
# print("ppppppppppssssssssssss: ", ps, ps.shape)
|
|
|
|
|
# pitch_loss = self.mse_criterion(p_outs, ps)
|
|
|
|
|
# energy_loss = self.mse_criterion(e_outs, es)
|
|
|
|
|
pitch_loss = self.l1_criterion(p_outs, ps)
|
|
|
|
|
|
|
|
|
|
if spk_logits is not None and spk_ids is not None:
|
|
|
|
|
speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
|
|
|
|
@ -618,8 +630,8 @@ class FastSpeech2MIDILoss(nn.Layer):
|
|
|
|
|
pitch_loss = pitch_loss.multiply(pitch_weights)
|
|
|
|
|
pitch_loss = pitch_loss.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(pitch_loss.shape)).sum()
|
|
|
|
|
energy_loss = energy_loss.multiply(pitch_weights)
|
|
|
|
|
energy_loss = energy_loss.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(energy_loss.shape)).sum()
|
|
|
|
|
# energy_loss = energy_loss.multiply(pitch_weights)
|
|
|
|
|
# energy_loss = energy_loss.masked_select(
|
|
|
|
|
# pitch_masks.broadcast_to(energy_loss.shape)).sum()
|
|
|
|
|
|
|
|
|
|
return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
|
|
|
|
|
return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
|
|
|
|
|