From d225a503312165d11dcd68773a37e50cc37dd640 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Wed, 27 Oct 2021 07:23:29 +0000
Subject: [PATCH] fix fastspeech2 to static

---
 parakeet/models/fastspeech2/fastspeech2.py        | 12 ++++++------
 .../fastspeech2_predictor/length_regulator.py     | 10 +++++-----
 .../modules/fastspeech2_transformer/attention.py  |  8 ++------
 .../modules/fastspeech2_transformer/embedding.py  | 15 ++++++++-------
 .../modules/fastspeech2_transformer/encoder.py    |  1 +
 .../fastspeech2_transformer/encoder_layer.py      |  2 +-
 parakeet/modules/layer_norm.py                    |  8 +++++---
 parakeet/modules/masked_fill.py                   | 14 ++++++++++++--
 parakeet/modules/nets_utils.py                    | 14 ++++----------
 9 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/parakeet/models/fastspeech2/fastspeech2.py b/parakeet/models/fastspeech2/fastspeech2.py
index 7c0e20bc2..21c2d2c3f 100644
--- a/parakeet/models/fastspeech2/fastspeech2.py
+++ b/parakeet/models/fastspeech2/fastspeech2.py
@@ -341,6 +341,7 @@ class FastSpeech2(nn.Layer):
         Tensor
             speech_lengths, modified if reduction_factor > 1
         """
+        # input of embedding must be int64
         xs = paddle.cast(text, 'int64')
         ilens = paddle.cast(text_lengths, 'int64')
 
@@ -387,8 +388,8 @@ class FastSpeech2(nn.Layer):
                  spk_id=None,
                  tone_id=None) -> Sequence[paddle.Tensor]:
         # forward encoder
+        bs = xs.shape[0]
         x_masks = self._source_mask(ilens)
-
         # (B, Tmax, adim)
         hs, _ = self.encoder(xs, x_masks)
 
@@ -405,7 +406,6 @@ class FastSpeech2(nn.Layer):
         if tone_id is not None:
             tone_embs = self.tone_embedding_table(tone_id)
             hs = self._integrate_with_tone_embed(hs, tone_embs)
-
         # forward duration predictor and variance predictors
         d_masks = make_pad_mask(ilens)
 
@@ -452,9 +452,10 @@ class FastSpeech2(nn.Layer):
         else:
             h_masks = None
         # (B, Lmax, adim)
+
         zs, _ = self.decoder(hs, h_masks)
         # (B, Lmax, odim)
-        before_outs = self.feat_out(zs).reshape((zs.shape[0], -1, self.odim))
+        before_outs = self.feat_out(zs).reshape((bs, -1, self.odim))
 
         # postnet -> (B, Lmax//r * r, odim)
         if self.postnet is None:
@@ -462,7 +463,6 @@ class FastSpeech2(nn.Layer):
         else:
             after_outs = before_outs + self.postnet(
                 before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
-
         return before_outs, after_outs, d_outs, p_outs, e_outs
 
     def inference(
@@ -517,8 +517,8 @@ class FastSpeech2(nn.Layer):
         d = paddle.cast(durations, 'int64')
         p, e = pitch, energy
         # setup batch axis
-        ilens = paddle.to_tensor(
-            [x.shape[0]], dtype=paddle.int64, place=x.place)
+        ilens = paddle.shape(x)[0]
+
         xs, ys = x.unsqueeze(0), None
 
         if y is not None:
diff --git a/parakeet/modules/fastspeech2_predictor/length_regulator.py b/parakeet/modules/fastspeech2_predictor/length_regulator.py
index e5195e536..e413812d2 100644
--- a/parakeet/modules/fastspeech2_predictor/length_regulator.py
+++ b/parakeet/modules/fastspeech2_predictor/length_regulator.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Length regulator related modules."""
-import numpy as np
 import paddle
 from paddle import nn
 
@@ -50,10 +49,10 @@ class LengthRegulator(nn.Layer):
             durations: (B, T)
         """
         batch_size, t_enc = durations.shape
-        durations = durations.numpy()
-        slens = np.sum(durations, -1)
-        t_dec = np.max(slens)
-        M = np.zeros([batch_size, t_dec, t_enc])
+        # durations = durations.numpy()
+        slens = paddle.sum(durations, -1)
+        t_dec = paddle.max(slens)
+        M = paddle.zeros([batch_size, t_dec, t_enc])
         for i in range(batch_size):
             k = 0
             for j in range(t_enc):
@@ -82,6 +81,7 @@ class LengthRegulator(nn.Layer):
         Tensor
             replicated input tensor based on durations (B, T*, D).
         """
+
         if alpha != 1.0:
             assert alpha > 0
             ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha)
diff --git a/parakeet/modules/fastspeech2_transformer/attention.py b/parakeet/modules/fastspeech2_transformer/attention.py
index ae941a79a..8cef0023c 100644
--- a/parakeet/modules/fastspeech2_transformer/attention.py
+++ b/parakeet/modules/fastspeech2_transformer/attention.py
@@ -37,7 +37,7 @@ class MultiHeadedAttention(nn.Layer):
     def __init__(self, n_head, n_feat, dropout_rate):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttention, self).__init__()
-        assert n_feat % n_head == 0
+        # assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
@@ -106,13 +106,9 @@ class MultiHeadedAttention(nn.Layer):
         n_batch = value.shape[0]
         softmax = paddle.nn.Softmax(axis=-1)
         if mask is not None:
-            mask = mask.unsqueeze(1)
             mask = paddle.logical_not(mask)
-            min_value = float(
-                numpy.finfo(
-                    paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min)
-
+            min_value = float(numpy.finfo("float32").min)
             scores = masked_fill(scores, mask, min_value)
             # (batch, head, time1, time2)
             self.attn = softmax(scores)
 
diff --git a/parakeet/modules/fastspeech2_transformer/embedding.py b/parakeet/modules/fastspeech2_transformer/embedding.py
index 6c1c7245f..888a209a5 100644
--- a/parakeet/modules/fastspeech2_transformer/embedding.py
+++ b/parakeet/modules/fastspeech2_transformer/embedding.py
@@ -46,13 +46,14 @@ class PositionalEncoding(nn.Layer):
 
     def extend_pe(self, x):
         """Reset the positional encodings."""
-        pe = paddle.zeros([x.shape[1], self.d_model])
+        pe = paddle.zeros([paddle.shape(x)[1], self.d_model])
         if self.reverse:
             position = paddle.arange(
-                x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1)
+                paddle.shape(x)[1] - 1, -1, -1.0,
+                dtype=paddle.float32).unsqueeze(1)
         else:
             position = paddle.arange(
-                0, x.shape[1], dtype=paddle.float32).unsqueeze(1)
+                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
         div_term = paddle.exp(
             paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
             -(math.log(10000.0) / self.d_model))
@@ -75,7 +76,8 @@
             Encoded tensor (batch, time, `*`).
         """
         self.extend_pe(x)
-        x = x * self.xscale + self.pe[:, :x.shape[1]]
+
+        x = x * self.xscale + self.pe[:, :paddle.shape(x)[1]]
         return self.dropout(x)
 
 
@@ -101,7 +103,7 @@ class ScaledPositionalEncoding(PositionalEncoding):
         x = paddle.ones([1], dtype="float32")
         self.alpha = paddle.create_parameter(
             shape=x.shape,
-            dtype=str(x.numpy().dtype),
+            dtype="float32",
             default_initializer=paddle.nn.initializer.Assign(x))
 
     def reset_parameters(self):
@@ -115,12 +117,11 @@ class ScaledPositionalEncoding(PositionalEncoding):
         ----------
         x : paddle.Tensor
             Input tensor (batch, time, `*`).
-
         Returns
         ----------
         paddle.Tensor
             Encoded tensor (batch, time, `*`).
         """
         self.extend_pe(x)
-        x = x + self.alpha * self.pe[:, :x.shape[1]]
+        x = x + self.alpha * self.pe[:, :paddle.shape(x)[1]]
         return self.dropout(x)
diff --git a/parakeet/modules/fastspeech2_transformer/encoder.py b/parakeet/modules/fastspeech2_transformer/encoder.py
index 630b50ff5..996e9dee0 100644
--- a/parakeet/modules/fastspeech2_transformer/encoder.py
+++ b/parakeet/modules/fastspeech2_transformer/encoder.py
@@ -185,6 +185,7 @@ class Encoder(nn.Layer):
         paddle.Tensor
             Mask tensor (#batch, time).
         """
+
         xs = self.embed(xs)
         xs, masks = self.encoders(xs, masks)
         if self.normalize_before:
diff --git a/parakeet/modules/fastspeech2_transformer/encoder_layer.py b/parakeet/modules/fastspeech2_transformer/encoder_layer.py
index d8f89d677..298e13f88 100644
--- a/parakeet/modules/fastspeech2_transformer/encoder_layer.py
+++ b/parakeet/modules/fastspeech2_transformer/encoder_layer.py
@@ -87,7 +87,7 @@ class EncoderLayer(nn.Layer):
         if cache is None:
             x_q = x
         else:
-            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            # assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
             x_q = x[:, -1:, :]
             residual = residual[:, -1:, :]
             mask = None if mask is None else mask[:, -1:, :]
diff --git a/parakeet/modules/layer_norm.py b/parakeet/modules/layer_norm.py
index 3bab823f2..f91b49ae6 100644
--- a/parakeet/modules/layer_norm.py
+++ b/parakeet/modules/layer_norm.py
@@ -44,6 +44,7 @@ class LayerNorm(paddle.nn.LayerNorm):
         paddle.Tensor
             Normalized tensor.
         """
+
         if self.dim == -1:
             return super(LayerNorm, self).forward(x)
         else:
@@ -54,9 +55,10 @@ class LayerNorm(paddle.nn.LayerNorm):
 
             orig_perm = list(range(len_dim))
             new_perm = orig_perm[:]
-            new_perm[self.dim], new_perm[len_dim -
-                                         1] = new_perm[len_dim -
-                                                       1], new_perm[self.dim]
+            temp = new_perm[self.dim]
+            new_perm[self.dim] = new_perm[len_dim - 1]
+            new_perm[len_dim - 1] = temp
+            # new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
 
             return paddle.transpose(
                 super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),
diff --git a/parakeet/modules/masked_fill.py b/parakeet/modules/masked_fill.py
index 34230f1c4..e42a3cc0d 100644
--- a/parakeet/modules/masked_fill.py
+++ b/parakeet/modules/masked_fill.py
@@ -25,12 +25,22 @@ def is_broadcastable(shp1, shp2):
     return True
 
 
+def broadcast_shape(shp1, shp2):
+    result = []
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        result.append(max(a, b))
+    return result[::-1]
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True
-    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    # assert is_broadcastable(xs.shape, mask.shape) is True
+    # bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    bshape = broadcast_shape(xs.shape, mask.shape)
+    mask.stop_gradient = True
     mask = mask.broadcast_to(bshape)
+
     trues = paddle.ones_like(xs) * value
     mask = mask.cast(dtype=paddle.bool)
     xs = paddle.where(mask, trues, xs)
diff --git a/parakeet/modules/nets_utils.py b/parakeet/modules/nets_utils.py
index 47eae65d6..0696335a5 100644
--- a/parakeet/modules/nets_utils.py
+++ b/parakeet/modules/nets_utils.py
@@ -56,7 +56,7 @@ def make_pad_mask(lengths, length_dim=-1):
 
     Parameters
     ----------
-    lengths : LongTensor or List
+    lengths : LongTensor
        Batch of lengths (B,).
 
     Returns
@@ -77,17 +77,11 @@ def make_pad_mask(lengths, length_dim=-1):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
 
-    if not isinstance(lengths, list):
-        lengths = lengths.tolist()
-    bs = int(len(lengths))
-
-    maxlen = int(max(lengths))
-
+    bs = paddle.shape(lengths)[0]
+    maxlen = lengths.max()
     seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
     seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
-
-    seq_length_expand = paddle.to_tensor(
-        lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
+    seq_length_expand = lengths.unsqueeze(-1)
     mask = seq_range_expand >= seq_length_expand
     return mask
 
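
The pattern behind most of these hunks is to keep every quantity the graph depends on inside the framework: paddle.shape(x)[1] instead of x.shape[1], lengths.max() instead of int(max(lengths)), and paddle.sum/paddle.zeros instead of their numpy counterparts, so the model can be converted to a static graph (e.g. via paddle.jit.to_static) without baking concrete batch or time sizes into the trace. A minimal sketch of that idiom, mirroring the patched make_pad_mask in nets_utils.py (the to_static wrapper and InputSpec below are illustrative assumptions, not part of the patch):

    import paddle
    from paddle.static import InputSpec


    def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
        """Bool mask of shape (B, T_max) that is True at padded positions."""
        bs = paddle.shape(lengths)[0]  # shape as a tensor, resolved at run time
        maxlen = lengths.max()         # tensor max, not int(max(python_list))
        seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
        seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
        return seq_range_expand >= lengths.unsqueeze(-1)


    # works eagerly ...
    print(make_pad_mask(paddle.to_tensor([3, 1, 2], dtype="int64")))
    # ... and the same function can be traced with a dynamic batch size
    # (illustrative wrapper, not taken from this patch)
    static_make_pad_mask = paddle.jit.to_static(
        make_pad_mask, input_spec=[InputSpec([None], dtype="int64")])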
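
In the same spirit, the rewritten masked_fill drops the is_broadcastable assert and paddle.broadcast_shape in favour of a small Python helper, and fills masked positions through paddle.where. A hedged usage sketch in the attention setting (the tensor shapes and pad mask below are made up for illustration; the helper bodies follow the patched masked_fill.py):

    import numpy
    import paddle


    def broadcast_shape(shp1, shp2):
        # elementwise max over aligned trailing dims, as in the patched helper
        return [max(a, b) for a, b in zip(shp1[::-1], shp2[::-1])][::-1]


    def masked_fill(xs, mask, value):
        bshape = broadcast_shape(xs.shape, mask.shape)
        mask.stop_gradient = True
        mask = mask.broadcast_to(bshape).cast(paddle.bool)
        return paddle.where(mask, paddle.ones_like(xs) * value, xs)


    # mask out padded key positions in attention scores, roughly as in
    # MultiHeadedAttention.forward_attention
    scores = paddle.randn([2, 4, 6, 6])  # (batch, head, time1, time2)
    lengths = paddle.to_tensor([6, 3], dtype="int64")
    pad = paddle.arange(6).unsqueeze(0) >= lengths.unsqueeze(-1)  # (batch, time2)
    min_value = float(numpy.finfo("float32").min)
    scores = masked_fill(scores, pad.unsqueeze(1).unsqueeze(1), min_value)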