@ -25,6 +25,8 @@ import paddle.nn.functional as F
from paddle import nn
from paddle import nn
from typeguard import check_argument_types
from typeguard import check_argument_types
from paddlespeech . t2s . modules . adversarial_loss . gradient_reversal import GradientReversalLayer
from paddlespeech . t2s . modules . adversarial_loss . speaker_classifier import SpeakerClassifier
from paddlespeech . t2s . modules . nets_utils import initialize
from paddlespeech . t2s . modules . nets_utils import initialize
from paddlespeech . t2s . modules . nets_utils import make_non_pad_mask
from paddlespeech . t2s . modules . nets_utils import make_non_pad_mask
from paddlespeech . t2s . modules . nets_utils import make_pad_mask
from paddlespeech . t2s . modules . nets_utils import make_pad_mask
@ -37,8 +39,6 @@ from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech . t2s . modules . transformer . encoder import CNNPostnet
from paddlespeech . t2s . modules . transformer . encoder import CNNPostnet
from paddlespeech . t2s . modules . transformer . encoder import ConformerEncoder
from paddlespeech . t2s . modules . transformer . encoder import ConformerEncoder
from paddlespeech . t2s . modules . transformer . encoder import TransformerEncoder
from paddlespeech . t2s . modules . transformer . encoder import TransformerEncoder
from paddlespeech . t2s . modules . multi_speakers . speaker_classifier import SpeakerClassifier
from paddlespeech . t2s . modules . multi_speakers . gradient_reversal import GradientReversalLayer
class FastSpeech2 ( nn . Layer ) :
class FastSpeech2 ( nn . Layer ) :
@ -140,10 +140,10 @@ class FastSpeech2(nn.Layer):
# training related
# training related
init_type : str = " xavier_uniform " ,
init_type : str = " xavier_uniform " ,
init_enc_alpha : float = 1.0 ,
init_enc_alpha : float = 1.0 ,
init_dec_alpha : float = 1.0 ,
init_dec_alpha : float = 1.0 ,
# speaker classifier
# speaker classifier
enable_speaker_classifier : bool = False ,
enable_speaker_classifier : bool = False ,
hidden_sc_dim : int = 256 , ) :
hidden_sc_dim : int = 256 , ) :
""" Initialize FastSpeech2 module.
""" Initialize FastSpeech2 module.
Args :
Args :
idim ( int ) :
idim ( int ) :
@ -388,7 +388,8 @@ class FastSpeech2(nn.Layer):
if self . spk_num and self . enable_speaker_classifier :
if self . spk_num and self . enable_speaker_classifier :
# set lambda = 1
# set lambda = 1
self . grad_reverse = GradientReversalLayer ( 1 )
self . grad_reverse = GradientReversalLayer ( 1 )
self . speaker_classifier = SpeakerClassifier ( idim = adim , hidden_sc_dim = self . hidden_sc_dim , spk_num = spk_num )
self . speaker_classifier = SpeakerClassifier (
idim = adim , hidden_sc_dim = self . hidden_sc_dim , spk_num = spk_num )
# define duration predictor
# define duration predictor
self . duration_predictor = DurationPredictor (
self . duration_predictor = DurationPredictor (
@ -601,7 +602,7 @@ class FastSpeech2(nn.Layer):
# (B, Tmax, adim)
# (B, Tmax, adim)
hs , _ = self . encoder ( xs , x_masks )
hs , _ = self . encoder ( xs , x_masks )
if self . spk_num and self . enable_speaker_classifier :
if self . spk_num and self . enable_speaker_classifier and not is_inference :
hs_for_spk_cls = self . grad_reverse ( hs )
hs_for_spk_cls = self . grad_reverse ( hs )
spk_logits = self . speaker_classifier ( hs_for_spk_cls , ilens )
spk_logits = self . speaker_classifier ( hs_for_spk_cls , ilens )
else :
else :
@ -794,7 +795,7 @@ class FastSpeech2(nn.Layer):
es = e . unsqueeze ( 0 ) if e is not None else None
es = e . unsqueeze ( 0 ) if e is not None else None
# (1, L, odim)
# (1, L, odim)
_ , outs , d_outs , p_outs , e_outs = self . _inference (
_ , outs , d_outs , p_outs , e_outs , _ = self . _forward (
xs ,
xs ,
ilens ,
ilens ,
ds = ds ,
ds = ds ,
@ -806,7 +807,7 @@ class FastSpeech2(nn.Layer):
is_inference = True )
is_inference = True )
else :
else :
# (1, L, odim)
# (1, L, odim)
_ , outs , d_outs , p_outs , e_outs = self . _inference (
_ , outs , d_outs , p_outs , e_outs , _ = self . _forward (
xs ,
xs ,
ilens ,
ilens ,
is_inference = True ,
is_inference = True ,
@ -815,121 +816,8 @@ class FastSpeech2(nn.Layer):
spk_id = spk_id ,
spk_id = spk_id ,
tone_id = tone_id )
tone_id = tone_id )
return outs [ 0 ] , d_outs [ 0 ] , p_outs [ 0 ] , e_outs [ 0 ]
return outs [ 0 ] , d_outs [ 0 ] , p_outs [ 0 ] , e_outs [ 0 ]
def _inference(self,
               xs: paddle.Tensor,
               ilens: paddle.Tensor,
               olens: paddle.Tensor=None,
               ds: paddle.Tensor=None,
               ps: paddle.Tensor=None,
               es: paddle.Tensor=None,
               is_inference: bool=False,
               return_after_enc=False,
               alpha: float=1.0,
               spk_emb=None,
               spk_id=None,
               tone_id=None) -> Sequence[paddle.Tensor]:
    """Shared train/inference computation for FastSpeech2.

    Encodes the token batch, mixes in optional speaker/tone embeddings,
    predicts (or consumes supplied) duration/pitch/energy, expands to
    frame rate with the length regulator, and decodes to spectrogram
    frames, optionally refined by the postnet.

    Args:
        xs: Padded token id batch.
        ilens: Input lengths per utterance.
        olens: Output (frame) lengths; used for decoder masks in training.
        ds: Ground-truth durations (overrides prediction when given at inference).
        ps: Ground-truth pitch (training input; inference override when given).
        es: Ground-truth energy (training input; inference override when given).
        is_inference: Selects predicted vs. ground-truth variance path.
        return_after_enc: Return (hs, h_masks) right before the decoder.
        alpha: Duration scaling factor for inference speed control.
        spk_emb: Precomputed speaker embedding (takes priority over spk_id).
        spk_id: Speaker id looked up in the embedding table.
        tone_id: Tone id looked up in the tone embedding table.

    Returns:
        Tuple of (before_outs, after_outs, d_outs, p_outs, e_outs),
        or (hs, h_masks) when return_after_enc is True.
    """
    # Encoder forward pass; hs: (B, Tmax, adim).
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder(xs, x_masks)

    # Integrate speaker embedding; an explicit spk_emb takes priority
    # over a table lookup by spk_id.
    if self.spk_embed_dim is not None:
        if spk_emb is not None:
            hs = self._integrate_with_spk_embed(hs, spk_emb)
        elif spk_id is not None:
            spk_emb = self.spk_embedding_table(spk_id)
            hs = self._integrate_with_spk_embed(hs, spk_emb)

    # Integrate tone embedding when the model is tone-aware.
    if self.tone_embed_dim is not None and tone_id is not None:
        tone_embs = self.tone_embedding_table(tone_id)
        hs = self._integrate_with_tone_embed(hs, tone_embs)

    # Variance predictors; detach() optionally stops gradients from the
    # pitch/energy branches flowing back into the encoder.
    d_masks = make_pad_mask(ilens)
    pitch_in = hs.detach() if self.stop_gradient_from_pitch_predictor else hs
    p_outs = self.pitch_predictor(pitch_in, d_masks.unsqueeze(-1))
    energy_in = hs.detach() if self.stop_gradient_from_energy_predictor else hs
    e_outs = self.energy_predictor(energy_in, d_masks.unsqueeze(-1))

    if is_inference:
        # Durations: caller-supplied ds wins, else predict. (B, Tmax)
        if ds is not None:
            d_outs = ds
        else:
            d_outs = self.duration_predictor.inference(hs, d_masks)
        # Caller-supplied pitch/energy override the predictions.
        if ps is not None:
            p_outs = ps
        if es is not None:
            e_outs = es
        # Embed predicted variances; inputs are (B, Tmax, 1).
        p_embs = self.pitch_embed(
            p_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
        e_embs = self.energy_embed(
            e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
        hs = hs + e_embs + p_embs
        # Expand to frame rate with predicted durations. (B, Lmax, adim)
        hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
    else:
        d_outs = self.duration_predictor(hs, d_masks)
        # Training path: embed ground-truth pitch/energy.
        p_embs = self.pitch_embed(
            ps.transpose((0, 2, 1))).transpose((0, 2, 1))
        e_embs = self.energy_embed(
            es.transpose((0, 2, 1))).transpose((0, 2, 1))
        hs = hs + e_embs + p_embs
        # Expand to frame rate with ground-truth durations. (B, Lmax, adim)
        hs = self.length_regulator(hs, ds, is_inference=False)

    # Decoder masks are only built in training, where olens is known.
    if olens is not None and not is_inference:
        if self.reduction_factor > 1:
            olens_in = paddle.to_tensor(
                [olen // self.reduction_factor for olen in olens.numpy()])
        else:
            olens_in = olens
        # (B, 1, T)
        h_masks = self._source_mask(olens_in)
    else:
        h_masks = None

    if return_after_enc:
        return hs, h_masks

    if self.decoder_type == 'cnndecoder':
        # Output masks removed here to ease dygraph-to-static conversion.
        zs = self.decoder(hs, h_masks)
        before_outs = zs
    else:
        # (B, Lmax, adim)
        zs, _ = self.decoder(hs, h_masks)
        # Project to spectrogram bins: (B, Lmax, odim).
        before_outs = self.feat_out(zs).reshape(
            (paddle.shape(zs)[0], -1, self.odim))

    # Postnet residual refinement -> (B, Lmax//r * r, odim).
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(
            before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

    return before_outs, after_outs, d_outs, p_outs, e_outs
def _integrate_with_spk_embed ( self , hs , spk_emb ) :
def _integrate_with_spk_embed ( self , hs , spk_emb ) :
""" Integrate speaker embedding with hidden states.
""" Integrate speaker embedding with hidden states.
@ -1212,7 +1100,8 @@ class FastSpeech2Loss(nn.Layer):
olens : paddle . Tensor ,
olens : paddle . Tensor ,
spk_logits : paddle . Tensor = None ,
spk_logits : paddle . Tensor = None ,
spk_ids : paddle . Tensor = None ,
spk_ids : paddle . Tensor = None ,
) - > Tuple [ paddle . Tensor , paddle . Tensor , paddle . Tensor , paddle . Tensor , paddle . Tensor , ] :
) - > Tuple [ paddle . Tensor , paddle . Tensor , paddle . Tensor , paddle . Tensor ,
paddle . Tensor , ] :
""" Calculate forward propagation.
""" Calculate forward propagation.
Args :
Args :
@ -1249,7 +1138,7 @@ class FastSpeech2Loss(nn.Layer):
"""
"""
speaker_loss = 0.0
speaker_loss = 0.0
# apply mask to remove padded part
# apply mask to remove padded part
if self . use_masking :
if self . use_masking :
out_masks = make_non_pad_mask ( olens ) . unsqueeze ( - 1 )
out_masks = make_non_pad_mask ( olens ) . unsqueeze ( - 1 )
@ -1273,12 +1162,13 @@ class FastSpeech2Loss(nn.Layer):
if spk_logits is not None and spk_ids is not None :
if spk_logits is not None and spk_ids is not None :
batch_size = spk_ids . shape [ 0 ]
batch_size = spk_ids . shape [ 0 ]
spk_ids = paddle . repeat_interleave ( spk_ids , spk_logits . shape [ 1 ] , None )
spk_ids = paddle . repeat_interleave ( spk_ids , spk_logits . shape [ 1 ] ,
spk_logits = paddle . reshape ( spk_logits , [ - 1 , spk_logits . shape [ - 1 ] ] )
None )
mask_index = spk_logits . abs ( ) . sum ( axis = 1 ) != 0
spk_logits = paddle . reshape ( spk_logits ,
[ - 1 , spk_logits . shape [ - 1 ] ] )
mask_index = spk_logits . abs ( ) . sum ( axis = 1 ) != 0
spk_ids = spk_ids [ mask_index ]
spk_ids = spk_ids [ mask_index ]
spk_logits = spk_logits [ mask_index ]
spk_logits = spk_logits [ mask_index ]
# calculate loss
# calculate loss
l1_loss = self . l1_criterion ( before_outs , ys )
l1_loss = self . l1_criterion ( before_outs , ys )
@ -1289,7 +1179,7 @@ class FastSpeech2Loss(nn.Layer):
energy_loss = self . mse_criterion ( e_outs , es )
energy_loss = self . mse_criterion ( e_outs , es )
if spk_logits is not None and spk_ids is not None :
if spk_logits is not None and spk_ids is not None :
speaker_loss = self . ce_criterion ( spk_logits , spk_ids ) / batch_size
speaker_loss = self . ce_criterion ( spk_logits , spk_ids ) / batch_size
# make weighted mask and apply it
# make weighted mask and apply it
if self . use_weighted_masking :
if self . use_weighted_masking :