You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
655 lines
25 KiB
655 lines
25 KiB
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# Modified from espnet(https://github.com/espnet/espnet)
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
from typeguard import check_argument_types
|
|
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
|
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
|
|
from paddlespeech.t2s.modules.losses import ssim
|
|
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
|
|
|
|
|
class FastSpeech2MIDI(FastSpeech2):
    """The Fastspeech2 module of DiffSinger.

    Extends FastSpeech2 with music-score inputs: a note-id embedding table,
    a linear projection of note durations (in seconds) and a slur-flag
    embedding table, all sized to the encoder attention dim ("adim").
    """

    def __init__(
            self,
            # fastspeech2 network structure related
            idim: int,
            odim: int,
            fastspeech2_params: Dict[str, Any],
            # note emb
            note_num: int=300,
            # is_slur emb
            is_slur_num: int=2,
            use_energy_pred: bool=False,
            use_postnet: bool=False, ):
        """Initialize FastSpeech2 module for svs.

        Args:
            idim (int):
                Dimension of the phone inputs.
            odim (int):
                Dimension of the outputs (e.g. number of mel bins).
            fastspeech2_params (Dict):
                The config of FastSpeech2 module on DiffSinger model
            note_num (Optional[int]):
                Number of note. If not None, assume that the
                note_ids will be provided as the input and use note_embedding_table.
            is_slur_num (Optional[int]):
                Number of slur classes. If not None, assume that the
                is_slur_ids will be provided as the input
            use_energy_pred (bool):
                Whether to use the energy-predictor branch in `_forward`.
            use_postnet (bool):
                Whether to keep the postnet; when False the postnet built by
                the parent class is dropped and `after_outs == before_outs`.
        """
        assert check_argument_types()
        super().__init__(idim=idim, odim=odim, **fastspeech2_params)
        self.use_energy_pred = use_energy_pred
        self.use_postnet = use_postnet
        if not self.use_postnet:
            # discard the postnet created by FastSpeech2.__init__
            self.postnet = None

        # note and slur embeddings share the encoder attention dimension
        self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
            "adim"]

        # note_ embed
        self.note_embedding_table = nn.Embedding(
            num_embeddings=note_num,
            embedding_dim=self.note_embed_dim,
            padding_idx=self.padding_idx)
        # projects the scalar note duration (seconds) to the embedding dim
        self.note_dur_layer = nn.Linear(1, self.note_embed_dim)

        # slur embed
        self.is_slur_embedding_table = nn.Embedding(
            num_embeddings=is_slur_num,
            embedding_dim=self.is_slur_embed_dim,
            padding_idx=self.padding_idx)
