# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Tacotron2 decoder related modules."""
import paddle
import paddle.nn.functional as F
import six
from paddle import nn

from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA


class Prenet(nn.Layer):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The Prenet performs nonlinear conversion
    of inputs before they are fed to the auto-regressive LSTM,
    which helps to learn diagonal attentions.

    Notes
    ----------
    This module always applies dropout even in evaluation.
    See the detail in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Parameters
        ----------
        idim : int
            Dimension of the inputs.
        n_layers : int, optional
            The number of prenet layers.
        n_units : int, optional
            The number of prenet units.
        dropout_rate : float, optional
            Dropout rate.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.prenet = nn.LayerList()
        for layer in six.moves.range(n_layers):
            n_inputs = idim if layer == 0 else n_units
            self.prenet.append(
                nn.Sequential(nn.Linear(n_inputs, n_units), nn.ReLU()))

    def forward(self, x):
        """Calculate forward propagation.

        Parameters
        ----------
        x : Tensor
            Batch of input tensors (B, ..., idim).

        Returns
        ----------
        Tensor
            Batch of output tensors (B, ..., n_units).
        """
        for i in six.moves.range(len(self.prenet)):
            # F.dropout introduces randomness; the dropout in Tacotron2 must
            # not be removed, it is applied even in evaluation.
            x = F.dropout(self.prenet[i](x), self.dropout_rate)
        return x


class Postnet(nn.Layer):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The Postnet refines the Mel-filterbank
    predicted by the decoder,
    which helps to compensate for the detailed structure of the spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(
            self,
            idim,
            odim,
            n_layers=5,
            n_chans=512,
            n_filts=5,
            dropout_rate=0.5,
            use_batch_norm=True, ):
        """Initialize postnet module.

        Parameters
        ----------
        idim : int
            Dimension of the inputs.
        odim : int
            Dimension of the outputs.
        n_layers : int, optional
            The number of layers.
        n_filts : int, optional
            The filter size.
        n_chans : int, optional
            The number of filter channels.
        use_batch_norm : bool, optional
            Whether to use batch normalization.
        dropout_rate : float, optional
            Dropout rate.
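
        Examples
        ----------
        A minimal, illustrative sketch (the sizes below are assumptions for
        demonstration only, not values taken from any recipe):

        >>> postnet = Postnet(idim=80, odim=80)
        >>> xs = paddle.randn([2, 80, 100])  # (B, odim, Tmax), channels first
        >>> ys = postnet(xs)                 # -> (B, odim, Tmax)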
""" super().__init__() self.postnet = nn.LayerList() for layer in six.moves.range(n_layers - 1): ichans = odim if layer == 0 else n_chans ochans = odim if layer == n_layers - 1 else n_chans if use_batch_norm: self.postnet.append( nn.Sequential( nn.Conv1D( ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias_attr=False, ), nn.BatchNorm1D(ochans), nn.Tanh(), nn.Dropout(dropout_rate), )) else: self.postnet.append( nn.Sequential( nn.Conv1D( ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias_attr=False, ), nn.Tanh(), nn.Dropout(dropout_rate), )) ichans = n_chans if n_layers != 1 else odim if use_batch_norm: self.postnet.append( nn.Sequential( nn.Conv1D( ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias_attr=False, ), nn.BatchNorm1D(odim), nn.Dropout(dropout_rate), )) else: self.postnet.append( nn.Sequential( nn.Conv1D( ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias_attr=False, ), nn.Dropout(dropout_rate), )) def forward(self, xs): """Calculate forward propagation. Parameters ---------- xs : Tensor Batch of the sequences of padded input tensors (B, idim, Tmax). Returns ---------- Tensor Batch of padded output tensor. (B, odim, Tmax). """ for i in six.moves.range(len(self.postnet)): xs = self.postnet[i](xs) return xs class ZoneOutCell(nn.Layer): """ZoneOut Cell module. This is a module of zoneout described in `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_. This code is modified from `eladhoffer/seq2seq.pytorch`_. Examples ---------- >>> lstm = paddle.nn.LSTMCell(16, 32) >>> lstm = ZoneOutCell(lstm, 0.5) .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`: https://arxiv.org/abs/1606.01305 .. _`eladhoffer/seq2seq.pytorch`: https://github.com/eladhoffer/seq2seq.pytorch """ def __init__(self, cell, zoneout_rate=0.1): """Initialize zone out cell module. Parameters ---------- cell : nn.Layer: Paddle recurrent cell module e.g. `paddle.nn.LSTMCell`. zoneout_rate : float, optional Probability of zoneout from 0.0 to 1.0. """ super().__init__() self.cell = cell self.hidden_size = cell.hidden_size self.zoneout_rate = zoneout_rate if zoneout_rate > 1.0 or zoneout_rate < 0.0: raise ValueError( "zoneout probability must be in the range from 0.0 to 1.0.") def forward(self, inputs, hidden): """Calculate forward propagation. Parameters ---------- inputs : Tensor Batch of input tensor (B, input_size). hidden : tuple - Tensor: Batch of initial hidden states (B, hidden_size). - Tensor: Batch of initial cell states (B, hidden_size). Returns ---------- Tensor Batch of next hidden states (B, hidden_size). tuple: - Tensor: Batch of next hidden states (B, hidden_size). - Tensor: Batch of next cell states (B, hidden_size). """ # we only use the second output of LSTMCell in paddle _, next_hidden = self.cell(inputs, hidden) next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate) # to have the same output format with LSTMCell in paddle return next_hidden[0], next_hidden def _zoneout(self, h, next_h, prob): # apply recursively if isinstance(h, tuple): num_h = len(h) if not isinstance(prob, tuple): prob = tuple([prob] * num_h) return tuple( [self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h)]) if self.training: mask = paddle.bernoulli(paddle.ones([*paddle.shape(h)]) * prob) return mask * h + (1 - mask) * next_h else: return prob * h + (1 - prob) * next_h class Decoder(nn.Layer): """Decoder module of Spectrogram prediction network. 


class Decoder(nn.Layer):
    """Decoder module of Spectrogram prediction network.

    This is a module of decoder of Spectrogram prediction network in Tacotron2,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The decoder generates the sequence of
    features from the sequence of the hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(
            self,
            idim,
            odim,
            att,
            dlayers=2,
            dunits=1024,
            prenet_layers=2,
            prenet_units=256,
            postnet_layers=5,
            postnet_chans=512,
            postnet_filts=5,
            output_activation_fn=None,
            cumulate_att_w=True,
            use_batch_norm=True,
            use_concate=True,
            dropout_rate=0.5,
            zoneout_rate=0.1,
            reduction_factor=1, ):
        """Initialize Tacotron2 decoder module.

        Parameters
        ----------
        idim : int
            Dimension of the inputs.
        odim : int
            Dimension of the outputs.
        att : nn.Layer
            Instance of attention class.
        dlayers : int, optional
            The number of decoder lstm layers.
        dunits : int, optional
            The number of decoder lstm units.
        prenet_layers : int, optional
            The number of prenet layers.
        prenet_units : int, optional
            The number of prenet units.
        postnet_layers : int, optional
            The number of postnet layers.
        postnet_filts : int, optional
            The postnet filter size.
        postnet_chans : int, optional
            The number of postnet filter channels.
        output_activation_fn : nn.Layer, optional
            Activation function for outputs.
        cumulate_att_w : bool, optional
            Whether to cumulate previous attention weight.
        use_batch_norm : bool, optional
            Whether to use batch normalization.
        use_concate : bool, optional
            Whether to concatenate encoder embedding with decoder lstm outputs.
        dropout_rate : float, optional
            Dropout rate.
        zoneout_rate : float, optional
            Zoneout rate.
        reduction_factor : int, optional
            Reduction factor.
        """
        super().__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # check attention type
        if isinstance(self.att, AttForwardTA):
            self.use_att_extra_inputs = True
        else:
            self.use_att_extra_inputs = False

        # define lstm network
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = nn.LayerList()
        for layer in six.moves.range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm.append(lstm)

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate, )
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate, )
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = nn.Linear(
            iunits, odim * reduction_factor, bias_attr=False)
        self.prob_out = nn.Linear(iunits, reduction_factor)

        # initialize
        # self.apply(decoder_init)

    def _zero_state(self, hs):
        init_hs = paddle.zeros([paddle.shape(hs)[0], self.lstm[0].hidden_size])
        return init_hs
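
    # A hedged teacher-forcing sketch (all sizes are illustrative; `att` is an
    # attention module from paddlespeech.t2s.modules.tacotron2.attentions,
    # e.g. an AttForwardTA instance whose constructor arguments are omitted):
    #
    #   decoder = Decoder(idim=512, odim=80, att=att)
    #   hs = paddle.randn([2, 100, 512])       # encoder states (B, Tmax, idim)
    #   hlens = paddle.to_tensor([100, 90])    # input lengths (B,)
    #   ys = paddle.randn([2, 300, 80])        # target features (B, Lmax, odim)
    #   after_outs, before_outs, logits, att_ws = decoder(hs, hlens, ys)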

    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Parameters
        ----------
        hs : Tensor
            Batch of the sequences of padded hidden states (B, Tmax, idim).
        hlens : Tensor(int64)
            Batch of lengths of each input batch (B,).
        ys : Tensor
            Batch of the sequences of padded target features (B, Lmax, odim).

        Returns
        ----------
        Tensor
            Batch of output tensors after postnet (B, Lmax, odim).
        Tensor
            Batch of output tensors before postnet (B, Lmax, odim).
        Tensor
            Batch of logits of stop prediction (B, Lmax).
        Tensor
            Batch of attention weights (B, Lmax, Tmax).

        Note
        ----------
        This computation is performed in teacher-forcing manner.
        """
        # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        # hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list = [self._zero_state(hs)]
        z_list = [self._zero_state(hs)]
        for _ in six.moves.range(1, len(self.lstm)):
            c_list += [self._zero_state(hs)]
            z_list += [self._zero_state(hs)]
        prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        outs, logits, att_ws = [], [], []
        for y in ys.transpose([1, 0, 2]):
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w,
                                        prev_out)
            else:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
            prenet_out = self.prenet(
                prev_out) if self.prenet is not None else prev_out
            xs = paddle.concat([att_c, prenet_out], axis=1)
            # we only use the second output of LSTMCell in paddle
            _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
            z_list[0], c_list[0] = next_hidden
            for i in six.moves.range(1, len(self.lstm)):
                # we only use the second output of LSTMCell in paddle
                _, next_hidden = self.lstm[i](z_list[i - 1],
                                              (z_list[i], c_list[i]))
                z_list[i], c_list[i] = next_hidden
            zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                   if self.use_concate else z_list[-1])
            outs += [
                self.feat_out(zcs).reshape([paddle.shape(hs)[0], self.odim, -1])
            ]
            logits += [self.prob_out(zcs)]
            att_ws += [att_w]
            # teacher forcing
            prev_out = y
            if self.cumulate_att_w and prev_att_w is not None:
                prev_att_w = prev_att_w + att_w  # Note: error when use +=
            else:
                prev_att_w = att_w
        # (B, Lmax)
        logits = paddle.concat(logits, axis=1)
        # (B, odim, Lmax)
        before_outs = paddle.concat(outs, axis=2)
        # (B, Lmax, Tmax)
        att_ws = paddle.stack(att_ws, axis=1)

        if self.reduction_factor > 1:
            # (B, odim, Lmax)
            before_outs = before_outs.reshape(
                [paddle.shape(before_outs)[0], self.odim, -1])

        if self.postnet is not None:
            # (B, odim, Lmax)
            after_outs = before_outs + self.postnet(before_outs)
        else:
            after_outs = before_outs
        # (B, Lmax, odim)
        before_outs = before_outs.transpose([0, 2, 1])
        # (B, Lmax, odim)
        after_outs = after_outs.transpose([0, 2, 1])

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws
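
    # A hedged inference sketch for a single utterance (sizes illustrative;
    # `decoder` is a Decoder instance as in the sketch above forward()):
    #
    #   h = paddle.randn([100, 512])   # encoder states of one utterance (T, idim)
    #   outs, probs, att_ws = decoder.inference(h, threshold=0.5, maxlenratio=10.0)
    #   # outs: (L, odim), probs: (L,), att_ws: (L, T)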

    def inference(
            self,
            h,
            threshold=0.5,
            minlenratio=0.0,
            maxlenratio=10.0,
            use_att_constraint=False,
            backward_window=None,
            forward_window=None, ):
        """Generate the sequence of features given the sequences of characters.

        Parameters
        ----------
        h : Tensor
            Input sequence of encoder hidden states (T, C).
        threshold : float, optional
            Threshold to stop generation.
        minlenratio : float, optional
            Minimum length ratio. If set to 1.0 and the length of input is 10,
            the minimum length of outputs will be 10 * 1 = 10.
        maxlenratio : float, optional
            Maximum length ratio. If set to 10 and the length of input is 10,
            the maximum length of outputs will be 10 * 10 = 100.
        use_att_constraint : bool
            Whether to apply attention constraint introduced in `Deep Voice 3`_.
        backward_window : int
            Backward window size in attention constraint.
        forward_window : int
            Forward window size in attention constraint.

        Returns
        ----------
        Tensor
            Output sequence of features (L, odim).
        Tensor
            Output sequence of stop probabilities (L,).
        Tensor
            Attention weights (L, T).

        Note
        ----------
        This computation is performed in auto-regressive manner.

        .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
        """
        # setup
        assert len(paddle.shape(h)) == 2
        hs = h.unsqueeze(0)
        ilens = paddle.shape(h)[0]
        maxlen = int(paddle.shape(h)[0] * maxlenratio)
        minlen = int(paddle.shape(h)[0] * minlenratio)

        # initialize hidden states of decoder
        c_list = [self._zero_state(hs)]
        z_list = [self._zero_state(hs)]
        for _ in six.moves.range(1, len(self.lstm)):
            c_list += [self._zero_state(hs)]
            z_list += [self._zero_state(hs)]
        prev_out = paddle.zeros([1, self.odim])

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # setup for attention constraint
        if use_att_constraint:
            last_attended_idx = 0
        else:
            last_attended_idx = None

        # loop for an output sequence
        idx = 0
        outs, att_ws, probs = [], [], []
        while True:
            # update index
            idx += self.reduction_factor

            # decoder calculation
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(
                    hs,
                    ilens,
                    z_list[0],
                    prev_att_w,
                    prev_out,
                    last_attended_idx=last_attended_idx,
                    backward_window=backward_window,
                    forward_window=forward_window, )
            else:
                att_c, att_w = self.att(
                    hs,
                    ilens,
                    z_list[0],
                    prev_att_w,
                    last_attended_idx=last_attended_idx,
                    backward_window=backward_window,
                    forward_window=forward_window, )

            att_ws += [att_w]
            prenet_out = self.prenet(
                prev_out) if self.prenet is not None else prev_out
            xs = paddle.concat([att_c, prenet_out], axis=1)
            # we only use the second output of LSTMCell in paddle
            _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
            z_list[0], c_list[0] = next_hidden
            for i in six.moves.range(1, len(self.lstm)):
                # we only use the second output of LSTMCell in paddle
                _, next_hidden = self.lstm[i](z_list[i - 1],
                                              (z_list[i], c_list[i]))
                z_list[i], c_list[i] = next_hidden
            zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                   if self.use_concate else z_list[-1])
            # [(1, odim, r), ...]
            outs += [self.feat_out(zcs).reshape([1, self.odim, -1])]
            # [(r), ...]
            probs += [F.sigmoid(self.prob_out(zcs))[0]]
            if self.output_activation_fn is not None:
                prev_out = self.output_activation_fn(
                    outs[-1][:, :, -1])  # (1, odim)
            else:
                prev_out = outs[-1][:, :, -1]  # (1, odim)
            if self.cumulate_att_w and prev_att_w is not None:
                prev_att_w = prev_att_w + att_w  # Note: error when use +=
            else:
                prev_att_w = att_w
            if use_att_constraint:
                last_attended_idx = int(att_w.argmax())

            # check whether to finish generation
            if sum(paddle.cast(probs[-1] >= threshold,
                               'int64')) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                # (1, odim, L)
                outs = paddle.concat(outs, axis=2)
                if self.postnet is not None:
                    # (1, odim, L)
                    outs = outs + self.postnet(outs)
                # (L, odim)
                outs = outs.transpose([0, 2, 1]).squeeze(0)
                probs = paddle.concat(probs, axis=0)
                att_ws = paddle.concat(att_ws, axis=0)
                break

        if self.output_activation_fn is not None:
            outs = self.output_activation_fn(outs)

        return outs, probs, att_ws
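
    # A hedged sketch for attention visualization (same illustrative inputs as
    # the teacher-forcing sketch above); the returned weights can be plotted
    # per utterance as an (Lmax, Tmax) alignment map:
    #
    #   att_ws = decoder.calculate_all_attentions(hs, hlens, ys)  # (B, Lmax, Tmax)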

    def calculate_all_attentions(self, hs, hlens, ys):
        """Calculate all of the attention weights.

        Parameters
        ----------
        hs : Tensor
            Batch of the sequences of padded hidden states (B, Tmax, idim).
        hlens : Tensor(int64)
            Batch of lengths of each input batch (B,).
        ys : Tensor
            Batch of the sequences of padded target features (B, Lmax, odim).

        Returns
        ----------
        Tensor
            Batch of attention weights (B, Lmax, Tmax).

        Note
        ----------
        This computation is performed in teacher-forcing manner.
        """
        # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list = [self._zero_state(hs)]
        z_list = [self._zero_state(hs)]
        for _ in six.moves.range(1, len(self.lstm)):
            c_list += [self._zero_state(hs)]
            z_list += [self._zero_state(hs)]
        prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        att_ws = []
        for y in ys.transpose([1, 0, 2]):
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w,
                                        prev_out)
            else:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
            att_ws += [att_w]
            prenet_out = self.prenet(
                prev_out) if self.prenet is not None else prev_out
            xs = paddle.concat([att_c, prenet_out], axis=1)
            # we only use the second output of LSTMCell in paddle
            _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
            z_list[0], c_list[0] = next_hidden
            for i in six.moves.range(1, len(self.lstm)):
                # we only use the second output of LSTMCell in paddle
                _, next_hidden = self.lstm[i](z_list[i - 1],
                                              (z_list[i], c_list[i]))
                z_list[i], c_list[i] = next_hidden
            # teacher forcing
            prev_out = y
            if self.cumulate_att_w and prev_att_w is not None:
                prev_att_w = prev_att_w + att_w  # Note: error when use +=
            else:
                prev_att_w = att_w
        # (B, Lmax, Tmax)
        att_ws = paddle.stack(att_ws, axis=1)

        return att_ws