fix dygraph to static for tacotron2, test=doc

pull/2426/head
TianYuan 3 years ago
parent e6cbcca3e2
commit 7cdd057988

@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
https://arxiv.org/abs/1710.07654
"""
if paddle.shape(e)[0] != 1:
raise NotImplementedError(
"Batch attention constraining is not yet supported.")
backward_idx = last_attended_idx - backward_window
forward_idx = last_attended_idx + forward_window
# for dygraph to static graph
# if e.shape[0] != 1:
# raise NotImplementedError(
# "Batch attention constraining is not yet supported.")
backward_idx = paddle.cast(
last_attended_idx - backward_window, dtype='int64')
forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
if backward_idx > 0:
e[:, :backward_idx] = -float("inf")
if forward_idx < paddle.shape(e)[1]:

@ -562,7 +562,7 @@ class Decoder(nn.Layer):
idx = 0
outs, att_ws, probs = [], [], []
prob = paddle.zeros([1])
while True:
while paddle.to_tensor(True):
# updated index
idx += self.reduction_factor

Loading…
Cancel
Save