This is a module of text-to-speech Transformer described in
`Neural Speech Synthesis with Transformer Network`_, which converts the
sequence of tokens into the sequence of Mel-filterbanks.

.. _`Neural Speech Synthesis with Transformer Network`:
    https://arxiv.org/pdf/1809.08895.pdf

Parameters
----------
idim : int
    Dimension of the inputs.
odim : int
    Dimension of the outputs.
embed_dim : int, optional
    Dimension of character embedding.
eprenet_conv_layers : int, optional
    Number of encoder prenet convolution layers.
eprenet_conv_chans : int, optional
    Number of encoder prenet convolution channels.
eprenet_conv_filts : int, optional
    Filter size of encoder prenet convolution.
dprenet_layers : int, optional
    Number of decoder prenet layers.
dprenet_units : int, optional
    Number of decoder prenet hidden units.
elayers : int, optional
    Number of encoder layers.
eunits : int, optional
    Number of encoder hidden units.
adim : int, optional
    Number of attention transformation dimensions.
aheads : int, optional
    Number of heads for multi head attention.
dlayers : int, optional
    Number of decoder layers.
dunits : int, optional
    Number of decoder hidden units.
postnet_layers : int, optional
    Number of postnet layers.
postnet_chans : int, optional
    Number of postnet channels.
postnet_filts : int, optional
    Filter size of postnet.
use_scaled_pos_enc : bool, optional
    Whether to use trainable scaled positional encoding.
use_batch_norm : bool, optional
    Whether to use batch normalization in encoder prenet.
encoder_normalize_before : bool, optional
    Whether to perform layer normalization before encoder block.
decoder_normalize_before : bool, optional
    Whether to perform layer normalization before decoder block.
encoder_concat_after : bool, optional
    Whether to concatenate attention layer's input and output in encoder.
decoder_concat_after : bool, optional
    Whether to concatenate attention layer's input and output in decoder.
positionwise_layer_type : str, optional
    Position-wise operation type.
positionwise_conv_kernel_size : int, optional
    Kernel size in position wise conv 1d.
reduction_factor : int, optional
    Reduction factor.
spk_embed_dim : int, optional
    Number of speaker embedding dimensions.
spk_embed_integration_type : str, optional
    How to integrate speaker embedding.
use_gst : bool, optional
    Whether to use global style token.
gst_tokens : int, optional
    The number of GST embeddings.
gst_heads : int, optional
    The number of heads in GST multihead attention.
gst_conv_layers : int, optional
    The number of conv layers in GST.
gst_conv_chans_list : Sequence[int], optional
    List of the number of channels of conv layers in GST.
gst_conv_kernel_size : int, optional
    Kernel size of conv layers in GST.
gst_conv_stride : int, optional
    Stride size of conv layers in GST.
gst_gru_layers : int, optional
    The number of GRU layers in GST.
gst_gru_units : int, optional
    The number of GRU units in GST.
transformer_lr : float, optional
    Initial value of learning rate.
transformer_warmup_steps : int, optional
    Optimizer warmup steps.
transformer_enc_dropout_rate : float, optional
    Dropout rate in encoder except attention and positional encoding.
transformer_enc_positional_dropout_rate : float, optional
    Dropout rate after encoder positional encoding.
transformer_enc_attn_dropout_rate : float, optional
    Dropout rate in encoder self-attention module.
transformer_dec_dropout_rate : float, optional
    Dropout rate in decoder except attention & positional encoding.
transformer_dec_positional_dropout_rate : float, optional
    Dropout rate after decoder positional encoding.
transformer_dec_attn_dropout_rate : float, optional
    Dropout rate in decoder self-attention module.
transformer_enc_dec_attn_dropout_rate : float, optional
    Dropout rate in encoder-decoder attention module.
init_type : str, optional
    How to initialize transformer parameters.
init_enc_alpha : float, optional
    Initial value of alpha in scaled pos encoding of the encoder.
