You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py

655 lines
25 KiB

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.modules.losses import ssim
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
class FastSpeech2MIDI(FastSpeech2):
    """The FastSpeech2 module of DiffSinger.

    A ``FastSpeech2`` acoustic model whose encoder is additionally fed with
    music-score features: note ids, note durations and slur flags. Each of
    them is embedded to the attention dimension (``adim``) and handed to the
    encoder together with the phone ids.
    """

    def __init__(
            self,
            # fastspeech2 network structure related
            idim: int,
            odim: int,
            fastspeech2_params: Dict[str, Any],
            # note emb
            note_num: int=300,
            # is_slur emb
            is_slur_num: int=2,
            use_energy_pred: bool=False,
            use_postnet: bool=False, ):
        """Initialize FastSpeech2 module for svs.

        Args:
            idim (int):
                Dimension of the inputs, forwarded to ``FastSpeech2``.
            odim (int):
                Dimension of the outputs, forwarded to ``FastSpeech2``.
            fastspeech2_params (Dict):
                The config of FastSpeech2 module on DiffSinger model.
                Must contain the key ``"adim"``, which is also used as the
                size of the note / note-duration / slur embeddings.
            note_num (int):
                Number of distinct note ids, i.e. the size of
                ``note_embedding_table``.
            is_slur_num (int):
                Number of distinct slur ids, i.e. the size of
                ``is_slur_embedding_table``.
            use_energy_pred (bool):
                Whether the energy predictor / energy embedding is used.
            use_postnet (bool):
                Whether to keep the postnet built by ``FastSpeech2``. If
                False, ``self.postnet`` is replaced with None.
        """
        assert check_argument_types()
        super().__init__(idim=idim, odim=odim, **fastspeech2_params)

        self.use_energy_pred = use_energy_pred
        self.use_postnet = use_postnet
        if not self.use_postnet:
            # drop the postnet created by the parent constructor
            self.postnet = None

        self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
            "adim"]

        # note embed
        self.note_embedding_table = nn.Embedding(
            num_embeddings=note_num,
            embedding_dim=self.note_embed_dim,
            padding_idx=self.padding_idx)
        # note durations are real-valued, so they are projected, not looked up
        self.note_dur_layer = nn.Linear(1, self.note_embed_dim)

        # slur embed
        self.is_slur_embedding_table = nn.Embedding(
            num_embeddings=is_slur_num,
            embedding_dim=self.is_slur_embed_dim,
            padding_idx=self.padding_idx)

    def forward(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            spk_emb: paddle.Tensor=None,
            spk_id: paddle.Tensor=None, ) -> Tuple[paddle.Tensor, ...]:
        """Calculate forward propagation.

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            text_lengths(Tensor(int64)):
                Batch of phone lengths of each input (B,).
            speech(Tensor[float32]):
                Batch of padded target features (e.g. mel) (B, Lmax, odim).
            speech_lengths(Tensor(int64)):
                Batch of the lengths of each target features (B,).
            durations(Tensor(int64)):
                Batch of padded token durations in frame (B, Tmax).
            pitch(Tensor[float32]):
                Batch of padded frame-averaged pitch (B, Lmax, 1).
            energy(Tensor[float32]):
                Batch of padded frame-averaged energy (B, Lmax, 1).
            spk_emb(Tensor[float32], optional):
                Batch of speaker embeddings (B, spk_embed_dim).
            spk_id(Tensor[int64], optional):
                Batch of speaker ids (B,)

        Returns:
            Tuple of ``(before_outs, after_outs, d_outs, p_outs, e_outs, ys,
            olens, spk_logits)``: decoder outputs before / after the postnet,
            duration / pitch / energy predictions, the (possibly truncated)
            targets, the (possibly adjusted) target lengths and the speaker
            classifier logits. ``e_outs`` is None when ``use_energy_pred`` is
            False; ``spk_logits`` is None when the speaker classifier is not
            enabled.
        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        # NOTE: the original re-assigned ``olens = speech_lengths`` afterwards,
        # silently discarding this cast.
        olens = paddle.cast(speech_lengths, 'int64')
        ds = paddle.cast(durations, 'int64')
        ps = pitch
        es = energy
        ys = speech
        if spk_id is not None:
            spk_id = paddle.cast(spk_id, 'int64')

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            is_inference=False,
            spk_emb=spk_emb,
            spk_id=spk_id, )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            # make every target length a multiple of the reduction factor
            olens = olens - olens % self.reduction_factor
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits

    def _forward(
            self,
            xs: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor=None,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            is_inference: bool=False,
            is_train_diffusion: bool=False,
            return_after_enc=False,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Sequence[paddle.Tensor]:
        """Shared implementation behind training, inference and encoder-only calls.

        Three mutually exclusive modes are selected by the flags, checked in
        this order:

        * ``is_train_diffusion``: expand with the given ``ds`` and add the
          embeddings of the *predicted* pitch (and energy); ``ps`` / ``es``
          are not read in this mode.
        * ``is_inference``: use ``ds`` / ``ps`` / ``es`` when given, otherwise
          the corresponding predictor output.
        * otherwise (training): run all predictors, but expand with the
          groundtruth ``ds`` and add the embeddings of groundtruth ``ps`` /
          ``es``.

        Args:
            xs: Batch of padded phone ids (B, Tmax).
            note: Batch of padded note ids (B, Tmax).
            note_dur: Batch of padded note durations (B, Tmax).
            is_slur: Batch of padded slur ids (B, Tmax).
            ilens: Input lengths.
            olens: Target lengths, used for the pitch / decoder masks.
            ds: Durations used by the length regulator.
            ps: Pitch values.
            es: Energy values.
            is_inference: Select the inference mode.
            is_train_diffusion: Select the diffusion-training mode.
            return_after_enc: If True, return before running the decoder.
            alpha: Speed control passed to the length regulator (inference only).
            spk_emb: Speaker embedding; takes priority over ``spk_id``.
            spk_id: Speaker ids, looked up in ``spk_embedding_table``.

        Returns:
            ``(hs, h_masks)`` when ``return_after_enc`` is True, where ``hs``
            are the length-regulated hidden states with the variance
            embeddings added. Otherwise ``(before_outs, after_outs, d_outs,
            p_outs, e_outs, spk_logits)``.
        """
        before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None

        # forward encoder
        masks = self._source_mask(ilens)
        note_emb = self.note_embedding_table(note)
        note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
        is_slur_emb = self.is_slur_embedding_table(is_slur)

        # (B, Tmax, adim)
        hs, _ = self.encoder(
            xs=xs,
            masks=masks,
            note_emb=note_emb,
            note_dur_emb=note_dur_emb,
            is_slur_emb=is_slur_emb, )

        # speaker classifier on gradient-reversed states (training only);
        # spk_logits stays None otherwise
        if self.spk_num and self.enable_speaker_classifier and not is_inference:
            hs_for_spk_cls = self.grad_reverse(hs)
            spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            # spk_emb has a higher priority than spk_id
            if spk_emb is not None:
                hs = self._integrate_with_spk_embed(hs, spk_emb)
            elif spk_id is not None:
                spk_emb = self.spk_embedding_table(spk_id)
                hs = self._integrate_with_spk_embed(hs, spk_emb)

        # forward duration predictor (phone-level) and variance predictors (frame-level)
        d_masks = make_pad_mask(ilens)
        pitch_masks = None
        if olens is not None:
            pitch_masks = make_pad_mask(olens).unsqueeze(-1)

        # inference for decoder input for diffusion
        if is_train_diffusion:
            hs = self.length_regulator(hs, ds, is_inference=False)
            p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs
            if self.use_energy_pred:
                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        elif is_inference:
            # (B, Tmax)
            if ds is not None:
                d_outs = ds
            else:
                d_outs = self.duration_predictor.inference(hs, d_masks)

            # (B, Lmax, adim)
            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)

            if ps is not None:
                p_outs = ps
            else:
                pitch_in = hs.detach(
                ) if self.stop_gradient_from_pitch_predictor else hs
                p_outs = self.pitch_predictor(pitch_in, pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if es is not None:
                    e_outs = es
                else:
                    energy_in = hs.detach(
                    ) if self.stop_gradient_from_energy_predictor else hs
                    e_outs = self.energy_predictor(energy_in, pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        # training
        else:
            d_outs = self.duration_predictor(hs, d_masks)

            # (B, Lmax, adim)
            hs = self.length_regulator(hs, ds, is_inference=False)

            pitch_in = hs.detach(
            ) if self.stop_gradient_from_pitch_predictor else hs
            p_outs = self.pitch_predictor(pitch_in, pitch_masks)
            # the groundtruth pitch (not the prediction) conditions the decoder
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                energy_in = hs.detach(
                ) if self.stop_gradient_from_energy_predictor else hs
                e_outs = self.energy_predictor(energy_in, pitch_masks)
                # the groundtruth energy (not the prediction) conditions the decoder
                e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                    (0, 2, 1))
                hs += e_embs

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        if return_after_enc:
            return hs, h_masks

        if self.decoder_type == 'cnndecoder':
            # remove output masks for dygraph to static graph
            zs = self.decoder(hs, h_masks)
            before_outs = zs
        else:
            # (B, Lmax, adim)
            zs, _ = self.decoder(hs, h_masks)
            # (B, Lmax, odim)
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits

    def encoder_infer(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> paddle.Tensor:
        """Return the decoder input for a single utterance (inference mode).

        Args:
            text(Tensor(int64)):
                Input sequence of phone ids (T,).
            note(Tensor(int64)):
                Input note ids (T,).
            note_dur(Tensor(float32)):
                Input note durations (T,).
            is_slur(Tensor(int64)):
                Input slur ids (T,).
            alpha(float, optional):
                Alpha to control the speed.
            spk_emb(Tensor, optional):
                Speaker embedding vector (spk_embed_dim,).
            spk_id(Tensor, optional):
                Speaker ids, passed to ``_forward`` unchanged.

        Returns:
            Tensor: Length-regulated hidden states with the variance
            embeddings added, with a leading batch axis of size 1.
        """
        # setup batch axis
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, adim); h_masks is always None in inference mode
        hs, _ = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            is_inference=True,
            return_after_enc=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs

    # get encoder output for diffusion training
    def encoder_infer_batch(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Return the decoder input and its mask for a batch (diffusion training).

        Runs ``_forward`` with ``is_train_diffusion=True``; in that mode the
        length regulator uses ``ds`` and the pitch / energy embeddings come
        from the predictors, so ``ps``, ``es`` and ``alpha`` are forwarded
        but not read.

        Args:
            text(Tensor(int64)):
                Batch of padded phone ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur ids (B, Tmax).
            text_lengths(Tensor(int64)):
                Batch of phone lengths of each input (B,).
            speech_lengths(Tensor(int64)):
                Batch of the lengths of each target features (B,).
            ds(Tensor(int64)):
                Batch of padded token durations in frame (B, Tmax).
            ps(Tensor, optional):
                Batch of pitch values (unused in this mode).
            es(Tensor, optional):
                Batch of energy values (unused in this mode).
            alpha(float, optional):
                Unused in this mode.
            spk_emb(Tensor, optional):
                Batch of speaker embeddings (B, spk_embed_dim), or a single
                vector (spk_embed_dim,) shared by the whole batch.
            spk_id(Tensor, optional):
                Batch of speaker ids (B,).

        Returns:
            Tuple of ``(hs, h_masks)``: hidden states (B, Lmax, adim) and the
            decoder mask built from ``speech_lengths``.
        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')

        # Inputs are already batched here, so only a bare 1-D speaker vector
        # needs a batch axis. The original unsqueezed unconditionally, turning
        # a (B, spk_embed_dim) batch into (1, B, spk_embed_dim).
        if spk_emb is not None and len(spk_emb.shape) == 1:
            spk_emb = spk_emb.unsqueeze(0)

        # (B, Lmax, adim)
        hs, h_masks = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            return_after_enc=True,
            is_train_diffusion=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs, h_masks

    def inference(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            durations: paddle.Tensor=None,
            pitch: paddle.Tensor=None,
            energy: paddle.Tensor=None,
            alpha: float=1.0,
            use_teacher_forcing: bool=False,
            spk_emb=None,
            spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor,
                                    paddle.Tensor, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text(Tensor(int64)):
                Input sequence of characters (T,).
            note(Tensor(int64)):
                Input note (element in music score) ids (T,).
            note_dur(Tensor(float32)):
                Input note durations in seconds (element in music score) (T,).
            is_slur(Tensor(int64)):
                Input slur (element in music score) ids (T,).
            durations(Tensor, optional (int64)):
                Groundtruth of duration (T,).
            pitch(Tensor, optional):
                Groundtruth of pitch. Its embedding is added to the
                length-regulated states, so it is frame-level (L, 1).
            energy(Tensor, optional):
                Groundtruth of energy, frame-level (L, 1) for the same reason.
            alpha(float, optional):
                Alpha to control the speed. Only applied when
                ``use_teacher_forcing`` is False.
            use_teacher_forcing(bool, optional):
                Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            spk_emb(Tensor, optional):
                Speaker embedding vector (spk_embed_dim,). (Default value = None)
            spk_id(Tensor, optional(int64)):
                spk ids (1,). (Default value = None)

        Returns:
            Tuple of ``(outs, d_outs, p_outs, e_outs)`` for the single input
            utterance (batch axis removed): output features (L, odim),
            durations, pitch and energy. ``e_outs`` is None when
            ``use_energy_pred`` is False.
        """
        # setup batch axis
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds = durations.unsqueeze(0) if durations is not None else None
            ps = pitch.unsqueeze(0) if pitch is not None else None
            es = energy.unsqueeze(0) if energy is not None else None

            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                ds=ds,
                ps=ps,
                es=es,
                spk_emb=spk_emb,
                spk_id=spk_id,
                is_inference=True)
        else:
            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                is_inference=True,
                alpha=alpha,
                spk_emb=spk_emb,
                spk_id=spk_id, )

        # keep the indexing below uniform when energy prediction is disabled
        if e_outs is None:
            e_outs = [None]

        return outs[0], d_outs[0], p_outs[0], e_outs[0]
