@@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
         https://arxiv.org/abs/1710.07654
 
     """
-    if paddle.shape(e)[0] != 1:
-        raise NotImplementedError(
-            "Batch attention constraining is not yet supported.")
-    backward_idx = last_attended_idx - backward_window
-    forward_idx = last_attended_idx + forward_window
+    # for dygraph to static graph
+    # if e.shape[0] != 1:
+    #     raise NotImplementedError(
+    #         "Batch attention constraining is not yet supported.")
+    backward_idx = paddle.cast(
+        last_attended_idx - backward_window, dtype='int64')
+    forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
     if backward_idx > 0:
         e[:, :backward_idx] = -float("inf")
     if forward_idx < paddle.shape(e)[1]:
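
For context, the function patched above implements the monotonic attention constraint from Deep Voice 3 (https://arxiv.org/abs/1710.07654): attention energies outside a small window around the previously attended input index are forced to -inf before the softmax, so decoding can only advance roughly monotonically. A minimal NumPy sketch of that windowing (the function name is illustrative, not the patched code):

import numpy as np

# Minimal sketch of the Deep Voice 3 window; illustrative name, not the
# patched function. e holds pre-softmax energies of shape (1, T_enc).
def apply_constraint_sketch(e, last_attended_idx, backward_window=1,
                            forward_window=3):
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        e[:, :backward_idx] = -np.inf   # cannot jump back past the window
    if forward_idx < e.shape[1]:
        e[:, forward_idx:] = -np.inf    # cannot skip ahead past the window
    return e

e = np.zeros((1, 10))
apply_constraint_sketch(e, last_attended_idx=4)
# only indices 3..6 stay finite, so softmax(e) attends near index 4

The added paddle.cast(..., dtype='int64') calls presumably keep the computed bounds integer-typed once last_attended_idx arrives as a tensor after dygraph-to-static conversion, which matches the "for dygraph to static graph" note on the commented-out batch-size check.
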
@@ -122,7 +124,7 @@ class AttLoc(nn.Layer):
                 dec_z,
                 att_prev,
                 scaling=2.0,
-                last_attended_idx=None,
+                last_attended_idx=-1,
                 backward_window=1,
                 forward_window=3, ):
         """Calculate AttLoc forward propagation.
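
The default sentinel changes from None to -1, presumably because a `last_attended_idx is not None` test is Python-level control flow that dygraph-to-static conversion resolves once at trace time, while an integer sentinel compares like any other value on every call. A hedged sketch of the pattern (`mask_before` is an illustrative name, not PaddleSpeech API):

import paddle

# Hedged sketch of the sentinel pattern; -1 now means "no constraint",
# the role None used to play.
def mask_before(e, last_attended_idx=-1):
    if last_attended_idx != -1:
        e[:, :last_attended_idx] = -float("inf")
    return e

print(mask_before(paddle.zeros([1, 8]), 3))  # energies 0..2 become -inf
print(mask_before(paddle.zeros([1, 8])))     # sentinel -1: returned unchanged

Callers that previously passed None should now pass -1 (or rely on the new default) to skip the constraint.
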
@@ -192,7 +194,7 @@ class AttLoc(nn.Layer):
         e = masked_fill(e, self.mask, -float("inf"))
 
         # apply monotonic attention constraint (mainly for TTS)
-        if last_attended_idx is not None:
+        if last_attended_idx != -1:
             e = _apply_attention_constraint(e, last_attended_idx,
                                             backward_window, forward_window)
 
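
Finally, a hedged end-to-end sketch of how the guard above behaves across a decoding loop. The energies are toy random values rather than real AttLoc outputs; only the -1 handling, the default window sizes (1 back, 3 forward), and the usual peak-feedback convention of monotonic TTS decoders are mirrored here:

import numpy as np

rng = np.random.default_rng(0)
T_enc = 12
last_attended_idx = -1  # -1 disables the constraint, as None used to
for step in range(5):
    e = rng.standard_normal((1, T_enc))
    if last_attended_idx != -1:  # same guard as the patched forward()
        e[:, :max(last_attended_idx - 1, 0)] = -np.inf  # backward_window=1
        e[:, last_attended_idx + 3:] = -np.inf          # forward_window=3
    att_w = np.exp(e) / np.exp(e).sum()                 # softmax over inputs
    last_attended_idx = int(att_w.argmax())             # feed the peak back
    print(step, last_attended_idx)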