[TTS] fix tacotron2 dygraph to static (#1414)

* fix tacotron2 dygraph to static , test=tts

* fix tacotron2 dygraph to static , test=tts

* simplify synthesize_e2e.py , test=tts
TianYuan 4 years ago committed by GitHub
parent 8891621e2c
commit 89e69ee10e
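For context: every change in this commit serves one export pattern: trace the model's inference entry point with `paddle.jit.to_static`, save the resulting program, and reload it to verify. A minimal sketch of that pattern, with a toy stand-in layer (the `TinyAM` class and the save path are illustrative, not from the commit):

```python
import paddle
from paddle.static import InputSpec

class TinyAM(paddle.nn.Layer):
    """Toy stand-in for Tacotron2's inference wrapper."""

    def __init__(self):
        super().__init__()
        self.emb = paddle.nn.Embedding(100, 8)
        self.out = paddle.nn.Linear(8, 80)

    def forward(self, phone_ids):
        # phone_ids: 1-D int64 tensor of variable length
        return self.out(self.emb(phone_ids))

am = TinyAM()
# InputSpec([-1], ...) marks the phone-sequence length as dynamic,
# matching the spec the diffs below use for tacotron2_csmsc.
am_static = paddle.jit.to_static(
    am, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_static, "inference/tacotron2_demo")

# Reloading the saved program is how synthesize_e2e.py sanity-checks the export.
am_loaded = paddle.jit.load("inference/tacotron2_demo")
mel = am_loaded(paddle.to_tensor([3, 1, 4], dtype="int64"))  # shape (3, 80)
```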

@@ -0,0 +1,51 @@
#!/bin/bash
train_output_path=$1
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=tacotron2_csmsc \
        --voc=pwgan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt
fi

# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=tacotron2_csmsc \
        --voc=mb_melgan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt
fi
# style melgan
# style melgan's dygraph-to-static export is not supported yet
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=tacotron2_csmsc \
        --voc=style_melgan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt
fi

# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=tacotron2_csmsc \
        --voc=hifigan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt
fi
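This script mirrors the sibling CSMSC recipes: it is expected to run from the example root after `synthesize_e2e.sh` has exported the static-graph model, e.g. `./local/inference.sh exp/default`, with `BIN_DIR` exported by the example's `path.sh` (the exact invocation is assumed here, not shown in the commit).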

@@ -22,8 +22,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/test_e2e \
-        --inference_dir=${train_output_path}/inference \
-        --phones_dict=dump/phone_id_map.txt
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
 fi
 # for more GAN Vocoders

@@ -33,7 +33,7 @@ def main():
         default='fastspeech2_csmsc',
         choices=[
             'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
-            'fastspeech2_vctk'
+            'fastspeech2_vctk', 'tacotron2_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(

@@ -178,10 +178,7 @@ def evaluate(args):
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[InputSpec([-1], dtype=paddle.int64)])
-            paddle.jit.save(am_inference,
-                            os.path.join(args.inference_dir, args.am))
-            am_inference = paddle.jit.load(
-                os.path.join(args.inference_dir, args.am))
     elif am_name == 'speedyspeech':
         if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
             am_inference = jit.to_static(
@@ -200,8 +197,11 @@ def evaluate(args):
                 InputSpec([-1], dtype=paddle.int64)
             ])
-        paddle.jit.save(am_inference,
-                        os.path.join(args.inference_dir, args.am))
+    elif am_name == 'tacotron2':
+        am_inference = jit.to_static(
+            am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+    paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
     am_inference = paddle.jit.load(
         os.path.join(args.inference_dir, args.am))
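The `.pdmodel` / `.pdiparams` pair written by `paddle.jit.save` is what `inference.py` then feeds through the Paddle Inference API. A hedged sketch of that consumer side (the file prefix `inference/tacotron2_csmsc` is assumed from `args.inference_dir` and `args.am` above):

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Paths assume jit.save ran with the prefix inference/tacotron2_csmsc.
config = Config("inference/tacotron2_csmsc.pdmodel",
                "inference/tacotron2_csmsc.pdiparams")
predictor = create_predictor(config)

# A variable-length phone-id sequence, as permitted by InputSpec([-1]).
phone_ids = np.array([3, 1, 4, 1, 5], dtype="int64")
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(phone_ids.shape)
input_handle.copy_from_cpu(phone_ids)

predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
mel = output_handle.copy_to_cpu()  # mel spectrogram, handed to the vocoder
```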

@@ -432,6 +432,7 @@ class Tacotron2(nn.Layer):
         # inference
         h = self.enc.inference(x)
         if self.spk_num is not None:
             sid_emb = self.sid_emb(spk_id.reshape([-1]))
             h = h + sid_emb

@@ -157,7 +157,7 @@ class AttLoc(nn.Layer):
         paddle.Tensor
             previous attention weights (B, T_max)
         """
-        batch = len(enc_hs_pad)
+        batch = paddle.shape(enc_hs_pad)[0]
         # pre-compute all h outside the decoder loop
         if self.pre_compute_enc_h is None or self.han_mode:
             # (utt, frame, hdim)
@@ -172,33 +172,30 @@ class AttLoc(nn.Layer):
         dec_z = dec_z.reshape([batch, self.dunits])
         # initialize attention weight with uniform dist.
-        if att_prev is None:
+        if paddle.sum(att_prev) == 0:
             # if no bias, 0 0-pad goes 0
             att_prev = 1.0 - make_pad_mask(enc_hs_len)
             att_prev = att_prev / enc_hs_len.unsqueeze(-1)
         # att_prev: (utt, frame) -> (utt, 1, 1, frame)
         # -> (utt, att_conv_chans, 1, frame)
         att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
         # att_conv: (utt, att_conv_chans, 1, frame) -> (utt, frame, att_conv_chans)
         att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
         # att_conv: (utt, frame, att_conv_chans) -> (utt, frame, att_dim)
         att_conv = self.mlp_att(att_conv)
         # dec_z_tiled: (utt, frame, att_dim)
         dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])
         # dot with gvec
         # (utt, frame, att_dim) -> (utt, frame)
-        e = self.gvec(
-            paddle.tanh(att_conv + self.pre_compute_enc_h +
-                        dec_z_tiled)).squeeze(2)
+        e = paddle.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
+        e = self.gvec(e).squeeze(2)
         # NOTE: consider zero padding when compute w.
         if self.mask is None:
             self.mask = make_pad_mask(enc_hs_len)
         e = masked_fill(e, self.mask, -float("inf"))
         # apply monotonic attention constraint (mainly for TTS)
         if last_attended_idx is not None:
@@ -211,7 +208,6 @@ class AttLoc(nn.Layer):
         # utt x hdim
         c = paddle.sum(
             self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)
         return c, w
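Both edits above are standard dygraph-to-static fixes: `len(enc_hs_pad)` and `att_prev is None` are resolved on the Python side while the function is being traced, so the exported program would freeze whatever value tracing saw, whereas `paddle.shape(...)` and tensor reductions like `paddle.sum(...)` stay inside the graph. A small sketch of the difference (toy function, not from the repo):

```python
import paddle
from paddle.static import InputSpec

def first_dim(x):
    # paddle.shape(x)[0] is evaluated inside the program, so it tracks the
    # real runtime batch size; len(x) / x.shape[0] can collapse to the
    # compile-time placeholder once the function is program-translated.
    return paddle.shape(x)[0]

static_fn = paddle.jit.to_static(
    first_dim, input_spec=[InputSpec([-1, 5], dtype='float32')])
print(static_fn(paddle.rand([2, 5])))  # tensor holding 2
print(static_fn(paddle.rand([7, 5])))  # tensor holding 7
```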

@@ -15,7 +15,6 @@
 """Tacotron2 decoder related modules."""
 import paddle
 import paddle.nn.functional as F
-import six
 from paddle import nn

 from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA
@@ -59,7 +58,7 @@ class Prenet(nn.Layer):
         super().__init__()
         self.dropout_rate = dropout_rate
         self.prenet = nn.LayerList()
-        for layer in six.moves.range(n_layers):
+        for layer in range(n_layers):
             n_inputs = idim if layer == 0 else n_units
             self.prenet.append(
                 nn.Sequential(nn.Linear(n_inputs, n_units), nn.ReLU()))
@@ -78,7 +77,7 @@ class Prenet(nn.Layer):
             Batch of output tensors (B, ..., odim).
         """
-        for i in six.moves.range(len(self.prenet)):
+        for i in range(len(self.prenet)):
             # F.dropout introduces randomness; tacotron2's dropout must not be removed
             x = F.dropout(self.prenet[i](x))
         return x
@@ -129,7 +128,7 @@ class Postnet(nn.Layer):
         """
         super().__init__()
         self.postnet = nn.LayerList()
-        for layer in six.moves.range(n_layers - 1):
+        for layer in range(n_layers - 1):
             ichans = odim if layer == 0 else n_chans
             ochans = odim if layer == n_layers - 1 else n_chans
             if use_batch_norm:
@@ -196,7 +195,7 @@ class Postnet(nn.Layer):
             Batch of padded output tensor. (B, odim, Tmax).
         """
-        for i in six.moves.range(len(self.postnet)):
+        for i in range(len(self.postnet)):
             xs = self.postnet[i](xs)
         return xs
@@ -360,7 +359,7 @@ class Decoder(nn.Layer):
         # define lstm network
         prenet_units = prenet_units if prenet_layers != 0 else odim
         self.lstm = nn.LayerList()
-        for layer in six.moves.range(dlayers):
+        for layer in range(dlayers):
             iunits = idim + prenet_units if layer == 0 else dunits
             lstm = nn.LSTMCell(iunits, dunits)
             if zoneout_rate > 0.0:
@@ -437,47 +436,50 @@ class Decoder(nn.Layer):
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])
         # initialize attention
-        prev_att_w = None
+        prev_att_ws = []
+        prev_att_w = paddle.zeros(paddle.shape(hlens))
+        prev_att_ws.append(prev_att_w)
         self.att.reset()
         # loop for an output sequence
         outs, logits, att_ws = [], [], []
         for y in ys.transpose([1, 0, 2]):
             if self.use_att_extra_inputs:
-                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w,
+                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_ws[-1],
                                         prev_out)
             else:
-                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
+                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_ws[-1])
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 # we only use the second output of LSTMCell in paddle
                 _, next_hidden = self.lstm[i](z_list[i - 1],
                                               (z_list[i], c_list[i]))
                 z_list[i], c_list[i] = next_hidden
             zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                    if self.use_concate else z_list[-1])
-            outs += [
-                self.feat_out(zcs).reshape([paddle.shape(hs)[0], self.odim, -1])
-            ]
-            logits += [self.prob_out(zcs)]
-            att_ws += [att_w]
+            outs.append(
+                self.feat_out(zcs).reshape(
+                    [paddle.shape(hs)[0], self.odim, -1]))
+            logits.append(self.prob_out(zcs))
+            att_ws.append(att_w)
             # teacher forcing
             prev_out = y
-            if self.cumulate_att_w and prev_att_w is not None:
+            if self.cumulate_att_w and paddle.sum(prev_att_w) != 0:
                 prev_att_w = prev_att_w + att_w  # Note: error when use +=
             else:
                 prev_att_w = att_w
+            prev_att_ws.append(prev_att_w)
         # (B, Lmax)
         logits = paddle.concat(logits, axis=1)
         # (B, odim, Lmax)
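The `prev_att_ws` rewrite applies the same idea to control flow: a static program cannot branch on `tensor is None`, so the "no attention yet" state becomes an all-zero tensor and the branch tests `paddle.sum(prev_att_w) != 0` instead, which is safe here because real attention weights are normalized and never sum to zero. A minimal dygraph illustration of the sentinel trick (toy values):

```python
import paddle

prev_att_w = paddle.zeros([4])  # zero tensor stands in for "no attention yet"
for att_w in [paddle.to_tensor([0.1, 0.2, 0.3, 0.4])] * 3:
    if paddle.sum(prev_att_w) == 0:
        # first decoder step: nothing to cumulate yet
        prev_att_w = att_w
    else:
        # later steps: cumulate (in-place += errors under to_static)
        prev_att_w = prev_att_w + att_w
print(prev_att_w)  # [0.3, 0.6, 0.9, 1.2]
```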
@@ -552,6 +554,7 @@ class Decoder(nn.Layer):
         .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
         """
         # setup
         assert len(paddle.shape(h)) == 2
         hs = h.unsqueeze(0)
         ilens = paddle.shape(h)[0]
@@ -561,13 +564,16 @@ class Decoder(nn.Layer):
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([1, self.odim])
         # initialize attention
-        prev_att_w = None
+        prev_att_ws = []
+        prev_att_w = paddle.zeros([ilens])
+        prev_att_ws.append(prev_att_w)
         self.att.reset()

         # setup for attention constraint
@@ -579,6 +585,7 @@ class Decoder(nn.Layer):
         # loop for an output sequence
         idx = 0
         outs, att_ws, probs = [], [], []
+        prob = paddle.zeros([1])
         while True:
             # updated index
             idx += self.reduction_factor
@@ -589,7 +596,7 @@ class Decoder(nn.Layer):
                     hs,
                     ilens,
                     z_list[0],
-                    prev_att_w,
+                    prev_att_ws[-1],
                     prev_out,
                     last_attended_idx=last_attended_idx,
                     backward_window=backward_window,
@@ -599,19 +606,20 @@ class Decoder(nn.Layer):
                     hs,
                     ilens,
                     z_list[0],
-                    prev_att_w,
+                    prev_att_ws[-1],
                     last_attended_idx=last_attended_idx,
                     backward_window=backward_window,
                     forward_window=forward_window, )
-            att_ws += [att_w]
+            att_ws.append(att_w)
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 # we only use the second output of LSTMCell in paddle
                 _, next_hidden = self.lstm[i](z_list[i - 1],
                                               (z_list[i], c_list[i]))
@@ -619,28 +627,29 @@ class Decoder(nn.Layer):
             zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                    if self.use_concate else z_list[-1])
             # [(1, odim, r), ...]
-            outs += [self.feat_out(zcs).reshape([1, self.odim, -1])]
-            # [(r), ...]
-            probs += [F.sigmoid(self.prob_out(zcs))[0]]
+            outs.append(self.feat_out(zcs).reshape([1, self.odim, -1]))
+            prob = F.sigmoid(self.prob_out(zcs))[0]
+            probs.append(prob)
             if self.output_activation_fn is not None:
                 prev_out = self.output_activation_fn(
                     outs[-1][:, :, -1])  # (1, odim)
             else:
                 prev_out = outs[-1][:, :, -1]  # (1, odim)
-            if self.cumulate_att_w and prev_att_w is not None:
+            if self.cumulate_att_w and paddle.sum(prev_att_w) != 0:
                 prev_att_w = prev_att_w + att_w  # Note: error when use +=
             else:
                 prev_att_w = att_w
+            prev_att_ws.append(prev_att_w)
             if use_att_constraint:
                 last_attended_idx = int(att_w.argmax())
-            # check whether to finish generation
-            if sum(paddle.cast(probs[-1] >= threshold,
-                               'int64')) > 0 or idx >= maxlen:
+            if prob >= threshold or idx >= maxlen:
                 # check minimum length
                 if idx < minlen:
                     continue
+                break
         # (1, odim, L)
         outs = paddle.concat(outs, axis=2)
         if self.postnet is not None:
@@ -650,7 +659,6 @@ class Decoder(nn.Layer):
         outs = outs.transpose([0, 2, 1]).squeeze(0)
         probs = paddle.concat(probs, axis=0)
         att_ws = paddle.concat(att_ws, axis=0)
-        break

         if self.output_activation_fn is not None:
             outs = self.output_activation_fn(outs)
@@ -685,9 +693,9 @@ class Decoder(nn.Layer):
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])
         # initialize attention
@@ -702,14 +710,14 @@ class Decoder(nn.Layer):
                                         prev_out)
             else:
                 att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
-            att_ws += [att_w]
+            att_ws.append(att_w)
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 z_list[i], c_list[i] = self.lstm[i](z_list[i - 1],
                                                     (z_list[i], c_list[i]))
             # teacher forcing

@@ -14,7 +14,6 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 """Tacotron2 encoder related modules."""
 import paddle
-import six
 from paddle import nn

@@ -88,7 +87,7 @@ class Encoder(nn.Layer):
         if econv_layers > 0:
             self.convs = nn.LayerList()
-            for layer in six.moves.range(econv_layers):
+            for layer in range(econv_layers):
                 ichans = (embed_dim if layer == 0 and input_layer == "embed"
                           else econv_chans)
                 if use_batch_norm:
@@ -130,6 +129,7 @@ class Encoder(nn.Layer):
                 direction='bidirectional',
                 bias_ih_attr=True,
                 bias_hh_attr=True)
+            self.blstm.flatten_parameters()
         else:
             self.blstm = None

@@ -157,7 +157,7 @@ class Encoder(nn.Layer):
         """
         xs = self.embed(xs).transpose([0, 2, 1])
         if self.convs is not None:
-            for i in six.moves.range(len(self.convs)):
+            for i in range(len(self.convs)):
                 if self.use_residual:
                     xs += self.convs[i](xs)
                 else:
@@ -167,7 +167,8 @@ class Encoder(nn.Layer):
         if not isinstance(ilens, paddle.Tensor):
             ilens = paddle.to_tensor(ilens)
         xs = xs.transpose([0, 2, 1])
-        self.blstm.flatten_parameters()
+        # for dygraph to static graph
+        # self.blstm.flatten_parameters()
         # (B, Tmax, C)
         # see https://www.paddlepaddle.org.cn/documentation/docs/zh/faq/train_cn.html#paddletorch-nn-utils-rnn-pack-padded-sequencetorch-nn-utils-rnn-pad-packed-sequenceapi
         xs, _ = self.blstm(xs, sequence_length=ilens)
@@ -191,6 +192,6 @@ class Encoder(nn.Layer):
         """
         xs = x.unsqueeze(0)
-        ilens = paddle.to_tensor([x.shape[0]])
+        ilens = paddle.shape(x)[0]
         return self.forward(xs, ilens)[0][0]
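The encoder edits close the loop on the same theme: `flatten_parameters()` is a one-time weight re-layout, so it moves from `forward` (where program translation would try to capture it into the graph) into `__init__`, and `inference` now derives `ilens` with `paddle.shape(x)[0]` so the text length stays symbolic in the exported program, just like the attention fix earlier.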
