# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tacotron 2 related modules for paddle"""
import logging
from typing import Dict
from typing import Optional
from typing import Tuple

import paddle
import paddle.nn.functional as F
from paddle import nn
from typeguard import check_argument_types

from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.modules.tacotron2.attentions import AttForward
from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA
from paddlespeech.t2s.modules.tacotron2.attentions import AttLoc
from paddlespeech.t2s.modules.tacotron2.decoder import Decoder
from paddlespeech.t2s.modules.tacotron2.encoder import Encoder


class Tacotron2(nn.Layer):
    """Tacotron2 module for end-to-end text-to-speech.

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters into the sequence of
    Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(
            self,
            # network structure related
            idim: int,
            odim: int,
            embed_dim: int=512,
            elayers: int=1,
            eunits: int=512,
            econv_layers: int=3,
            econv_chans: int=512,
            econv_filts: int=5,
            atype: str="location",
            adim: int=512,
            aconv_chans: int=32,
            aconv_filts: int=15,
            cumulate_att_w: bool=True,
            dlayers: int=2,
            dunits: int=1024,
            prenet_layers: int=2,
            prenet_units: int=256,
            postnet_layers: int=5,
            postnet_chans: int=512,
            postnet_filts: int=5,
            output_activation: str=None,
            use_batch_norm: bool=True,
            use_concate: bool=True,
            use_residual: bool=False,
            reduction_factor: int=1,
            # extra embedding related
            spk_num: Optional[int]=None,
            lang_num: Optional[int]=None,
            spk_embed_dim: Optional[int]=None,
            spk_embed_integration_type: str="concat",
            dropout_rate: float=0.5,
            zoneout_rate: float=0.1,
            # training related
            init_type: str="xavier_uniform", ):
        """Initialize Tacotron2 module.

        Parameters
        ----------
        idim : int
            Dimension of the inputs (vocabulary size). The last id
            (``idim - 1``) is used as the end-of-sequence token.
        odim : int
            Dimension of the outputs.
        embed_dim : int
            Dimension of the token embedding.
        elayers : int
            Number of encoder blstm layers.
        eunits : int
            Number of encoder blstm units.
        econv_layers : int
            Number of encoder conv layers.
        econv_chans : int
            Number of encoder conv filter channels.
        econv_filts : int
            Number of encoder conv filter size.
        atype : str
            Type of attention: "location", "forward" or "forward_ta".
            Cumulative attention weights are disabled for the forward types.
        adim : int
            Number of dimension of mlp in attention.
        aconv_chans : int
            Number of attention conv filter channels.
        aconv_filts : int
            Number of attention conv filter size.
        cumulate_att_w : bool
            Whether to cumulate previous attention weight.
        dlayers : int
            Number of decoder lstm layers.
        dunits : int
            Number of decoder lstm units.
        prenet_layers : int
            Number of prenet layers.
        prenet_units : int
            Number of prenet units.
        postnet_layers : int
            Number of postnet layers.
        postnet_chans : int
            Number of postnet filter channels.
        postnet_filts : int
            Number of postnet filter size.
        output_activation : str
            Name of an activation function in ``paddle.nn.functional`` to
            apply to the outputs, or None for no activation.
        use_batch_norm : bool
            Whether to use batch normalization.
        use_concate : bool
            Whether to concat enc outputs w/ dec lstm outputs.
        use_residual : bool
            Whether the encoder uses residual connections (forwarded to
            ``Encoder``).
        reduction_factor : int
            Reduction factor.
        spk_num : Optional[int]
            Number of speakers. If set to > 1, assume that the sids will be
            provided as the input and use sid embedding layer.
        lang_num : Optional[int]
            Number of languages. If set to > 1, assume that the lids will be
            provided as the input and use lid embedding layer.
        spk_embed_dim : Optional[int]
            Speaker embedding dimension. If set to > 0, assume that spk_emb
            will be provided as the input.
        spk_embed_integration_type : str
            How to integrate speaker embedding: "concat" or "add".
        dropout_rate : float
            Dropout rate.
        zoneout_rate : float
            Zoneout rate.
        init_type : str
            Name of the parameter initialization method passed to
            ``initialize``.

        Raises
        ------
        ValueError
            If ``output_activation`` is not found in ``paddle.nn.functional``
            or ``spk_embed_integration_type`` is unsupported.
        NotImplementedError
            If ``atype`` is not a supported attention type.
        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last token id is reserved for end-of-sequence
        self.eos = idim - 1
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(f"there is no such an activation function. "
                             f"({output_activation})")

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # initialize parameters
        # NOTE: called before the sub-layers are built; the global initializer
        # is reset with ``set_global_initializer(None)`` at the end of
        # __init__, presumably because ``initialize`` installs a global
        # initializer that applies to the layers created below.
        initialize(self, init_type)

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx, )

        self.spk_num = None
        if spk_num is not None and spk_num > 1:
            self.spk_num = spk_num
            self.sid_emb = nn.Embedding(spk_num, eunits)

        self.lang_num = None
        if lang_num is not None and lang_num > 1:
            self.lang_num = lang_num
            self.lid_emb = nn.Embedding(lang_num, eunits)

        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type

        # the decoder input width depends on how speaker embeddings are merged
        if self.spk_embed_dim is None:
            dec_idim = eunits
        elif self.spk_embed_integration_type == "concat":
            dec_idim = eunits + spk_embed_dim
        elif self.spk_embed_integration_type == "add":
            dec_idim = eunits
            self.projection = nn.Linear(self.spk_embed_dim, eunits)
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans,
                               aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Support only location, forward or forward_ta")

        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor, )

        # reset the global initializer so later layers are unaffected
        nn.initializer.set_global_initializer(None)

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None
    ) -> Tuple[paddle.Tensor, ...]:
        """Calculate forward propagation.

        Parameters
        ----------
        text : Tensor(int64)
            Batch of padded character ids (B, T_text).
        text_lengths : Tensor(int64)
            Batch of lengths of each input batch (B,).
        speech : Tensor
            Batch of padded target features (B, T_feats, odim).
        speech_lengths : Tensor(int64)
            Batch of the lengths of each target (B,).
        spk_emb : Optional[Tensor]
            Batch of speaker embeddings (B, spk_embed_dim).
        spk_id : Optional[Tensor]
            Batch of speaker IDs (B, 1).
        lang_id : Optional[Tensor]
            Batch of language IDs (B, 1).

        Returns
        ----------
        Tuple[Tensor, ...]
            An 8-tuple ``(after_outs, before_outs, logits, ys, labels, olens,
            att_ws, ilens)``:

            * after_outs, before_outs, logits, att_ws: the outputs of the
              decoder, in the order ``Decoder.forward`` returns them.
            * ys: target features, trimmed to a multiple of
              ``reduction_factor`` when it is > 1.
            * labels: float32 stop-token targets built from
              ``make_pad_mask(olens - 1)`` (1.0 from the last valid frame
              onward).
            * olens: target lengths (rounded down to a multiple of
              ``reduction_factor`` when it is > 1).
            * ilens: input lengths including the appended eos token.
        """
        text = text[:, :text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 0, 0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1)
        # bool tensors cannot be sliced, so cast to float32 first
        labels = paddle.cast(labels, dtype='float32')
        labels = F.pad(labels, [0, 0, 0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spk_emb=spk_emb,
            spk_id=spk_id,
            lang_id=lang_id, )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert (olens >= self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens - olens % self.reduction_factor
            max_out = olens.max()
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            # mark the new last frame of each utterance as a stop frame;
            # paddle.scatter has no axis argument (unlike torch's scatter_),
            # so set the entries per row, as done for the eos tokens above
            for i, l in enumerate(olens):
                labels[i, l - 1] = 1.0

        return after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens

    def _forward(
            self,
            xs: paddle.Tensor,
            ilens: paddle.Tensor,
            ys: paddle.Tensor,
            olens: paddle.Tensor,
            spk_emb: Optional[paddle.Tensor],
            spk_id: Optional[paddle.Tensor],
            lang_id: Optional[paddle.Tensor],
    ) -> Tuple[paddle.Tensor, ...]:
        """Encode ``xs``, add the optional speaker/language information, and decode.

        Returns whatever ``self.dec(hs, hlens, ys)`` returns; callers unpack
        it as ``(after_outs, before_outs, logits, att_ws)``.
        """
        hs, hlens = self.enc(xs, ilens)
        if self.spk_num is not None:
            sid_embs = self.sid_emb(spk_id.reshape([-1]))
            hs = hs + sid_embs.unsqueeze(1)
        if self.lang_num is not None:
            lid_embs = self.lid_emb(lang_id.reshape([-1]))
            hs = hs + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spk_emb)

        return self.dec(hs, hlens, ys)

    def inference(
            self,
            text: paddle.Tensor,
            speech: Optional[paddle.Tensor]=None,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None,
            threshold: float=0.5,
            minlenratio: float=0.0,
            maxlenratio: float=10.0,
            use_att_constraint: bool=False,
            backward_window: int=1,
            forward_window: int=3,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Parameters
        ----------
        text : Tensor(int64)
            Input sequence of characters (T_text,).
        speech : Optional[Tensor]
            Target feature sequence (T_feats, odim); required and used only
            when ``use_teacher_forcing`` is True.
        spk_emb : Optional[Tensor]
            Speaker embedding (spk_embed_dim,).
        spk_id : Optional[Tensor]
            Speaker ID (1,).
        lang_id : Optional[Tensor]
            Language ID (1,).
        threshold : float
            Threshold in inference.
        minlenratio : float
            Minimum length ratio in inference.
        maxlenratio : float
            Maximum length ratio in inference.
        use_att_constraint : bool
            Whether to apply attention constraint.
        backward_window : int
            Backward window in attention constraint.
        forward_window : int
            Forward window in attention constraint.
        use_teacher_forcing : bool
            Whether to use teacher forcing.

        Returns
        ----------
        Dict[str, Tensor]
            Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * prob (Tensor): Output sequence of stop probabilities
                  (T_feats,). Not present when ``use_teacher_forcing`` is True.
                * att_w (Tensor): Attention weights (T_feats, T).
        """
        x = text
        y = speech

        # add eos at the last of sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # inference with teacher forcing
        if use_teacher_forcing:
            assert speech is not None, "speech must be provided with teacher forcing."

            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0)
            ilens = paddle.shape(xs)[1]
            olens = paddle.shape(ys)[1]
            outs, _, _, att_ws = self._forward(
                xs=xs,
                ilens=ilens,
                ys=ys,
                olens=olens,
                spk_emb=spk_emb,
                spk_id=spk_id,
                lang_id=lang_id, )

            return dict(feat_gen=outs[0], att_w=att_ws[0])

        # inference
        h = self.enc.inference(x)

        if self.spk_num is not None:
            sid_emb = self.sid_emb(spk_id.reshape([-1]))
            h = h + sid_emb
        if self.lang_num is not None:
            lid_emb = self.lid_emb(lang_id.reshape([-1]))
            h = h + lid_emb
        if self.spk_embed_dim is not None:
            # add a batch axis for the integration helper, then drop it
            hs, spk_emb = h.unsqueeze(0), spk_emb.unsqueeze(0)
            h = self._integrate_with_spk_embed(hs, spk_emb)[0]

        out, prob, att_w = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window, )

        return dict(feat_gen=out, prob=prob, att_w=att_w)

    def _integrate_with_spk_embed(self,
                                  hs: paddle.Tensor,
                                  spk_emb: paddle.Tensor) -> paddle.Tensor:
        """Integrate speaker embedding with hidden states.

        Parameters
        ----------
        hs : Tensor
            Batch of hidden state sequences (B, Tmax, eunits).
        spk_emb : Tensor
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        ----------
        Tensor
            Batch of integrated hidden state sequences (B, Tmax, eunits) if
            integration_type is "add" else (B, Tmax, eunits + spk_embed_dim).

        Raises
        ------
        NotImplementedError
            If the integration type is neither "add" nor "concat".
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spk_emb = self.projection(F.normalize(spk_emb))
            hs = hs + spk_emb.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds; paddle's expand takes the
            # target shape as a single list (-1 keeps that dimension)
            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
                [-1, paddle.shape(hs)[1], -1])
            hs = paddle.concat([hs, spk_emb], axis=-1)
        else:
            raise NotImplementedError("support only add or concat.")

        return hs


class Tacotron2Inference(nn.Layer):
    """Wrap a Tacotron2 model and a feature normalizer for inference.

    ``forward`` runs ``model.inference`` and maps the generated (normalized)
    features back through ``normalizer.inverse``.
    """

    def __init__(self, normalizer, model):
        super().__init__()
        self.normalizer = normalizer
        self.acoustic_model = model

    def forward(self, text, spk_id=None, spk_emb=None, lang_id=None):
        """Synthesize de-normalized features for ``text``.

        ``lang_id`` is forwarded to the model so that multilingual models
        (``lang_num > 1``) can be used through this wrapper as well.
        """
        out = self.acoustic_model.inference(
            text, spk_id=spk_id, spk_emb=spk_emb, lang_id=lang_id)
        normalized_mel = out["feat_gen"]
        logmel = self.normalizer.inverse(normalized_mel)
        return logmel