@ -25,6 +25,8 @@ import paddle.nn.functional as F
from paddle import nn
from typeguard import check_argument_types
from paddlespeech . t2s . modules . adversarial_loss . gradient_reversal import GradientReversalLayer
from paddlespeech . t2s . modules . adversarial_loss . speaker_classifier import SpeakerClassifier
from paddlespeech . t2s . modules . nets_utils import initialize
from paddlespeech . t2s . modules . nets_utils import make_non_pad_mask
from paddlespeech . t2s . modules . nets_utils import make_pad_mask
@ -138,7 +140,10 @@ class FastSpeech2(nn.Layer):
# training related
init_type : str = " xavier_uniform " ,
init_enc_alpha : float = 1.0 ,
init_dec_alpha : float = 1.0 , ) :
init_dec_alpha : float = 1.0 ,
# speaker classifier
enable_speaker_classifier : bool = False ,
hidden_sc_dim : int = 256 , ) :
""" Initialize FastSpeech2 module.
Args :
idim ( int ) :
@ -268,6 +273,10 @@ class FastSpeech2(nn.Layer):
Initial value of alpha in scaled pos encoding of the encoder .
init_dec_alpha ( float ) :
Initial value of alpha in scaled pos encoding of the decoder .
enable_speaker_classifier ( bool ) :
Whether to use speaker classifier module
hidden_sc_dim ( int ) :
The hidden layer dim of speaker classifier
"""
assert check_argument_types ( )
@ -281,6 +290,9 @@ class FastSpeech2(nn.Layer):
self . stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
self . stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
self . use_scaled_pos_enc = use_scaled_pos_enc
self . hidden_sc_dim = hidden_sc_dim
self . spk_num = spk_num
self . enable_speaker_classifier = enable_speaker_classifier
self . spk_embed_dim = spk_embed_dim
if self . spk_embed_dim is not None :
@ -373,6 +385,12 @@ class FastSpeech2(nn.Layer):
self . tone_projection = nn . Linear ( adim + self . tone_embed_dim ,
adim )
if self . spk_num and self . enable_speaker_classifier :
# set lambda = 1
self . grad_reverse = GradientReversalLayer ( 1 )
self . speaker_classifier = SpeakerClassifier (
idim = adim , hidden_sc_dim = self . hidden_sc_dim , spk_num = spk_num )
# define duration predictor
self . duration_predictor = DurationPredictor (
idim = adim ,
@ -547,7 +565,7 @@ class FastSpeech2(nn.Layer):
if tone_id is not None :
tone_id = paddle . cast ( tone_id , ' int64 ' )
# forward propagation
before_outs , after_outs , d_outs , p_outs , e_outs = self . _forward (
before_outs , after_outs , d_outs , p_outs , e_outs , spk_logits = self . _forward (
xs ,
ilens ,
olens ,
@ -564,7 +582,7 @@ class FastSpeech2(nn.Layer):
max_olen = max ( olens )
ys = ys [ : , : max_olen ]
return before_outs , after_outs , d_outs , p_outs , e_outs , ys , olens
return before_outs , after_outs , d_outs , p_outs , e_outs , ys , olens , spk_logits
def _forward ( self ,
xs : paddle . Tensor ,
@ -584,6 +602,12 @@ class FastSpeech2(nn.Layer):
# (B, Tmax, adim)
hs , _ = self . encoder ( xs , x_masks )
if self . spk_num and self . enable_speaker_classifier and not is_inference :
hs_for_spk_cls = self . grad_reverse ( hs )
spk_logits = self . speaker_classifier ( hs_for_spk_cls , ilens )
else :
spk_logits = None
# integrate speaker embedding
if self . spk_embed_dim is not None :
# spk_emb has a higher priority than spk_id
@ -676,7 +700,7 @@ class FastSpeech2(nn.Layer):
after_outs = before_outs + self . postnet (
before_outs . transpose ( ( 0 , 2 , 1 ) ) ) . transpose ( ( 0 , 2 , 1 ) )
return before_outs , after_outs , d_outs , p_outs , e_outs
return before_outs , after_outs , d_outs , p_outs , e_outs , spk_logits
def encoder_infer (
self ,
@ -771,7 +795,7 @@ class FastSpeech2(nn.Layer):
es = e . unsqueeze ( 0 ) if e is not None else None
# (1, L, odim)
_ , outs , d_outs , p_outs , e_outs = self . _forward (
_ , outs , d_outs , p_outs , e_outs , _ = self . _forward (
xs ,
ilens ,
ds = ds ,
@ -783,7 +807,7 @@ class FastSpeech2(nn.Layer):
is_inference = True )
else :
# (1, L, odim)
_ , outs , d_outs , p_outs , e_outs = self . _forward (
_ , outs , d_outs , p_outs , e_outs , _ = self . _forward (
xs ,
ilens ,
is_inference = True ,
@ -791,6 +815,7 @@ class FastSpeech2(nn.Layer):
spk_emb = spk_emb ,
spk_id = spk_id ,
tone_id = tone_id )
return outs [ 0 ] , d_outs [ 0 ] , p_outs [ 0 ] , e_outs [ 0 ]
def _integrate_with_spk_embed ( self , hs , spk_emb ) :
@ -1058,6 +1083,7 @@ class FastSpeech2Loss(nn.Layer):
self . l1_criterion = nn . L1Loss ( reduction = reduction )
self . mse_criterion = nn . MSELoss ( reduction = reduction )
self . duration_criterion = DurationPredictorLoss ( reduction = reduction )
self . ce_criterion = nn . CrossEntropyLoss ( )
def forward (
self ,
@ -1072,7 +1098,10 @@ class FastSpeech2Loss(nn.Layer):
es : paddle . Tensor ,
ilens : paddle . Tensor ,
olens : paddle . Tensor ,
) - > Tuple [ paddle . Tensor , paddle . Tensor , paddle . Tensor , paddle . Tensor ] :
spk_logits : paddle . Tensor = None ,
spk_ids : paddle . Tensor = None ,
) - > Tuple [ paddle . Tensor , paddle . Tensor , paddle . Tensor , paddle . Tensor ,
paddle . Tensor , ] :
""" Calculate forward propagation.
Args :
@ -1098,11 +1127,18 @@ class FastSpeech2Loss(nn.Layer):
Batch of the lengths of each input ( B , ) .
olens ( Tensor ) :
Batch of the lengths of each target ( B , ) .
spk_logits ( Option [ Tensor ] ) :
Batch of outputs after speaker classifier ( B , Lmax , num_spk )
spk_ids ( Option [ Tensor ] ) :
Batch of target spk_id ( B , )
Returns :
"""
speaker_loss = 0.0
# apply mask to remove padded part
if self . use_masking :
out_masks = make_non_pad_mask ( olens ) . unsqueeze ( - 1 )
@ -1124,6 +1160,16 @@ class FastSpeech2Loss(nn.Layer):
ps = ps . masked_select ( pitch_masks . broadcast_to ( ps . shape ) )
es = es . masked_select ( pitch_masks . broadcast_to ( es . shape ) )
if spk_logits is not None and spk_ids is not None :
batch_size = spk_ids . shape [ 0 ]
spk_ids = paddle . repeat_interleave ( spk_ids , spk_logits . shape [ 1 ] ,
None )
spk_logits = paddle . reshape ( spk_logits ,
[ - 1 , spk_logits . shape [ - 1 ] ] )
mask_index = spk_logits . abs ( ) . sum ( axis = 1 ) != 0
spk_ids = spk_ids [ mask_index ]
spk_logits = spk_logits [ mask_index ]
# calculate loss
l1_loss = self . l1_criterion ( before_outs , ys )
if after_outs is not None :
@ -1132,6 +1178,9 @@ class FastSpeech2Loss(nn.Layer):
pitch_loss = self . mse_criterion ( p_outs , ps )
energy_loss = self . mse_criterion ( e_outs , es )
if spk_logits is not None and spk_ids is not None :
speaker_loss = self . ce_criterion ( spk_logits , spk_ids ) / batch_size
# make weighted mask and apply it
if self . use_weighted_masking :
out_masks = make_non_pad_mask ( olens ) . unsqueeze ( - 1 )
@ -1161,4 +1210,4 @@ class FastSpeech2Loss(nn.Layer):
energy_loss = energy_loss . masked_select (
pitch_masks . broadcast_to ( energy_loss . shape ) ) . sum ( )
return l1_loss , duration_loss , pitch_loss , energy_loss
return l1_loss , duration_loss , pitch_loss , energy_loss , speaker_loss