init_dec_alpha : float, optional Initial value of alpha in scaled pos encoding of the decoder. eprenet_dropout_rate : float, optional Dropout rate in encoder prenet. dprenet_dropout_rate : float, optional Dropout rate in decoder prenet. postnet_dropout_rate : float, optional Dropout rate in postnet. use_masking : bool, optional Whether to apply masking for padded part in loss calculation. use_weighted_masking : bool, optional Whether to apply weighted masking in loss calculation. bce_pos_weight : float, optional Positive sample weight in bce calculation (only for use_masking=true). loss_type : str, optional How to calculate loss. use_guided_attn_loss : bool, optional Whether to use guided attention loss. num_heads_applied_guided_attn : int, optional Number of heads in each layer to apply guided attention loss. num_layers_applied_guided_attn : int, optional Number of layers to apply guided attention loss. List of module names to apply guided attention loss. """ def __init__( self, # network structure related idim: int, odim: int, embed_dim: int=512, eprenet_conv_layers: int=3, eprenet_conv_chans: int=256, eprenet_conv_filts: int=5, dprenet_layers: int=2, dprenet_units: int=256, elayers: int=6, eunits: int=1024, adim: int=512, aheads: int=4, dlayers: int=6, dunits: int=1024, postnet_layers: int=5, postnet_chans: int=256, postnet_filts: int=5, positionwise_layer_type: str="conv1d", positionwise_conv_kernel_size: int=1, use_scaled_pos_enc: bool=True, use_batch_norm: bool=True, encoder_normalize_before: bool=True, decoder_normalize_before: bool=True, encoder_concat_after: bool=False, decoder_concat_after: bool=False, reduction_factor: int=1, spk_embed_dim: int=None, spk_embed_integration_type: str="add", use_gst: bool=False, gst_tokens: int=10, gst_heads: int=4, gst_conv_layers: int=6, gst_conv_chans_list: Sequence[int]=(32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int=3, gst_conv_stride: int=2, gst_gru_layers: int=1, gst_gru_units: int=128, # training related 
transformer_enc_dropout_rate: float=0.1, transformer_enc_positional_dropout_rate: float=0.1, transformer_enc_attn_dropout_rate: float=0.1, transformer_dec_dropout_rate: float=0.1, transformer_dec_positional_dropout_rate: float=0.1, transformer_dec_attn_dropout_rate: float=0.1, transformer_enc_dec_attn_dropout_rate: float=0.1, eprenet_dropout_rate: float=0.5, dprenet_dropout_rate: float=0.5, postnet_dropout_rate: float=0.5, init_type: str="xavier_uniform", init_enc_alpha: float=1.0, init_dec_alpha: float=1.0, use_guided_attn_loss: bool=True, num_heads_applied_guided_attn: int=2, num_layers_applied_guided_attn: int=2, ): """Initialize Transformer module.""" assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.eos = idim - 1 self.spk_embed_dim = spk_embed_dim self.reduction_factor = reduction_factor self.use_gst = use_gst self.use_scaled_pos_enc = use_scaled_pos_enc self.use_guided_attn_loss = use_guided_attn_loss if self.use_guided_attn_loss: if num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = elayers else: self.num_layers_applied_guided_attn = num_layers_applied_guided_attn if num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = aheads else: self.num_heads_applied_guided_attn = num_heads_applied_guided_attn if self.spk_embed_dim is not None: self.spk_embed_integration_type = spk_embed_integration_type # use idx 0 as padding idx self.padding_idx = 0 # set_global_initializer 会影响后面的全局,包括 create_parameter initialize(self, init_type) # get positional encoding layer type transformer_pos_enc_layer_type = "scaled_abs_pos" if self.use_scaled_pos_enc else "abs_pos" # define transformer encoder if eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = nn.Sequential( EncoderPrenet( idim=idim, embed_dim=embed_dim, elayers=0, econv_layers=eprenet_conv_layers, econv_chans=eprenet_conv_chans, econv_filts=eprenet_conv_filts, use_batch_norm=use_batch_norm, 
dropout_rate=eprenet_dropout_rate, padding_idx=self.padding_idx, ), nn.Linear(eprenet_conv_chans, adim), ) else: encoder_input_layer = nn.Embedding( num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx) self.encoder = TransformerEncoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, pos_enc_layer_type=transformer_pos_enc_layer_type, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) # define GST if self.use_gst: self.gst = StyleEncoder( idim=odim, # the input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=adim, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = nn.Linear(self.spk_embed_dim, adim) else: self.projection = nn.Linear(adim + self.spk_embed_dim, adim) # define transformer decoder if dprenet_layers != 0: # decoder prenet decoder_input_layer = nn.Sequential( DecoderPrenet( idim=odim, n_layers=dprenet_layers, n_units=dprenet_units, dropout_rate=dprenet_dropout_rate, ), nn.Linear(dprenet_units, adim), ) else: decoder_input_layer = "linear" # get positional encoding class pos_enc_class = (ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding) self.decoder = Decoder( odim=odim, # odim is needed when no prenet is used attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, dropout_rate=transformer_dec_dropout_rate, 
positional_dropout_rate=transformer_dec_positional_dropout_rate, self_attention_dropout_rate=transformer_dec_attn_dropout_rate, src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, ) # define final projection self.feat_out = nn.Linear(adim, odim * reduction_factor) self.prob_out = nn.Linear(adim, reduction_factor) # define postnet self.postnet = (None if postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm, dropout_rate=postnet_dropout_rate, )) # 闭合的 initialize() 中的 set_global_initializer 的作用域,防止其影响到 self._reset_parameters() nn.initializer.set_global_initializer(None) self._reset_parameters( init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha, ) def _reset_parameters(self, init_enc_alpha: float, init_dec_alpha: float): # initialize alpha in scaled positional encoding if self.use_scaled_pos_enc: init_enc_alpha = paddle.to_tensor(init_enc_alpha) self.encoder.embed[-1].alpha = paddle.create_parameter( shape=init_enc_alpha.shape, dtype=str(init_enc_alpha.numpy().dtype), default_initializer=paddle.nn.initializer.Assign( init_enc_alpha)) init_dec_alpha = paddle.to_tensor(init_dec_alpha) self.decoder.embed[-1].alpha = paddle.create_parameter( shape=init_dec_alpha.shape, dtype=str(init_dec_alpha.numpy().dtype), default_initializer=paddle.nn.initializer.Assign( init_dec_alpha)) def forward( self, text: paddle.Tensor, text_lengths: paddle.Tensor, speech: paddle.Tensor, speech_lengths: paddle.Tensor, spk_emb: paddle.Tensor=None, ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]: """Calculate forward propagation. Parameters ---------- text : Tensor(int64) Batch of padded character ids (B, Tmax). text_lengths : Tensor(int64) Batch of lengths of each input batch (B,). 
speech : Tensor Batch of padded target features (B, Lmax, odim). speech_lengths : Tensor(int64) Batch of the lengths of each target (B,). spk_emb : Tensor, optional Batch of speaker embeddings (B, spk_embed_dim). Returns ---------- Tensor Loss scalar value. Dict Statistics to be monitored. """ # input of embedding must be int64 text_lengths = paddle.cast(text_lengths, 'int64') # Add eos at the last of sequence text = numpy.pad(text.numpy(), ((0, 0), (0, 1)), 'constant') xs = paddle.to_tensor(text, dtype='int64') for i, l in enumerate(text_lengths): xs[i, l] = self.eos ilens = text_lengths + 1 ys = speech olens = paddle.cast(speech_lengths, 'int64') # make labels for stop prediction labels = make_pad_mask(olens - 1) labels = numpy.pad( labels.numpy(), ((0, 0), (0, 1)), 'constant', constant_values=1.0) labels = paddle.to_tensor(labels) labels = paddle.cast(labels, dtype="float32") # labels = F.pad(labels, [0, 1], "constant", 1.0) # calculate transformer outputs after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens, spk_emb) # modifiy mod part of groundtruth if self.reduction_factor > 1: olens = paddle.to_tensor( [olen - olen % self.reduction_factor for olen in olens.numpy()]) max_olen = max(olens) ys = ys[:, :max_olen] labels = labels[:, :max_olen] labels[:, -1] = 1.0 # make sure at least one frame has 1 need_dict = {} need_dict['encoder'] = self.encoder need_dict['decoder'] = self.decoder need_dict[ 'num_heads_applied_guided_attn'] = self.num_heads_applied_guided_attn need_dict[ 'num_layers_applied_guided_attn'] = self.num_layers_applied_guided_attn need_dict['use_scaled_pos_enc'] = self.use_scaled_pos_enc return after_outs, before_outs, logits, ys, labels, olens, ilens, need_dict def _forward( self, xs: paddle.Tensor, ilens: paddle.Tensor, ys: paddle.Tensor, olens: paddle.Tensor, spk_emb: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: # forward encoder x_masks = self._source_mask(ilens) hs, h_masks = self.encoder(xs, x_masks) 
# integrate with GST if self.use_gst: style_embs = self.gst(ys) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spk_emb) # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor] olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: ys_in, olens_in = ys, olens # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # forward decoder y_masks = self._target_mask(olens_in) zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).reshape([zs.shape[0], -1, self.odim]) # (B, Lmax//r, r) -> (B, Lmax//r * r) logits = self.prob_out(zs).reshape([zs.shape[0], -1]) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet( before_outs.transpose([0, 2, 1])).transpose([0, 2, 1]) return after_outs, before_outs, logits def inference( self, text: paddle.Tensor, speech: paddle.Tensor=None, spk_emb: paddle.Tensor=None, threshold: float=0.5, minlenratio: float=0.0, maxlenratio: float=10.0, use_teacher_forcing: bool=False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Generate the sequence of features given the sequences of characters. Parameters ---------- text : Tensor(int64) Input sequence of characters (T,). speech : Tensor, optional Feature sequence to extract style (N, idim). spk_emb : Tensor, optional Speaker embedding vector (spk_embed_dim,). threshold : float, optional Threshold in inference. minlenratio : float, optional Minimum length ratio in inference. maxlenratio : float, optional Maximum length ratio in inference. use_teacher_forcing : bool, optional Whether to use teacher forcing. 
Returns ---------- Tensor Output sequence of features (L, odim). Tensor Output sequence of stop probabilities (L,). Tensor Encoder-decoder (source) attention weights (#layers, #heads, L, T). """ # input of embedding must be int64 y = speech # add eos at the last of sequence text = numpy.pad( text.numpy(), (0, 1), 'constant', constant_values=self.eos) x = paddle.to_tensor(text, dtype='int64') # inference with teacher forcing if use_teacher_forcing: assert speech is not None, "speech must be provided with teacher forcing." # get teacher forcing outputs xs, ys = x.unsqueeze(0), y.unsqueeze(0) spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0) ilens = paddle.to_tensor( [xs.shape[1]], dtype=paddle.int64, place=xs.place) olens = paddle.to_tensor( [ys.shape[1]], dtype=paddle.int64, place=ys.place) outs, *_ = self._forward(xs, ilens, ys, olens, spk_emb) # get attention weights att_ws = [] for i in range(len(self.decoder.decoders)): att_ws += [self.decoder.decoders[i].src_attn.attn] # (B, L, H, T_out, T_in) att_ws = paddle.stack(att_ws, axis=1) return outs[0], None, att_ws[0] # forward encoder xs = x.unsqueeze(0) hs, _ = self.encoder(xs, None) # integrate GST if self.use_gst: style_embs = self.gst(y.unsqueeze(0)) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if spk_emb is not None: spk_emb = spk_emb.unsqueeze(0) hs = self._integrate_with_spk_embed(hs, spk_emb) # set limits of length maxlen = int(hs.shape[1] * maxlenratio / self.reduction_factor) minlen = int(hs.shape[1] * minlenratio / self.reduction_factor) # initialize idx = 0 ys = paddle.zeros([1, 1, self.odim]) outs, probs = [], [] # forward decoder step-by-step z_cache = None while True: # update index idx += 1 # calculate output and stop prob at idx-th step y_masks = subsequent_mask(idx).unsqueeze(0) z, z_cache = self.decoder.forward_one_step( ys, y_masks, hs, cache=z_cache) # (B, adim) outs += [ self.feat_out(z).reshape([self.reduction_factor, self.odim]) ] # [(r, odim), ...] 
probs += [F.sigmoid(self.prob_out(z))[0]] # [(r), ...] # update next inputs ys = paddle.concat( (ys, outs[-1][-1].reshape([1, 1, self.odim])), axis=1) # (1, idx + 1, odim) # get attention weights att_ws_ = [] for name, m in self.named_sublayers(): if isinstance(m, MultiHeadedAttention) and "src" in name: # [(#heads, 1, T),...] att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] if idx == 1: att_ws = att_ws_ else: # [(#heads, l, T), ...] att_ws = [ paddle.concat([att_w, att_w_], axis=1) for att_w, att_w_ in zip(att_ws, att_ws_) ] # check whether to finish generation if sum(paddle.cast(probs[-1] >= threshold, 'int64')) > 0 or idx >= maxlen: # check mininum length if idx < minlen: continue # (L, odim) -> (1, L, odim) -> (1, odim, L) outs = (paddle.concat(outs, axis=0).unsqueeze(0).transpose( [0, 2, 1])) if self.postnet is not None: # (1, odim, L) outs = outs + self.postnet(outs) # (L, odim) outs = outs.transpose([0, 2, 1]).squeeze(0) probs = paddle.concat(probs, axis=0) break # concatenate attention weights -> (#layers, #heads, L, T) att_ws = paddle.stack(att_ws, axis=0) return outs, probs, att_ws def _add_first_frame_and_remove_last_frame( self, ys: paddle.Tensor) -> paddle.Tensor: ys_in = paddle.concat( [paddle.zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], axis=1) return ys_in def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor: """Make masks for self-attention. Parameters ---------- ilens : Tensor Batch of lengths (B,). Returns ------- Tensor Mask tensor for self-attention. dtype=paddle.bool Examples ------- >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]]) bool """ x_masks = make_non_pad_mask(ilens) return x_masks.unsqueeze(-2) def _target_mask(self, olens: paddle.Tensor) -> paddle.Tensor: """Make masks for masked self-attention. Parameters ---------- olens : LongTensor Batch of lengths (B,). Returns ---------- Tensor Mask tensor for masked self-attention. 
Examples ---------- >>> olens = [5, 3] >>> self._target_mask(olens) tensor([[[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]], [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]], dtype=paddle.uint8) """ y_masks = make_non_pad_mask(olens) s_masks = subsequent_mask(y_masks.shape[-1]).unsqueeze(0) return paddle.logical_and(y_masks.unsqueeze(-2), s_masks) def _integrate_with_spk_embed(self, hs: paddle.Tensor, spk_emb: paddle.Tensor) -> paddle.Tensor: """Integrate speaker embedding with hidden states. Parameters ---------- hs : Tensor Batch of hidden state sequences (B, Tmax, adim). spk_emb : Tensor Batch of speaker embeddings (B, spk_embed_dim). Returns ---------- Tensor Batch of integrated hidden state sequences (B, Tmax, adim). """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spk_emb = self.projection(F.normalize(spk_emb)) hs = hs + spk_emb.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(-1, hs.shape[1], -1) hs = self.projection(paddle.concat([hs, spk_emb], axis=-1)) else: raise NotImplementedError("support only add or concat.") return hs class TransformerTTSInference(nn.Layer): def __init__(self, normalizer, model): super().__init__() self.normalizer = normalizer self.acoustic_model = model def forward(self, text, spk_id=None): normalized_mel = self.acoustic_model.inference(text)[0] logmel = self.normalizer.inverse(normalized_mel) return logmel class TransformerTTSLoss(nn.Layer): """Loss function module for Tacotron2.""" def __init__(self, use_masking=True, use_weighted_masking=False, bce_pos_weight=5.0): """Initialize Tactoron2 loss module. Parameters ---------- use_masking : bool Whether to apply masking for padded part in loss calculation. 
use_weighted_masking : bool Whether to apply weighted masking in loss calculation. bce_pos_weight : float Weight of positive sample of stop token. """ super().__init__() assert (use_masking != use_weighted_masking) or not use_masking self.use_masking = use_masking self.use_weighted_masking = use_weighted_masking # define criterions reduction = "none" if self.use_weighted_masking else "mean" self.l1_criterion = nn.L1Loss(reduction=reduction) self.mse_criterion = nn.MSELoss(reduction=reduction) self.bce_criterion = nn.BCEWithLogitsLoss( reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight)) def forward(self, after_outs, before_outs, logits, ys, labels, olens): """Calculate forward propagation. Parameters ---------- after_outs : Tensor Batch of outputs after postnets (B, Lmax, odim). before_outs : Tensor Batch of outputs before postnets (B, Lmax, odim). logits : Tensor Batch of stop logits (B, Lmax). ys : Tensor Batch of padded target features (B, Lmax, odim). labels : LongTensor Batch of the sequences of stop token labels (B, Lmax). olens : LongTensor Batch of the lengths of each target (B,). Returns ---------- Tensor L1 loss value. Tensor Mean square error loss value. Tensor Binary cross entropy loss value. 
""" # make mask and apply it if self.use_masking: masks = make_non_pad_mask(olens).unsqueeze(-1) ys = ys.masked_select(masks.broadcast_to(ys.shape)) after_outs = after_outs.masked_select( masks.broadcast_to(after_outs.shape)) before_outs = before_outs.masked_select( masks.broadcast_to(before_outs.shape)) # Operator slice does not have kernel for data_type[bool] tmp_masks = paddle.cast(masks, dtype='int64') tmp_masks = tmp_masks[:, :, 0] tmp_masks = paddle.cast(tmp_masks, dtype='bool') labels = labels.masked_select(tmp_masks.broadcast_to(labels.shape)) logits = logits.masked_select(tmp_masks.broadcast_to(logits.shape)) # calculate loss l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion( before_outs, ys) mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion( before_outs, ys) bce_loss = self.bce_criterion(logits, labels) # make weighted mask and apply it if self.use_weighted_masking: masks = make_non_pad_mask(olens).unsqueeze(-1) weights = masks.float() / masks.sum(dim=1, keepdim=True).float() out_weights = weights.div(ys.shape[0] * ys.shape[2]) logit_weights = weights.div(ys.shape[0]) # apply weight l1_loss = l1_loss.multiply(out_weights) l1_loss = l1_loss.masked_select( masks.broadcast_to(l1_loss.shape)).sum() mse_loss = mse_loss.multiply(out_weights) mse_loss = mse_loss.masked_select( masks.broadcast_to(mse_loss.shape)).sum() bce_loss = bce_loss.multiply(logit_weights.squeeze(-1)) bce_loss = bce_loss.masked_select( masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum() return l1_loss, mse_loss, bce_loss class GuidedAttentionLoss(nn.Layer): """Guided attention loss function module. This module calculates the guided attention loss described in `Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`_, which forces the attention to be diagonal. .. 
_`Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`: https://arxiv.org/abs/1710.08969 """ def __init__(self, sigma=0.4, alpha=1.0, reset_always=True): """Initialize guided attention loss module. Parameters ---------- sigma : float, optional Standard deviation to control how close attention to a diagonal. alpha : float, optional Scaling coefficient (lambda). reset_always : bool, optional Whether to always reset masks. """ super(GuidedAttentionLoss, self).__init__() self.sigma = sigma self.alpha = alpha self.reset_always = reset_always self.guided_attn_masks = None self.masks = None def _reset_masks(self): self.guided_attn_masks = None self.masks = None def forward(self, att_ws, ilens, olens): """Calculate forward propagation. Parameters ---------- att_ws : Tensor Batch of attention weights (B, T_max_out, T_max_in). ilens : LongTensor Batch of input lenghts (B,). olens : LongTensor Batch of output lenghts (B,). Returns ---------- Tensor Guided attention loss value. """ if self.guided_attn_masks is None: self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens) if self.masks is None: self.masks = self._make_masks(ilens, olens) losses = self.guided_attn_masks * att_ws loss = paddle.mean( losses.masked_select(self.masks.broadcast_to(losses.shape))) if self.reset_always: self._reset_masks() return self.alpha * loss def _make_guided_attention_masks(self, ilens, olens): n_batches = len(ilens) max_ilen = max(ilens) max_olen = max(olens) guided_attn_masks = paddle.zeros((n_batches, max_olen, max_ilen)) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): ilen = int(ilen) olen = int(olen) guided_attn_masks[idx, :olen, : ilen] = self._make_guided_attention_mask( ilen, olen, self.sigma) return guided_attn_masks @staticmethod def _make_guided_attention_mask(ilen, olen, sigma): """Make guided attention mask. 
Examples ---------- >>> guided_attn_mask =_make_guided_attention(5, 5, 0.4) >>> guided_attn_mask.shape [5, 5] >>> guided_attn_mask tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647], [0.1175, 0.0000, 0.1175, 0.3935, 0.6753], [0.3935, 0.1175, 0.0000, 0.1175, 0.3935], [0.6753, 0.3935, 0.1175, 0.0000, 0.1175], [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]]) >>> guided_attn_mask =_make_guided_attention(3, 6, 0.4) >>> guided_attn_mask.shape [6, 3] >>> guided_attn_mask tensor([[0.0000, 0.2934, 0.7506], [0.0831, 0.0831, 0.5422], [0.2934, 0.0000, 0.2934], [0.5422, 0.0831, 0.0831], [0.7506, 0.2934, 0.0000], [0.8858, 0.5422, 0.0831]]) """ grid_x, grid_y = paddle.meshgrid( paddle.arange(olen), paddle.arange(ilen)) grid_x = grid_x.cast(dtype=paddle.float32) grid_y = grid_y.cast(dtype=paddle.float32) return 1.0 - paddle.exp(-( (grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2))) @staticmethod def _make_masks(ilens, olens): """Make masks indicating non-padded part. Parameters ---------- ilens (LongTensor or List): Batch of lengths (B,). olens (LongTensor or List): Batch of lengths (B,). Returns ---------- Tensor Mask tensor indicating non-padded part. Examples ---------- >>> ilens, olens = [5, 2], [8, 5] >>> _make_mask(ilens, olens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=paddle.uint8) """ # (B, T_in) in_masks = make_non_pad_mask(ilens) # (B, T_out) out_masks = make_non_pad_mask(olens) # (B, T_out, T_in) return paddle.logical_and( out_masks.unsqueeze(-1), in_masks.unsqueeze(-2)) class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): """Guided attention loss function module for multi head attention. Parameters ---------- sigma : float, optional Standard deviation to controlGuidedAttentionLoss how close attention to a diagonal. 
alpha : float, optional Scaling coefficient (lambda). reset_always : bool, optional Whether to always reset masks. """ def forward(self, att_ws, ilens, olens): """Calculate forward propagation. Parameters ---------- att_ws : Tensor Batch of multi head attention weights (B, H, T_max_out, T_max_in). ilens : Tensor Batch of input lenghts (B,). olens : Tensor Batch of output lenghts (B,). Returns ---------- Tensor Guided attention loss value. """ if self.guided_attn_masks is None: self.guided_attn_masks = ( self._make_guided_attention_masks(ilens, olens).unsqueeze(1)) if self.masks is None: self.masks = self._make_masks(ilens, olens).unsqueeze(1) losses = self.guided_attn_masks * att_ws loss = paddle.mean( losses.masked_select(self.masks.broadcast_to(losses.shape))) if self.reset_always: self._reset_masks() return self.alpha * loss