|
|
|
@ -47,11 +47,13 @@ def _apply_attention_constraint(e,
|
|
|
|
|
https://arxiv.org/abs/1710.07654
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if paddle.shape(e)[0] != 1:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Batch attention constraining is not yet supported.")
|
|
|
|
|
backward_idx = last_attended_idx - backward_window
|
|
|
|
|
forward_idx = last_attended_idx + forward_window
|
|
|
|
|
# for dygraph to static graph
|
|
|
|
|
# if e.shape[0] != 1:
|
|
|
|
|
# raise NotImplementedError(
|
|
|
|
|
# "Batch attention constraining is not yet supported.")
|
|
|
|
|
backward_idx = paddle.cast(
|
|
|
|
|
last_attended_idx - backward_window, dtype='int64')
|
|
|
|
|
forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
|
|
|
|
|
if backward_idx > 0:
|
|
|
|
|
e[:, :backward_idx] = -float("inf")
|
|
|
|
|
if forward_idx < paddle.shape(e)[1]:
|
|
|
|
|