Diff attention opearation, test=tts

pull/2770/head
WongLaw 3 years ago
parent f28d0a103b
commit c5f8e44e53

@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x
@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x

Loading…
Cancel
Save