You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1083 lines
40 KiB
1083 lines
40 KiB
3 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Transformer TTS related modules for paddle"""
|
||
|
from typing import Dict
|
||
|
from typing import Sequence
|
||
|
from typing import Tuple
|
||
|
import numpy
|
||
|
import paddle
|
||
|
from paddle import nn
|
||
|
import paddle.nn.functional as F
|
||
|
from typeguard import check_argument_types
|
||
|
|
||
|
from parakeet.modules.fastspeech2_transformer.attention import MultiHeadedAttention
|
||
|
from parakeet.modules.fastspeech2_transformer.decoder import Decoder
|
||
|
from parakeet.modules.fastspeech2_transformer.embedding import PositionalEncoding
|
||
|
from parakeet.modules.fastspeech2_transformer.embedding import ScaledPositionalEncoding
|
||
|
from parakeet.modules.fastspeech2_transformer.encoder import Encoder
|
||
|
from parakeet.modules.fastspeech2_transformer.mask import subsequent_mask
|
||
|
from parakeet.modules.style_encoder import StyleEncoder
|
||
|
from parakeet.modules.tacotron2.decoder import Postnet
|
||
|
from parakeet.modules.tacotron2.decoder import Prenet as DecoderPrenet
|
||
|
from parakeet.modules.tacotron2.encoder import Encoder as EncoderPrenet
|
||
|
from parakeet.modules.nets_utils import initialize
|
||
|
from parakeet.modules.nets_utils import make_non_pad_mask
|
||
|
from parakeet.modules.nets_utils import make_pad_mask
|
||
|
|
||
|
|
||
|
class TransformerTTS(nn.Layer):
    """TTS-Transformer module.

    This is a module of text-to-speech Transformer described in `Neural Speech Synthesis
    with Transformer Network`_, which converts the sequence of tokens into the sequence
    of Mel-filterbanks.

    .. _`Neural Speech Synthesis with Transformer Network`:
        https://arxiv.org/pdf/1809.08895.pdf

    Parameters
    ----------
    idim : int
        Dimension of the inputs. The last token id (``idim - 1``) is used as
        the end-of-sequence symbol.
    odim : int
        Dimension of the outputs.
    embed_dim : int, optional
        Dimension of character embedding.
    eprenet_conv_layers : int, optional
        Number of encoder prenet convolution layers. If 0, a plain embedding
        layer is used as the encoder input layer.
    eprenet_conv_chans : int, optional
        Number of encoder prenet convolution channels.
    eprenet_conv_filts : int, optional
        Filter size of encoder prenet convolution.
    dprenet_layers : int, optional
        Number of decoder prenet layers. If 0, ``"linear"`` is passed to the
        decoder as its input layer.
    dprenet_units : int, optional
        Number of decoder prenet hidden units.
    elayers : int, optional
        Number of encoder layers.
    eunits : int, optional
        Number of encoder hidden units.
    adim : int, optional
        Number of attention transformation dimensions.
    aheads : int, optional
        Number of heads for multi head attention.
    dlayers : int, optional
        Number of decoder layers.
    dunits : int, optional
        Number of decoder hidden units.
    postnet_layers : int, optional
        Number of postnet layers. If 0, no postnet is built.
    postnet_chans : int, optional
        Number of postnet channels.
    postnet_filts : int, optional
        Filter size of postnet.
    positionwise_layer_type : str, optional
        Position-wise operation type.
    positionwise_conv_kernel_size : int, optional
        Kernel size in position wise conv 1d.
    use_scaled_pos_enc : bool, optional
        Whether to use trainable scaled positional encoding.
    use_batch_norm : bool, optional
        Whether to use batch normalization in encoder prenet and postnet.
    encoder_normalize_before : bool, optional
        Whether to perform layer normalization before encoder block.
    decoder_normalize_before : bool, optional
        Whether to perform layer normalization before decoder block.
    encoder_concat_after : bool, optional
        Whether to concatenate attention layer's input and output in encoder.
    decoder_concat_after : bool, optional
        Whether to concatenate attention layer's input and output in decoder.
    reduction_factor : int, optional
        Reduction factor: number of output frames predicted per decoder step.
    spk_embed_dim : int, optional
        Number of speaker embedding dimensions. If None, speaker embeddings
        are not used.
    spk_embed_integration_type : str, optional
        How to integrate speaker embedding: ``"add"`` or ``"concat"``.
    use_gst : bool, optional
        Whether to use global style token.
    gst_tokens : int, optional
        The number of GST embeddings.
    gst_heads : int, optional
        The number of heads in GST multihead attention.
    gst_conv_layers : int, optional
        The number of conv layers in GST.
    gst_conv_chans_list : Sequence[int], optional
        List of the number of channels of conv layers in GST.
    gst_conv_kernel_size : int, optional
        Kernel size of conv layers in GST.
    gst_conv_stride : int, optional
        Stride size of conv layers in GST.
    gst_gru_layers : int, optional
        The number of GRU layers in GST.
    gst_gru_units : int, optional
        The number of GRU units in GST.
    transformer_enc_dropout_rate : float, optional
        Dropout rate in encoder except attention and positional encoding.
    transformer_enc_positional_dropout_rate : float, optional
        Dropout rate after encoder positional encoding.
    transformer_enc_attn_dropout_rate : float, optional
        Dropout rate in encoder self-attention module.
    transformer_dec_dropout_rate : float, optional
        Dropout rate in decoder except attention & positional encoding.
    transformer_dec_positional_dropout_rate : float, optional
        Dropout rate after decoder positional encoding.
    transformer_dec_attn_dropout_rate : float, optional
        Dropout rate in decoder self-attention module.
    transformer_enc_dec_attn_dropout_rate : float, optional
        Dropout rate in encoder-decoder attention module.
    eprenet_dropout_rate : float, optional
        Dropout rate in encoder prenet.
    dprenet_dropout_rate : float, optional
        Dropout rate in decoder prenet.
    postnet_dropout_rate : float, optional
        Dropout rate in postnet.
    init_type : str, optional
        How to initialize transformer parameters.
    init_enc_alpha : float, optional
        Initial value of alpha in scaled pos encoding of the encoder.
    init_dec_alpha : float, optional
        Initial value of alpha in scaled pos encoding of the decoder.
    use_guided_attn_loss : bool, optional
        Whether to use guided attention loss.
    num_heads_applied_guided_attn : int, optional
        Number of heads in each layer to apply guided attention loss.
        ``-1`` means ``aheads``.
    num_layers_applied_guided_attn : int, optional
        Number of layers to apply guided attention loss.
        ``-1`` means ``elayers``.
    """

    def __init__(
            self,
            # network structure related
            idim: int,
            odim: int,
            embed_dim: int=512,
            eprenet_conv_layers: int=3,
            eprenet_conv_chans: int=256,
            eprenet_conv_filts: int=5,
            dprenet_layers: int=2,
            dprenet_units: int=256,
            elayers: int=6,
            eunits: int=1024,
            adim: int=512,
            aheads: int=4,
            dlayers: int=6,
            dunits: int=1024,
            postnet_layers: int=5,
            postnet_chans: int=256,
            postnet_filts: int=5,
            positionwise_layer_type: str="conv1d",
            positionwise_conv_kernel_size: int=1,
            use_scaled_pos_enc: bool=True,
            use_batch_norm: bool=True,
            encoder_normalize_before: bool=True,
            decoder_normalize_before: bool=True,
            encoder_concat_after: bool=False,
            decoder_concat_after: bool=False,
            reduction_factor: int=1,
            spk_embed_dim: int=None,
            spk_embed_integration_type: str="add",
            use_gst: bool=False,
            gst_tokens: int=10,
            gst_heads: int=4,
            gst_conv_layers: int=6,
            gst_conv_chans_list: Sequence[int]=(32, 32, 64, 64, 128, 128),
            gst_conv_kernel_size: int=3,
            gst_conv_stride: int=2,
            gst_gru_layers: int=1,
            gst_gru_units: int=128,
            # training related
            transformer_enc_dropout_rate: float=0.1,
            transformer_enc_positional_dropout_rate: float=0.1,
            transformer_enc_attn_dropout_rate: float=0.1,
            transformer_dec_dropout_rate: float=0.1,
            transformer_dec_positional_dropout_rate: float=0.1,
            transformer_dec_attn_dropout_rate: float=0.1,
            transformer_enc_dec_attn_dropout_rate: float=0.1,
            eprenet_dropout_rate: float=0.5,
            dprenet_dropout_rate: float=0.5,
            postnet_dropout_rate: float=0.5,
            init_type: str="xavier_uniform",
            init_enc_alpha: float=1.0,
            init_dec_alpha: float=1.0,
            use_guided_attn_loss: bool=True,
            num_heads_applied_guided_attn: int=2,
            num_layers_applied_guided_attn: int=2, ):
        """Initialize Transformer module."""
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last token id is reserved as end-of-sequence
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.reduction_factor = reduction_factor
        self.use_gst = use_gst
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_guided_attn_loss = use_guided_attn_loss
        # Resolved unconditionally: ``forward`` always exports these in its
        # ``need_dict``, so they must exist even when guided attention loss
        # is disabled (otherwise ``forward`` raises AttributeError).
        if num_layers_applied_guided_attn == -1:
            self.num_layers_applied_guided_attn = elayers
        else:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
        if num_heads_applied_guided_attn == -1:
            self.num_heads_applied_guided_attn = aheads
        else:
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0
        # ``initialize`` sets a global initializer, which affects every
        # parameter created afterwards (including via create_parameter) until
        # it is reset further below.
        initialize(self, init_type)
        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define transformer encoder
        if eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = nn.Sequential(
                EncoderPrenet(
                    idim=idim,
                    embed_dim=embed_dim,
                    elayers=0,
                    econv_layers=eprenet_conv_layers,
                    econv_chans=eprenet_conv_chans,
                    econv_filts=eprenet_conv_filts,
                    use_batch_norm=use_batch_norm,
                    dropout_rate=eprenet_dropout_rate,
                    padding_idx=self.padding_idx, ),
                nn.Linear(eprenet_conv_chans, adim), )
        else:
            encoder_input_layer = nn.Embedding(
                num_embeddings=idim,
                embedding_dim=adim,
                padding_idx=self.padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size, )

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units, )

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = nn.Linear(adim + self.spk_embed_dim, adim)

        # define transformer decoder
        if dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = nn.Sequential(
                DecoderPrenet(
                    idim=odim,
                    n_layers=dprenet_layers,
                    n_units=dprenet_units,
                    dropout_rate=dprenet_dropout_rate, ),
                nn.Linear(dprenet_units, adim), )
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=odim,  # odim is needed when no prenet is used
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after, )

        # define final projection: r frames of features and r stop logits
        # are produced per decoder step
        self.feat_out = nn.Linear(adim, odim * reduction_factor)
        self.prob_out = nn.Linear(adim, reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate, ))

        # Close the scope of the global initializer set by ``initialize()``
        # so that it does not affect ``self._reset_parameters()``.
        nn.initializer.set_global_initializer(None)

        self._reset_parameters(
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha, )

    def _reset_parameters(self, init_enc_alpha: float, init_dec_alpha: float):
        """Re-create the ``alpha`` parameters of the scaled positional encodings.

        Only has an effect when ``use_scaled_pos_enc`` is True; the last layer
        of ``encoder.embed`` / ``decoder.embed`` is expected to own ``alpha``.
        """
        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            init_enc_alpha = paddle.to_tensor(init_enc_alpha)
            self.encoder.embed[-1].alpha = paddle.create_parameter(
                shape=init_enc_alpha.shape,
                dtype=str(init_enc_alpha.numpy().dtype),
                default_initializer=paddle.nn.initializer.Assign(
                    init_enc_alpha))

            init_dec_alpha = paddle.to_tensor(init_dec_alpha)
            self.decoder.embed[-1].alpha = paddle.create_parameter(
                shape=init_dec_alpha.shape,
                dtype=str(init_dec_alpha.numpy().dtype),
                default_initializer=paddle.nn.initializer.Assign(
                    init_dec_alpha))

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            spembs: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, paddle.Tensor, paddle.Tensor, Dict]:
        """Calculate forward propagation.

        Parameters
        ----------
        text : Tensor(int64)
            Batch of padded character ids (B, Tmax).
        text_lengths : Tensor(int64)
            Batch of lengths of each input batch (B,).
        speech : Tensor
            Batch of padded target features (B, Lmax, odim).
        speech_lengths : Tensor(int64)
            Batch of the lengths of each target (B,).
        spembs : Tensor, optional
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        ----------
        after_outs : Tensor
            Outputs after postnet (B, Lmax//r * r, odim).
        before_outs : Tensor
            Outputs before postnet (B, Lmax//r * r, odim).
        logits : Tensor
            Stop-token logits (B, Lmax//r * r).
        ys : Tensor
            Target features, truncated to a multiple of the reduction factor.
        labels : Tensor(float32)
            Stop-token labels aligned with ``ys``.
        olens : Tensor(int64)
            Target lengths, rounded down to a multiple of the reduction factor (B,).
        ilens : Tensor(int64)
            Input lengths including the appended eos (B,).
        need_dict : Dict
            Encoder, decoder and guided-attention settings; presumably consumed
            by the guided attention loss in the training loop.
        """
        # input of embedding must be int64
        text_lengths = paddle.cast(text_lengths, 'int64')

        # Add eos at the last of sequence: append one padding column, then
        # write eos right after the last valid token of every row.
        xs_np = numpy.pad(text.numpy(), ((0, 0), (0, 1)), 'constant')
        xs_np[numpy.arange(xs_np.shape[0]),
              text_lengths.numpy().reshape([-1])] = self.eos
        xs = paddle.to_tensor(xs_np, dtype='int64')
        ilens = text_lengths + 1

        ys = speech
        olens = paddle.cast(speech_lengths, 'int64')

        # make labels for stop prediction: 1 from the last valid frame onward
        labels = make_pad_mask(olens - 1)
        labels = numpy.pad(
            labels.numpy(), ((0, 0), (0, 1)), 'constant', constant_values=1.0)
        labels = paddle.cast(paddle.to_tensor(labels), dtype="float32")

        # calculate transformer outputs
        after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens,
                                                        spembs)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens - olens % self.reduction_factor
            max_olen = int(olens.max())
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            # Truncation can cut off the original stop frame, so mark the new
            # last valid frame of every utterance (not just the last column).
            labels_np = labels.numpy()
            labels_np[numpy.arange(labels_np.shape[0]),
                      olens.numpy().reshape([-1]) - 1] = 1.0
            labels = paddle.to_tensor(labels_np, dtype="float32")

        need_dict = {
            'encoder': self.encoder,
            'decoder': self.decoder,
            'num_heads_applied_guided_attn':
            self.num_heads_applied_guided_attn,
            'num_layers_applied_guided_attn':
            self.num_layers_applied_guided_attn,
            'use_scaled_pos_enc': self.use_scaled_pos_enc,
        }

        return after_outs, before_outs, logits, ys, labels, olens, ilens, need_dict

    def _forward(
            self,
            xs: paddle.Tensor,
            ilens: paddle.Tensor,
            ys: paddle.Tensor,
            olens: paddle.Tensor,
            spembs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Run encoder, decoder and postnet with teacher forcing.

        Returns ``(after_outs, before_outs, logits)``.
        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, h_masks = self.encoder(xs, x_masks)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            # element-wise floor division on the paddle tensor
            # (torch-only ``Tensor.new`` is not available in paddle)
            olens_in = olens // self.reduction_factor
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, h_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).reshape([zs.shape[0], -1, self.odim])
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).reshape([zs.shape[0], -1])

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose([0, 2, 1])).transpose([0, 2, 1])

        return after_outs, before_outs, logits

    def inference(
            self,
            text: paddle.Tensor,
            speech: paddle.Tensor=None,
            spembs: paddle.Tensor=None,
            threshold: float=0.5,
            minlenratio: float=0.0,
            maxlenratio: float=10.0,
            use_teacher_forcing: bool=False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Parameters
        ----------
        text : Tensor(int64)
            Input sequence of characters (T,).
        speech : Tensor, optional
            Feature sequence (N, odim); required for teacher forcing and
            used to extract the style when GST is enabled.
        spembs : Tensor, optional
            Speaker embedding vector (spk_embed_dim,).
        threshold : float, optional
            Stop-probability threshold in inference.
        minlenratio : float, optional
            Minimum length ratio in inference.
        maxlenratio : float, optional
            Maximum length ratio in inference.
        use_teacher_forcing : bool, optional
            Whether to use teacher forcing.

        Returns
        ----------
        Tensor
            Output sequence of features (L, odim).
        Tensor
            Output sequence of stop probabilities (L,); None with teacher forcing.
        Tensor
            Encoder-decoder (source) attention weights (#layers, #heads, L, T).
        """
        y = speech
        spemb = spembs

        # add eos at the last of sequence
        text = numpy.pad(
            text.numpy(), (0, 1), 'constant', constant_values=self.eos)
        # input of embedding must be int64
        x = paddle.to_tensor(text, dtype='int64')

        # inference with teacher forcing
        if use_teacher_forcing:
            assert speech is not None, "speech must be provided with teacher forcing."

            # get teacher forcing outputs
            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spembs = None if spemb is None else spemb.unsqueeze(0)
            ilens = paddle.to_tensor(
                [xs.shape[1]], dtype=paddle.int64, place=xs.place)
            olens = paddle.to_tensor(
                [ys.shape[1]], dtype=paddle.int64, place=ys.place)
            outs, *_ = self._forward(xs, ilens, ys, olens, spembs)

            # get attention weights
            att_ws = [
                self.decoder.decoders[i].src_attn.attn
                for i in range(len(self.decoder.decoders))
            ]
            # (B, L, H, T_out, T_in)
            att_ws = paddle.stack(att_ws, axis=1)

            return outs[0], None, att_ws[0]

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # integrate GST
        if self.use_gst:
            style_embs = self.gst(y.unsqueeze(0))
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            spembs = spemb.unsqueeze(0)
            hs = self._integrate_with_spk_embed(hs, spembs)

        # set limits of length (in decoder steps)
        maxlen = int(hs.shape[1] * maxlenratio / self.reduction_factor)
        minlen = int(hs.shape[1] * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = paddle.zeros([1, 1, self.odim])
        outs, probs = [], []

        # forward decoder step-by-step
        z_cache = None
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z, z_cache = self.decoder.forward_one_step(
                ys, y_masks, hs, cache=z_cache)  # (B, adim)
            outs += [
                self.feat_out(z).reshape([self.reduction_factor, self.odim])
            ]  # [(r, odim), ...]
            probs += [F.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

            # update next inputs: feed back the last of the r predicted frames
            ys = paddle.concat(
                (ys, outs[-1][-1].reshape([1, 1, self.odim])),
                axis=1)  # (1, idx + 1, odim)

            # get attention weights
            att_ws_ = []
            for name, m in self.named_sublayers():
                if isinstance(m, MultiHeadedAttention) and "src" in name:
                    # [(#heads, 1, T),...]
                    att_ws_ += [m.attn[0, :, -1].unsqueeze(1)]
            if idx == 1:
                att_ws = att_ws_
            else:
                # [(#heads, l, T), ...]
                att_ws = [
                    paddle.concat([att_w, att_w_], axis=1)
                    for att_w, att_w_ in zip(att_ws, att_ws_)
                ]

            # check whether to finish generation
            if sum(paddle.cast(probs[-1] >= threshold,
                               'int64')) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                # (L, odim) -> (1, L, odim) -> (1, odim, L)
                outs = (paddle.concat(outs, axis=0).unsqueeze(0).transpose(
                    [0, 2, 1]))
                if self.postnet is not None:
                    # (1, odim, L)
                    outs = outs + self.postnet(outs)
                # (L, odim)
                outs = outs.transpose([0, 2, 1]).squeeze(0)
                probs = paddle.concat(probs, axis=0)
                break

        # concatenate attention weights -> (#layers, #heads, L, T)
        att_ws = paddle.stack(att_ws, axis=0)

        return outs, probs, att_ws

    def _add_first_frame_and_remove_last_frame(
            self, ys: paddle.Tensor) -> paddle.Tensor:
        """Shift ``ys`` right by one frame along axis 1, inserting a zero frame.

        The zero frame uses ``ys.dtype`` so the concatenation never mixes dtypes.
        """
        first_frame = paddle.zeros(
            (ys.shape[0], 1, ys.shape[2]), dtype=ys.dtype)
        ys_in = paddle.concat([first_frame, ys[:, :-1]], axis=1)
        return ys_in

    def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor:
        """Make masks for self-attention.

        Parameters
        ----------
        ilens : Tensor
            Batch of lengths (B,).

        Returns
        -------
        Tensor
            Mask tensor for self-attention (B, 1, Tmax), dtype=paddle.bool.

        Examples
        -------
        >>> ilens = [5, 3]
        >>> self._source_mask(ilens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0]]]) bool
        """
        x_masks = make_non_pad_mask(ilens)
        return x_masks.unsqueeze(-2)

    def _target_mask(self, olens: paddle.Tensor) -> paddle.Tensor:
        """Make masks for masked self-attention.

        Combines the non-padding mask with a subsequent (causal) mask.

        Parameters
        ----------
        olens : LongTensor
            Batch of lengths (B,).

        Returns
        ----------
        Tensor
            Mask tensor for masked self-attention (B, Lmax, Lmax), dtype=paddle.bool.

        Examples
        ----------
        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],
                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0]]], dtype=paddle.bool)
        """
        y_masks = make_non_pad_mask(olens)
        s_masks = subsequent_mask(y_masks.shape[-1]).unsqueeze(0)
        return paddle.logical_and(y_masks.unsqueeze(-2), s_masks)

    def _integrate_with_spk_embed(self,
                                  hs: paddle.Tensor,
                                  spembs: paddle.Tensor) -> paddle.Tensor:
        """Integrate speaker embedding with hidden states.

        Parameters
        ----------
        hs : Tensor
            Batch of hidden state sequences (B, Tmax, adim).
        spembs : Tensor
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        ----------
        Tensor
            Batch of integrated hidden state sequences (B, Tmax, adim).

        Raises
        ----------
        NotImplementedError
            If ``spk_embed_integration_type`` is neither "add" nor "concat".
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection;
            # paddle's ``expand`` takes the target shape as a single list
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                [-1, hs.shape[1], -1])
            hs = self.projection(paddle.concat([hs, spembs], axis=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
|
||
|
|
||
|
|
||
|
class TransformerTTSInference(nn.Layer):
    """Chain an acoustic model with a normalizer for inference.

    Parameters
    ----------
    normalizer : object
        Object providing an ``inverse`` method applied to the model output.
    model : object
        Acoustic model providing an ``inference`` method (e.g. TransformerTTS).
    """

    def __init__(self, normalizer, model):
        super().__init__()
        self.normalizer = normalizer
        self.acoustic_model = model

    def forward(self, text, spk_id=None):
        """Generate de-normalized features for ``text``.

        Parameters
        ----------
        text : Tensor
            Token id sequence passed to ``acoustic_model.inference``.
        spk_id : optional
            Not used by this method.

        Returns
        ----------
        Tensor
            ``normalizer.inverse`` of the first element returned by
            ``acoustic_model.inference``.
        """
        generated = self.acoustic_model.inference(text)
        normalized_feats = generated[0]
        return self.normalizer.inverse(normalized_feats)
|
||
|
|
||
|
|
||
|
class TransformerTTSLoss(nn.Layer):
    """Loss function module for TransformerTTS.

    Computes L1 and MSE losses on both the before- and after-postnet outputs,
    and a binary cross entropy loss (with positive weight) on stop logits.
    """

    def __init__(self,
                 use_masking=True,
                 use_weighted_masking=False,
                 bce_pos_weight=5.0):
        """Initialize TransformerTTS loss module.

        Parameters
        ----------
        use_masking : bool
            Whether to apply masking for padded part in loss calculation.
        use_weighted_masking : bool
            Whether to apply weighted masking in loss calculation.
        bce_pos_weight : float
            Weight of positive sample of stop token.
        """
        super().__init__()
        # at most one of the two masking modes may be enabled
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions; weighted masking needs per-element losses
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = nn.L1Loss(reduction=reduction)
        self.mse_criterion = nn.MSELoss(reduction=reduction)
        self.bce_criterion = nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight))

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Calculate forward propagation.

        Parameters
        ----------
        after_outs : Tensor
            Batch of outputs after postnets (B, Lmax, odim).
        before_outs : Tensor
            Batch of outputs before postnets (B, Lmax, odim).
        logits : Tensor
            Batch of stop logits (B, Lmax).
        ys : Tensor
            Batch of padded target features (B, Lmax, odim).
        labels : Tensor
            Batch of the sequences of stop token labels (B, Lmax).
        olens : LongTensor
            Batch of the lengths of each target (B,).

        Returns
        ----------
        Tensor
            L1 loss value.
        Tensor
            Mean square error loss value.
        Tensor
            Binary cross entropy loss value.
        """
        # make mask and apply it
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            ys = ys.masked_select(masks.broadcast_to(ys.shape))
            after_outs = after_outs.masked_select(
                masks.broadcast_to(after_outs.shape))
            before_outs = before_outs.masked_select(
                masks.broadcast_to(before_outs.shape))
            # Operator slice does not have kernel for data_type[bool]
            tmp_masks = paddle.cast(masks, dtype='int64')
            tmp_masks = tmp_masks[:, :, 0]
            tmp_masks = paddle.cast(tmp_masks, dtype='bool')
            labels = labels.masked_select(tmp_masks.broadcast_to(labels.shape))
            logits = logits.masked_select(tmp_masks.broadcast_to(logits.shape))

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
            before_outs, ys)
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys)
        bce_loss = self.bce_criterion(logits, labels)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            # each utterance gets total weight 1, spread over its valid frames
            # (paddle equivalents of torch's ``.float()``, ``sum(dim=)``, ``.div``)
            float_masks = paddle.cast(masks, dtype='float32')
            weights = float_masks / float_masks.sum(axis=1, keepdim=True)
            out_weights = weights / (ys.shape[0] * ys.shape[2])
            logit_weights = weights / ys.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                masks.broadcast_to(l1_loss.shape)).sum()

            mse_loss = mse_loss.multiply(out_weights)
            mse_loss = mse_loss.masked_select(
                masks.broadcast_to(mse_loss.shape)).sum()

            bce_loss = bce_loss.multiply(logit_weights.squeeze(-1))
            bce_loss = bce_loss.masked_select(
                masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum()

        return l1_loss, mse_loss, bce_loss
|
||
|
|
||
|
|
||
|
class GuidedAttentionLoss(nn.Layer):
    """Guided attention loss function module.

    This module calculates the guided attention loss described
    in `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969

    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Initialize guided attention loss module.

        Parameters
        ----------
        sigma : float, optional
            Standard deviation controlling how close to the diagonal the
            attention is pushed (smaller values penalize off-diagonal
            weights more sharply).
        alpha : float, optional
            Scaling coefficient (lambda) applied to the loss.
        reset_always : bool, optional
            Whether to discard the cached masks after every forward call.
            If False, the masks built on the first call are reused by all
            subsequent calls.

        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Built lazily in ``forward`` and cached until ``_reset_masks``.
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop the cached masks so the next forward call rebuilds them."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Parameters
        ----------
        att_ws : Tensor
            Batch of attention weights (B, T_max_out, T_max_in).
        ilens : Tensor(int)
            Batch of input lengths (B,).
        olens : Tensor(int)
            Batch of output lengths (B,).

        Returns
        ----------
        Tensor
            Guided attention loss value (mean penalty over non-padded
            positions, scaled by ``alpha``).

        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(ilens,
                                                                       olens)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens)
        # Weight every attention entry by its distance-from-diagonal penalty;
        # only the non-padded region contributes to the mean.
        losses = self.guided_attn_masks * att_ws
        loss = paddle.mean(
            losses.masked_select(self.masks.broadcast_to(losses.shape)))
        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Build a zero-padded batch of guided attention masks.

        Parameters
        ----------
        ilens : Tensor(int) or List[int]
            Batch of input lengths (B,).
        olens : Tensor(int) or List[int]
            Batch of output lengths (B,).

        Returns
        ----------
        Tensor
            Masks of shape (B, T_max_out, T_max_in). For item ``idx`` only
            the ``[:olens[idx], :ilens[idx]]`` region is filled; the rest
            stays zero.

        """
        # Convert lengths to Python ints once: the buffer shape and the
        # slices below are then plain ints for both Tensor and list inputs,
        # instead of 1-element Tensors produced by ``max`` over a Tensor.
        ilens = [int(ilen) for ilen in ilens]
        olens = [int(olen) for olen in olens]
        guided_attn_masks = paddle.zeros(
            (len(ilens), max(olens), max(ilens)))

        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = (
                self._make_guided_attention_mask(ilen, olen, self.sigma))
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Make a single guided attention mask.

        Entry ``[t, n]`` equals
        ``1 - exp(-((n / ilen - t / olen) ** 2) / (2 * sigma ** 2))``:
        it is 0 on the length-normalized diagonal and approaches 1 away
        from it.

        Parameters
        ----------
        ilen : int
            Input length (number of columns).
        olen : int
            Output length (number of rows).
        sigma : float
            Standard deviation controlling the width of the low-penalty band.

        Returns
        ----------
        Tensor
            float32 mask of shape (olen, ilen).

        Examples
        ----------
        >>> mask = GuidedAttentionLoss._make_guided_attention_mask(5, 5, 0.4)
        >>> mask.shape
        [5, 5]

        Values (rounded to 4 decimals):

            [[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
             [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
             [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
             [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
             [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]]

        >>> mask = GuidedAttentionLoss._make_guided_attention_mask(3, 6, 0.4)
        >>> mask.shape
        [6, 3]

        Values (rounded to 4 decimals):

            [[0.0000, 0.2934, 0.7506],
             [0.0831, 0.0831, 0.5422],
             [0.2934, 0.0000, 0.2934],
             [0.5422, 0.0831, 0.0831],
             [0.7506, 0.2934, 0.0000],
             [0.8858, 0.5422, 0.0831]]

        """
        # Row index runs over output steps, column index over input steps.
        grid_x, grid_y = paddle.meshgrid(
            paddle.arange(olen), paddle.arange(ilen))
        grid_x = grid_x.cast(dtype=paddle.float32)
        grid_y = grid_y.cast(dtype=paddle.float32)
        return 1.0 - paddle.exp(-(
            (grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """Make masks indicating the non-padded part.

        Parameters
        ----------
        ilens : Tensor(int) or List[int]
            Batch of input lengths (B,).
        olens : Tensor(int) or List[int]
            Batch of output lengths (B,).

        Returns
        ----------
        Tensor
            Bool mask of shape (B, T_max_out, T_max_in) that is True where
            both the output step and the input step are within length.

        Examples
        ----------
        >>> ilens, olens = [5, 2], [8, 5]
        >>> GuidedAttentionLoss._make_masks(ilens, olens)

        Result (bool, shown as 1/0):

            [[[1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1]],

             [[1, 1, 0, 0, 0],
              [1, 1, 0, 0, 0],
              [1, 1, 0, 0, 0],
              [1, 1, 0, 0, 0],
              [1, 1, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0]]]

        """
        # (B, T_in)
        in_masks = make_non_pad_mask(ilens)
        # (B, T_out)
        out_masks = make_non_pad_mask(olens)
        # Outer AND via broadcasting: (B, T_out, 1) & (B, 1, T_in)
        # -> (B, T_out, T_in)
        return paddle.logical_and(
            out_masks.unsqueeze(-1), in_masks.unsqueeze(-2))
|
||
|
|
||
|
|
||
|
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """Guided attention loss function module for multi-head attention.

    Same penalty as ``GuidedAttentionLoss``, but the attention weights carry
    an extra head axis. The per-item masks are built once and broadcast
    across all heads.

    Parameters
    ----------
    sigma : float, optional
        Standard deviation controlling how close to the diagonal the
        attention is pushed.
    alpha : float, optional
        Scaling coefficient (lambda).
    reset_always : bool, optional
        Whether to always reset masks after each forward call.

    """

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Parameters
        ----------
        att_ws : Tensor
            Batch of multi-head attention weights (B, H, T_max_out, T_max_in).
        ilens : Tensor
            Batch of input lengths (B,).
        olens : Tensor
            Batch of output lengths (B,).

        Returns
        ----------
        Tensor
            Guided attention loss value.

        """
        if self.guided_attn_masks is None:
            # Insert a singleton head axis so one mask serves every head.
            diag_penalty = self._make_guided_attention_masks(ilens, olens)
            self.guided_attn_masks = diag_penalty.unsqueeze(1)
        if self.masks is None:
            valid_region = self._make_masks(ilens, olens)
            self.masks = valid_region.unsqueeze(1)

        weighted = self.guided_attn_masks * att_ws
        selector = self.masks.broadcast_to(weighted.shape)
        loss = paddle.mean(weighted.masked_select(selector))

        if self.reset_always:
            self._reset_masks()

        return self.alpha * loss
|