[Fix] fastspeech2 0d

pull/3951/head
megemini 9 months ago
parent 73beb187da
commit f1da82751b

@@ -903,14 +903,14 @@ class FastSpeech2(nn.Layer):
         # initialize alpha in scaled positional encoding
         if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
-            init_enc_alpha = paddle.to_tensor(init_enc_alpha)
+            init_enc_alpha = paddle.to_tensor(init_enc_alpha).reshape([1])
             self.encoder.embed[-1].alpha = paddle.create_parameter(
                 shape=init_enc_alpha.shape,
                 dtype=str(init_enc_alpha.numpy().dtype),
                 default_initializer=paddle.nn.initializer.Assign(
                     init_enc_alpha))
         if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
-            init_dec_alpha = paddle.to_tensor(init_dec_alpha)
+            init_dec_alpha = paddle.to_tensor(init_dec_alpha).reshape([1])
             self.decoder.embed[-1].alpha = paddle.create_parameter(
                 shape=init_dec_alpha.shape,
                 dtype=str(init_dec_alpha.numpy().dtype),
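
Why the .reshape([1]) matters (a minimal sketch, not part of the commit): since Paddle 2.5, converting a Python scalar with paddle.to_tensor produces a 0-D tensor whose shape is the empty list [], which is not a usable shape for paddle.create_parameter. Reshaping to [1] gives the alpha parameter a valid 1-D shape. The scalar default 1.0 and the exact failure mode are assumptions for illustration; the variable names mirror the diff.

# Sketch only; assumes Paddle >= 2.5, where Python scalars become 0-D tensors.
import paddle

init_enc_alpha = 1.0                    # scalar default, assumed for illustration
t = paddle.to_tensor(init_enc_alpha)
print(t.shape)                          # [] -- a 0-D tensor, unusable as a parameter shape

t = t.reshape([1])                      # the fix: promote the scalar to a 1-D tensor
alpha = paddle.create_parameter(
    shape=t.shape,                      # now [1], a valid non-empty shape
    dtype=str(t.numpy().dtype),
    default_initializer=paddle.nn.initializer.Assign(t))
print(alpha.shape)                      # [1]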
