fix dygraph to static for tacotron2, test=doc

pull/2426/head
TianYuan 3 years ago
parent e6cbcca3e2
commit 7cdd057988

@@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
         https://arxiv.org/abs/1710.07654
     """
-    if paddle.shape(e)[0] != 1:
-        raise NotImplementedError(
-            "Batch attention constraining is not yet supported.")
-    backward_idx = last_attended_idx - backward_window
-    forward_idx = last_attended_idx + forward_window
+    # for dygraph to static graph
+    # if e.shape[0] != 1:
+    #     raise NotImplementedError(
+    #         "Batch attention constraining is not yet supported.")
+    backward_idx = paddle.cast(
+        last_attended_idx - backward_window, dtype='int64')
+    forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
     if backward_idx > 0:
         e[:, :backward_idx] = -float("inf")
     if forward_idx < paddle.shape(e)[1]:
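
Why the cast: after the conversion, last_attended_idx is a Tensor rather than a Python int, and the window bounds are later used as slice indices on e. Casting them to int64 keeps the slice bounds integer tensors with a fixed dtype for the exported static graph. Below is a minimal sketch of that pattern, assuming a Paddle 2.x environment that supports tensor-valued slice bounds (as the patched code relies on); the snippet is illustrative, not code from this repository.

# Illustrative only: mirrors the pattern used in the patch.
import paddle

e = paddle.rand([1, 10])                   # attention energies for a single item
last_attended_idx = paddle.to_tensor(4)    # now a Tensor, not a Python int
backward_window, forward_window = 1, 3

backward_idx = paddle.cast(
    last_attended_idx - backward_window, dtype='int64')
forward_idx = paddle.cast(
    last_attended_idx + forward_window, dtype='int64')

# int64 tensors can be used as slice bounds both in dygraph
# and after paddle.jit.to_static conversion.
e[:, :backward_idx] = -float("inf")
e[:, forward_idx:] = -float("inf")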

@@ -562,7 +562,7 @@ class Decoder(nn.Layer):
         idx = 0
         outs, att_ws, probs = [], [], []
         prob = paddle.zeros([1])
-        while True:
+        while paddle.to_tensor(True):
             # updated index
             idx += self.reduction_factor
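
Why the loop condition changes: with a literal Python True, the dygraph-to-static converter treats the loop as ordinary Python control flow at transcription time, so the stopping decision cannot become part of the exported program. Wrapping the condition in paddle.to_tensor(True) presumably makes it a Tensor, so the converter lowers the loop (including its break conditions) into a static-graph while_loop that runs at inference time. Below is a minimal sketch of that conversion rule, assuming Paddle 2.x; the CountUp layer and the save path are hypothetical and not part of PaddleSpeech.

# Illustrative only (hypothetical layer, not from this repository).
import paddle


class CountUp(paddle.nn.Layer):
    @paddle.jit.to_static
    def forward(self, x, n):
        i = paddle.zeros([1], dtype='int64')
        # Tensor-valued condition: converted into a static-graph while_loop
        # instead of being evaluated as Python control flow at conversion time.
        while i < n:
            x = x + 1
            i = i + 1
        return x


layer = CountUp()
print(layer(paddle.zeros([1]), paddle.full([1], 3, dtype='int64')))  # [3.]
paddle.jit.save(
    layer,
    "./count_up",
    input_spec=[
        paddle.static.InputSpec([1], 'float32'),
        paddle.static.InputSpec([1], 'int64'),
    ])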
