[TTS]fix dygraph to static for tacotron2, test=doc (#2426)

* fix dygraph to static for tacotron2, test=doc

* Fix dy2st error for taco2

* Update attentions.py

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
TianYuan 3 years ago committed by GitHub
parent d9b041e999
commit c8d5a01bdb

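Background: "dygraph to static" here means converting the imperative (dygraph) Tacotron2 inference code into a static graph with paddle.jit so it can be saved and deployed. A minimal, self-contained sketch of that export path using a toy layer; the class name, shapes, and save path below are illustrative, not PaddleSpeech's actual export code:

import paddle


class TinyDecoder(paddle.nn.Layer):
    """Toy stand-in for the Tacotron2 decoder, only to show the export path."""

    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(80, 80)

    def forward(self, x):
        return self.proj(x)


net = TinyDecoder()
# convert dygraph code to a static graph and save it for inference
static_net = paddle.jit.to_static(
    net,
    input_spec=[paddle.static.InputSpec(shape=[None, 80], dtype='float32')])
paddle.jit.save(static_net, "./tiny_decoder")
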
@@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
     https://arxiv.org/abs/1710.07654
     """
-    if paddle.shape(e)[0] != 1:
-        raise NotImplementedError(
-            "Batch attention constraining is not yet supported.")
-    backward_idx = last_attended_idx - backward_window
-    forward_idx = last_attended_idx + forward_window
+    # for dygraph to static graph
+    # if e.shape[0] != 1:
+    #     raise NotImplementedError(
+    #         "Batch attention constraining is not yet supported.")
+    backward_idx = paddle.cast(
+        last_attended_idx - backward_window, dtype='int64')
+    forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
     if backward_idx > 0:
         e[:, :backward_idx] = -float("inf")
     if forward_idx < paddle.shape(e)[1]:
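
Note on the changed lines above: under dygraph-to-static conversion the batch-size check is commented out (raising from a tensor-dependent branch is presumably not convertible), and last_attended_idx may arrive as a tensor rather than a Python int, so the window bounds are cast to int64 before being used as slice indices. A minimal illustrative sketch (standalone, not part of this diff) mirroring the new code:

import paddle

e = paddle.randn([1, 10])
last_attended_idx = paddle.to_tensor(4)  # e.g. argmax of previous attention weights
backward_idx = paddle.cast(last_attended_idx - 1, dtype='int64')
forward_idx = paddle.cast(last_attended_idx + 3, dtype='int64')
# mask energies outside the monotonic window, as in the patched function
if backward_idx > 0:
    e[:, :backward_idx] = -float("inf")
if forward_idx < paddle.shape(e)[1]:
    e[:, forward_idx:] = -float("inf")
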
@@ -122,7 +124,7 @@ class AttLoc(nn.Layer):
             dec_z,
             att_prev,
             scaling=2.0,
-            last_attended_idx=None,
+            last_attended_idx=-1,
             backward_window=1,
             forward_window=3, ):
         """Calculate AttLoc forward propagation.
"""Calculate AttLoc forward propagation.
@@ -192,7 +194,7 @@ class AttLoc(nn.Layer):
             e = masked_fill(e, self.mask, -float("inf"))
         # apply monotonic attention constraint (mainly for TTS)
-        if last_attended_idx is not None:
+        if last_attended_idx != -1:
             e = _apply_attention_constraint(e, last_attended_idx,
                                             backward_window, forward_window)
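
Note on the sentinel change: "last_attended_idx is not None" is a Python object-identity check, which cannot be expressed once the index is traced as a tensor value, so the default becomes the integer sentinel -1 and the guard compares by value. A small illustrative sketch under that assumption (the function below is made up for the example, not PaddleSpeech code):

import paddle


def constrain(e, last_attended_idx=-1):
    # -1 means "no constraint"; the comparison still works when the
    # index arrives as a tensor after dygraph-to-static conversion
    if last_attended_idx != -1:
        e = e * 0.5  # stand-in for the real attention constraint
    return e


static_fn = paddle.jit.to_static(constrain)
print(static_fn(paddle.ones([1, 4]), paddle.to_tensor(2)))   # constrained branch
print(static_fn(paddle.ones([1, 4]), paddle.to_tensor(-1)))  # constraint skipped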

@@ -556,13 +556,15 @@ class Decoder(nn.Layer):
         if use_att_constraint:
             last_attended_idx = 0
         else:
-            last_attended_idx = None
+            last_attended_idx = -1
         # loop for an output sequence
         idx = 0
         outs, att_ws, probs = [], [], []
         prob = paddle.zeros([1])
-        while True:
+        while paddle.to_tensor(True):
             z_list = z_list
             c_list = c_list
             # updated index
             idx += self.reduction_factor
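
A likely reason for the loop change: the decoder's stop condition depends on tensors computed inside the loop, and making the loop condition itself a tensor signals the dygraph-to-static converter to turn the loop into a static-graph while loop instead of treating it as plain Python control flow. A small illustrative sketch of the pattern (the counter and stop logic are made up, not the decoder's real stop criterion):

import paddle


def count_to(limit):
    idx = paddle.zeros([1], dtype='int64')
    while paddle.to_tensor(True):
        idx = idx + 1
        if idx >= limit:  # tensor-valued stop condition
            break
    return idx


static_fn = paddle.jit.to_static(count_to)
print(static_fn(paddle.to_tensor(5)))  # tensor holding [5]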