|
|
|
|
def forward(
|
|
self,
|
|
text: paddle.Tensor,
|
|
note: paddle.Tensor,
|
|
note_dur: paddle.Tensor,
|
|
is_slur: paddle.Tensor,
|
|
text_lengths: paddle.Tensor,
|
|
speech: paddle.Tensor,
|
|
speech_lengths: paddle.Tensor,
|
|
durations: paddle.Tensor,
|
|
pitch: paddle.Tensor,
|
|
energy: paddle.Tensor,
|
|
spk_emb: paddle.Tensor=None,
|
|
spk_id: paddle.Tensor=None,
|
|
) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
text(Tensor(int64)):
|
|
Batch of padded token (phone) ids (B, Tmax).
|
|
note(Tensor(int64)):
|
|
Batch of padded note (element in music score) ids (B, Tmax).
|
|
note_dur(Tensor(float32)):
|
|
Batch of padded note durations in seconds (element in music score) (B, Tmax).
|
|
is_slur(Tensor(int64)):
|
|
Batch of padded slur (element in music score) ids (B, Tmax).
|
|
text_lengths(Tensor(int64)):
|
|
Batch of phone lengths of each input (B,).
|
|
speech(Tensor[float32]):
|
|
Batch of padded target features (e.g. mel) (B, Lmax, odim).
|
|
speech_lengths(Tensor(int64)):
|
|
Batch of the lengths of each target features (B,).
|
|
durations(Tensor(int64)):
|
|
Batch of padded token durations in frame (B, Tmax).
|
|
pitch(Tensor[float32]):
|
|
Batch of padded frame-averaged pitch (B, Lmax, 1).
|
|
energy(Tensor[float32]):
|
|
Batch of padded frame-averaged energy (B, Lmax, 1).
|
|
spk_emb(Tensor[float32], optional):
|
|
Batch of speaker embeddings (B, spk_embed_dim).
|
|
spk_id(Tnesor[int64], optional(int64)):
|
|
Batch of speaker ids (B,)
|
|
|
|
Returns:
|
|
|
|
"""
|
|
xs = paddle.cast(text, 'int64')
|
|
note = paddle.cast(note, 'int64')
|
|
note_dur = paddle.cast(note_dur, 'float32')
|
|
is_slur = paddle.cast(is_slur, 'int64')
|
|
ilens = paddle.cast(text_lengths, 'int64')
|
|
olens = paddle.cast(speech_lengths, 'int64')
|
|
ds = paddle.cast(durations, 'int64')
|
|
ps = pitch
|
|
es = energy
|
|
ys = speech
|
|
olens = speech_lengths
|
|
if spk_id is not None:
|
|
spk_id = paddle.cast(spk_id, 'int64')
|
|
# forward propagation
|
|
before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
|
|
xs=xs,
|
|
note=note,
|
|
note_dur=note_dur,
|
|
is_slur=is_slur,
|
|
ilens=ilens,
|
|
olens=olens,
|
|
ds=ds,
|
|
ps=ps,
|
|
es=es,
|
|
is_inference=False,
|
|
spk_emb=spk_emb,
|
|
spk_id=spk_id, )
|
|
# modify mod part of groundtruth
|
|
if self.reduction_factor > 1:
|
|
olens = olens - olens % self.reduction_factor
|
|
max_olen = max(olens)
|
|
ys = ys[:, :max_olen]
|
|
|
|
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
|
|
|
|
    def _forward(
            self,
            xs: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor=None,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            is_inference: bool=False,
            is_train_diffusion: bool=False,
            return_after_enc=False,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Sequence[paddle.Tensor]:
        """Shared forward body for training, inference and diffusion setup.

        Three mutually exclusive modes:
          * is_train_diffusion=True: ground-truth durations expand the
            encoder output; pitch/energy are predicted (detached) and added.
          * is_inference=True: durations/pitch/energy fall back to predictor
            outputs when the corresponding ground truth is None.
          * otherwise (training): predictors run for the loss, but the
            ground-truth pitch/energy embeddings are what is added to ``hs``.

        When ``return_after_enc`` is True, returns ``(hs, h_masks)`` right
        before the decoder; otherwise returns
        ``(before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits)``.
        """

        before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
        # forward encoder
        masks = self._source_mask(ilens)
        note_emb = self.note_embedding_table(note)
        # note_dur is (B, Tmax) in seconds; project each scalar to adim
        note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
        is_slur_emb = self.is_slur_embedding_table(is_slur)

        # (B, Tmax, adim)
        hs, _ = self.encoder(
            xs=xs,
            masks=masks,
            note_emb=note_emb,
            note_dur_emb=note_dur_emb,
            is_slur_emb=is_slur_emb, )

        # adversarial speaker classifier (training only): gradient reversal
        # pushes the encoder towards speaker-independent representations
        if self.spk_num and self.enable_speaker_classifier and not is_inference:
            hs_for_spk_cls = self.grad_reverse(hs)
            spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
        else:
            spk_logits = None

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            # spk_emb has a higher priority than spk_id
            if spk_emb is not None:
                hs = self._integrate_with_spk_embed(hs, spk_emb)
            elif spk_id is not None:
                spk_emb = self.spk_embedding_table(spk_id)
                hs = self._integrate_with_spk_embed(hs, spk_emb)

        # forward duration predictor (phone-level) and variance predictors (frame-level)
        d_masks = make_pad_mask(ilens)
        if olens is not None:
            pitch_masks = make_pad_mask(olens).unsqueeze(-1)
        else:
            pitch_masks = None

        # inference for decoder input for diffusion
        if is_train_diffusion:
            # expand by ground-truth durations; predictors are detached so no
            # gradient flows back into the encoder through them
            hs = self.length_regulator(hs, ds, is_inference=False)
            p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs
            if self.use_energy_pred:
                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        elif is_inference:
            # (B, Tmax) — use ground-truth durations when provided
            if ds is not None:
                d_outs = ds
            else:
                d_outs = self.duration_predictor.inference(hs, d_masks)

            # (B, Lmax, adim) — alpha rescales durations (speed control)
            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)

            if ps is not None:
                p_outs = ps
            else:
                if self.stop_gradient_from_pitch_predictor:
                    p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
                else:
                    p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if es is not None:
                    e_outs = es
                else:
                    if self.stop_gradient_from_energy_predictor:
                        e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                    else:
                        e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        # training
        else:
            # ds/ps (and es when use_energy_pred) must be provided here
            d_outs = self.duration_predictor(hs, d_masks)
            # (B, Lmax, adim)
            hs = self.length_regulator(hs, ds, is_inference=False)
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            else:
                p_outs = self.pitch_predictor(hs, pitch_masks)
            # NOTE: ground-truth pitch (ps) is embedded; p_outs only feeds
            # the loss (teacher forcing of the variance embeddings)
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if self.stop_gradient_from_energy_predictor:
                    e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                else:
                    e_outs = self.energy_predictor(hs, pitch_masks)
                # same teacher forcing for energy: embed es, predict e_outs
                e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                    (0, 2, 1))
                hs += e_embs

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        # early exit used by encoder_infer / encoder_infer_batch
        if return_after_enc:
            return hs, h_masks

        if self.decoder_type == 'cnndecoder':
            # remove output masks for dygraph to static graph
            zs = self.decoder(hs, h_masks)
            before_outs = zs
        else:
            # (B, Lmax, adim)
            zs, _ = self.decoder(hs, h_masks)
            # (B, Lmax, odim)
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            # postnet disabled (use_postnet=False): identical tensors
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
|
|
|
|
def encoder_infer(
|
|
self,
|
|
text: paddle.Tensor,
|
|
note: paddle.Tensor,
|
|
note_dur: paddle.Tensor,
|
|
is_slur: paddle.Tensor,
|
|
alpha: float=1.0,
|
|
spk_emb=None,
|
|
spk_id=None,
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
xs = paddle.cast(text, 'int64').unsqueeze(0)
|
|
note = paddle.cast(note, 'int64').unsqueeze(0)
|
|
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
|
|
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
|
|
# setup batch axis
|
|
ilens = paddle.shape(xs)[1]
|
|
|
|
if spk_emb is not None:
|
|
spk_emb = spk_emb.unsqueeze(0)
|
|
|
|
# (1, L, odim)
|
|
# use *_ to avoid bug in dygraph to static graph
|
|
hs, _ = self._forward(
|
|
xs=xs,
|
|
note=note,
|
|
note_dur=note_dur,
|
|
is_slur=is_slur,
|
|
ilens=ilens,
|
|
is_inference=True,
|
|
return_after_enc=True,
|
|
alpha=alpha,
|
|
spk_emb=spk_emb,
|
|
spk_id=spk_id, )
|
|
return hs
|
|
|
|
# get encoder output for diffusion training
|
|
def encoder_infer_batch(
|
|
self,
|
|
text: paddle.Tensor,
|
|
note: paddle.Tensor,
|
|
note_dur: paddle.Tensor,
|
|
is_slur: paddle.Tensor,
|
|
text_lengths: paddle.Tensor,
|
|
speech_lengths: paddle.Tensor,
|
|
ds: paddle.Tensor=None,
|
|
ps: paddle.Tensor=None,
|
|
es: paddle.Tensor=None,
|
|
alpha: float=1.0,
|
|
spk_emb=None,
|
|
spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
|
|
|
xs = paddle.cast(text, 'int64')
|
|
note = paddle.cast(note, 'int64')
|
|
note_dur = paddle.cast(note_dur, 'float32')
|
|
is_slur = paddle.cast(is_slur, 'int64')
|
|
ilens = paddle.cast(text_lengths, 'int64')
|
|
olens = paddle.cast(speech_lengths, 'int64')
|
|
|
|
if spk_emb is not None:
|
|
spk_emb = spk_emb.unsqueeze(0)
|
|
|
|
# (1, L, odim)
|
|
# use *_ to avoid bug in dygraph to static graph
|
|
hs, h_masks = self._forward(
|
|
xs=xs,
|
|
note=note,
|
|
note_dur=note_dur,
|
|
is_slur=is_slur,
|
|
ilens=ilens,
|
|
olens=olens,
|
|
ds=ds,
|
|
ps=ps,
|
|
es=es,
|
|
return_after_enc=True,
|
|
is_train_diffusion=True,
|
|
alpha=alpha,
|
|
spk_emb=spk_emb,
|
|
spk_id=spk_id, )
|
|
return hs, h_masks
|
|
|
|
def inference(
|
|
self,
|
|
text: paddle.Tensor,
|
|
note: paddle.Tensor,
|
|
note_dur: paddle.Tensor,
|
|
is_slur: paddle.Tensor,
|
|
durations: paddle.Tensor=None,
|
|
pitch: paddle.Tensor=None,
|
|
energy: paddle.Tensor=None,
|
|
alpha: float=1.0,
|
|
use_teacher_forcing: bool=False,
|
|
spk_emb=None,
|
|
spk_id=None,
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
"""Generate the sequence of features given the sequences of characters.
|
|
|
|
Args:
|
|
text(Tensor(int64)):
|
|
Input sequence of characters (T,).
|
|
note(Tensor(int64)):
|
|
Input note (element in music score) ids (T,).
|
|
note_dur(Tensor(float32)):
|
|
Input note durations in seconds (element in music score) (T,).
|
|
is_slur(Tensor(int64)):
|
|
Input slur (element in music score) ids (T,).
|
|
durations(Tensor, optional (int64)):
|
|
Groundtruth of duration (T,).
|
|
pitch(Tensor, optional):
|
|
Groundtruth of token-averaged pitch (T, 1).
|
|
energy(Tensor, optional):
|
|
Groundtruth of token-averaged energy (T, 1).
|
|
alpha(float, optional):
|
|
Alpha to control the speed.
|
|
use_teacher_forcing(bool, optional):
|
|
Whether to use teacher forcing.
|
|
If true, groundtruth of duration, pitch and energy will be used.
|
|
spk_emb(Tensor, optional, optional):
|
|
peaker embedding vector (spk_embed_dim,). (Default value = None)
|
|
spk_id(Tensor, optional(int64), optional):
|
|
spk ids (1,). (Default value = None)
|
|
|
|
Returns:
|
|
|
|
"""
|
|
xs = paddle.cast(text, 'int64').unsqueeze(0)
|
|
note = paddle.cast(note, 'int64').unsqueeze(0)
|
|
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
|
|
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
|
|
d, p, e = durations, pitch, energy
|
|
# setup batch axis
|
|
ilens = paddle.shape(xs)[1]
|
|
|
|
if spk_emb is not None:
|
|
spk_emb = spk_emb.unsqueeze(0)
|
|
|
|
if use_teacher_forcing:
|
|
# use groundtruth of duration, pitch, and energy
|
|
ds = d.unsqueeze(0) if d is not None else None
|
|
ps = p.unsqueeze(0) if p is not None else None
|
|
es = e.unsqueeze(0) if e is not None else None
|
|
|
|
# (1, L, odim)
|
|
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
|
|
xs=xs,
|
|
note=note,
|
|
note_dur=note_dur,
|
|
is_slur=is_slur,
|
|
ilens=ilens,
|
|
ds=ds,
|
|
ps=ps,
|
|
es=es,
|
|
spk_emb=spk_emb,
|
|
spk_id=spk_id,
|
|
is_inference=True)
|
|
else:
|
|
# (1, L, odim)
|
|
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
|
|
xs=xs,
|
|
note=note,
|
|
note_dur=note_dur,
|
|
is_slur=is_slur,
|
|
ilens=ilens,
|
|
is_inference=True,
|
|
alpha=alpha,
|
|
spk_emb=spk_emb,
|
|
spk_id=spk_id, )
|
|
|
|
if e_outs is None:
|
|
e_outs = [None]
|
|
|
|
return outs[0], d_outs[0], p_outs[0], e_outs[0]
|
|
|
|
|
|
class FastSpeech2MIDILoss(FastSpeech2Loss):
    """Loss function module for DiffSinger."""

    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize feed-forward Transformer loss module.

        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to weighted masking in loss calculation.
        """
        assert check_argument_types()
        # delegate all state to the FastSpeech2 loss implementation
        super().__init__(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)
|
|
|
|
    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor,
            spk_logits: paddle.Tensor=None,
            spk_ids: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, ]:
        """Calculate forward propagation.

        Args:
            after_outs(Tensor):
                Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor):
                Batch of outputs before postnets (B, Lmax, odim).
            d_outs(Tensor):
                Batch of outputs of duration predictor (B, Tmax).
            p_outs(Tensor):
                Batch of outputs of pitch predictor (B, Lmax, 1).
            e_outs(Tensor):
                Batch of outputs of energy predictor (B, Lmax, 1).
            ys(Tensor):
                Batch of target features (B, Lmax, odim).
            ds(Tensor):
                Batch of durations (B, Tmax).
            ps(Tensor):
                Batch of target frame-averaged pitch (B, Lmax, 1).
            es(Tensor):
                Batch of target frame-averaged energy (B, Lmax, 1).
            ilens(Tensor):
                Batch of the lengths of each input (B,).
            olens(Tensor):
                Batch of the lengths of each target (B,).
            spk_logits(Option[Tensor]):
                Batch of outputs after speaker classifier (B, Lmax, num_spk)
            spk_ids(Option[Tensor]):
                Batch of target spk_id (B,)

        Returns:
            Tuple: (l1_loss, ssim_loss, duration_loss, pitch_loss,
            energy_loss, speaker_loss); energy/speaker losses stay 0.0 when
            their inputs are absent.

        NOTE(review): ``before_outs_ssim``/``ys_ssim`` (and ``batch_size``)
        are only defined inside the ``self.use_masking`` branch, yet they are
        used unconditionally below — calling this with use_masking=False
        would raise NameError. Looks like use_masking=True is the only
        supported configuration; confirm before relying on the flag.
        """
        l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0

        # apply mask to remove padded part
        if self.use_masking:
            # make feature for ssim loss: zero-fill padding instead of
            # selecting, since ssim needs the 2-D (image-like) layout intact
            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
            # skip the extra fill when the postnet is disabled
            # (after_outs is then the same tensor as before_outs)
            if not paddle.equal_all(after_outs, before_outs):
                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)

            # flatten non-padded elements for the pointwise losses
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            # NOTE(review): before_outs was just flattened above, so this
            # compares tensors of different shapes — presumably relies on
            # equal_all returning False there; confirm intended behavior.
            if not paddle.equal_all(after_outs, before_outs):
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
            # pitch/energy are frame-level, so they share the output mask
            pitch_masks = out_masks
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            if e_outs is not None:
                e_outs = e_outs.masked_select(
                    pitch_masks.broadcast_to(e_outs.shape))
                es = es.masked_select(pitch_masks.broadcast_to(es.shape))

            if spk_logits is not None and spk_ids is not None:
                batch_size = spk_ids.shape[0]
                # repeat each utterance id over its frames, flatten logits,
                # then drop all-zero (padded) logit rows
                spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
                                                   None)
                spk_logits = paddle.reshape(spk_logits,
                                            [-1, spk_logits.shape[-1]])
                mask_index = spk_logits.abs().sum(axis=1) != 0
                spk_ids = spk_ids[mask_index]
                spk_logits = spk_logits[mask_index]

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        ssim_loss = 1.0 - ssim(
            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
        # when the postnet is active, average the before/after contributions
        if not paddle.equal_all(after_outs, before_outs):
            l1_loss += self.l1_criterion(after_outs, ys)
            ssim_loss += (
                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
            l1_loss = l1_loss * 0.5
            ssim_loss = ssim_loss * 0.5

        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.l1_criterion(p_outs, ps)
        if e_outs is not None:
            energy_loss = self.l1_criterion(e_outs, es)

        if spk_logits is not None and spk_ids is not None:
            # normalize by utterances, not frames
            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            # per-frame weights: each utterance contributes equally
            # regardless of its length
            out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            ssim_loss = ssim_loss.multiply(out_weights)
            ssim_loss = ssim_loss.masked_select(
                out_masks.broadcast_to(ssim_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            pitch_masks = out_masks
            pitch_weights = out_weights
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            if e_outs is not None:
                energy_loss = energy_loss.multiply(pitch_weights)
                energy_loss = energy_loss.masked_select(
                    pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
|