[TTS]fix dygraph to static for tacotron2, test=doc (#2426)

* fix dygraph to static for tacotron2, test=doc

* Fix dy2st error for taco2

* Update attentions.py

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
TianYuan 3 years ago committed by GitHub
parent d9b041e999
commit c8d5a01bdb
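The hunks below swap Python-only constructs for graph-friendly ones: the `None` default for `last_attended_idx` becomes the integer sentinel `-1`, and the window indices are cast to int64 tensors. A minimal sketch of the sentinel idea, not taken from this PR (the function name `constrain` and the `paddle.clip` placeholder body are illustrative only):

import paddle


@paddle.jit.to_static
def constrain(e, last_attended_idx=-1):
    # An integer sentinel (-1) instead of `None` keeps the "no constraint"
    # branch expressible after dygraph-to-static conversion.
    if last_attended_idx != -1:
        # stand-in for the real windowed masking done in attentions.py
        e = paddle.clip(e, min=-1e9)
    return e


out = constrain(paddle.zeros([1, 4]))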

@@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
         https://arxiv.org/abs/1710.07654
     """
-    if paddle.shape(e)[0] != 1:
-        raise NotImplementedError(
-            "Batch attention constraining is not yet supported.")
-    backward_idx = last_attended_idx - backward_window
-    forward_idx = last_attended_idx + forward_window
+    # for dygraph to static graph
+    # if e.shape[0] != 1:
+    #     raise NotImplementedError(
+    #         "Batch attention constraining is not yet supported.")
+    backward_idx = paddle.cast(
+        last_attended_idx - backward_window, dtype='int64')
+    forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
     if backward_idx > 0:
         e[:, :backward_idx] = -float("inf")
     if forward_idx < paddle.shape(e)[1]:
@@ -122,7 +124,7 @@ class AttLoc(nn.Layer):
                 dec_z,
                 att_prev,
                 scaling=2.0,
-                last_attended_idx=None,
+                last_attended_idx=-1,
                 backward_window=1,
                 forward_window=3, ):
         """Calculate AttLoc forward propagation.
@@ -192,7 +194,7 @@ class AttLoc(nn.Layer):
             e = masked_fill(e, self.mask, -float("inf"))
         # apply monotonic attention constraint (mainly for TTS)
-        if last_attended_idx is not None:
+        if last_attended_idx != -1:
             e = _apply_attention_constraint(e, last_attended_idx,
                                             backward_window, forward_window)

@@ -556,13 +556,15 @@ class Decoder(nn.Layer):
         if use_att_constraint:
             last_attended_idx = 0
         else:
-            last_attended_idx = None
+            last_attended_idx = -1
         # loop for an output sequence
         idx = 0
         outs, att_ws, probs = [], [], []
         prob = paddle.zeros([1])
-        while True:
+        while paddle.to_tensor(True):
+            z_list = z_list
+            c_list = c_list
             # updated index
             idx += self.reduction_factor

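The `while paddle.to_tensor(True)` change applies the same idea to the decoding loop: a tensor-valued condition lets the dygraph-to-static pass keep the loop inside the graph rather than relying on Python control flow. A small sketch of that pattern, assuming Paddle's converter handles tensor loop conditions and `break` as documented (`count_up` and its stopping rule are made up for illustration, not part of the PR):

import paddle


@paddle.jit.to_static
def count_up(limit):
    # Tensor-valued loop condition plus a tensor-based break, mirroring the
    # decoder's inference loop; `limit` is expected to be an int64 tensor.
    i = paddle.zeros([1], dtype='int64')
    while paddle.to_tensor(True):
        i += 1
        if i >= limit:
            break
    return i


print(count_up(paddle.to_tensor([5], dtype='int64')))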