You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py

655 lines
25 KiB

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple
import paddle
from paddle import nn
from typeguard import typechecked
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.modules.losses import ssim
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
class FastSpeech2MIDI(FastSpeech2):
    """The FastSpeech2 module of DiffSinger (singing voice synthesis).

    Extends FastSpeech2 with music-score inputs: note ids, note durations
    (in seconds) and slur flags. Each is embedded/projected to the encoder
    attention dimension and consumed by the encoder alongside the phone ids.
    """

    @typechecked
    def __init__(
            self,
            # fastspeech2 network structure related
            idim: int,
            odim: int,
            fastspeech2_params: Dict[str, Any],
            # note emb
            note_num: int=300,
            # is_slur emb
            is_slur_num: int=2,
            use_energy_pred: bool=False,
            use_postnet: bool=False, ):
        """Initialize FastSpeech2 module for svs.

        Args:
            idim (int):
                Dimension of the inputs (phone vocabulary size).
            odim (int):
                Dimension of the outputs (e.g. number of mel bins).
            fastspeech2_params (Dict):
                The config of FastSpeech2 module on DiffSinger model
            note_num (Optional[int]):
                Number of notes. The note_ids provided as the input are
                embedded with note_embedding_table.
            is_slur_num (Optional[int]):
                Number of slur states (typically 2). The is_slur_ids provided
                as the input are embedded with is_slur_embedding_table.
            use_energy_pred (bool):
                Whether to run the energy predictor branch.
            use_postnet (bool):
                Whether to keep the postnet. When False, ``self.postnet`` is
                set to None and ``after_outs`` equals ``before_outs``.
        """
        super().__init__(idim=idim, odim=odim, **fastspeech2_params)
        self.use_energy_pred = use_energy_pred
        self.use_postnet = use_postnet
        if not self.use_postnet:
            self.postnet = None

        # note/slur embeddings share the encoder attention dim (adim) so they
        # can be summed with the phone embeddings inside the encoder
        self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
            "adim"]
        # note embed
        self.note_embedding_table = nn.Embedding(
            num_embeddings=note_num,
            embedding_dim=self.note_embed_dim,
            padding_idx=self.padding_idx)
        # projects the scalar note duration (seconds) to the embedding dim
        self.note_dur_layer = nn.Linear(1, self.note_embed_dim)

        # slur embed
        self.is_slur_embedding_table = nn.Embedding(
            num_embeddings=is_slur_num,
            embedding_dim=self.is_slur_embed_dim,
            padding_idx=self.padding_idx)

    def forward(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            spk_emb: paddle.Tensor=None,
            spk_id: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
        """Calculate forward propagation (training).

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            text_lengths(Tensor(int64)):
                Batch of phone lengths of each input (B,).
            speech(Tensor[float32]):
                Batch of padded target features (e.g. mel) (B, Lmax, odim).
            speech_lengths(Tensor(int64)):
                Batch of the lengths of each target features (B,).
            durations(Tensor(int64)):
                Batch of padded token durations in frame (B, Tmax).
            pitch(Tensor[float32]):
                Batch of padded frame-averaged pitch (B, Lmax, 1).
            energy(Tensor[float32]):
                Batch of padded frame-averaged energy (B, Lmax, 1).
            spk_emb(Tensor[float32], optional):
                Batch of speaker embeddings (B, spk_embed_dim).
            spk_id(Tensor(int64), optional):
                Batch of speaker ids (B,)

        Returns:
            before_outs(Tensor): decoder outputs before postnet (B, Lmax, odim).
            after_outs(Tensor): outputs after postnet (equals before_outs when
                postnet is None).
            d_outs(Tensor): duration predictor outputs (B, Tmax).
            p_outs(Tensor): pitch predictor outputs (B, Lmax, 1).
            e_outs(Tensor): energy predictor outputs or None.
            ys(Tensor): target features, truncated to a multiple of the
                reduction factor (B, Lmax', odim).
            olens(Tensor): adjusted target lengths (B,).
            spk_logits(Tensor): speaker classifier logits or None.
        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        # NOTE: previously olens was cast here and then immediately
        # overwritten with the uncast speech_lengths; keep only the cast.
        olens = paddle.cast(speech_lengths, 'int64')
        ds = paddle.cast(durations, 'int64')
        ps = pitch
        es = energy
        ys = speech
        if spk_id is not None:
            spk_id = paddle.cast(spk_id, 'int64')

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            is_inference=False,
            spk_emb=spk_emb,
            spk_id=spk_id, )

        # modify mod part of groundtruth so target length is a multiple of
        # the reduction factor, matching the decoder output length
        if self.reduction_factor > 1:
            olens = olens - olens % self.reduction_factor
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits

    def _forward(
            self,
            xs: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor=None,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            is_inference: bool=False,
            is_train_diffusion: bool=False,
            return_after_enc=False,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Sequence[paddle.Tensor]:
        """Shared forward pass for training, inference and diffusion training.

        Dispatches on ``is_train_diffusion`` / ``is_inference`` to decide
        whether predicted or ground-truth durations/pitch/energy are used.
        When ``return_after_enc`` is True, returns (hs, h_masks) right after
        the variance adaptors, before the decoder.
        """
        before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
        # forward encoder
        masks = self._source_mask(ilens)
        note_emb = self.note_embedding_table(note)
        note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
        is_slur_emb = self.is_slur_embedding_table(is_slur)

        # (B, Tmax, adim)
        hs, _ = self.encoder(
            xs=xs,
            masks=masks,
            note_emb=note_emb,
            note_dur_emb=note_dur_emb,
            is_slur_emb=is_slur_emb, )

        if self.spk_num and self.enable_speaker_classifier and not is_inference:
            # gradient reversal: the encoder is pushed toward
            # speaker-invariant features while the classifier learns speakers
            hs_for_spk_cls = self.grad_reverse(hs)
            spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
        else:
            spk_logits = None

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            # spk_emb has a higher priority than spk_id
            if spk_emb is not None:
                hs = self._integrate_with_spk_embed(hs, spk_emb)
            elif spk_id is not None:
                spk_emb = self.spk_embedding_table(spk_id)
                hs = self._integrate_with_spk_embed(hs, spk_emb)

        # forward duration predictor (phone-level) and variance predictors (frame-level)
        d_masks = make_pad_mask(ilens)
        if olens is not None:
            pitch_masks = make_pad_mask(olens).unsqueeze(-1)
        else:
            pitch_masks = None

        # inference for decoder input for diffusion
        if is_train_diffusion:
            # expand with ground-truth durations, predict pitch/energy and
            # add their embeddings; hs.detach() blocks gradients into the
            # encoder from the variance predictors
            hs = self.length_regulator(hs, ds, is_inference=False)
            p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs
            if self.use_energy_pred:
                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        elif is_inference:
            # (B, Tmax) -- use ground-truth durations when provided
            # (teacher forcing), otherwise predict them
            if ds is not None:
                d_outs = ds
            else:
                d_outs = self.duration_predictor.inference(hs, d_masks)

            # (B, Lmax, adim)
            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)

            if ps is not None:
                p_outs = ps
            else:
                if self.stop_gradient_from_pitch_predictor:
                    p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
                else:
                    p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if es is not None:
                    e_outs = es
                else:
                    if self.stop_gradient_from_energy_predictor:
                        e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                    else:
                        e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        # training
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # (B, Lmax, adim)
            hs = self.length_regulator(hs, ds, is_inference=False)
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            else:
                p_outs = self.pitch_predictor(hs, pitch_masks)
            # embeddings are computed from the ground-truth ps/es (teacher
            # forcing); the predictor outputs only feed the loss
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if self.stop_gradient_from_energy_predictor:
                    e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                else:
                    e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                    (0, 2, 1))
                hs += e_embs

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        if return_after_enc:
            return hs, h_masks

        if self.decoder_type == 'cnndecoder':
            # remove output masks for dygraph to static graph
            zs = self.decoder(hs, h_masks)
            before_outs = zs
        else:
            # (B, Lmax, adim)
            zs, _ = self.decoder(hs, h_masks)
            # (B, Lmax, odim)
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits

    def encoder_infer(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None,
    ) -> paddle.Tensor:
        """Run encoder + variance adaptors for a single unbatched utterance.

        Returns the decoder input hidden states (1, L, adim); the h_masks
        part of the _forward return is discarded.
        """
        # add batch axis; casts mirror those in forward()
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph
        hs, _ = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            is_inference=True,
            return_after_enc=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs

    # get encoder output for diffusion training
    def encoder_infer_batch(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Batched encoder pass used to build diffusion-training inputs.

        Returns:
            hs(Tensor): decoder input hidden states (B, Lmax, adim).
            h_masks(Tensor): decoder source masks (B, 1, Lmax) or None.
        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph
        hs, h_masks = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            return_after_enc=True,
            is_train_diffusion=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs, h_masks

    def inference(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            durations: paddle.Tensor=None,
            pitch: paddle.Tensor=None,
            energy: paddle.Tensor=None,
            alpha: float=1.0,
            use_teacher_forcing: bool=False,
            spk_emb=None,
            spk_id=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text(Tensor(int64)):
                Input sequence of characters (T,).
            note(Tensor(int64)):
                Input note (element in music score) ids (T,).
            note_dur(Tensor(float32)):
                Input note durations in seconds (element in music score) (T,).
            is_slur(Tensor(int64)):
                Input slur (element in music score) ids (T,).
            durations(Tensor, optional (int64)):
                Groundtruth of duration (T,).
            pitch(Tensor, optional):
                Groundtruth of token-averaged pitch (T, 1).
            energy(Tensor, optional):
                Groundtruth of token-averaged energy (T, 1).
            alpha(float, optional):
                Alpha to control the speed.
            use_teacher_forcing(bool, optional):
                Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            spk_emb(Tensor, optional):
                Speaker embedding vector (spk_embed_dim,). (Default value = None)
            spk_id(Tensor(int64), optional):
                spk ids (1,). (Default value = None)

        Returns:
            outs(Tensor): generated features (L, odim).
            d_outs(Tensor): durations (T,).
            p_outs(Tensor): pitch (L, 1).
            e_outs(Tensor): energy (L, 1) or None when no energy predictor.
        """
        # add batch axis
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        d, p, e = durations, pitch, energy
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds = d.unsqueeze(0) if d is not None else None
            ps = p.unsqueeze(0) if p is not None else None
            es = e.unsqueeze(0) if e is not None else None

            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                ds=ds,
                ps=ps,
                es=es,
                spk_emb=spk_emb,
                spk_id=spk_id,
                is_inference=True)
        else:
            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                is_inference=True,
                alpha=alpha,
                spk_emb=spk_emb,
                spk_id=spk_id, )

        # keep the "index [0] strips the batch axis" pattern uniform even
        # when the energy predictor is disabled
        if e_outs is None:
            e_outs = [None]

        return outs[0], d_outs[0], p_outs[0], e_outs[0]
