update tacotron2 and transformer tts, test=tts

pull/1314/head
TianYuan 3 years ago
parent 89e988a69e
commit 9c7f0762b0

@@ -324,7 +324,10 @@ class Tacotron2(nn.Layer):
             ys = ys[:, :max_out]
             labels = labels[:, :max_out]
             labels = paddle.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)
-        return after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens
+            olens_in = olens // self.reduction_factor
+        else:
+            olens_in = olens
+        return after_outs, before_outs, logits, ys, labels, olens, att_ws, olens_in
 
     def _forward(
             self,
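Note for readers: with reduction_factor r > 1 the decoder attends over olens // r steps rather than the raw frame count, which is why the model now returns olens_in instead of ilens. A toy sketch of that bookkeeping (shapes and values are illustrative, not taken from the repo):

import paddle

r = 2
olens = paddle.to_tensor([20, 14], dtype="int64")   # trimmed target frame lengths
olens_in = olens // r                               # decoder steps actually taken
# attention weights then have shape (B, max(olens_in), T_in), e.g. (2, 10, T_in) here
print(olens_in.numpy())                             # [10  7]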

@@ -72,11 +72,10 @@ class Tacotron2Updater(StandardUpdater):
         # spk_id!=None in multiple spk fastspeech2
         spk_id = batch["spk_id"] if "spk_id" in batch else None
         spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
-        # No explicit speaker identifier labels are used during voice cloning training.
         if spk_emb is not None:
             spk_id = None
 
-        after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens = self.model(
+        after_outs, before_outs, logits, ys, labels, olens, att_ws, olens_in = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
@@ -101,11 +100,8 @@ class Tacotron2Updater(StandardUpdater):
         if self.use_guided_attn_loss:
             # NOTE: length of output for auto-regressive
             # input will be changed when r > 1
-            if self.model.reduction_factor > 1:
-                olens_in = olens // self.model.reduction_factor
-            else:
-                olens_in = olens
-            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
+            attn_loss = self.attn_loss(
+                att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
             loss = loss + attn_loss
 
         optimizer = self.optimizer
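The ilens passed to the guided-attention loss is batch["text_lengths"] + 1, presumably to count the appended <eos> token, while olens_in is the reduction-factor-adjusted length returned by the model above. A minimal sketch of the call, assuming the GuidedAttentionLoss class defined in paddlespeech/t2s/modules/losses.py further down this diff (toy shapes and values):

import paddle
from paddlespeech.t2s.modules.losses import GuidedAttentionLoss

attn_loss_fn = GuidedAttentionLoss(sigma=0.4, alpha=1.0)
att_ws = paddle.rand([2, 10, 8])                      # (B, T_out, T_in) decoder alignments
ilens = paddle.to_tensor([8, 6], dtype="int64")       # text_lengths + 1
olens_in = paddle.to_tensor([10, 7], dtype="int64")   # olens // reduction_factor
attn_loss = attn_loss_fn(att_ws=att_ws, ilens=ilens, olens=olens_in)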
@@ -169,7 +165,7 @@ class Tacotron2Evaluator(StandardEvaluator):
         if spk_emb is not None:
             spk_id = None
 
-        after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens = self.model(
+        after_outs, before_outs, logits, ys, labels, olens, att_ws, olens_in = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
@@ -194,11 +190,8 @@ class Tacotron2Evaluator(StandardEvaluator):
         if self.use_guided_attn_loss:
             # NOTE: length of output for auto-regressive
             # input will be changed when r > 1
-            if self.model.reduction_factor > 1:
-                olens_in = olens // self.model.reduction_factor
-            else:
-                olens_in = olens
-            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
+            attn_loss = self.attn_loss(
+                att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
             loss = loss + attn_loss
 
         report("eval/l1_loss", float(l1_loss))

@@ -447,12 +447,15 @@ class TransformerTTS(nn.Layer):
         # modifiy mod part of groundtruth
         if self.reduction_factor > 1:
-            olens = paddle.to_tensor(
-                [olen - olen % self.reduction_factor for olen in olens.numpy()])
+            olens = olens - olens % self.reduction_factor
             max_olen = max(olens)
             ys = ys[:, :max_olen]
             labels = labels[:, :max_olen]
             labels[:, -1] = 1.0  # make sure at least one frame has 1
+            olens_in = olens // self.reduction_factor
+        else:
+            olens_in = olens
 
         need_dict = {}
         need_dict['encoder'] = self.encoder
         need_dict['decoder'] = self.decoder
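The list comprehension over olens.numpy() is replaced by a vectorized trim, so olens becomes a multiple of the reduction factor without a round trip through numpy. A quick check on toy values:

import paddle

reduction_factor = 3
olens = paddle.to_tensor([10, 7, 9], dtype="int64")
olens = olens - olens % reduction_factor   # drop the trailing partial group of frames
print(olens.numpy())                       # [9 6 9]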
@@ -462,7 +465,7 @@ class TransformerTTS(nn.Layer):
             'num_layers_applied_guided_attn'] = self.num_layers_applied_guided_attn
         need_dict['use_scaled_pos_enc'] = self.use_scaled_pos_enc
-        return after_outs, before_outs, logits, ys, labels, olens, ilens, need_dict
+        return after_outs, before_outs, logits, ys, labels, olens, olens_in, need_dict
 
     def _forward(
             self,
@@ -488,8 +491,7 @@ class TransformerTTS(nn.Layer):
         # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
         if self.reduction_factor > 1:
             ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
-            olens_in = olens.new(
-                [olen // self.reduction_factor for olen in olens])
+            olens_in = olens // self.reduction_factor
         else:
             ys_in, olens_in = ys, olens
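For reference, the thinning above keeps the last frame of every group of r frames as decoder input, and olens_in is now plain integer tensor division instead of olens.new(...). A small sketch with assumed toy shapes:

import paddle

r = 2
ys = paddle.randn([2, 10, 80])                    # (B, Lmax, odim) padded targets
olens = paddle.to_tensor([10, 8], dtype="int64")

ys_in = ys[:, r - 1::r]                           # (B, Lmax // r, odim)
olens_in = olens // r
print(ys_in.shape, olens_in.numpy())              # [2, 5, 80] [5 4]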
@@ -769,318 +771,3 @@ class TransformerTTSInference(nn.Layer):
         normalized_mel = self.acoustic_model.inference(text)[0]
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel
class TransformerTTSLoss(nn.Layer):
"""Loss function module for Tacotron2."""
def __init__(self,
use_masking=True,
use_weighted_masking=False,
bce_pos_weight=5.0):
"""Initialize Tactoron2 loss module.
Parameters
----------
use_masking : bool
Whether to apply masking for padded part in loss calculation.
use_weighted_masking : bool
Whether to apply weighted masking in loss calculation.
bce_pos_weight : float
Weight of positive sample of stop token.
"""
super().__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = nn.L1Loss(reduction=reduction)
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.bce_criterion = nn.BCEWithLogitsLoss(
reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight))
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
"""Calculate forward propagation.
Parameters
----------
after_outs : Tensor
Batch of outputs after postnets (B, Lmax, odim).
before_outs : Tensor
Batch of outputs before postnets (B, Lmax, odim).
logits : Tensor
Batch of stop logits (B, Lmax).
ys : Tensor
Batch of padded target features (B, Lmax, odim).
labels : LongTensor
Batch of the sequences of stop token labels (B, Lmax).
olens : LongTensor
Batch of the lengths of each target (B,).
Returns
----------
Tensor
L1 loss value.
Tensor
Mean square error loss value.
Tensor
Binary cross entropy loss value.
"""
# make mask and apply it
if self.use_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1)
ys = ys.masked_select(masks.broadcast_to(ys.shape))
after_outs = after_outs.masked_select(
masks.broadcast_to(after_outs.shape))
before_outs = before_outs.masked_select(
masks.broadcast_to(before_outs.shape))
# Operator slice does not have kernel for data_type[bool]
tmp_masks = paddle.cast(masks, dtype='int64')
tmp_masks = tmp_masks[:, :, 0]
tmp_masks = paddle.cast(tmp_masks, dtype='bool')
labels = labels.masked_select(tmp_masks.broadcast_to(labels.shape))
logits = logits.masked_select(tmp_masks.broadcast_to(logits.shape))
# calculate loss
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
before_outs, ys)
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
before_outs, ys)
bce_loss = self.bce_criterion(logits, labels)
# make weighted mask and apply it
if self.use_weighted_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1)
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
out_weights = weights.div(ys.shape[0] * ys.shape[2])
logit_weights = weights.div(ys.shape[0])
# apply weight
l1_loss = l1_loss.multiply(out_weights)
l1_loss = l1_loss.masked_select(
masks.broadcast_to(l1_loss.shape)).sum()
mse_loss = mse_loss.multiply(out_weights)
mse_loss = mse_loss.masked_select(
masks.broadcast_to(mse_loss.shape)).sum()
bce_loss = bce_loss.multiply(logit_weights.squeeze(-1))
bce_loss = bce_loss.masked_select(
masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum()
return l1_loss, mse_loss, bce_loss
class GuidedAttentionLoss(nn.Layer):
"""Guided attention loss function module.
This module calculates the guided attention loss described
in `Efficiently Trainable Text-to-Speech System Based
on Deep Convolutional Networks with Guided Attention`_,
which forces the attention to be diagonal.
.. _`Efficiently Trainable Text-to-Speech System
Based on Deep Convolutional Networks with Guided Attention`:
https://arxiv.org/abs/1710.08969
"""
def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
"""Initialize guided attention loss module.
Parameters
----------
sigma : float, optional
Standard deviation to control how close attention to a diagonal.
alpha : float, optional
Scaling coefficient (lambda).
reset_always : bool, optional
Whether to always reset masks.
"""
super(GuidedAttentionLoss, self).__init__()
self.sigma = sigma
self.alpha = alpha
self.reset_always = reset_always
self.guided_attn_masks = None
self.masks = None
def _reset_masks(self):
self.guided_attn_masks = None
self.masks = None
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Parameters
----------
att_ws : Tensor
Batch of attention weights (B, T_max_out, T_max_in).
ilens : LongTensor
Batch of input lenghts (B,).
olens : LongTensor
Batch of output lenghts (B,).
Returns
----------
Tensor
Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens,
olens)
if self.masks is None:
self.masks = self._make_masks(ilens, olens)
losses = self.guided_attn_masks * att_ws
loss = paddle.mean(
losses.masked_select(self.masks.broadcast_to(losses.shape)))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
def _make_guided_attention_masks(self, ilens, olens):
n_batches = len(ilens)
max_ilen = max(ilens)
max_olen = max(olens)
guided_attn_masks = paddle.zeros((n_batches, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
ilen = int(ilen)
olen = int(olen)
guided_attn_masks[idx, :olen, :
ilen] = self._make_guided_attention_mask(
ilen, olen, self.sigma)
return guided_attn_masks
@staticmethod
def _make_guided_attention_mask(ilen, olen, sigma):
"""Make guided attention mask.
Examples
----------
>>> guided_attn_mask =_make_guided_attention(5, 5, 0.4)
>>> guided_attn_mask.shape
[5, 5]
>>> guided_attn_mask
tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
[0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
[0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
[0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
[0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
>>> guided_attn_mask =_make_guided_attention(3, 6, 0.4)
>>> guided_attn_mask.shape
[6, 3]
>>> guided_attn_mask
tensor([[0.0000, 0.2934, 0.7506],
[0.0831, 0.0831, 0.5422],
[0.2934, 0.0000, 0.2934],
[0.5422, 0.0831, 0.0831],
[0.7506, 0.2934, 0.0000],
[0.8858, 0.5422, 0.0831]])
"""
grid_x, grid_y = paddle.meshgrid(
paddle.arange(olen), paddle.arange(ilen))
grid_x = grid_x.cast(dtype=paddle.float32)
grid_y = grid_y.cast(dtype=paddle.float32)
return 1.0 - paddle.exp(-(
(grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2)))
@staticmethod
def _make_masks(ilens, olens):
"""Make masks indicating non-padded part.
Parameters
----------
ilens (LongTensor or List): Batch of lengths (B,).
olens (LongTensor or List): Batch of lengths (B,).
Returns
----------
Tensor
Mask tensor indicating non-padded part.
Examples
----------
>>> ilens, olens = [5, 2], [8, 5]
>>> _make_mask(ilens, olens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=paddle.uint8)
"""
# (B, T_in)
in_masks = make_non_pad_mask(ilens)
# (B, T_out)
out_masks = make_non_pad_mask(olens)
# (B, T_out, T_in)
return paddle.logical_and(
out_masks.unsqueeze(-1), in_masks.unsqueeze(-2))
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
"""Guided attention loss function module for multi head attention.
Parameters
----------
sigma : float, optional
Standard deviation to controlGuidedAttentionLoss
how close attention to a diagonal.
alpha : float, optional
Scaling coefficient (lambda).
reset_always : bool, optional
Whether to always reset masks.
"""
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Parameters
----------
att_ws : Tensor
Batch of multi head attention weights (B, H, T_max_out, T_max_in).
ilens : Tensor
Batch of input lenghts (B,).
olens : Tensor
Batch of output lenghts (B,).
Returns
----------
Tensor
Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = (
self._make_guided_attention_masks(ilens, olens).unsqueeze(1))
if self.masks is None:
self.masks = self._make_masks(ilens, olens).unsqueeze(1)
losses = self.guided_attn_masks * att_ws
loss = paddle.mean(
losses.masked_select(self.masks.broadcast_to(losses.shape)))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
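The classes removed above are not dropped: their counterparts live in paddlespeech/t2s/modules/losses.py (see the import change just below and the losses.py hunks at the end of this diff). As a reminder of the masking pattern they implement, here is a minimal sketch on toy tensors (shapes assumed):

import paddle
import paddle.nn as nn
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask

olens = paddle.to_tensor([4, 3], dtype="int64")
ys = paddle.randn([2, 4, 80])      # padded targets (B, Lmax, odim)
outs = paddle.randn([2, 4, 80])    # model outputs

# True on real frames, False on padding; keep only the valid positions
masks = make_non_pad_mask(olens).unsqueeze(-1)
ys = ys.masked_select(masks.broadcast_to(ys.shape))
outs = outs.masked_select(masks.broadcast_to(outs.shape))
l1_loss = nn.L1Loss()(outs, ys)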

@@ -17,8 +17,8 @@ from typing import Sequence
 import paddle
 from paddle import distributed as dist
-from paddlespeech.t2s.models.transformer_tts import GuidedMultiHeadAttentionLoss
-from paddlespeech.t2s.models.transformer_tts import TransformerTTSLoss
+from paddlespeech.t2s.modules.losses import GuidedMultiHeadAttentionLoss
+from paddlespeech.t2s.modules.losses import Tacotron2Loss as TransformerTTSLoss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
@@ -71,7 +71,7 @@ class TransformerTTSUpdater(StandardUpdater):
         self.msg = "Rank: {}, ".format(dist.get_rank())
         losses_dict = {}
-        after_outs, before_outs, logits, ys, labels, olens, ilens, need_dict = self.model(
+        after_outs, before_outs, logits, ys, labels, olens, olens_in, need_dict = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
@@ -116,7 +116,10 @@ class TransformerTTSUpdater(StandardUpdater):
                         break
                 # (B, H*L, T_in, T_in)
                 att_ws = paddle.concat(att_ws, axis=1)
-                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
+                enc_attn_loss = self.attn_criterion(
+                    att_ws=att_ws,
+                    ilens=batch["text_lengths"] + 1,
+                    olens=batch["text_lengths"] + 1)
                 loss = loss + enc_attn_loss
                 report("train/enc_attn_loss", float(enc_attn_loss))
                 losses_dict["enc_attn_loss"] = float(enc_attn_loss)
@@ -133,7 +136,8 @@ class TransformerTTSUpdater(StandardUpdater):
                         break
                 # (B, H*L, T_out, T_out)
                 att_ws = paddle.concat(att_ws, axis=1)
-                dec_attn_loss = self.attn_criterion(att_ws, olens, olens)
+                dec_attn_loss = self.attn_criterion(
+                    att_ws=att_ws, ilens=olens_in, olens=olens_in)
                 report("train/dec_attn_loss", float(dec_attn_loss))
                 losses_dict["dec_attn_loss"] = float(dec_attn_loss)
                 loss = loss + dec_attn_loss
@@ -150,7 +154,10 @@ class TransformerTTSUpdater(StandardUpdater):
                         break
                 # (B, H*L, T_out, T_in)
                 att_ws = paddle.concat(att_ws, axis=1)
-                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens)
+                enc_dec_attn_loss = self.attn_criterion(
+                    att_ws=att_ws,
+                    ilens=batch["text_lengths"] + 1,
+                    olens=olens_in)
                 report("train/enc_dec_attn_loss", float(enc_dec_attn_loss))
                 losses_dict["enc_dec_attn_loss"] = float(enc_dec_attn_loss)
                 loss = loss + enc_dec_attn_loss
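Taken together, the three guided-attention terms now use text_lengths + 1 (the appended <eos>) on both axes for encoder self-attention, olens_in on both axes for decoder self-attention, and the mixed pair for cross-attention. A toy sketch of that wiring, with assumed shapes and values:

import paddle
from paddlespeech.t2s.modules.losses import GuidedMultiHeadAttentionLoss

attn_criterion = GuidedMultiHeadAttentionLoss(sigma=0.4, alpha=1.0)
ilens = paddle.to_tensor([8, 6], dtype="int64")       # text_lengths + 1
olens_in = paddle.to_tensor([10, 7], dtype="int64")   # decoder lengths after reduction

enc_att = paddle.rand([2, 4, 8, 8])         # (B, H*L, T_in, T_in)
dec_att = paddle.rand([2, 4, 10, 10])       # (B, H*L, T_out, T_out)
enc_dec_att = paddle.rand([2, 4, 10, 8])    # (B, H*L, T_out, T_in)

enc_attn_loss = attn_criterion(att_ws=enc_att, ilens=ilens, olens=ilens)
dec_attn_loss = attn_criterion(att_ws=dec_att, ilens=olens_in, olens=olens_in)
enc_dec_attn_loss = attn_criterion(att_ws=enc_dec_att, ilens=ilens, olens=olens_in)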
@@ -215,7 +222,7 @@ class TransformerTTSEvaluator(StandardEvaluator):
     def evaluate_core(self, batch):
         self.msg = "Evaluate: "
         losses_dict = {}
-        after_outs, before_outs, logits, ys, labels, olens, ilens, need_dict = self.model(
+        after_outs, before_outs, logits, ys, labels, olens, olens_in, need_dict = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
@@ -260,7 +267,10 @@ class TransformerTTSEvaluator(StandardEvaluator):
                         break
                 # (B, H*L, T_in, T_in)
                 att_ws = paddle.concat(att_ws, axis=1)
-                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
+                enc_attn_loss = self.attn_criterion(
+                    att_ws=att_ws,
+                    ilens=batch["text_lengths"] + 1,
+                    olens=batch["text_lengths"] + 1)
                 loss = loss + enc_attn_loss
                 report("train/enc_attn_loss", float(enc_attn_loss))
                 losses_dict["enc_attn_loss"] = float(enc_attn_loss)
@@ -277,7 +287,8 @@ class TransformerTTSEvaluator(StandardEvaluator):
                         break
                 # (B, H*L, T_out, T_out)
                 att_ws = paddle.concat(att_ws, axis=1)
-                dec_attn_loss = self.attn_criterion(att_ws, olens, olens)
+                dec_attn_loss = self.attn_criterion(
+                    att_ws=att_ws, ilens=olens_in, olens=olens_in)
                 report("eval/dec_attn_loss", float(dec_attn_loss))
                 losses_dict["dec_attn_loss"] = float(dec_attn_loss)
                 loss = loss + dec_attn_loss
@@ -295,7 +306,10 @@ class TransformerTTSEvaluator(StandardEvaluator):
                         break
                 # (B, H*L, T_out, T_in)
                 att_ws = paddle.concat(att_ws, axis=1)
-                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens)
+                enc_dec_attn_loss = self.attn_criterion(
+                    att_ws=att_ws,
+                    ilens=batch["text_lengths"] + 1,
+                    olens=olens_in)
                 report("eval/enc_dec_attn_loss", float(enc_dec_attn_loss))
                 losses_dict["enc_dec_attn_loss"] = float(enc_dec_attn_loss)
                 loss = loss + enc_dec_attn_loss

@@ -26,26 +26,30 @@ from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
 # Loss for new Tacotron2
 class GuidedAttentionLoss(nn.Layer):
     """Guided attention loss function module.
     This module calculates the guided attention loss described
     in `Efficiently Trainable Text-to-Speech System Based
     on Deep Convolutional Networks with Guided Attention`_,
     which forces the attention to be diagonal.
     .. _`Efficiently Trainable Text-to-Speech System
         Based on Deep Convolutional Networks with Guided Attention`:
         https://arxiv.org/abs/1710.08969
     """
 
     def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
         """Initialize guided attention loss module.
         Parameters
         ----------
         sigma : float, optional
-            Standard deviation to control
-            how close attention to a diagonal.
+            Standard deviation to control how close attention to a diagonal.
         alpha : float, optional
             Scaling coefficient (lambda).
         reset_always : bool, optional
             Whether to always reset masks.
         """
         super().__init__()
         self.sigma = sigma
@@ -60,18 +64,21 @@ class GuidedAttentionLoss(nn.Layer):
     def forward(self, att_ws, ilens, olens):
         """Calculate forward propagation.
         Parameters
         ----------
         att_ws : Tensor
             Batch of attention weights (B, T_max_out, T_max_in).
         ilens : Tensor(int64)
-            Batch of input lengths (B,).
+            Batch of input lenghts (B,).
         olens : Tensor(int64)
-            Batch of output lengths (B,).
+            Batch of output lenghts (B,).
         Returns
         ----------
         Tensor
             Guided attention loss value.
         """
         if self.guided_attn_masks is None:
             self.guided_attn_masks = self._make_guided_attention_masks(ilens,
@@ -79,7 +86,8 @@ class GuidedAttentionLoss(nn.Layer):
         if self.masks is None:
             self.masks = self._make_masks(ilens, olens)
         losses = self.guided_attn_masks * att_ws
-        loss = paddle.mean(losses.masked_select(self.masks))
+        loss = paddle.mean(
+            losses.masked_select(self.masks.broadcast_to(losses.shape)))
         if self.reset_always:
             self._reset_masks()
         return self.alpha * loss
@@ -89,6 +97,7 @@ class GuidedAttentionLoss(nn.Layer):
         max_ilen = max(ilens)
         max_olen = max(olens)
         guided_attn_masks = paddle.zeros((n_batches, max_olen, max_ilen))
         for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
             guided_attn_masks[idx, :olen, :
                               ilen] = self._make_guided_attention_mask(
@@ -98,11 +107,12 @@ class GuidedAttentionLoss(nn.Layer):
     @staticmethod
     def _make_guided_attention_mask(ilen, olen, sigma):
         """Make guided attention mask.
-        Parameters
+        Examples
         ----------
         >>> guided_attn_mask =_make_guided_attention(5, 5, 0.4)
         >>> guided_attn_mask.shape
-        Size([5, 5])
+        [5, 5]
         >>> guided_attn_mask
         tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                 [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
@@ -111,7 +121,7 @@ class GuidedAttentionLoss(nn.Layer):
                 [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
         >>> guided_attn_mask =_make_guided_attention(3, 6, 0.4)
         >>> guided_attn_mask.shape
-        Size([6, 3])
+        [6, 3]
         >>> guided_attn_mask
         tensor([[0.0000, 0.2934, 0.7506],
                 [0.0831, 0.0831, 0.5422],
@@ -119,55 +129,109 @@ class GuidedAttentionLoss(nn.Layer):
                 [0.5422, 0.0831, 0.0831],
                 [0.7506, 0.2934, 0.0000],
                 [0.8858, 0.5422, 0.0831]])
         """
         grid_x, grid_y = paddle.meshgrid(
             paddle.arange(olen), paddle.arange(ilen))
-        grid_x = paddle.cast(grid_x, dtype='float32')
-        grid_y = paddle.cast(grid_y, dtype='float32')
+        grid_x = grid_x.cast(dtype=paddle.float32)
+        grid_y = grid_y.cast(dtype=paddle.float32)
         return 1.0 - paddle.exp(-(
             (grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2)))
 
     @staticmethod
     def _make_masks(ilens, olens):
         """Make masks indicating non-padded part.
-        Examples
+        Parameters
         ----------
         ilens : Tensor(int64) or List
             Batch of lengths (B,).
         olens : Tensor(int64) or List
             Batch of lengths (B,).
         Returns
         ----------
         Tensor
             Mask tensor indicating non-padded part.
         Examples
         ----------
         >>> ilens, olens = [5, 2], [8, 5]
         >>> _make_mask(ilens, olens)
         tensor([[[1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1]],
                 [[1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0]]],)
+                 [0, 0, 0, 0, 0]]], dtype=paddle.uint8)
         """
         # (B, T_in)
         in_masks = make_non_pad_mask(ilens)
         # (B, T_out)
         out_masks = make_non_pad_mask(olens)
         # (B, T_out, T_in)
-        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)
+        return paddle.logical_and(
+            out_masks.unsqueeze(-1), in_masks.unsqueeze(-2))
+
+
+class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
+    """Guided attention loss function module for multi head attention.
+    Parameters
+    ----------
+    sigma : float, optional
+        Standard deviation to controlGuidedAttentionLoss
+        how close attention to a diagonal.
+    alpha : float, optional
+        Scaling coefficient (lambda).
+    reset_always : bool, optional
+        Whether to always reset masks.
+    """
+
+    def forward(self, att_ws, ilens, olens):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        att_ws : Tensor
+            Batch of multi head attention weights (B, H, T_max_out, T_max_in).
+        ilens : Tensor
+            Batch of input lenghts (B,).
+        olens : Tensor
+            Batch of output lenghts (B,).
+        Returns
+        ----------
+        Tensor
+            Guided attention loss value.
+        """
+        if self.guided_attn_masks is None:
+            self.guided_attn_masks = (
+                self._make_guided_attention_masks(ilens, olens).unsqueeze(1))
+        if self.masks is None:
+            self.masks = self._make_masks(ilens, olens).unsqueeze(1)
+        losses = self.guided_attn_masks * att_ws
+        loss = paddle.mean(
+            losses.masked_select(self.masks.broadcast_to(losses.shape)))
+        if self.reset_always:
+            self._reset_masks()
+        return self.alpha * loss
+
+
 class Tacotron2Loss(nn.Layer):
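Both GuidedAttentionLoss and GuidedMultiHeadAttentionLoss share the diagonal penalty built by _make_guided_attention_mask. A self-contained sketch of that formula (it should reproduce the 5x5 doctest values shown above):

import paddle

def guided_attention_mask(ilen, olen, sigma=0.4):
    # near 0 where grid_y / ilen == grid_x / olen (the diagonal), approaching 1 far from it
    grid_x, grid_y = paddle.meshgrid(paddle.arange(olen), paddle.arange(ilen))
    grid_x = grid_x.cast("float32")
    grid_y = grid_y.cast("float32")
    return 1.0 - paddle.exp(-((grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2)))

print(guided_attention_mask(5, 5))   # compare with the tensor in the docstring above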
