diff --git a/paddlespeech/s2t/models/wavlm/modules/functional.py b/paddlespeech/s2t/models/wavlm/modules/functional.py index 2959c810..d2ebdc71 100644 --- a/paddlespeech/s2t/models/wavlm/modules/functional.py +++ b/paddlespeech/s2t/models/wavlm/modules/functional.py @@ -49,9 +49,6 @@ def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Ten raise AssertionError( f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") -def masked_fill(x, mask, value): - y = paddle.full(x.shape, value) - def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal): """ @@ -61,18 +58,22 @@ def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal): d_key = k.shape[-1] scaled_q = paddle.scale(x=q, scale=d_key ** -0.5) product = paddle.matmul(x=scaled_q, y=k, transpose_y=True) - weights = paddle.nn.functional.softmax(x=product + attn_mask) + weights = F.softmax(x=product + attn_mask) if dropout_p: - weights = paddle.fluid.layers.nn.dropout( + weights = F.dropout( weights, - dropout_prob=dropout_p, - dropout_implementation="upscale_in_train", - is_test=False) + p=dropout_p, + training=True, + mode="upscale_in_train" + ) out = paddle.matmul(x=weights, y=v) return out def addr(input, vec1, vec2, beta=1, alpha=1, out=None): + """ + A helper function to calculate alpha*(vec1*vec2^T) + beta*input + """ row = vec1.shape[0] column = vec2.shape[0] vec1 = paddle.unsqueeze(vec1, 0) @@ -164,12 +165,11 @@ def _in_projection_packed( - in output list :math:`[q', k', v']`, each output tensor will have the same shape as the corresponding input tensor. """ - # E = q.size(-1) E = q.shape[-1] if k is v: if q is k: # self-attention - proj = linear(q, w, b) + proj = F.linear(q, w, b) # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous() return proj[0], proj[1], proj[2] @@ -180,8 +180,8 @@ def _in_projection_packed( b_q = b_kv = None else: b_q, b_kv = b.split([E, E * 2]) - q_proj = linear(q, w_q, b_q) - kv_proj = linear(k, w_kv, b_kv) + q_proj = F.linear(q, w_q, b_q) + kv_proj = F.linear(k, w_kv, b_kv) # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous() return (q_proj, kv_proj[0], kv_proj[1]) @@ -191,7 +191,7 @@ def _in_projection_packed( b_q = b_k = b_v = None else: b_q, b_k, b_v = b.chunk(3) - return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) def _in_projection( q: paddle.Tensor, @@ -204,10 +204,8 @@ def _in_projection( b_k: Optional[paddle.Tensor] = None, b_v: Optional[paddle.Tensor] = None, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - A, B, C = linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) - + A, B, C = F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) return A, B, C - # return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) def multi_head_attention_forward_paddle( query: paddle.Tensor, @@ -299,22 +297,7 @@ def multi_head_attention_forward_paddle( """ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) - - # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input - # is batched, run the computation and before returning 
squeeze the - # batch dimension so that the output doesn't carry this temporary batch dimension. - # if not is_batched: - # # unsqueeze if the input is unbatched - # query = query.unsqueeze(1) - # key = key.unsqueeze(1) - # value = value.unsqueeze(1) - # if key_padding_mask is not None: - # key_padding_mask = key_padding_mask.unsqueeze(0) - - # set up shape vars - # import pdb; pdb.set_trace() tgt_len, bsz, embed_dim = query.shape - # tgt_len, bsz, embed_dim = query.shape src_len, _, _ = key.shape if is_causal: @@ -373,9 +356,7 @@ def multi_head_attention_forward_paddle( if bias_k is not None and bias_v is not None: assert static_k is None, "bias cannot be added to static key." assert static_v is None, "bias cannot be added to static value." - # k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1) - # v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1) if attn_mask is not None: # attn_mask = pad(attn_mask, (0, 1)) @@ -392,22 +373,18 @@ def multi_head_attention_forward_paddle( # # reshape q, k, v for multihead attention and make em batch first # - # q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2]) if static_k is None: - # k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2]) else: - # TODO finish disentangling control flow so we don't do in-projections when statics are passed assert static_k.size(0) == bsz * num_heads, \ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" assert static_k.size(2) == head_dim, \ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" k = static_k if static_v is None: - # v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2]) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -420,9 +397,7 @@ def multi_head_attention_forward_paddle( # add zero attention along batch dimension (now first) if add_zero_attn: zero_attn_shape = (bsz * num_heads, 1, head_dim) - # k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], axis=1) - # v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], axis=1) if attn_mask is not None: # attn_mask = pad(attn_mask, (0, 1)) @@ -438,7 +413,6 @@ def multi_head_attention_forward_paddle( if key_padding_mask is not None: assert key_padding_mask.shape == (bsz, src_len), \ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" - # key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len]) if attn_mask is None: attn_mask = key_padding_mask @@ -456,25 +430,20 @@ def multi_head_attention_forward_paddle( B, Nt, E = q.shape q_scaled = q / math.sqrt(E) if attn_mask is not None: - # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) attn_output_weights = addr(q_scaled, k.transpose(-2, -1)) else: - # 
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = paddle.bmm(q_scaled, k.transpose(0, 2, 1)) - # attn_output_weights = softmax(attn_output_weights, dim=-1) - attn_output_weights = paddle.nn.functional.softmax(attn_output_weights, axis=-1) + attn_output_weights = F.softmax(attn_output_weights, axis=-1) if dropout_p > 0.0: - # attn_output_weights = dropout(attn_output_weights, p=dropout_p) - attn_output_weights = paddle.nn.functional.dropout(attn_output_weights, p=dropout_p) + attn_output_weights = F.dropout(attn_output_weights, p=dropout_p) - # attn_output = torch.bmm(attn_output_weights, v) attn_output = paddle.bmm(attn_output_weights, v) attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim]) - attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + # attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]]) # optionally average attention weights over heads - # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len]) if average_attn_weights: attn_output_weights = attn_output_weights.mean(dim=1) @@ -492,7 +461,6 @@ def multi_head_attention_forward_paddle( if attn_mask.shape[0] == 1 and attn_mask.dim() == 3: attn_mask = attn_mask.unsqueeze(0) else: - # attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len]) q = q.reshape([bsz, num_heads, tgt_len, head_dim]) @@ -500,9 +468,6 @@ def multi_head_attention_forward_paddle( v = v.reshape([bsz, num_heads, src_len, head_dim]) attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim]) - attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]]) - # if not is_batched: - # # squeeze the output if input was unbatched - # attn_output = attn_output.squeeze(1) return attn_output, None \ No newline at end of file diff --git a/paddlespeech/s2t/models/wavlm/modules/modules.py b/paddlespeech/s2t/models/wavlm/modules/modules.py index 09c72088..5ef42e60 100644 --- a/paddlespeech/s2t/models/wavlm/modules/modules.py +++ b/paddlespeech/s2t/models/wavlm/modules/modules.py @@ -60,19 +60,6 @@ class Fp32GroupNorm(nn.GroupNorm): return output.type_as(input) -# class GradMultiply(torch.autograd.Function): -# convert into paddle equivalent -# class GradMultiply(torch.autograd.Function): - # @staticmethod - # def forward(ctx, x, scale): - # ctx.scale = scale - # res = x.new(x) - # return res - - # @staticmethod - # def backward(ctx, grad): - # return grad * ctx.scale, None - class SamePad(nn.Layer): def __init__(self, kernel_size, causal=False): @@ -95,7 +82,6 @@ class Swish(nn.Layer): def __init__(self): """Construct an MultiHeadedAttention object.""" super(Swish, self).__init__() - # self.act = torch.nn.Sigmoid() self.act = nn.Sigmoid() def forward(self, x): @@ -162,7 +148,6 @@ def get_activation_fn(activation: str): elif activation == "gelu_accurate": return gelu_accurate elif activation == "tanh": - # return torch.tanh return paddle.tanh elif activation == "linear": return lambda x: x @@ -172,44 +157,6 @@ def 
get_activation_fn(activation: str): raise RuntimeError("--activation-fn {} not supported".format(activation)) -def init_bert_params(module): - """ - Initialize the weights specific to the BERT Model. - This overrides the default initializations depending on the specified arguments. - 1. If normal_init_linear_weights is set then weights of linear - layer will be initialized using the normal distribution and - bais will be set to the specified value. - 2. If normal_init_embed_weights is set then weights of embedding - layer will be initialized using the normal distribution. - 3. If normal_init_proj_weights is set then weights of - in_project_weight for MultiHeadAttention initialized using - the normal distribution (to be validated). - """ - - def normal_(data): - # with FSDP, module params will be on CUDA, so we cast them back to CPU - # so that the RNG is consistent with and without FSDP - data.copy_( - data.cpu().normal_(mean=0.0, std=0.02).to(data.device) - ) - - if isinstance(module, nn.Linear): - # normal_(module.weight.data) - if module.bias is not None: - # module.bias.data.zero_() - pass - if isinstance(module, nn.Embedding): - # normal_(module.weight.data) - if module.padding_idx is not None: - # module.weight.data[module.padding_idx].zero_() - pass - if isinstance(module, MultiheadAttention): - pass - # normal_(module.q_proj.weight.data) - # normal_(module.k_proj.weight.data) - # normal_(module.v_proj.weight.data) - - def quant_noise(module, p, block_size): """ Wraps modules and applies quantization noise to the weights for @@ -302,9 +249,8 @@ def quant_noise(module, p, block_size): # scale weights and apply mask mask = mask.to( - # torch.bool paddle.bool - ) # x.bool() is not currently supported in TorchScript + ) s = 1 / (1 - p) mod.weight.data = s * weight.masked_fill(mask, 0) @@ -405,7 +351,6 @@ class MultiheadAttention(nn.Layer): self.gru_rel_pos = gru_rel_pos if self.gru_rel_pos: self.grep_linear = nn.Linear(self.q_head_dim, 8) - # self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) self.grep_a = self.create_parameter( shape=[1, num_heads, 1, 1], dtype="float32" ) @@ -415,48 +360,7 @@ class MultiheadAttention(nn.Layer): def reset_parameters(self): pass - # if self.qkv_same_dim: - # # Empirically observed the convergence to be much better with - # # the scaled initialization - # # nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) - # # self.k_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform(1, 1)(self.k_proj.weight.shape) - # # ) - # # self.v_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform(1, 1)(self.v_proj.weight.shape) - # # ) - # # self.q_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform(1, 1)(self.q_proj.weight.shape) - # # ) - # pass - # # nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) - # # nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) - # else: - # # nn.init.xavier_uniform_(self.k_proj.weight) - # # nn.init.xavier_uniform_(self.v_proj.weight) - # # nn.init.xavier_uniform_(self.q_proj.weight) - # # self.k_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform()(self.k_proj.weight.shape) - # # ) - # # self.v_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform()(self.v_proj.weight.shape) - # # ) - # # self.q_proj.weight.set_value( - # # paddle.nn.initializer.XavierUniform()(self.q_proj.weight.shape) - # # ) - # pass - - - # nn.init.xavier_uniform_(self.out_proj.weight) - # if self.out_proj.bias is not None: - # 
nn.init.constant_(self.out_proj.bias, 0.0) - # if self.bias_k is not None: - # nn.init.xavier_normal_(self.bias_k) - # if self.bias_v is not None: - # nn.init.xavier_normal_(self.bias_v) - # if self.has_relative_attention_bias: - # nn.init.xavier_normal_(self.relative_attention_bias.weight) - + def _relative_positions_bucket(self, relative_positions, bidirectional=True): num_buckets = self.num_buckets max_distance = self.max_distance @@ -544,7 +448,6 @@ class MultiheadAttention(nn.Layer): position_bias = paddle.concat([position_bias_ for _ in range(bsz)], axis=0) position_bias = position_bias.reshape([bsz * self.num_heads, tgt_len, src_len]) if ( - # not is_tpu # don't use PyTorch version on TPUs incremental_state is None and not static_kv and self.q_head_dim == self.head_dim @@ -740,7 +643,6 @@ class MultiheadAttention(nn.Layer): ) - # attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = paddle.bmm(q, k.transpose(1, 2)) attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) @@ -753,16 +655,10 @@ class MultiheadAttention(nn.Layer): if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - if not is_tpu: - attn_weights = attn_weights.masked_fill( - # key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - key_padding_mask.unsqueeze(1).unsqueeze(2).to(paddle.bool), - float("-inf"), - ) - else: - attn_weights = attn_weights.transpose(0, 2) - attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) - attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(paddle.bool), + float("-inf"), + ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if before_softmax: @@ -772,8 +668,6 @@ class MultiheadAttention(nn.Layer): if self.gru_rel_pos == 1: query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) _B, _H, _L, __ = query_layer.shape - # gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( - # _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) gate_a, gate_b = paddle.sigmoid(self.grep_linear(query_layer).view( _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, axis=-1) @@ -791,7 +685,6 @@ class MultiheadAttention(nn.Layer): attn_probs = self.dropout_module(attn_weights) assert v is not None - # attn = torch.bmm(attn_probs, v) attn = paddle.bmm(attn_probs, v) assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) @@ -819,9 +712,6 @@ class MultiheadAttention(nn.Layer): if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask elif prev_key_padding_mask is not None and key_padding_mask is not None: - # new_key_padding_mask = torch.cat( - # [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 - # ) new_key_padding_mask = paddle.concat( [prev_key_padding_mask.float(), key_padding_mask.float()], axis=1 ) @@ -830,18 +720,10 @@ class MultiheadAttention(nn.Layer): # is None elif prev_key_padding_mask is not None: if src_len > prev_key_padding_mask.size(1): - # filler = torch.zeros( - # (batch_size, src_len - prev_key_padding_mask.size(1)), - # device=prev_key_padding_mask.device, - # ) filler = paddle.zeros( (batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device, ) - - # new_key_padding_mask = torch.cat( - # [prev_key_padding_mask.float(), filler.float()], dim=1 - # ) 
new_key_padding_mask = paddle.concat( [prev_key_padding_mask.float(), filler.float()], axis=1 ) @@ -850,17 +732,10 @@ class MultiheadAttention(nn.Layer): new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: if src_len > key_padding_mask.size(1): - # filler = torch.zeros( - # (batch_size, src_len - key_padding_mask.size(1)), - # device=key_padding_mask.device, - # ) filler = paddle.zeros( (batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device, ) - # new_key_padding_mask = torch.cat( - # [filler.float(), key_padding_mask.float()], dim=1 - # ) new_key_padding_mask = paddle.concat( [filler.float(), key_padding_mask.float()], axis=1 ) diff --git a/paddlespeech/s2t/models/wavlm/wavlm_paddle.py b/paddlespeech/s2t/models/wavlm/wavlm_paddle.py index 3d1b5c4d..6ed9ecd0 100644 --- a/paddlespeech/s2t/models/wavlm/wavlm_paddle.py +++ b/paddlespeech/s2t/models/wavlm/wavlm_paddle.py @@ -21,7 +21,6 @@ from paddle import Tensor from .modules.modules import ( MultiheadAttention, SamePad, - init_bert_params, get_activation_fn, TransposeLast, GLU_Linear,
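# --- Illustrative sketches (not part of the patch; helper names below are assumptions) ---
#
# The patch drops an unfinished masked_fill() helper from functional.py (it built
# paddle.full(...) but never returned it), while modules.py still calls
# Tensor.masked_fill.  A minimal stand-in with the usual masked_fill semantics,
# assuming only paddle.where / paddle.full_like are available:
import paddle


def masked_fill(x, mask, value):
    # wherever `mask` is True take `value`, elsewhere keep `x`
    return paddle.where(mask, paddle.full_like(x, value), x)


# The docstring added to addr() describes out = alpha * (vec1 @ vec2^T) + beta * input.
# A reference implementation of that formula for 1-D vec1/vec2 and a
# broadcast-compatible `input`, useful for spot-checking the helper:
def addr_reference(input, vec1, vec2, beta=1, alpha=1):
    outer = paddle.matmul(vec1.unsqueeze(1), vec2.unsqueeze(0))  # shape [len(vec1), len(vec2)]
    return beta * input + alpha * outer


if __name__ == "__main__":
    scores = paddle.zeros([2, 3])
    pad = paddle.to_tensor([[False, True, False], [True, False, False]])
    print(masked_fill(scores, pad, float("-inf")))
    print(addr_reference(paddle.ones([3, 4]),
                         paddle.arange(3, dtype="float32"),
                         paddle.arange(4, dtype="float32")))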