|
|
|
@ -23,10 +23,10 @@ from typeguard import check_argument_types
|
|
|
|
|
|
|
|
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
|
|
|
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
|
|
|
|
from paddlespeech.t2s.modules.losses import ssim
|
|
|
|
|
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
@ -61,18 +61,18 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
|
|
|
|
|
"adim"]
|
|
|
|
|
|
|
|
|
|
if note_num is not None:
|
|
|
|
|
self.note_embedding_table = nn.Embedding(
|
|
|
|
|
num_embeddings=note_num,
|
|
|
|
|
embedding_dim=self.note_embed_dim,
|
|
|
|
|
padding_idx=self.padding_idx)
|
|
|
|
|
self.note_dur_layer = nn.Linear(1, self.note_embed_dim)
|
|
|
|
|
# note_ embed
|
|
|
|
|
self.note_embedding_table = nn.Embedding(
|
|
|
|
|
num_embeddings=note_num,
|
|
|
|
|
embedding_dim=self.note_embed_dim,
|
|
|
|
|
padding_idx=self.padding_idx)
|
|
|
|
|
self.note_dur_layer = nn.Linear(1, self.note_embed_dim)
|
|
|
|
|
|
|
|
|
|
if is_slur_num is not None:
|
|
|
|
|
self.is_slur_embedding_table = nn.Embedding(
|
|
|
|
|
num_embeddings=is_slur_num,
|
|
|
|
|
embedding_dim=self.is_slur_embed_dim,
|
|
|
|
|
padding_idx=self.padding_idx)
|
|
|
|
|
# slur embed
|
|
|
|
|
self.is_slur_embedding_table = nn.Embedding(
|
|
|
|
|
num_embeddings=is_slur_num,
|
|
|
|
|
embedding_dim=self.is_slur_embed_dim,
|
|
|
|
|
padding_idx=self.padding_idx)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
@ -203,7 +203,7 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
spk_emb = self.spk_embedding_table(spk_id)
|
|
|
|
|
hs = self._integrate_with_spk_embed(hs, spk_emb)
|
|
|
|
|
|
|
|
|
|
# forward duration predictor and variance predictors
|
|
|
|
|
# forward duration predictor (phone-level) and variance predictors (frame-level)
|
|
|
|
|
d_masks = make_pad_mask(ilens)
|
|
|
|
|
if olens is not None:
|
|
|
|
|
pitch_masks = make_pad_mask(olens).unsqueeze(-1)
|
|
|
|
@ -214,13 +214,12 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
if is_train_diffusion:
|
|
|
|
|
hs = self.length_regulator(hs, ds, is_inference=False)
|
|
|
|
|
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
# e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + p_embs + e_embs
|
|
|
|
|
|
|
|
|
|
elif is_inference:
|
|
|
|
|
# (B, Tmax)
|
|
|
|
@ -240,20 +239,19 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
else:
|
|
|
|
|
p_outs = self.pitch_predictor(hs, pitch_masks)
|
|
|
|
|
|
|
|
|
|
# if es is not None:
|
|
|
|
|
# e_outs = es
|
|
|
|
|
# else:
|
|
|
|
|
# if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# else:
|
|
|
|
|
# e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
if es is not None:
|
|
|
|
|
e_outs = es
|
|
|
|
|
else:
|
|
|
|
|
if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
|
|
|
|
|
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
# e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + p_embs + e_embs
|
|
|
|
|
|
|
|
|
|
# training
|
|
|
|
|
else:
|
|
|
|
@ -264,16 +262,15 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
p_outs = self.pitch_predictor(hs, pitch_masks)
|
|
|
|
|
# if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
# e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
# else:
|
|
|
|
|
# e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
if self.stop_gradient_from_energy_predictor:
|
|
|
|
|
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
|
|
|
|
|
else:
|
|
|
|
|
e_outs = self.energy_predictor(hs, pitch_masks)
|
|
|
|
|
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
# e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
|
|
|
|
|
# (0, 2, 1))
|
|
|
|
|
# hs = hs + p_embs + e_embs
|
|
|
|
|
hs = hs + p_embs
|
|
|
|
|
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + p_embs + e_embs
|
|
|
|
|
|
|
|
|
|
# forward decoder
|
|
|
|
|
if olens is not None and not is_inference:
|
|
|
|
@ -302,11 +299,11 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
(paddle.shape(zs)[0], -1, self.odim))
|
|
|
|
|
|
|
|
|
|
# postnet -> (B, Lmax//r * r, odim)
|
|
|
|
|
# if self.postnet is None:
|
|
|
|
|
# after_outs = before_outs
|
|
|
|
|
# else:
|
|
|
|
|
# after_outs = before_outs + self.postnet(
|
|
|
|
|
# before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
|
|
|
|
|
if self.postnet is None:
|
|
|
|
|
after_outs = before_outs
|
|
|
|
|
else:
|
|
|
|
|
after_outs = before_outs + self.postnet(
|
|
|
|
|
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
|
|
|
|
|
after_outs = before_outs
|
|
|
|
|
|
|
|
|
|
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
|
|
|
|
@ -478,8 +475,7 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
|
|
|
spk_emb=spk_emb,
|
|
|
|
|
spk_id=spk_id, )
|
|
|
|
|
|
|
|
|
|
# return outs[0], d_outs[0], p_outs[0], e_outs[0]
|
|
|
|
|
return outs[0], d_outs[0], p_outs[0], None
|
|
|
|
|
return outs[0], d_outs[0], p_outs[0], e_outs[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
@ -551,21 +547,21 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
"""
|
|
|
|
|
l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0
|
|
|
|
|
|
|
|
|
|
out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
before_outs_batch = masked_fill(before_outs, out_pad_masks, 0.0)
|
|
|
|
|
# print(before_outs.shape, ys.shape)
|
|
|
|
|
ssim_loss = 1.0 - ssim(before_outs_batch.unsqueeze(1), ys.unsqueeze(1))
|
|
|
|
|
ssim_loss = ssim_loss * 0.5
|
|
|
|
|
|
|
|
|
|
# apply mask to remove padded part
|
|
|
|
|
if self.use_masking:
|
|
|
|
|
# make feature for ssim loss
|
|
|
|
|
out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
|
|
|
|
|
if after_outs is not None:
|
|
|
|
|
after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
|
|
|
|
|
ys_ssim = masked_fill(ys, out_pad_masks, 0.0)
|
|
|
|
|
|
|
|
|
|
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
before_outs = before_outs.masked_select(
|
|
|
|
|
out_masks.broadcast_to(before_outs.shape))
|
|
|
|
|
|
|
|
|
|
# if after_outs is not None:
|
|
|
|
|
# after_outs = after_outs.masked_select(
|
|
|
|
|
# out_masks.broadcast_to(after_outs.shape))
|
|
|
|
|
if after_outs is not None:
|
|
|
|
|
after_outs = after_outs.masked_select(
|
|
|
|
|
out_masks.broadcast_to(after_outs.shape))
|
|
|
|
|
ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
|
|
|
|
|
duration_masks = make_non_pad_mask(ilens)
|
|
|
|
|
d_outs = d_outs.masked_select(
|
|
|
|
@ -574,8 +570,8 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
pitch_masks = out_masks
|
|
|
|
|
p_outs = p_outs.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(p_outs.shape))
|
|
|
|
|
# e_outs = e_outs.masked_select(
|
|
|
|
|
# pitch_masks.broadcast_to(e_outs.shape))
|
|
|
|
|
e_outs = e_outs.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(e_outs.shape))
|
|
|
|
|
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
|
|
|
|
|
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
|
|
|
|
|
|
|
|
|
@ -591,17 +587,18 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
|
|
|
|
|
# calculate loss
|
|
|
|
|
l1_loss = self.l1_criterion(before_outs, ys)
|
|
|
|
|
# if after_outs is not None:
|
|
|
|
|
# l1_loss += self.l1_criterion(after_outs, ys)
|
|
|
|
|
# ssim_loss += (1.0 - ssim(after_outs, ys))
|
|
|
|
|
ssim_loss = 1.0 - ssim(
|
|
|
|
|
before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
|
|
|
|
|
if after_outs is not None:
|
|
|
|
|
l1_loss += self.l1_criterion(after_outs, ys)
|
|
|
|
|
ssim_loss += (
|
|
|
|
|
1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
|
|
|
|
|
l1_loss = l1_loss * 0.5
|
|
|
|
|
|
|
|
|
|
ssim_loss = ssim_loss * 0.5
|
|
|
|
|
|
|
|
|
|
duration_loss = self.duration_criterion(d_outs, ds)
|
|
|
|
|
# print("ppppppppppoooooooooooo: ", p_outs, p_outs.shape)
|
|
|
|
|
# print("ppppppppppssssssssssss: ", ps, ps.shape)
|
|
|
|
|
# pitch_loss = self.mse_criterion(p_outs, ps)
|
|
|
|
|
# energy_loss = self.mse_criterion(e_outs, es)
|
|
|
|
|
pitch_loss = self.l1_criterion(p_outs, ps)
|
|
|
|
|
energy_loss = self.l1_criterion(e_outs, es)
|
|
|
|
|
|
|
|
|
|
if spk_logits is not None and spk_ids is not None:
|
|
|
|
|
speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
|
|
|
|
@ -623,6 +620,9 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
l1_loss = l1_loss.multiply(out_weights)
|
|
|
|
|
l1_loss = l1_loss.masked_select(
|
|
|
|
|
out_masks.broadcast_to(l1_loss.shape)).sum()
|
|
|
|
|
ssim_loss = ssim_loss.multiply(out_weights)
|
|
|
|
|
ssim_loss = ssim_loss.masked_select(
|
|
|
|
|
out_masks.broadcast_to(ssim_loss.shape)).sum()
|
|
|
|
|
duration_loss = (duration_loss.multiply(duration_weights)
|
|
|
|
|
.masked_select(duration_masks).sum())
|
|
|
|
|
pitch_masks = out_masks
|
|
|
|
@ -630,8 +630,8 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
|
|
|
|
|
pitch_loss = pitch_loss.multiply(pitch_weights)
|
|
|
|
|
pitch_loss = pitch_loss.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(pitch_loss.shape)).sum()
|
|
|
|
|
# energy_loss = energy_loss.multiply(pitch_weights)
|
|
|
|
|
# energy_loss = energy_loss.masked_select(
|
|
|
|
|
# pitch_masks.broadcast_to(energy_loss.shape)).sum()
|
|
|
|
|
energy_loss = energy_loss.multiply(pitch_weights)
|
|
|
|
|
energy_loss = energy_loss.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(energy_loss.shape)).sum()
|
|
|
|
|
|
|
|
|
|
return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
|
|
|
|
|