|
|
@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|
|
|
|
|
|
|
|
|
|
|
if self.zero_triu:
|
|
|
|
if self.zero_triu:
|
|
|
|
ones = paddle.ones((t1, t2))
|
|
|
|
ones = paddle.ones((t1, t2))
|
|
|
|
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
|
|
|
|
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|
|
|
|
|
|
|
|
|
|
|
if self.zero_triu:
|
|
|
|
if self.zero_triu:
|
|
|
|
ones = paddle.ones((t1, t2))
|
|
|
|
ones = paddle.ones((t1, t2))
|
|
|
|
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
|
|
|
|
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|