|
|
@ -49,9 +49,6 @@ def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Ten
|
|
|
|
raise AssertionError(
|
|
|
|
raise AssertionError(
|
|
|
|
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
|
|
|
|
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
|
|
|
|
|
|
|
|
|
|
|
|
def masked_fill(x, mask, value):
|
|
|
|
|
|
|
|
y = paddle.full(x.shape, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
|
|
|
|
def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -61,18 +58,22 @@ def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
|
|
|
|
d_key = k.shape[-1]
|
|
|
|
d_key = k.shape[-1]
|
|
|
|
scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
|
|
|
|
scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
|
|
|
|
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
|
|
|
|
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
|
|
|
|
weights = paddle.nn.functional.softmax(x=product + attn_mask)
|
|
|
|
weights = F.softmax(x=product + attn_mask)
|
|
|
|
if dropout_p:
|
|
|
|
if dropout_p:
|
|
|
|
weights = paddle.fluid.layers.nn.dropout(
|
|
|
|
weights = F.dropout(
|
|
|
|
weights,
|
|
|
|
weights,
|
|
|
|
dropout_prob=dropout_p,
|
|
|
|
p=dropout_p,
|
|
|
|
dropout_implementation="upscale_in_train",
|
|
|
|
training=True,
|
|
|
|
is_test=False)
|
|
|
|
mode="upscale_in_train"
|
|
|
|
|
|
|
|
)
|
|
|
|
out = paddle.matmul(x=weights, y=v)
|
|
|
|
out = paddle.matmul(x=weights, y=v)
|
|
|
|
return out
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
|
|
|
|
def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
A helper function to calculate alpha*(vec1*vec2^T) + beta*input
|
|
|
|
|
|
|
|
"""
|
|
|
|
row = vec1.shape[0]
|
|
|
|
row = vec1.shape[0]
|
|
|
|
column = vec2.shape[0]
|
|
|
|
column = vec2.shape[0]
|
|
|
|
vec1 = paddle.unsqueeze(vec1, 0)
|
|
|
|
vec1 = paddle.unsqueeze(vec1, 0)
|
|
|
@ -164,12 +165,11 @@ def _in_projection_packed(
|
|
|
|
- in output list :math:`[q', k', v']`, each output tensor will have the
|
|
|
|
- in output list :math:`[q', k', v']`, each output tensor will have the
|
|
|
|
same shape as the corresponding input tensor.
|
|
|
|
same shape as the corresponding input tensor.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# E = q.size(-1)
|
|
|
|
|
|
|
|
E = q.shape[-1]
|
|
|
|
E = q.shape[-1]
|
|
|
|
if k is v:
|
|
|
|
if k is v:
|
|
|
|
if q is k:
|
|
|
|
if q is k:
|
|
|
|
# self-attention
|
|
|
|
# self-attention
|
|
|
|
proj = linear(q, w, b)
|
|
|
|
proj = F.linear(q, w, b)
|
|
|
|
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
|
|
|
|
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
|
|
|
|
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
|
|
|
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
|
|
|
return proj[0], proj[1], proj[2]
|
|
|
|
return proj[0], proj[1], proj[2]
|
|
|
@ -180,8 +180,8 @@ def _in_projection_packed(
|
|
|
|
b_q = b_kv = None
|
|
|
|
b_q = b_kv = None
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
b_q, b_kv = b.split([E, E * 2])
|
|
|
|
b_q, b_kv = b.split([E, E * 2])
|
|
|
|
q_proj = linear(q, w_q, b_q)
|
|
|
|
q_proj = F.linear(q, w_q, b_q)
|
|
|
|
kv_proj = linear(k, w_kv, b_kv)
|
|
|
|
kv_proj = F.linear(k, w_kv, b_kv)
|
|
|
|
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
|
|
|
|
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
|
|
|
|
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
|
|
|
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
|
|
|
return (q_proj, kv_proj[0], kv_proj[1])
|
|
|
|
return (q_proj, kv_proj[0], kv_proj[1])
|
|
|
@ -191,7 +191,7 @@ def _in_projection_packed(
|
|
|
|
b_q = b_k = b_v = None
|
|
|
|
b_q = b_k = b_v = None
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
b_q, b_k, b_v = b.chunk(3)
|
|
|
|
b_q, b_k, b_v = b.chunk(3)
|
|
|
|
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
|
|
|
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
|
|
|
|
|
|
|
|
|
|
|
def _in_projection(
|
|
|
|
def _in_projection(
|
|
|
|
q: paddle.Tensor,
|
|
|
|
q: paddle.Tensor,
|
|
|
@ -204,10 +204,8 @@ def _in_projection(
|
|
|
|
b_k: Optional[paddle.Tensor] = None,
|
|
|
|
b_k: Optional[paddle.Tensor] = None,
|
|
|
|
b_v: Optional[paddle.Tensor] = None,
|
|
|
|
b_v: Optional[paddle.Tensor] = None,
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
A, B, C = linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
|
|
|
A, B, C = F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
|
|
|
|
|
|
|
|
|
|
|
return A, B, C
|
|
|
|
return A, B, C
|
|
|
|
# return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def multi_head_attention_forward_paddle(
|
|
|
|
def multi_head_attention_forward_paddle(
|
|
|
|
query: paddle.Tensor,
|
|
|
|
query: paddle.Tensor,
|
|
|
@ -299,22 +297,7 @@ def multi_head_attention_forward_paddle(
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
|
|
|
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
|
|
|
|
|
|
|
|
|
|
|
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
|
|
|
|
|
|
|
# is batched, run the computation and before returning squeeze the
|
|
|
|
|
|
|
|
# batch dimension so that the output doesn't carry this temporary batch dimension.
|
|
|
|
|
|
|
|
# if not is_batched:
|
|
|
|
|
|
|
|
# # unsqueeze if the input is unbatched
|
|
|
|
|
|
|
|
# query = query.unsqueeze(1)
|
|
|
|
|
|
|
|
# key = key.unsqueeze(1)
|
|
|
|
|
|
|
|
# value = value.unsqueeze(1)
|
|
|
|
|
|
|
|
# if key_padding_mask is not None:
|
|
|
|
|
|
|
|
# key_padding_mask = key_padding_mask.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# set up shape vars
|
|
|
|
|
|
|
|
# import pdb; pdb.set_trace()
|
|
|
|
|
|
|
|
tgt_len, bsz, embed_dim = query.shape
|
|
|
|
tgt_len, bsz, embed_dim = query.shape
|
|
|
|
# tgt_len, bsz, embed_dim = query.shape
|
|
|
|
|
|
|
|
src_len, _, _ = key.shape
|
|
|
|
src_len, _, _ = key.shape
|
|
|
|
|
|
|
|
|
|
|
|
if is_causal:
|
|
|
|
if is_causal:
|
|
|
@ -373,9 +356,7 @@ def multi_head_attention_forward_paddle(
|
|
|
|
if bias_k is not None and bias_v is not None:
|
|
|
|
if bias_k is not None and bias_v is not None:
|
|
|
|
assert static_k is None, "bias cannot be added to static key."
|
|
|
|
assert static_k is None, "bias cannot be added to static key."
|
|
|
|
assert static_v is None, "bias cannot be added to static value."
|
|
|
|
assert static_v is None, "bias cannot be added to static value."
|
|
|
|
# k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
|
|
|
|
|
|
|
k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1)
|
|
|
|
k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1)
|
|
|
|
# v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
|
|
|
|
|
|
|
v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1)
|
|
|
|
v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1)
|
|
|
|
if attn_mask is not None:
|
|
|
|
if attn_mask is not None:
|
|
|
|
# attn_mask = pad(attn_mask, (0, 1))
|
|
|
|
# attn_mask = pad(attn_mask, (0, 1))
|
|
|
@ -392,22 +373,18 @@ def multi_head_attention_forward_paddle(
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# reshape q, k, v for multihead attention and make em batch first
|
|
|
|
# reshape q, k, v for multihead attention and make em batch first
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
|
|
|
|
|
|
|
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if static_k is None:
|
|
|
|
if static_k is None:
|
|
|
|
# k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
|
|
|
|
|
|
|
k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
|
|
|
|
|
|
|
assert static_k.size(0) == bsz * num_heads, \
|
|
|
|
assert static_k.size(0) == bsz * num_heads, \
|
|
|
|
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
|
|
|
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
|
|
|
assert static_k.size(2) == head_dim, \
|
|
|
|
assert static_k.size(2) == head_dim, \
|
|
|
|
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
|
|
|
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
|
|
|
k = static_k
|
|
|
|
k = static_k
|
|
|
|
if static_v is None:
|
|
|
|
if static_v is None:
|
|
|
|
# v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
|
|
|
|
|
|
|
v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
|
|
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
|
|
@ -420,9 +397,7 @@ def multi_head_attention_forward_paddle(
|
|
|
|
# add zero attention along batch dimension (now first)
|
|
|
|
# add zero attention along batch dimension (now first)
|
|
|
|
if add_zero_attn:
|
|
|
|
if add_zero_attn:
|
|
|
|
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
|
|
|
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
|
|
|
# k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
|
|
|
|
|
|
|
k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], axis=1)
|
|
|
|
k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], axis=1)
|
|
|
|
# v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
|
|
|
|
|
|
|
v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], axis=1)
|
|
|
|
v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], axis=1)
|
|
|
|
if attn_mask is not None:
|
|
|
|
if attn_mask is not None:
|
|
|
|
# attn_mask = pad(attn_mask, (0, 1))
|
|
|
|
# attn_mask = pad(attn_mask, (0, 1))
|
|
|
@ -438,7 +413,6 @@ def multi_head_attention_forward_paddle(
|
|
|
|
if key_padding_mask is not None:
|
|
|
|
if key_padding_mask is not None:
|
|
|
|
assert key_padding_mask.shape == (bsz, src_len), \
|
|
|
|
assert key_padding_mask.shape == (bsz, src_len), \
|
|
|
|
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
|
|
|
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
|
|
|
# key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
|
|
|
|
|
|
|
key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
|
|
|
|
key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
|
|
|
|
if attn_mask is None:
|
|
|
|
if attn_mask is None:
|
|
|
|
attn_mask = key_padding_mask
|
|
|
|
attn_mask = key_padding_mask
|
|
|
@ -456,25 +430,20 @@ def multi_head_attention_forward_paddle(
|
|
|
|
B, Nt, E = q.shape
|
|
|
|
B, Nt, E = q.shape
|
|
|
|
q_scaled = q / math.sqrt(E)
|
|
|
|
q_scaled = q / math.sqrt(E)
|
|
|
|
if attn_mask is not None:
|
|
|
|
if attn_mask is not None:
|
|
|
|
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
|
|
|
|
|
|
|
attn_output_weights = addr(q_scaled, k.transpose(-2, -1))
|
|
|
|
attn_output_weights = addr(q_scaled, k.transpose(-2, -1))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
|
|
|
|
|
|
|
attn_output_weights = paddle.bmm(q_scaled, k.transpose(0, 2, 1))
|
|
|
|
attn_output_weights = paddle.bmm(q_scaled, k.transpose(0, 2, 1))
|
|
|
|
# attn_output_weights = softmax(attn_output_weights, dim=-1)
|
|
|
|
attn_output_weights = F.softmax(attn_output_weights, axis=-1)
|
|
|
|
attn_output_weights = paddle.nn.functional.softmax(attn_output_weights, axis=-1)
|
|
|
|
|
|
|
|
if dropout_p > 0.0:
|
|
|
|
if dropout_p > 0.0:
|
|
|
|
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
|
|
|
|
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)
|
|
|
|
attn_output_weights = paddle.nn.functional.dropout(attn_output_weights, p=dropout_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# attn_output = torch.bmm(attn_output_weights, v)
|
|
|
|
|
|
|
|
attn_output = paddle.bmm(attn_output_weights, v)
|
|
|
|
attn_output = paddle.bmm(attn_output_weights, v)
|
|
|
|
attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
|
|
|
|
attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
|
|
|
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
|
|
|
# attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
|
|
|
|
|
|
|
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
|
|
|
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
|
|
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
|
|
|
|
|
|
|
|
|
|
|
# optionally average attention weights over heads
|
|
|
|
# optionally average attention weights over heads
|
|
|
|
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
|
|
|
|
|
|
|
attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
|
|
|
|
attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
|
|
|
|
if average_attn_weights:
|
|
|
|
if average_attn_weights:
|
|
|
|
attn_output_weights = attn_output_weights.mean(dim=1)
|
|
|
|
attn_output_weights = attn_output_weights.mean(dim=1)
|
|
|
@ -492,7 +461,6 @@ def multi_head_attention_forward_paddle(
|
|
|
|
if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
|
|
|
|
if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
|
|
|
|
attn_mask = attn_mask.unsqueeze(0)
|
|
|
|
attn_mask = attn_mask.unsqueeze(0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
|
|
|
|
|
|
|
|
attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])
|
|
|
|
attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])
|
|
|
|
|
|
|
|
|
|
|
|
q = q.reshape([bsz, num_heads, tgt_len, head_dim])
|
|
|
|
q = q.reshape([bsz, num_heads, tgt_len, head_dim])
|
|
|
@ -500,9 +468,6 @@ def multi_head_attention_forward_paddle(
|
|
|
|
v = v.reshape([bsz, num_heads, src_len, head_dim])
|
|
|
|
v = v.reshape([bsz, num_heads, src_len, head_dim])
|
|
|
|
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
|
|
|
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
|
|
|
attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
|
|
|
|
attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
|
|
|
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
|
|
|
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
|
|
|
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
|
|
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
|
|
|
# if not is_batched:
|
|
|
|
|
|
|
|
# # squeeze the output if input was unbatched
|
|
|
|
|
|
|
|
# attn_output = attn_output.squeeze(1)
|
|
|
|
|
|
|
|
return attn_output, None
|
|
|
|
return attn_output, None
|