class FastSpeech2MIDILoss(FastSpeech2Loss):
    """Loss function module for DiffSinger."""

    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize feed-forward Transformer loss module.

        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to weighted masking in loss calculation.
        """
        assert check_argument_types()
        super().__init__(use_masking, use_weighted_masking)

    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor,
            spk_logits: paddle.Tensor=None,
            spk_ids: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            after_outs(Tensor):
                Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor):
                Batch of outputs before postnets (B, Lmax, odim).
            d_outs(Tensor):
                Batch of outputs of duration predictor (B, Tmax).
            p_outs(Tensor):
                Batch of outputs of pitch predictor (B, Lmax, 1).
            e_outs(Tensor):
                Batch of outputs of energy predictor (B, Lmax, 1).
                May be None, in which case the energy loss stays 0.0.
            ys(Tensor):
                Batch of target features (B, Lmax, odim).
            ds(Tensor):
                Batch of durations (B, Tmax).
            ps(Tensor):
                Batch of target frame-averaged pitch (B, Lmax, 1).
            es(Tensor):
                Batch of target frame-averaged energy (B, Lmax, 1).
            ilens(Tensor):
                Batch of the lengths of each input (B,).
            olens(Tensor):
                Batch of the lengths of each target (B,).
            spk_logits(Option[Tensor]):
                Batch of outputs after speaker classifier (B, Lmax, num_spk)
            spk_ids(Option[Tensor]):
                Batch of target spk_id (B,)

        Returns:
            Tuple of ``(l1_loss, ssim_loss, duration_loss, pitch_loss,
            energy_loss, speaker_loss)``. ``energy_loss`` and
            ``speaker_loss`` are the float ``0.0`` when their inputs are
            not provided.
        """
        l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0

        # Decide ONCE, on the untouched (B, Lmax, odim) tensors, whether the
        # postnet produced a distinct output. The original repeated this
        # comparison after `before_outs` had been flattened by masked_select.
        has_after_outs = not paddle.equal_all(after_outs, before_outs)
        has_speaker = spk_logits is not None and spk_ids is not None

        # SSIM needs the un-flattened tensors. Default to the raw inputs so the
        # names exist when `use_masking` is False (the original raised an
        # unbound-local error in that case, including the weighted-masking
        # path); they are zero-filled at padded frames below when masking.
        before_outs_ssim = before_outs
        after_outs_ssim = after_outs
        ys_ssim = ys

        if has_speaker:
            # Remember the batch size before flattening; it normalises the
            # speaker loss below. Hoisted out of the masking branch so the
            # speaker loss no longer depends on `use_masking` to be computable.
            batch_size = spk_ids.shape[0]
            # one label per (utterance, position) pair
            spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
                                               None)
            spk_logits = paddle.reshape(spk_logits,
                                        [-1, spk_logits.shape[-1]])

        # apply mask to remove padded part
        if self.use_masking:
            # make feature for ssim loss
            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
            if has_after_outs:
                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)

            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            if has_after_outs:
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))

            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))

            # pitch and energy are frame-level, so they share the output mask
            pitch_masks = out_masks
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            if e_outs is not None:
                e_outs = e_outs.masked_select(
                    pitch_masks.broadcast_to(e_outs.shape))
                es = es.masked_select(pitch_masks.broadcast_to(es.shape))

            if has_speaker:
                # drop positions whose logits are all zero
                # (presumably the padded positions -- TODO confirm)
                mask_index = spk_logits.abs().sum(axis=1) != 0
                spk_ids = spk_ids[mask_index]
                spk_logits = spk_logits[mask_index]

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        ssim_loss = 1.0 - ssim(
            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
        if has_after_outs:
            # average the before-/after-postnet terms
            l1_loss += self.l1_criterion(after_outs, ys)
            ssim_loss += (
                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
            l1_loss = l1_loss * 0.5
            ssim_loss = ssim_loss * 0.5

        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.l1_criterion(p_outs, ps)
        if e_outs is not None:
            energy_loss = self.l1_criterion(e_outs, es)

        if has_speaker:
            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            # each utterance's frames share a total weight of 1 ...
            out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            # ... then average over batch and feature dimension
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            ssim_loss = ssim_loss.multiply(out_weights)
            ssim_loss = ssim_loss.masked_select(
                out_masks.broadcast_to(ssim_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            # NOTE(review): pitch/energy reuse `out_weights`, which were also
            # divided by odim above -- confirm this scaling is intended.
            pitch_masks = out_masks
            pitch_weights = out_weights
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            if e_outs is not None:
                energy_loss = energy_loss.multiply(pitch_weights)
                energy_loss = energy_loss.masked_select(
                    pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss