Code cleanup for CIs

pull/3242/head
jiamingkong 2 years ago
parent 3ef28dee45
commit 0e2068e2cf

@@ -49,9 +49,6 @@ def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Ten
         raise AssertionError(
             f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
 
-def masked_fill(x, mask, value):
-    y = paddle.full(x.shape, value)
-
 def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
     """
@@ -61,18 +58,22 @@ def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
     d_key = k.shape[-1]
     scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
     product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
-    weights = paddle.nn.functional.softmax(x=product + attn_mask)
+    weights = F.softmax(x=product + attn_mask)
     if dropout_p:
-        weights = paddle.fluid.layers.nn.dropout(
+        weights = F.dropout(
             weights,
-            dropout_prob=dropout_p,
-            dropout_implementation="upscale_in_train",
-            is_test=False)
+            p=dropout_p,
+            training=True,
+            mode="upscale_in_train"
+        )
     out = paddle.matmul(x=weights, y=v)
     return out
 
 def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
+    """
+    A helper function to calculate alpha*(vec1*vec2^T) + beta*input
+    """
     row = vec1.shape[0]
     column = vec2.shape[0]
     vec1 = paddle.unsqueeze(vec1, 0)
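
Note: the hunk above replaces the legacy `paddle.fluid.layers.nn.dropout` call with the `paddle.nn.functional` API. A minimal sketch of the new calls, assuming `paddle.nn.functional` is imported as `F` at the top of the module (the import is not visible in this diff) and using made-up shapes purely for illustration:

```python
import paddle
import paddle.nn.functional as F

# Dummy attention scores and additive mask; shapes chosen only for the demo.
product = paddle.randn([2, 4, 4])     # (batch, tgt_len, src_len)
attn_mask = paddle.zeros([2, 4, 4])   # additive mask; zeros = no masking

weights = F.softmax(product + attn_mask, axis=-1)
# "upscale_in_train" keeps inference outputs unscaled, matching the removed
# fluid call's dropout_implementation="upscale_in_train".
weights = F.dropout(weights, p=0.1, training=True, mode="upscale_in_train")
print(weights.shape)                  # [2, 4, 4]
```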
@@ -164,12 +165,11 @@ def _in_projection_packed(
         - in output list :math:`[q', k', v']`, each output tensor will have the
           same shape as the corresponding input tensor.
     """
-    # E = q.size(-1)
     E = q.shape[-1]
     if k is v:
         if q is k:
             # self-attention
-            proj = linear(q, w, b)
+            proj = F.linear(q, w, b)
             # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
             proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
             return proj[0], proj[1], proj[2]
@@ -180,8 +180,8 @@ def _in_projection_packed(
                 b_q = b_kv = None
             else:
                 b_q, b_kv = b.split([E, E * 2])
-            q_proj = linear(q, w_q, b_q)
-            kv_proj = linear(k, w_kv, b_kv)
+            q_proj = F.linear(q, w_q, b_q)
+            kv_proj = F.linear(k, w_kv, b_kv)
             # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
             kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
             return (q_proj, kv_proj[0], kv_proj[1])
@@ -191,7 +191,7 @@ def _in_projection_packed(
             b_q = b_k = b_v = None
         else:
             b_q, b_k, b_v = b.chunk(3)
-        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
 
 def _in_projection(
     q: paddle.Tensor,
@@ -204,10 +204,8 @@ def _in_projection(
     b_k: Optional[paddle.Tensor] = None,
     b_v: Optional[paddle.Tensor] = None,
 ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-    A, B, C = linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+    A, B, C = F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
     return A, B, C
-    # return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
 def multi_head_attention_forward_paddle(
     query: paddle.Tensor,
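
Note: the projection hunks above route all input/output projections through `F.linear`. One caveat worth keeping in mind when reviewing: `paddle.nn.functional.linear` expects the weight laid out as `[in_features, out_features]` and computes `x @ W + b`, whereas `torch.nn.functional.linear` uses `[out_features, in_features]` and computes `x @ W.T + b`, so torch-style weights need a transpose. A small illustrative sketch with invented shapes (not part of the commit):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([8, 16])        # (tokens, in_features)
w = paddle.randn([16, 32])       # paddle layout: [in_features, out_features]
b = paddle.zeros([32])

y = F.linear(x, w, b)            # -> shape [8, 32]

w_torch_style = paddle.randn([32, 16])   # torch layout: [out_features, in_features]
y2 = F.linear(x, w_torch_style.t(), b)   # transpose before calling paddle's linear
print(y.shape, y2.shape)                 # [8, 32] [8, 32]
```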
@@ -299,22 +297,7 @@ def multi_head_attention_forward_paddle(
     """
     is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
 
-    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
-    # is batched, run the computation and before returning squeeze the
-    # batch dimension so that the output doesn't carry this temporary batch dimension.
-    # if not is_batched:
-    #     # unsqueeze if the input is unbatched
-    #     query = query.unsqueeze(1)
-    #     key = key.unsqueeze(1)
-    #     value = value.unsqueeze(1)
-    #     if key_padding_mask is not None:
-    #         key_padding_mask = key_padding_mask.unsqueeze(0)
-    # set up shape vars
-    # import pdb; pdb.set_trace()
     tgt_len, bsz, embed_dim = query.shape
-    # tgt_len, bsz, embed_dim = query.shape
     src_len, _, _ = key.shape
 
     if is_causal:
@@ -373,9 +356,7 @@ def multi_head_attention_forward_paddle(
     if bias_k is not None and bias_v is not None:
         assert static_k is None, "bias cannot be added to static key."
         assert static_v is None, "bias cannot be added to static value."
-        # k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
         k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1)
-        # v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
         v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1)
         if attn_mask is not None:
             # attn_mask = pad(attn_mask, (0, 1))
@@ -392,22 +373,18 @@ def multi_head_attention_forward_paddle(
     #
     # reshape q, k, v for multihead attention and make em batch first
     #
-    # q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
     q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
     if static_k is None:
-        # k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
         k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
     else:
-        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
         assert static_k.size(0) == bsz * num_heads, \
             f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
         assert static_k.size(2) == head_dim, \
             f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
         k = static_k
     if static_v is None:
-        # v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
         v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
     else:
         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
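
Note: the hunk above folds the head dimension into the batch dimension, going from the sequence-first layout `(tgt_len, bsz, embed_dim)` to `(bsz * num_heads, tgt_len, head_dim)`. A dummy-shape sketch of just that reshape/transpose (values invented for illustration):

```python
import paddle

tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 8

q = paddle.randn([tgt_len, bsz, num_heads * head_dim])
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
print(q.shape)   # [8, 5, 8] == [bsz * num_heads, tgt_len, head_dim]
```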
@@ -420,9 +397,7 @@ def multi_head_attention_forward_paddle(
     # add zero attention along batch dimension (now first)
     if add_zero_attn:
         zero_attn_shape = (bsz * num_heads, 1, head_dim)
-        # k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
         k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], axis=1)
-        # v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
         v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], axis=1)
         if attn_mask is not None:
             # attn_mask = pad(attn_mask, (0, 1))
@@ -438,7 +413,6 @@ def multi_head_attention_forward_paddle(
     if key_padding_mask is not None:
         assert key_padding_mask.shape == (bsz, src_len), \
             f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
-        # key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
         key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
         if attn_mask is None:
             attn_mask = key_padding_mask
@@ -456,25 +430,20 @@ def multi_head_attention_forward_paddle(
         B, Nt, E = q.shape
         q_scaled = q / math.sqrt(E)
         if attn_mask is not None:
-            # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
             attn_output_weights = addr(q_scaled, k.transpose(-2, -1))
         else:
-            # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
             attn_output_weights = paddle.bmm(q_scaled, k.transpose(0, 2, 1))
-        # attn_output_weights = softmax(attn_output_weights, dim=-1)
-        attn_output_weights = paddle.nn.functional.softmax(attn_output_weights, axis=-1)
+        attn_output_weights = F.softmax(attn_output_weights, axis=-1)
         if dropout_p > 0.0:
-            # attn_output_weights = dropout(attn_output_weights, p=dropout_p)
-            attn_output_weights = paddle.nn.functional.dropout(attn_output_weights, p=dropout_p)
+            attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)
 
-        # attn_output = torch.bmm(attn_output_weights, v)
         attn_output = paddle.bmm(attn_output_weights, v)
         attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
-        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+        # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
 
         # optionally average attention weights over heads
-        # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
         attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
         if average_attn_weights:
             attn_output_weights = attn_output_weights.mean(dim=1)
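
Note: the removed comment in the hunk above references `torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))`, i.e. an additive mask plus a batched matmul. For reference only (shapes invented, not part of the commit), that expression can be written with plain paddle ops as:

```python
import paddle

bsz_heads, tgt_len, src_len, head_dim = 8, 5, 7, 16
q_scaled = paddle.randn([bsz_heads, tgt_len, head_dim])
k = paddle.randn([bsz_heads, src_len, head_dim])
attn_mask = paddle.zeros([bsz_heads, tgt_len, src_len])   # additive mask

# attn_mask + q_scaled @ k^T, the computation torch.baddbmm performed
attn_output_weights = attn_mask + paddle.bmm(q_scaled, k.transpose([0, 2, 1]))
print(attn_output_weights.shape)   # [8, 5, 7]
```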
@@ -492,7 +461,6 @@ def multi_head_attention_forward_paddle(
             if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
                 attn_mask = attn_mask.unsqueeze(0)
             else:
-                # attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
                 attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])
 
         q = q.reshape([bsz, num_heads, tgt_len, head_dim])
@@ -500,9 +468,6 @@ def multi_head_attention_forward_paddle(
         v = v.reshape([bsz, num_heads, src_len, head_dim])
         attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
         attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
-        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
-        # if not is_batched:
-        #     # squeeze the output if input was unbatched
-        #     attn_output = attn_output.squeeze(1)
         return attn_output, None

@@ -60,19 +60,6 @@ class Fp32GroupNorm(nn.GroupNorm):
         return output.type_as(input)
 
-# class GradMultiply(torch.autograd.Function):
-# convert into paddle equivalent
-# class GradMultiply(torch.autograd.Function):
-#     @staticmethod
-#     def forward(ctx, x, scale):
-#         ctx.scale = scale
-#         res = x.new(x)
-#         return res
-
-#     @staticmethod
-#     def backward(ctx, grad):
-#         return grad * ctx.scale, None
-
 class SamePad(nn.Layer):
     def __init__(self, kernel_size, causal=False):
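
Note: the block removed above was a commented-out `torch.autograd.Function` (`GradMultiply`) carrying a reminder to "convert into paddle equivalent". If that port is ever needed, one possible sketch uses `paddle.autograd.PyLayer`; this is an illustration of the idea, not code introduced by this commit:

```python
import paddle
from paddle.autograd import PyLayer


class GradMultiply(PyLayer):
    """Identity in the forward pass; scales gradients in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale      # stash the python scalar for backward
        return x * 1.0         # behaves like the identity

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale


x = paddle.randn([2, 3])
x.stop_gradient = False
y = GradMultiply.apply(x, 0.5)
y.sum().backward()
print(x.grad)                  # gradients scaled by 0.5
```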
@@ -95,7 +82,6 @@ class Swish(nn.Layer):
     def __init__(self):
         """Construct an MultiHeadedAttention object."""
         super(Swish, self).__init__()
-        # self.act = torch.nn.Sigmoid()
         self.act = nn.Sigmoid()
 
     def forward(self, x):
@@ -162,7 +148,6 @@ def get_activation_fn(activation: str):
     elif activation == "gelu_accurate":
         return gelu_accurate
     elif activation == "tanh":
-        # return torch.tanh
         return paddle.tanh
     elif activation == "linear":
         return lambda x: x
@@ -172,44 +157,6 @@ def get_activation_fn(activation: str):
         raise RuntimeError("--activation-fn {} not supported".format(activation))
 
-def init_bert_params(module):
-    """
-    Initialize the weights specific to the BERT Model.
-    This overrides the default initializations depending on the specified arguments.
-    1. If normal_init_linear_weights is set then weights of linear
-       layer will be initialized using the normal distribution and
-       bais will be set to the specified value.
-    2. If normal_init_embed_weights is set then weights of embedding
-       layer will be initialized using the normal distribution.
-    3. If normal_init_proj_weights is set then weights of
-       in_project_weight for MultiHeadAttention initialized using
-       the normal distribution (to be validated).
-    """
-
-    def normal_(data):
-        # with FSDP, module params will be on CUDA, so we cast them back to CPU
-        # so that the RNG is consistent with and without FSDP
-        data.copy_(
-            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
-        )
-
-    if isinstance(module, nn.Linear):
-        # normal_(module.weight.data)
-        if module.bias is not None:
-            # module.bias.data.zero_()
-            pass
-    if isinstance(module, nn.Embedding):
-        # normal_(module.weight.data)
-        if module.padding_idx is not None:
-            # module.weight.data[module.padding_idx].zero_()
-            pass
-    if isinstance(module, MultiheadAttention):
-        pass
-        # normal_(module.q_proj.weight.data)
-        # normal_(module.k_proj.weight.data)
-        # normal_(module.v_proj.weight.data)
-
 def quant_noise(module, p, block_size):
     """
     Wraps modules and applies quantization noise to the weights for
@@ -302,9 +249,8 @@ def quant_noise(module, p, block_size):
             # scale weights and apply mask
             mask = mask.to(
-                # torch.bool
                 paddle.bool
-            )  # x.bool() is not currently supported in TorchScript
+            )
             s = 1 / (1 - p)
             mod.weight.data = s * weight.masked_fill(mask, 0)
@@ -405,7 +351,6 @@ class MultiheadAttention(nn.Layer):
         self.gru_rel_pos = gru_rel_pos
         if self.gru_rel_pos:
             self.grep_linear = nn.Linear(self.q_head_dim, 8)
-            # self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
             self.grep_a = self.create_parameter(
                 shape=[1, num_heads, 1, 1], dtype="float32"
             )
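
Note: the replaced torch line initialised `grep_a` with `torch.ones`, while `create_parameter(shape, dtype)` falls back to the layer's default initializer. If a ones-initialised parameter is wanted, `create_parameter` accepts an explicit `default_initializer`; the toy layer below (hypothetical, not part of the commit) shows the pattern:

```python
import paddle
import paddle.nn as nn


class GrepParamDemo(nn.Layer):
    """Toy layer with a ones-initialised parameter, mirroring grep_a."""

    def __init__(self, num_heads=4):
        super().__init__()
        self.grep_a = self.create_parameter(
            shape=[1, num_heads, 1, 1],
            dtype="float32",
            default_initializer=nn.initializer.Constant(1.0),
        )


layer = GrepParamDemo()
print(float(layer.grep_a.sum()))   # 4.0 -> every entry starts at 1
```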
@@ -415,47 +360,6 @@ class MultiheadAttention(nn.Layer):
     def reset_parameters(self):
         pass
-        # if self.qkv_same_dim:
-        #     # Empirically observed the convergence to be much better with
-        #     # the scaled initialization
-        #     # nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
-        #     # self.k_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform(1, 1)(self.k_proj.weight.shape)
-        #     # )
-        #     # self.v_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform(1, 1)(self.v_proj.weight.shape)
-        #     # )
-        #     # self.q_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform(1, 1)(self.q_proj.weight.shape)
-        #     # )
-        #     pass
-        #     # nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
-        #     # nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
-        # else:
-        #     # nn.init.xavier_uniform_(self.k_proj.weight)
-        #     # nn.init.xavier_uniform_(self.v_proj.weight)
-        #     # nn.init.xavier_uniform_(self.q_proj.weight)
-        #     # self.k_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform()(self.k_proj.weight.shape)
-        #     # )
-        #     # self.v_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform()(self.v_proj.weight.shape)
-        #     # )
-        #     # self.q_proj.weight.set_value(
-        #     #     paddle.nn.initializer.XavierUniform()(self.q_proj.weight.shape)
-        #     # )
-        #     pass
-        # nn.init.xavier_uniform_(self.out_proj.weight)
-        # if self.out_proj.bias is not None:
-        #     nn.init.constant_(self.out_proj.bias, 0.0)
-        # if self.bias_k is not None:
-        #     nn.init.xavier_normal_(self.bias_k)
-        # if self.bias_v is not None:
-        #     nn.init.xavier_normal_(self.bias_v)
-        # if self.has_relative_attention_bias:
-        #     nn.init.xavier_normal_(self.relative_attention_bias.weight)
 
     def _relative_positions_bucket(self, relative_positions, bidirectional=True):
         num_buckets = self.num_buckets
@@ -544,7 +448,6 @@ class MultiheadAttention(nn.Layer):
             position_bias = paddle.concat([position_bias_ for _ in range(bsz)], axis=0)
             position_bias = position_bias.reshape([bsz * self.num_heads, tgt_len, src_len])
         if (
-            # not is_tpu  # don't use PyTorch version on TPUs
             incremental_state is None
             and not static_kv
             and self.q_head_dim == self.head_dim
@@ -740,7 +643,6 @@ class MultiheadAttention(nn.Layer):
             )
 
-        # attn_weights = torch.bmm(q, k.transpose(1, 2))
         attn_weights = paddle.bmm(q, k.transpose(1, 2))
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
@@ -753,16 +655,10 @@ class MultiheadAttention(nn.Layer):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            if not is_tpu:
             attn_weights = attn_weights.masked_fill(
-                # key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                 key_padding_mask.unsqueeze(1).unsqueeze(2).to(paddle.bool),
                 float("-inf"),
             )
-            else:
-                attn_weights = attn_weights.transpose(0, 2)
-                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
-                attn_weights = attn_weights.transpose(0, 2)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         if before_softmax:
@@ -772,8 +668,6 @@ class MultiheadAttention(nn.Layer):
             if self.gru_rel_pos == 1:
                 query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                 _B, _H, _L, __ = query_layer.shape
-                # gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
-                #     _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                 gate_a, gate_b = paddle.sigmoid(self.grep_linear(query_layer).view(
                     _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, axis=-1)
@@ -791,7 +685,6 @@ class MultiheadAttention(nn.Layer):
         attn_probs = self.dropout_module(attn_weights)
 
         assert v is not None
-        # attn = torch.bmm(attn_probs, v)
         attn = paddle.bmm(attn_probs, v)
         assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
@@ -819,9 +712,6 @@ class MultiheadAttention(nn.Layer):
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            # new_key_padding_mask = torch.cat(
-            #     [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
-            # )
             new_key_padding_mask = paddle.concat(
                 [prev_key_padding_mask.float(), key_padding_mask.float()], axis=1
             )
@@ -830,18 +720,10 @@ class MultiheadAttention(nn.Layer):
         # is None
         elif prev_key_padding_mask is not None:
             if src_len > prev_key_padding_mask.size(1):
-                # filler = torch.zeros(
-                #     (batch_size, src_len - prev_key_padding_mask.size(1)),
-                #     device=prev_key_padding_mask.device,
-                # )
                 filler = paddle.zeros(
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
                     device=prev_key_padding_mask.device,
                 )
-                # new_key_padding_mask = torch.cat(
-                #     [prev_key_padding_mask.float(), filler.float()], dim=1
-                # )
                 new_key_padding_mask = paddle.concat(
                     [prev_key_padding_mask.float(), filler.float()], axis=1
                 )
@@ -850,17 +732,10 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = prev_key_padding_mask.float()
         elif key_padding_mask is not None:
             if src_len > key_padding_mask.size(1):
-                # filler = torch.zeros(
-                #     (batch_size, src_len - key_padding_mask.size(1)),
-                #     device=key_padding_mask.device,
-                # )
                 filler = paddle.zeros(
                     (batch_size, src_len - key_padding_mask.size(1)),
                     device=key_padding_mask.device,
                 )
-                # new_key_padding_mask = torch.cat(
-                #     [filler.float(), key_padding_mask.float()], dim=1
-                # )
                 new_key_padding_mask = paddle.concat(
                     [filler.float(), key_padding_mask.float()], axis=1
                 )

@@ -21,7 +21,6 @@ from paddle import Tensor
 from .modules.modules import (
     MultiheadAttention,
     SamePad,
-    init_bert_params,
     get_activation_fn,
     TransposeLast,
     GLU_Linear,
