From d368d57d67ec8239c42a25a95e56d534cf23b005 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 17 Jan 2022 16:11:27 +0800 Subject: [PATCH] fix low ips bug of speedyspeech and fastspeech2, test=tts (#1349) --- .../t2s/models/fastspeech2/fastspeech2.py | 4 +- .../t2s/models/speedyspeech/speedyspeech.py | 62 ++++++++----------- .../t2s/modules/predictor/length_regulator.py | 35 +++++++++-- 3 files changed, 60 insertions(+), 41 deletions(-) diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 405ad957..6bb651a0 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -627,7 +627,7 @@ class FastSpeech2(nn.Layer): hs = hs + e_embs + p_embs # (B, Lmax, adim) - hs = self.length_regulator(hs, d_outs, alpha) + hs = self.length_regulator(hs, d_outs, alpha, is_inference=True) else: d_outs = self.duration_predictor(hs, d_masks) # use groundtruth in training @@ -638,7 +638,7 @@ class FastSpeech2(nn.Layer): hs = hs + e_embs + p_embs # (B, Lmax, adim) - hs = self.length_regulator(hs, ds) + hs = self.length_regulator(hs, ds, is_inference=False) # forward decoder if olens is not None and not is_inference: diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py index cc9e2066..42e8f743 100644 --- a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py +++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py @@ -14,28 +14,9 @@ import paddle from paddle import nn +from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding - - -def expand(encodings: paddle.Tensor, durations: paddle.Tensor) -> paddle.Tensor: - """ - encodings: (B, T, C) - durations: (B, T) - """ - batch_size, t_enc = paddle.shape(durations) - slens = paddle.sum(durations, -1) - t_dec = paddle.max(slens) - M = paddle.zeros([batch_size, t_dec, t_enc]) - for i in range(batch_size): - k = 0 - for j in range(t_enc): - d = durations[i, j] - # If the d == 0, slice action is meaningless and not supported - if d >= 1: - M[0, k:k + d, j] = 1 - k += d - encodings = paddle.matmul(M, encodings) - return encodings +from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator class ResidualBlock(nn.Layer): @@ -175,19 +156,25 @@ class SpeedySpeechDecoder(nn.Layer): class SpeedySpeech(nn.Layer): - def __init__(self, - vocab_size, - encoder_hidden_size, - encoder_kernel_size, - encoder_dilations, - duration_predictor_hidden_size, - decoder_hidden_size, - decoder_output_size, - decoder_kernel_size, - decoder_dilations, - tone_size=None, - spk_num=None): + def __init__( + self, + vocab_size, + encoder_hidden_size, + encoder_kernel_size, + encoder_dilations, + duration_predictor_hidden_size, + decoder_hidden_size, + decoder_output_size, + decoder_kernel_size, + decoder_dilations, + tone_size=None, + spk_num=None, + init_type: str="xavier_uniform", ): super().__init__() + + # initialize parameters + initialize(self, init_type) + encoder = SpeedySpeechEncoder(vocab_size, tone_size, encoder_hidden_size, encoder_kernel_size, encoder_dilations, spk_num) @@ -198,6 +185,10 @@ class SpeedySpeech(nn.Layer): self.encoder = encoder self.duration_predictor = duration_predictor self.decoder = decoder + # define length regulator + self.length_regulator = LengthRegulator() + + nn.initializer.set_global_initializer(None) def forward(self, text, tones, durations, spk_id: paddle.Tensor=None): # input of embedding must be int64 @@ -212,7 +203,7 @@ class SpeedySpeech(nn.Layer): # expand encodings durations_to_expand = durations - encodings = expand(encodings, durations_to_expand) + encodings = self.length_regulator(encodings, durations_to_expand) # decode # remove positional encoding here @@ -240,7 +231,8 @@ class SpeedySpeech(nn.Layer): durations_to_expand = durations_to_expand.astype(paddle.int64) else: durations_to_expand = durations - encodings = expand(encodings, durations_to_expand) + encodings = self.length_regulator( + encodings, durations_to_expand, is_inference=True) shape = paddle.shape(encodings) t_dec, feature_size = shape[1], shape[2] diff --git a/paddlespeech/t2s/modules/predictor/length_regulator.py b/paddlespeech/t2s/modules/predictor/length_regulator.py index f1ecfb7c..9510dd88 100644 --- a/paddlespeech/t2s/modules/predictor/length_regulator.py +++ b/paddlespeech/t2s/modules/predictor/length_regulator.py @@ -13,6 +13,7 @@ # limitations under the License. # Modified from espnet(https://github.com/espnet/espnet) """Length regulator related modules.""" +import numpy as np import paddle from paddle import nn @@ -43,6 +44,28 @@ class LengthRegulator(nn.Layer): super().__init__() self.pad_value = pad_value + # expand_numpy is faster than expand + def expand_numpy(self, encodings: paddle.Tensor, + durations: paddle.Tensor) -> paddle.Tensor: + """ + encodings: (B, T, C) + durations: (B, T) + """ + batch_size, t_enc = durations.shape + durations = durations.numpy() + slens = np.sum(durations, -1) + t_dec = np.max(slens) + M = np.zeros([batch_size, t_dec, t_enc]) + for i in range(batch_size): + k = 0 + for j in range(t_enc): + d = durations[i, j] + M[i, k:k + d, j] = 1 + k += d + M = paddle.to_tensor(M, dtype=encodings.dtype) + encodings = paddle.matmul(M, encodings) + return encodings + def expand(self, encodings: paddle.Tensor, durations: paddle.Tensor) -> paddle.Tensor: """ @@ -50,20 +73,21 @@ class LengthRegulator(nn.Layer): durations: (B, T) """ batch_size, t_enc = paddle.shape(durations) - slens = durations.sum(-1) - t_dec = slens.max() + slens = paddle.sum(durations, -1) + t_dec = paddle.max(slens) M = paddle.zeros([batch_size, t_dec, t_enc]) for i in range(batch_size): k = 0 for j in range(t_enc): d = durations[i, j] + # If the d == 0, slice action is meaningless and not supported in paddle if d >= 1: M[i, k:k + d, j] = 1 k += d encodings = paddle.matmul(M, encodings) return encodings - def forward(self, xs, ds, alpha=1.0): + def forward(self, xs, ds, alpha=1.0, is_inference=False): """Calculate forward propagation. Parameters @@ -85,4 +109,7 @@ class LengthRegulator(nn.Layer): assert alpha > 0 ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha) ds = ds.cast(dtype=paddle.int64) - return self.expand(xs, ds) + if is_inference: + return self.expand(xs, ds) + else: + return self.expand_numpy(xs, ds)