class FastSpeech2MIDILoss(FastSpeech2Loss):
    """Loss function module for DiffSinger."""

    @typechecked
    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize feed-forward Transformer loss module.
        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to weighted masking in loss calculation.
        """
        super().__init__(use_masking, use_weighted_masking)

    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor,
            spk_logits: paddle.Tensor=None,
            spk_ids: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, ]:
        """Calculate forward propagation.
        Args:
            after_outs(Tensor):
                Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor):
                Batch of outputs before postnets (B, Lmax, odim).
            d_outs(Tensor):
                Batch of outputs of duration predictor (B, Tmax).
            p_outs(Tensor):
                Batch of outputs of pitch predictor (B, Lmax, 1).
            e_outs(Tensor):
                Batch of outputs of energy predictor (B, Lmax, 1).
            ys(Tensor):
                Batch of target features (B, Lmax, odim).
            ds(Tensor):
                Batch of durations (B, Tmax).
            ps(Tensor):
                Batch of target frame-averaged pitch (B, Lmax, 1).
            es(Tensor):
                Batch of target frame-averaged energy (B, Lmax, 1).
            ilens(Tensor):
                Batch of the lengths of each input (B,).
            olens(Tensor):
                Batch of the lengths of each target (B,).
            spk_logits(Option[Tensor]):
                Batch of outputs after speaker classifier (B, Lmax, num_spk)
            spk_ids(Option[Tensor]):
                Batch of target spk_id (B,)
        Returns:
            Tuple of scalar losses:
            l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss,
            speaker_loss (energy/speaker are 0.0 when their inputs are absent).
        """
        # all losses default to 0.0 so disabled branches (no energy predictor,
        # no speaker classifier) still return a full tuple
        l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0

        # apply mask to remove padded part
        # NOTE(review): several names (before_outs_ssim, ys_ssim, batch_size)
        # are only bound inside this branch; with use_masking=False the ssim
        # and speaker-loss code below would raise NameError — confirm callers
        # always construct this loss with use_masking=True.
        if self.use_masking:
            # make feature for ssim loss: zero out padded frames but keep the
            # (B, Lmax, odim) shape, since ssim needs image-like inputs
            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
            # equal_all(after_outs, before_outs) detects whether a postnet was
            # applied (without one, the model returns the same tensor twice)
            if not paddle.equal_all(after_outs, before_outs):
                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)

            # flatten valid (non-padded) positions for the element-wise losses
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            # NOTE(review): before_outs was just re-bound to a flattened
            # tensor, so this compares differently-shaped tensors when a
            # postnet exists; it appears to rely on equal_all returning False
            # on shape mismatch — confirm against the paddle.equal_all docs.
            if not paddle.equal_all(after_outs, before_outs):
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))

            # durations are phone-level, so they use the input-length mask
            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))

            # pitch/energy are frame-level: reuse the output-length mask
            pitch_masks = out_masks
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            if e_outs is not None:
                e_outs = e_outs.masked_select(
                    pitch_masks.broadcast_to(e_outs.shape))
                es = es.masked_select(pitch_masks.broadcast_to(es.shape))

            if spk_logits is not None and spk_ids is not None:
                batch_size = spk_ids.shape[0]
                # repeat each utterance's spk_id across its frames, then drop
                # positions whose logits are all-zero (padded frames)
                spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
                                                   None)
                spk_logits = paddle.reshape(spk_logits,
                                            [-1, spk_logits.shape[-1]])
                mask_index = spk_logits.abs().sum(axis=1) != 0
                spk_ids = spk_ids[mask_index]
                spk_logits = spk_logits[mask_index]

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        # ssim treats the mel as a 1-channel image: unsqueeze adds the channel
        ssim_loss = 1.0 - ssim(
            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
        # if a postnet was applied, average the pre- and post-net losses
        if not paddle.equal_all(after_outs, before_outs):
            l1_loss += self.l1_criterion(after_outs, ys)
            ssim_loss += (
                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
            l1_loss = l1_loss * 0.5
            ssim_loss = ssim_loss * 0.5

        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.l1_criterion(p_outs, ps)
        if e_outs is not None:
            energy_loss = self.l1_criterion(e_outs, es)

        if spk_logits is not None and spk_ids is not None:
            # normalize by batch size since spk_ids was frame-expanded above
            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

        # make weighted mask and apply it
        if self.use_weighted_masking:
            # per-utterance weights: each utterance contributes equally
            # regardless of its length
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            # NOTE(review): this branch assumes the criteria returned
            # per-element (unreduced) losses of the original (B, Lmax, ...)
            # shape — i.e. use_masking was False — TODO confirm the two flags
            # are mutually exclusive.
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            ssim_loss = ssim_loss.multiply(out_weights)
            ssim_loss = ssim_loss.masked_select(
                out_masks.broadcast_to(ssim_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            pitch_masks = out_masks
            pitch_weights = out_weights
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            if e_outs is not None:
                energy_loss = energy_loss.multiply(pitch_weights)
                energy_loss = energy_loss.masked_select(
                    pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss