fix fastspeech2 for static graph export

pull/948/head
TianYuan 3 years ago
parent 6dbcd7720d
commit f652ba3a34

@@ -341,6 +341,7 @@ class FastSpeech2(nn.Layer):
Tensor
speech_lengths, modified if reduction_factor > 1
"""
# input of embedding must be int64
xs = paddle.cast(text, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
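The added comment reflects that the text ids are cast before the embedding lookup. A minimal, stand-alone example of that pattern (toy sizes, not from this commit):

import paddle

# Cast ids to int64 before the embedding lookup, mirroring the cast above.
emb = paddle.nn.Embedding(10, 4)                  # vocab of 10, dim 4
text = paddle.to_tensor([1, 2, 3], dtype="int32")
xs = paddle.cast(text, "int64")
print(emb(xs).shape)                              # [3, 4]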
@@ -387,8 +388,8 @@ class FastSpeech2(nn.Layer):
spk_id=None,
tone_id=None) -> Sequence[paddle.Tensor]:
# forward encoder
bs = xs.shape[0]
x_masks = self._source_mask(ilens)
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)
@@ -405,7 +406,6 @@ class FastSpeech2(nn.Layer):
if tone_id is not None:
tone_embs = self.tone_embedding_table(tone_id)
hs = self._integrate_with_tone_embed(hs, tone_embs)
# forward duration predictor and variance predictors
d_masks = make_pad_mask(ilens)
@@ -452,9 +452,10 @@ class FastSpeech2(nn.Layer):
else:
h_masks = None
# (B, Lmax, adim)
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape((zs.shape[0], -1, self.odim))
before_outs = self.feat_out(zs).reshape((bs, -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
@@ -462,7 +463,6 @@ class FastSpeech2(nn.Layer):
else:
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs
def inference(
@@ -517,8 +517,8 @@ class FastSpeech2(nn.Layer):
d = paddle.cast(durations, 'int64')
p, e = pitch, energy
# setup batch axis
ilens = paddle.to_tensor(
[x.shape[0]], dtype=paddle.int64, place=x.place)
ilens = paddle.shape(x)[0]
xs, ys = x.unsqueeze(0), None
if y is not None:
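The inference-side change above keeps the sequence length on the graph: instead of building ilens from a Python int via paddle.to_tensor, it is read from paddle.shape(x). A minimal sketch of the same idea with a made-up phone-id tensor (names are illustrative, not from this commit):

import paddle

# Derive the length from the tensor itself so the value stays a graph
# variable under paddle.jit.to_static, rather than a Python int captured
# at trace time.
x = paddle.randint(0, 100, [37], dtype="int64")  # fake phone-id sequence
ilens = paddle.shape(x)[0]                       # Tensor length, not a Python int
xs = x.unsqueeze(0)                              # add the batch axis -> (1, T)
print(ilens, xs.shape)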

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Length regulator related modules."""
import numpy as np
import paddle
from paddle import nn
@@ -50,10 +49,10 @@ class LengthRegulator(nn.Layer):
durations: (B, T)
"""
batch_size, t_enc = durations.shape
durations = durations.numpy()
slens = np.sum(durations, -1)
t_dec = np.max(slens)
M = np.zeros([batch_size, t_dec, t_enc])
# durations = durations.numpy()
slens = paddle.sum(durations, -1)
t_dec = paddle.max(slens)
M = paddle.zeros([batch_size, t_dec, t_enc])
for i in range(batch_size):
k = 0
for j in range(t_enc):
@@ -82,6 +81,7 @@ class LengthRegulator(nn.Layer):
Tensor
replicated input tensor based on durations (B, T*, D).
"""
if alpha != 1.0:
assert alpha > 0
ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha)
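Replacing the numpy-based alignment construction with paddle ops keeps the length regulator traceable. A self-contained toy version of the duration-to-alignment expansion (illustrative, not the library function):

import paddle

# Toy duration expansion built from paddle ops only (no numpy round-trip).
hs = paddle.randn([1, 3, 4])        # (B, T_enc, D) encoder outputs
ds = paddle.to_tensor([[2, 1, 3]])  # (B, T_enc) integer durations
bs, t_enc = ds.shape
t_dec = paddle.max(paddle.sum(ds, -1)).item()
M = paddle.zeros([bs, t_dec, t_enc])
for b in range(bs):
    k = 0
    for j in range(t_enc):
        d = ds[b, j].item()
        M[b, k:k + d, j] = 1.0      # frame j is repeated d times
        k += d
out = paddle.matmul(M, hs)          # (B, T_dec, D)
print(out.shape)                    # [1, 6, 4]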

@@ -37,7 +37,7 @@ class MultiHeadedAttention(nn.Layer):
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
@@ -106,13 +106,9 @@ class MultiHeadedAttention(nn.Layer):
n_batch = value.shape[0]
softmax = paddle.nn.Softmax(axis=-1)
if mask is not None:
mask = mask.unsqueeze(1)
mask = paddle.logical_not(mask)
min_value = float(
numpy.finfo(
paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min)
min_value = float(numpy.finfo("float32").min)
scores = masked_fill(scores, mask, min_value)
# (batch, head, time1, time2)
self.attn = softmax(scores)
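The fill value for masked attention scores is now simply the float32 minimum rather than being derived from a temporary tensor's dtype. A small check (not from this commit) that such a fill drives masked positions to zero weight after softmax:

import numpy
import paddle

min_value = float(numpy.finfo("float32").min)
scores = paddle.to_tensor([[0.5, 1.0, min_value]])
# exp(min_value) underflows to 0, so the masked position gets ~0 probability
print(paddle.nn.functional.softmax(scores, axis=-1))
# approximately [[0.38, 0.62, 0.00]]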

@@ -46,13 +46,14 @@ class PositionalEncoding(nn.Layer):
def extend_pe(self, x):
"""Reset the positional encodings."""
pe = paddle.zeros([x.shape[1], self.d_model])
pe = paddle.zeros([paddle.shape(x)[1], self.d_model])
if self.reverse:
position = paddle.arange(
x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1)
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, x.shape[1], dtype=paddle.float32).unsqueeze(1)
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
@@ -75,7 +76,8 @@ class PositionalEncoding(nn.Layer):
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, :x.shape[1]]
x = x * self.xscale + self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x)
@@ -101,7 +103,7 @@ class ScaledPositionalEncoding(PositionalEncoding):
x = paddle.ones([1], dtype="float32")
self.alpha = paddle.create_parameter(
shape=x.shape,
dtype=str(x.numpy().dtype),
dtype="float32",
default_initializer=paddle.nn.initializer.Assign(x))
def reset_parameters(self):
@@ -115,12 +117,11 @@ class ScaledPositionalEncoding(PositionalEncoding):
----------
x : paddle.Tensor
Input tensor (batch, time, `*`).
Returns
----------
paddle.Tensor
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, :x.shape[1]]
x = x + self.alpha * self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x)
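The positional-encoding changes swap x.shape[1] for paddle.shape(x)[1], which stays a runtime tensor when the time axis is dynamic. A tiny sketch (my own assumption about usage, not from this commit) of reading a dynamic dimension under paddle.jit.to_static:

import paddle
from paddle.static import InputSpec

# The time axis is left dynamic (None); paddle.shape(x)[1] still yields the
# concrete length at run time.
@paddle.jit.to_static(input_spec=[InputSpec([None, None, 8], "float32")])
def time_steps(x):
    return paddle.shape(x)[1]

print(time_steps(paddle.randn([2, 5, 8])))  # a tensor holding 5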

@@ -185,6 +185,7 @@ class Encoder(nn.Layer):
paddle.Tensor
Mask tensor (#batch, time).
"""
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if self.normalize_before:

@@ -87,7 +87,7 @@ class EncoderLayer(nn.Layer):
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
# assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]

@@ -44,6 +44,7 @@ class LayerNorm(paddle.nn.LayerNorm):
paddle.Tensor
Normalized tensor.
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
else:
@@ -54,9 +55,10 @@ class LayerNorm(paddle.nn.LayerNorm):
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
new_perm[self.dim], new_perm[len_dim -
1] = new_perm[len_dim -
1], new_perm[self.dim]
temp = new_perm[self.dim]
new_perm[self.dim] = new_perm[len_dim - 1]
new_perm[len_dim - 1] = temp
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
return paddle.transpose(
super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),
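The simultaneous tuple swap of permutation entries is rewritten with a temporary variable; both forms build the same permutation, as this toy check shows (values here are made up):

# Toy check: swapping via a temporary gives the same permutation as the
# original tuple assignment, e.g. for dim=1 on a 4-D input.
dim, len_dim = 1, 4
new_perm = list(range(len_dim))
tmp = new_perm[dim]
new_perm[dim] = new_perm[len_dim - 1]
new_perm[len_dim - 1] = tmp
print(new_perm)  # [0, 3, 2, 1]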

@@ -25,12 +25,22 @@ def is_broadcastable(shp1, shp2):
return True
def broadcast_shape(shp1, shp2):
result = []
for a, b in zip(shp1[::-1], shp2[::-1]):
result.append(max(a, b))
return result[::-1]
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
# assert is_broadcastable(xs.shape, mask.shape) is True
# bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = broadcast_shape(xs.shape, mask.shape)
mask.stop_gradient = True
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
mask = mask.cast(dtype=paddle.bool)
xs = paddle.where(mask, trues, xs)
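paddle.broadcast_shape is replaced by a small pure-Python helper. For the same-rank shapes that masked_fill sees here, the two agree; a quick comparison (my own example, not from this commit):

import paddle

def broadcast_shape(shp1, shp2):
    result = []
    for a, b in zip(shp1[::-1], shp2[::-1]):
        result.append(max(a, b))
    return result[::-1]

# Same-rank shapes, as used when broadcasting an attention mask onto scores.
print(broadcast_shape([2, 1, 4, 4], [2, 1, 1, 4]))         # [2, 1, 4, 4]
print(paddle.broadcast_shape([2, 1, 4, 4], [2, 1, 1, 4]))  # [2, 1, 4, 4]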

@@ -56,7 +56,7 @@ def make_pad_mask(lengths, length_dim=-1):
Parameters
----------
lengths : LongTensor or List
lengths : LongTensor
Batch of lengths (B,).
Returns
@@ -77,17 +77,11 @@ def make_pad_mask(lengths, length_dim=-1):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
maxlen = int(max(lengths))
bs = paddle.shape(lengths)[0]
maxlen = lengths.max()
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
seq_length_expand = paddle.to_tensor(
lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
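The updated make_pad_mask keeps lengths as a tensor rather than converting it to a Python list. A self-contained illustration of the same masking logic (not the library function itself):

import paddle

lengths = paddle.to_tensor([5, 3, 2], dtype="int64")  # (B,)
maxlen = lengths.max()
seq_range = paddle.arange(0, maxlen, dtype="int64")   # (maxlen,)
# positions at or beyond each length are padding
mask = seq_range.unsqueeze(0) >= lengths.unsqueeze(-1)
print(mask.astype("int64"))
# [[0 0 0 0 0]
#  [0 0 0 1 1]
#  [0 0 1 1 1]]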
