@@ -103,7 +103,7 @@ class MultiHeadedAttention(nn.Layer):
             mask = paddle.logical_not(mask)
             # assume scores.dtype==paddle.float32, we only use "float32" here
             dtype = str(scores.dtype).split(".")[-1]
-            min_value = numpy.finfo(dtype).min
+            min_value = float(numpy.finfo(dtype).min)
             scores = masked_fill(scores, mask, min_value)
             # (batch, head, time1, time2)
             self.attn = softmax(scores)
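Note (not part of the diff): the fill value is the most negative number representable in the score dtype, so masked positions get essentially zero weight after softmax, and wrapping it in float() keeps it a plain Python scalar rather than a numpy scalar. Below is a minimal sketch of the same pattern built only from stock Paddle ops, assuming mask is a boolean tensor broadcastable to scores with True meaning "may attend"; the helper name masked_softmax_sketch is made up for illustration.

import numpy
import paddle

def masked_softmax_sketch(scores: paddle.Tensor, mask: paddle.Tensor) -> paddle.Tensor:
    # scores: (batch, head, time1, time2); mask: bool, True where attention is allowed
    dtype = str(scores.dtype).split(".")[-1]       # "paddle.float32" -> "float32"
    min_value = float(numpy.finfo(dtype).min)      # plain Python float
    blocked = paddle.logical_not(mask)             # True where the score must be suppressed
    scores = paddle.where(blocked, paddle.full_like(scores, min_value), scores)
    return paddle.nn.functional.softmax(scores, axis=-1)  # (batch, head, time1, time2)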
@@ -192,12 +192,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         x_padded = paddle.concat([zero_pad, x], axis=-1)
         x_padded = x_padded.reshape([b, h, t2 + 1, t1])
         # only keep the positions from 0 to time2
-        x = x_padded[:, :, 1:].reshape([b, h, t1, t2])[:, :, :, :t2 // 2 + 1]
+        new_t = paddle.cast(paddle.floor(t2 / 2) + 1, dtype='int32')
+        x = x_padded[:, :, 1:].reshape([b, h, t1, t2])[:, :, :, :new_t]

         if self.zero_triu:
             ones = paddle.ones((t1, t2))
             x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]

         return x

     def forward(self, query, key, value, pos_emb, mask):
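Note (not part of the diff): both sides compute the same slice bound, since floor(t2 / 2) + 1 equals t2 // 2 + 1 for a non-negative integer t2; the paddle.cast/paddle.floor form simply keeps the bound as a tensor operation. The pad/reshape/slice sequence is the Transformer-XL relative-shift trick. A small hand-worked illustration with one head, t1 = 2 queries and t2 = 2*t1 - 1 = 3 relative positions (assumed to run +1, 0, -1 from left to right); the tensor values are made up:

import paddle

x = paddle.to_tensor([[[[1., 2., 3.],
                        [4., 5., 6.]]]])          # (b, h, t1, t2) = (1, 1, 2, 3)
b, h, t1, t2 = 1, 1, 2, 3

zero_pad = paddle.zeros([b, h, t1, 1], dtype=x.dtype)
x_padded = paddle.concat([zero_pad, x], axis=-1)   # (1, 1, 2, 4)
x_padded = x_padded.reshape([b, h, t2 + 1, t1])    # (1, 1, 4, 2), rows deliberately misaligned
x = x_padded[:, :, 1:].reshape([b, h, t1, t2])     # dropping one row shifts each query row
x = x[:, :, :, :t2 // 2 + 1]                       # keep only real key positions 0..t1-1

# x is now [[[[2., 3.], [4., 5.]]]]: under the assumed layout, entry (i, j) holds the
# score stored at the relative-position index for (i - j), i.e. query i attending key j.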
@@ -221,7 +220,6 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         q, k, v = self.forward_qkv(query, key, value)
         # (batch, time1, head, d_k)
         q = q.transpose([0, 2, 1, 3])
-
         n_batch_pos = paddle.shape(pos_emb)[0]
         p = self.linear_pos(pos_emb).reshape(
             [n_batch_pos, -1, self.h, self.d_k])
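Note (not part of the diff): per the comment in the hunk, q sits in (batch, time1, head, d_k) layout here, and linear_pos projects pos_emb into the same per-head shape; in Transformer-XL style relative attention that layout lets per-head position biases broadcast over batch and time before the heads are moved back. A rough sketch of how such tensors are typically combined; the names pos_bias_u, pos_bias_v, matrix_ac, matrix_bd and the helper rel_attention_scores_sketch are illustrative assumptions, not lines from this diff:

import math
import paddle

def rel_attention_scores_sketch(q, k, p, pos_bias_u, pos_bias_v, d_k, rel_shift):
    # q: (batch, time1, head, d_k)        query, already transposed as in the hunk above
    # k: (batch, head, time2, d_k)        key
    # p: (batch, head, 2*time1-1, d_k)    projected relative position embeddings
    q_with_bias_u = (q + pos_bias_u).transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
    q_with_bias_v = (q + pos_bias_v).transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
    # content term ("a + c" in the Transformer-XL paper)
    matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
    # position term ("b + d"), realigned by the rel_shift shown in the previous hunk
    matrix_bd = rel_shift(paddle.matmul(q_with_bias_v, p, transpose_y=True))
    return (matrix_ac + matrix_bd) / math.sqrt(d_k)  # (batch, head, time1, time2)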