[Fix] use reshape instead of view (#3939)

pull/3946/head
megemini authored 9 months ago, committed by GitHub
parent a34bf501a5
commit e3c4d4bd7e

@@ -177,8 +177,9 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     Returns:
         float: Accuracy value (0.0 - 1.0).
     """
-    pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
-                                pad_outputs.shape[1]).argmax(2)
+    pad_pred = pad_outputs.reshape(
+        [pad_targets.shape[0], pad_targets.shape[1],
+         pad_outputs.shape[1]]).argmax(2)
     mask = pad_targets != ignore_label
     #TODO(Hui Zhang): sum not support bool type
     # numerator = paddle.sum(
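
Note (editorial, not part of the commit): the same view-to-reshape substitution repeats throughout this diff. A minimal sketch of the signature difference, with hypothetical shapes, assuming Paddle 2.x where Tensor.reshape takes the target shape as one list argument while the torch-style .view(d0, d1, ...) varargs call being removed is not supported the same way:

    import paddle

    # Hypothetical stand-ins for th_accuracy's arguments (B=2, Lmax=3, D=5).
    pad_outputs = paddle.rand([6, 5])                  # (B * Lmax, D)
    pad_targets = paddle.zeros([2, 3], dtype="int64")  # (B, Lmax)

    # New form used by this commit: the whole target shape is a single list.
    pad_pred = pad_outputs.reshape(
        [pad_targets.shape[0], pad_targets.shape[1],
         pad_outputs.shape[1]]).argmax(2)              # (B, Lmax)

    # Old form being removed (torch-style varargs), kept here only as a comment:
    # pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], pad_outputs.shape[1])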

@@ -86,7 +86,7 @@ class CTCPrefixScorePD():
                 dtype=self.dtype, )  # (T, 2, B, W)
             r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank],
                                          0).unsqueeze(2)
-            r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
+            r_prev = r_prev.reshape([-1, 2, n_bh])  # (T, 2, BW)
             s_prev = 0.0  # score
             f_min_prev = 0  # eq. 22-23
             f_max_prev = 1  # eq. 22-23
@@ -100,23 +100,23 @@ class CTCPrefixScorePD():
                 (n_bh, self.odim), -1, dtype=paddle.long)
             snum = self.scoring_num
             if self.idx_bh is None or n_bh > len(self.idx_bh):
-                self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
+                self.idx_bh = paddle.arange(n_bh).reshape([-1, 1])  # (BW, 1)
             scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
             scoring_idx = (
-                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1,
-                                                                 1)  # (BW,1)
-            ).view(-1)  # (BWO)
+                scoring_ids + self.idx_bo.repeat(1, n_hyps).reshape(
+                    [-1, 1])  # (BW,1)
+            ).reshape([-1])  # (BWO)
             # x_ shape (2, T, B*W, O)
             x_ = paddle.index_select(
-                self.x.view(2, -1, self.batch * self.odim), scoring_idx,
-                2).view(2, -1, n_bh, snum)
+                self.x.reshape([2, -1, self.batch * self.odim]), scoring_idx,
+                2).reshape([2, -1, n_bh, snum])
         else:
             scoring_ids = None
             scoring_idmap = None
             snum = self.odim
             # x_ shape (2, T, B*W, O)
-            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1,
-                                                                     n_bh, snum)
+            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).reshape(
+                [2, -1, n_bh, snum])

         # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
         # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
@@ -154,8 +154,8 @@ class CTCPrefixScorePD():
         # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
         for t in range(start, end):
             rp = r[t - 1]  # (2 x BW x O')
-            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
-                2, 2, n_bh, snum)  # (2,2,BW,O')
+            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).reshape(
+                [2, 2, n_bh, snum])  # (2,2,BW,O')
             r[t] = paddle.logsumexp(rr, 1) + x_[:, t]

         # compute log prefix probabilities log(psi)
@@ -197,25 +197,27 @@ class CTCPrefixScorePD():
         # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
-        vidx = (best_ids + (self.idx_b *
-                            (n_hyps * self.odim)).view(-1, 1)).view(-1)
+        vidx = (best_ids +
+                (self.idx_b *
+                 (n_hyps * self.odim)).reshape([-1, 1])).reshape([-1])
         # select hypothesis scores
-        s_new = paddle.index_select(s.view(-1), vidx, 0)
-        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
+        s_new = paddle.index_select(s.reshape([-1]), vidx, 0)
+        s_new = s_new.reshape([-1, 1]).repeat(1, self.odim).reshape(
+            [n_bh, self.odim])
         # convert ids to BHS space (S: scoring_num)
         if scoring_idmap is not None:
             snum = self.scoring_num
             hyp_idx = (best_ids // self.odim +
-                       (self.idx_b * n_hyps).view(-1, 1)).view(-1)
-            label_ids = paddle.fmod(best_ids, self.odim).view(-1)
+                       (self.idx_b * n_hyps).reshape([-1, 1])).reshape([-1])
+            label_ids = paddle.fmod(best_ids, self.odim).reshape([-1])
             score_idx = scoring_idmap[hyp_idx, label_ids]
             score_idx[score_idx == -1] = 0
             vidx = score_idx + hyp_idx * snum
         else:
             snum = self.odim
         # select forward probabilities
-        r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
-            -1, 2, n_bh)
+        r_new = paddle.index_select(r.reshape([-1, 2, n_bh * snum]), vidx,
+                                    2).reshape([-1, 2, n_bh])
         return r_new, s_new, f_min, f_max

     def extend_prob(self, x):

@@ -135,7 +135,7 @@ class BatchScorerInterface(ScorerInterface):
             score, outstate = self.score(y, state, x)
             outstates.append(outstate)
             scores.append(score)
-        scores = paddle.cat(scores, 0).view(ys.shape[0], -1)
+        scores = paddle.cat(scores, 0).reshape([ys.shape[0], -1])
         return scores, outstates

@@ -213,7 +213,7 @@ class HubertASR(nn.Layer):
         x_lens = x.shape[1]
         ctc_probs = self.ctc.log_softmax(x)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
-        topk_index = topk_index.view(batch_size, x_lens)  # (B, maxlen)
+        topk_index = topk_index.reshape([batch_size, x_lens])  # (B, maxlen)
         hyps = [hyp.tolist() for hyp in topk_index]
         hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]

@@ -122,10 +122,12 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
         h, _ = self.encoder(emb, xlen)
         y = self.decoder(h)
         loss = F.cross_entropy(
-            y.view(-1, paddle.shape(y)[-1]), t.view(-1), reduction="none")
+            y.reshape([-1, paddle.shape(y)[-1]]),
+            t.reshape([-1]),
+            reduction="none")
         mask = xm.to(loss.dtype)
-        logp = loss * mask.view(-1)
-        nll = logp.view(batch_size, -1).sum(-1)
+        logp = loss * mask.reshape([-1])
+        nll = logp.reshape([batch_size, -1]).sum(-1)
         nll_count = mask.sum(-1)
         logp = logp.sum()
         count = mask.sum()
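
For reference, the flatten-then-score idiom in the TransformerLM hunk above as a standalone sketch; the shapes and vocabulary size are invented, and the xm mask handling from the real code is omitted:

    import paddle
    import paddle.nn.functional as F

    batch_size, seq_len, vocab = 2, 4, 10
    y = paddle.rand([batch_size, seq_len, vocab])        # decoder logits
    t = paddle.randint(0, vocab, [batch_size, seq_len])  # target token ids (int64)

    # Same reshape([-1, ...]) idiom the commit switches to.
    loss = F.cross_entropy(
        y.reshape([-1, paddle.shape(y)[-1]]),
        t.reshape([-1]),
        reduction="none")                                # per-token loss
    nll = loss.reshape([batch_size, -1]).sum(-1)         # per-sequence negative log-likelihood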

@@ -176,7 +176,7 @@ class U2STBaseModel(nn.Layer):
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
+            decoder_out.reshape([-1, self.vocab_size]),
             ys_out_pad,
             ignore_label=self.ignore_id, )
         return loss_att, acc_att
@@ -209,7 +209,7 @@ class U2STBaseModel(nn.Layer):
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
+            decoder_out.reshape([-1, self.vocab_size]),
             ys_out_pad,
             ignore_label=self.ignore_id, )
         return loss_att, acc_att

@@ -6,17 +6,18 @@
 # Based on fairseq code bases
 # https://github.com/pytorch/fairseq
 # --------------------------------------------------------
 import math
 import warnings
-from typing import Dict, Optional, Tuple
-from .functional import multi_head_attention_forward_paddle
+from typing import Dict
+from typing import Optional
+from typing import Tuple

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor

+from .functional import multi_head_attention_forward_paddle


 class TransposeLast(nn.Layer):
@@ -40,8 +41,7 @@ class Fp32LayerNorm(nn.LayerNorm):
             self.normalized_shape,
             self.weight.float() if self.weight is not None else None,
             self.bias.float() if self.bias is not None else None,
-            self.eps,
-        )
+            self.eps, )
         return output.type_as(input)
@@ -55,12 +55,10 @@ class Fp32GroupNorm(nn.GroupNorm):
             self.num_groups,
             self.weight.float() if self.weight is not None else None,
             self.bias.float() if self.bias is not None else None,
-            self.eps,
-        )
+            self.eps, )
         return output.type_as(input)


 class SamePad(nn.Layer):
     def __init__(self, kernel_size, causal=False):
         super().__init__()
@@ -71,7 +69,7 @@ class SamePad(nn.Layer):
     def forward(self, x):
         if self.remove > 0:
-            x = x[:, :, : -self.remove]
+            x = x[:, :, :-self.remove]
         return x
@@ -89,7 +87,11 @@ class Swish(nn.Layer):


 class GLU_Linear(nn.Layer):
-    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+    def __init__(self,
+                 input_dim,
+                 output_dim,
+                 glu_type="sigmoid",
+                 bias_in_glu=True):
         super(GLU_Linear, self).__init__()

         self.glu_type = glu_type
@@ -114,9 +116,11 @@ class GLU_Linear(nn.Layer):
         x = self.linear(x)

         if self.glu_type == "bilinear":
-            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+            x = (x[:, :, 0:self.output_dim] *
+                 x[:, :, self.output_dim:self.output_dim * 2])
         else:
-            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+            x = (x[:, :, 0:self.output_dim] *
+                 self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))

         return x
@@ -124,9 +128,8 @@ class GLU_Linear(nn.Layer):
 def gelu_accurate(x):
     if not hasattr(gelu_accurate, "_a"):
         gelu_accurate._a = math.sqrt(2 / math.pi)
-    return (
-        0.5 * x * (1 + paddle.tanh(gelu_accurate._a * (x + 0.044715 * paddle.pow(x, 3))))
-    )
+    return (0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
+                                       (x + 0.044715 * paddle.pow(x, 3)))))


 def gelu(x: Tensor) -> Tensor:
@@ -142,8 +145,7 @@ def get_activation_fn(activation: str):
         return gelu
     elif activation == "gelu_fast":
         warnings.warn(
-            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
-        )
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate")
         return gelu_accurate
     elif activation == "gelu_accurate":
         return gelu_accurate
@@ -154,7 +156,8 @@ def get_activation_fn(activation: str):
     elif activation == "glu":
         return lambda x: x
     else:
-        raise RuntimeError("--activation-fn {} not supported".format(activation))
+        raise RuntimeError(
+            "--activation-fn {} not supported".format(activation))


 def quant_noise(module, p, block_size):
@@ -190,15 +193,14 @@ def quant_noise(module, p, block_size):
     # 2D matrix
     if not is_conv:
         assert (
-            module.weight.size(1) % block_size == 0
-        ), "Input features must be a multiple of block sizes"
+            module.weight.size(1) %
+            block_size == 0), "Input features must be a multiple of block sizes"

     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
-            assert (
-                module.in_channels % block_size == 0
+            assert (module.in_channels % block_size == 0
             ), "Input channels must be a multiple of block sizes"
         # regular convolutions
         else:
@@ -216,10 +218,11 @@ def quant_noise(module, p, block_size):
                 # split weight matrix into blocks and randomly drop selected blocks
                 mask = paddle.zeros(
-                    in_features // block_size * out_features, device=weight.device
-                )
+                    in_features // block_size * out_features,
+                    device=weight.device)
                 mask.bernoulli_(p)
-                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+                mask = mask.repeat_interleave(block_size, -1).reshape(
+                    [-1, in_features])

             else:
                 # gather weight and sizes
@@ -231,26 +234,21 @@ def quant_noise(module, p, block_size):
                 if mod.kernel_size == (1, 1):
                     mask = paddle.zeros(
                         int(in_channels // block_size * out_channels),
-                        device=weight.device,
-                    )
+                        device=weight.device, )
                     mask.bernoulli_(p)
-                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                    mask = mask.repeat_interleave(block_size, -1).reshape(
+                        [-1, in_channels])
                 else:
                     mask = paddle.zeros(
-                        weight.size(0), weight.size(1), device=weight.device
-                    )
+                        weight.size(0), weight.size(1), device=weight.device)
                     mask.bernoulli_(p)
                     mask = (
-                        mask.unsqueeze(2)
-                        .unsqueeze(3)
-                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
-                    )
+                        mask.unsqueeze(2).unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))

             # scale weights and apply mask
-            mask = mask.to(
-                paddle.bool
-            )
+            mask = mask.to(paddle.bool)
             s = 1 / (1 - p)
             mod.weight.data = s * weight.masked_fill(mask, 0)
@@ -282,8 +280,7 @@ class MultiheadAttention(nn.Layer):
             num_buckets=32,
             max_distance=128,
             gru_rel_pos=True,
-            rescale_init=False,
-    ):
+            rescale_init=False, ):
         super().__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -302,17 +299,16 @@ class MultiheadAttention(nn.Layer):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
         self.k_head_dim = self.head_dim
-        assert (
-            self.head_dim * num_heads == self.embed_dim
+        assert (self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5

         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention

         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
-        )
+            "Self-attention requires query, key and "
+            "value to be of the same size")

         k_bias = True
         if rescale_init:
@@ -322,26 +318,24 @@ class MultiheadAttention(nn.Layer):
             q_embed_dim = embed_dim

         self.k_proj = quant_noise(
-            nn.Linear(self.kdim, k_embed_dim, bias_attr=k_bias), q_noise, qn_block_size
-        )
+            nn.Linear(self.kdim, k_embed_dim, bias_attr=k_bias), q_noise,
+            qn_block_size)
         self.v_proj = quant_noise(
-            nn.Linear(self.vdim, embed_dim, bias_attr=bias), q_noise, qn_block_size
-        )
+            nn.Linear(self.vdim, embed_dim, bias_attr=bias), q_noise,
+            qn_block_size)
         self.q_proj = quant_noise(
-            nn.Linear(embed_dim, q_embed_dim, bias_attr=bias), q_noise, qn_block_size
-        )
+            nn.Linear(embed_dim, q_embed_dim, bias_attr=bias), q_noise,
+            qn_block_size)

         self.out_proj = quant_noise(
-            nn.Linear(embed_dim, embed_dim, bias_attr=bias), q_noise, qn_block_size
-        )
+            nn.Linear(embed_dim, embed_dim, bias_attr=bias), q_noise,
+            qn_block_size)

         if add_bias_kv:
             self.bias_k = self.create_parameter(
-                shape=[1, 1, embed_dim], dtype="float32"
-            )
+                shape=[1, 1, embed_dim], dtype="float32")
             self.bias_v = self.create_parameter(
-                shape=[1, 1, embed_dim], dtype="float32"
-            )
+                shape=[1, 1, embed_dim], dtype="float32")
         else:
             self.bias_k = self.bias_v = None
@@ -352,40 +346,41 @@ class MultiheadAttention(nn.Layer):
         if self.gru_rel_pos:
             self.grep_linear = nn.Linear(self.q_head_dim, 8)
             self.grep_a = self.create_parameter(
-                shape=[1, num_heads, 1, 1], dtype="float32"
-            )
+                shape=[1, num_heads, 1, 1], dtype="float32")

         self.reset_parameters()

     def reset_parameters(self):
         pass

-    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
+    def _relative_positions_bucket(self, relative_positions,
+                                   bidirectional=True):
         num_buckets = self.num_buckets
         max_distance = self.max_distance
         relative_buckets = 0

         if bidirectional:
             num_buckets = num_buckets // 2
-            relative_buckets += (relative_positions > 0).astype("int64") * num_buckets
+            relative_buckets += (
+                relative_positions > 0).astype("int64") * num_buckets
             relative_positions = paddle.abs(relative_positions)
         else:
-            relative_positions = -paddle.minimum(relative_positions, paddle.zeros_like(relative_positions))
+            relative_positions = -paddle.minimum(
+                relative_positions, paddle.zeros_like(relative_positions))

         max_exact = num_buckets // 2
         is_small = relative_positions < max_exact
         relative_postion_if_large = max_exact + (
-            paddle.log(relative_positions.astype("float32") / max_exact)
-            / math.log(max_distance / max_exact)
-            * (num_buckets - max_exact)
-        ).astype("int64")
+            paddle.log(relative_positions.astype("float32") /
+                       max_exact) / math.log(max_distance / max_exact) *
+            (num_buckets - max_exact)).astype("int64")
         relative_postion_if_large = paddle.minimum(
-            relative_postion_if_large, paddle.full_like(relative_postion_if_large, num_buckets - 1)
-        )
-        relative_buckets += paddle.where(is_small, relative_positions, relative_postion_if_large)
+            relative_postion_if_large,
+            paddle.full_like(relative_postion_if_large, num_buckets - 1))
+        relative_buckets += paddle.where(is_small, relative_positions,
+                                         relative_postion_if_large)
         return relative_buckets

     def compute_bias(self, query_length, key_length):
@@ -393,27 +388,25 @@ class MultiheadAttention(nn.Layer):
         memory_position = paddle.arange(key_length, dtype="int64")[None, :]
         relative_position = memory_position - context_position
         relative_position_bucket = self._relative_positions_bucket(
-            relative_position,
-            bidirectional=True
-        )
+            relative_position, bidirectional=True)
         # relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         values = self.relative_attention_bias(relative_position_bucket)
         values = values.transpose([2, 0, 1])
         return values

-    def forward(
-        self,
-        query,
-        key: Optional[Tensor],
-        value: Optional[Tensor],
-        key_padding_mask: Optional[Tensor] = None,
-        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        need_weights: bool = True,
-        static_kv: bool = False,
-        attn_mask: Optional[Tensor] = None,
-        before_softmax: bool = False,
-        need_head_weights: bool = False,
-        position_bias: Optional[Tensor] = None
-    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    def forward(self,
+                query,
+                key: Optional[Tensor],
+                value: Optional[Tensor],
+                key_padding_mask: Optional[Tensor]=None,
+                incremental_state: Optional[Dict[str, Dict[str, Optional[
+                    Tensor]]]]=None,
+                need_weights: bool=True,
+                static_kv: bool=False,
+                attn_mask: Optional[Tensor]=None,
+                before_softmax: bool=False,
+                need_head_weights: bool=False,
+                position_bias: Optional[Tensor]=None
+                ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         """Input shape: Time x Batch x Channel
@@ -445,13 +438,12 @@ class MultiheadAttention(nn.Layer):
         if self.has_relative_attention_bias and position_bias is None:
             position_bias = self.compute_bias(tgt_len, src_len)
             position_bias_ = position_bias.unsqueeze(0)
-            position_bias = paddle.concat([position_bias_ for _ in range(bsz)], axis=0)
-            position_bias = position_bias.reshape([bsz * self.num_heads, tgt_len, src_len])
-        if (
-            incremental_state is None
-            and not static_kv
-            and self.q_head_dim == self.head_dim
-        ):
+            position_bias = paddle.concat(
+                [position_bias_ for _ in range(bsz)], axis=0)
+            position_bias = position_bias.reshape(
+                [bsz * self.num_heads, tgt_len, src_len])
+        if (incremental_state is None and not static_kv and
+                self.q_head_dim == self.head_dim):
             assert key is not None and value is not None
             assert attn_mask is None
@@ -465,17 +457,21 @@ class MultiheadAttention(nn.Layer):
                 query_layer = query_layer.transpose([0, 2, 1, 3])
                 _B, _H, _L, __ = query_layer.shape
-                gate_a, gate_b = paddle.nn.functional.sigmoid(self.grep_linear(query_layer).reshape([_B, _H, _L, 2, 4]).sum(-1, keepdim=False)).chunk(2, axis=-1)
+                gate_a, gate_b = paddle.nn.functional.sigmoid(
+                    self.grep_linear(query_layer).reshape(
+                        [_B, _H, _L, 2, 4]).sum(-1, keepdim=False)).chunk(
+                            2, axis=-1)
                 gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                attn_mask_rel_pos = gate_a_1.reshape([bsz * self.num_heads, -1, 1]) * position_bias
-                attn_mask_rel_pos = attn_mask_rel_pos.reshape((-1, tgt_len, tgt_len))
+                attn_mask_rel_pos = gate_a_1.reshape(
+                    [bsz * self.num_heads, -1, 1]) * position_bias
+                attn_mask_rel_pos = attn_mask_rel_pos.reshape(
+                    (-1, tgt_len, tgt_len))

             k_proj_bias = self.k_proj.bias
             if k_proj_bias is None:
                 k_proj_bias = paddle.zeros_like(self.q_proj.bias)

             x, attn = multi_head_attention_forward_paddle(
                 query,
                 key,
@@ -483,7 +479,9 @@ class MultiheadAttention(nn.Layer):
                 self.embed_dim,
                 self.num_heads,
                 paddle.empty([0]),
-                paddle.concat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias), axis=0),
+                paddle.concat(
+                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias),
+                    axis=0),
                 self.bias_k,
                 self.bias_v,
                 self.add_zero_attn,
@@ -497,8 +495,7 @@ class MultiheadAttention(nn.Layer):
                 use_separate_proj_weight=True,
                 q_proj_weight=self.q_proj.weight,
                 k_proj_weight=self.k_proj.weight,
-                v_proj_weight=self.v_proj.weight,
-            )
+                v_proj_weight=self.v_proj.weight, )

             return x, attn, position_bias
@@ -540,8 +537,8 @@ class MultiheadAttention(nn.Layer):
             v = paddle.concat([v, self.bias_v.repeat(1, bsz, 1)], axis=0)
             if attn_mask is not None:
                 attn_mask = paddle.concat(
-                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], axis=1
-                )
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
+                    axis=1)
             if key_padding_mask is not None:
                 key_padding_mask = paddle.concat(
@@ -549,33 +546,27 @@ class MultiheadAttention(nn.Layer):
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                     ],
-                    axis=1,
-                )
+                    axis=1, )

-        q = (
-            q.contiguous()
-            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
-            .transpose([1, 0, 2])
-        )
+        q = (q.contiguous()
+             .reshape([tgt_len, bsz * self.num_heads, self.q_head_dim])
+             .transpose([1, 0, 2]))
         if k is not None:
-            k = (
-                k.contiguous()
-                .view(-1, bsz * self.num_heads, self.k_head_dim)
-                .transpose([1, 0, 2])
-            )
+            k = (k.contiguous()
+                 .reshape([-1, bsz * self.num_heads, self.k_head_dim])
+                 .transpose([1, 0, 2]))
         if v is not None:
-            v = (
-                v.contiguous()
-                .view(-1, bsz * self.num_heads, self.head_dim)
-                .transpose([1, 0, 2])
-            )
+            v = (v.contiguous()
+                 .reshape([-1, bsz * self.num_heads, self.head_dim])
+                 .transpose([1, 0, 2]))

         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
             if "prev_key" in saved_state:
                 _prev_key = saved_state["prev_key"]
                 assert _prev_key is not None
-                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+                prev_key = _prev_key.reshape(
+                    [bsz * self.num_heads, -1, self.head_dim])
                 if static_kv:
                     k = prev_key
                 else:
@@ -585,7 +576,8 @@ class MultiheadAttention(nn.Layer):
             if "prev_value" in saved_state:
                 _prev_value = saved_state["prev_value"]
                 assert _prev_value is not None
-                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+                prev_value = _prev_value.reshape(
+                    [bsz * self.num_heads, -1, self.head_dim])
                 if static_kv:
                     v = prev_value
                 else:
@@ -600,15 +592,17 @@ class MultiheadAttention(nn.Layer):
                 prev_key_padding_mask=prev_key_padding_mask,
                 batch_size=bsz,
                 src_len=k.size(1),
-                static_kv=static_kv,
-            )
+                static_kv=static_kv, )

-            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
-            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_key"] = k.reshape(
+                [bsz, self.num_heads, -1, self.head_dim])
+            saved_state["prev_value"] = v.reshape(
+                [bsz, self.num_heads, -1, self.head_dim])
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            incremental_state = self._set_input_buffer(incremental_state, saved_state)
+            incremental_state = self._set_input_buffer(incremental_state,
+                                                       saved_state)

         assert k is not None
         assert k.size(1) == src_len
@@ -624,30 +618,31 @@ class MultiheadAttention(nn.Layer):
         if self.add_zero_attn:
             assert v is not None
             src_len += 1
-            k = paddle.concat([k, k.new_zeros((k.size(0), 1) + k.shape[2:])], axis=1)
-            v = paddle.concat([v, v.new_zeros((v.size(0), 1) + v.shape[2:])], axis=1)
+            k = paddle.concat(
+                [k, k.new_zeros((k.size(0), 1) + k.shape[2:])], axis=1)
+            v = paddle.concat(
+                [v, v.new_zeros((v.size(0), 1) + v.shape[2:])], axis=1)
             if attn_mask is not None:
                 attn_mask = paddle.concat(
-                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], axis=1
-                )
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
+                    axis=1)
             if key_padding_mask is not None:
                 key_padding_mask = paddle.concat(
                     [
                         key_padding_mask,
-                        paddle.zeros(key_padding_mask.size(0), 1).type_as(
-                            key_padding_mask
-                        ),
+                        paddle.zeros(key_padding_mask.size(0),
+                                     1).type_as(key_padding_mask),
                     ],
-                    axis=1,
-                )
+                    axis=1, )

         attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
+                                              bsz)

-        assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
+        assert list(
+            attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
@@ -655,46 +650,49 @@ class MultiheadAttention(nn.Layer):
         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len])
             attn_weights = attn_weights.masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2).to(paddle.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+                float("-inf"), )
+            attn_weights = attn_weights.reshape(
+                [bsz * self.num_heads, tgt_len, src_len])

         if before_softmax:
             return attn_weights, v, position_bias

         if position_bias is not None:
             if self.gru_rel_pos == 1:
-                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
+                query_layer = q.reshape(
+                    [bsz, self.num_heads, tgt_len, self.q_head_dim])
                 _B, _H, _L, __ = query_layer.shape
-                gate_a, gate_b = paddle.sigmoid(self.grep_linear(query_layer).view(
-                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, axis=-1)
+                gate_a, gate_b = paddle.sigmoid(
+                    self.grep_linear(query_layer).reshape([_B, _H, _L, 2, 4])
+                    .sum(-1, keepdim=False)).chunk(
+                        2, axis=-1)
                 gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                position_bias = gate_a_1.reshape(
+                    [bsz * self.num_heads, -1, 1]) * position_bias

-            position_bias = position_bias.view(attn_weights.shape)
+            position_bias = position_bias.reshape(attn_weights.shape)

             attn_weights = attn_weights + position_bias

-        attn_weights_float = F.softmax(
-            attn_weights, dim=-1
-        )
+        attn_weights_float = F.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights_float.type_as(attn_weights)
         attn_probs = self.dropout_module(attn_weights)

         assert v is not None
         attn = paddle.bmm(attn_probs, v)
-        assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        assert list(
+            attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]

         attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
-            attn_weights = attn_weights_float.view(
-                bsz, self.num_heads, tgt_len, src_len
-            ).transpose([1, 0, 2, 3])
+            attn_weights = attn_weights_float.reshape(
+                [bsz, self.num_heads, tgt_len, src_len]).transpose([1, 0, 2, 3])
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
@@ -707,15 +705,14 @@ class MultiheadAttention(nn.Layer):
             prev_key_padding_mask: Optional[Tensor],
             batch_size: int,
             src_len: int,
-            static_kv: bool,
-    ) -> Optional[Tensor]:
+            static_kv: bool, ) -> Optional[Tensor]:
         # saved key padding masks have shape (bsz, seq_len)
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
             new_key_padding_mask = paddle.concat(
-                [prev_key_padding_mask.float(), key_padding_mask.float()], axis=1
-            )
+                [prev_key_padding_mask.float(), key_padding_mask.float()],
+                axis=1)
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -723,11 +720,9 @@ class MultiheadAttention(nn.Layer):
             if src_len > prev_key_padding_mask.size(1):
                 filler = paddle.zeros(
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
-                    device=prev_key_padding_mask.device,
-                )
+                    device=prev_key_padding_mask.device, )
                 new_key_padding_mask = paddle.concat(
-                    [prev_key_padding_mask.float(), filler.float()], axis=1
-                )
+                    [prev_key_padding_mask.float(), filler.float()], axis=1)
             else:
                 new_key_padding_mask = prev_key_padding_mask.float()
@@ -735,11 +730,9 @@ class MultiheadAttention(nn.Layer):
             if src_len > key_padding_mask.size(1):
                 filler = paddle.zeros(
                     (batch_size, src_len - key_padding_mask.size(1)),
-                    device=key_padding_mask.device,
-                )
+                    device=key_padding_mask.device, )
                 new_key_padding_mask = paddle.concat(
-                    [filler.float(), key_padding_mask.float()], axis=1
-                )
+                    [filler.float(), key_padding_mask.float()], axis=1)
             else:
                 new_key_padding_mask = key_padding_mask.float()
@@ -748,7 +741,8 @@ class MultiheadAttention(nn.Layer):
         return new_key_padding_mask

     def _get_input_buffer(
-            self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+            self,
+            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
     ) -> Dict[str, Optional[Tensor]]:
         result = self.get_incremental_state(incremental_state, "attn_state")
         if result is not None:
@@ -760,9 +754,13 @@ class MultiheadAttention(nn.Layer):
     def _set_input_buffer(
             self,
            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
-            buffer: Dict[str, Optional[Tensor]],
-    ):
-        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+            buffer: Dict[str, Optional[Tensor]], ):
+        return self.set_incremental_state(incremental_state, "attn_state",
+                                          buffer)

-    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+    def apply_sparse_mask(self,
+                          attn_weights,
+                          tgt_len: int,
+                          src_len: int,
+                          bsz: int):
         return attn_weights

@@ -188,7 +188,7 @@ class WavLMASR(nn.Layer):
         x_lens = x.shape[1]
         ctc_probs = self.ctc.log_softmax(x)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
-        topk_index = topk_index.view(batch_size, x_lens)  # (B, maxlen)
+        topk_index = topk_index.reshape([batch_size, x_lens])  # (B, maxlen)
         hyps = [hyp.tolist() for hyp in topk_index]
         hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]

@@ -297,8 +297,8 @@ class WavLM(nn.Layer):
         extra = padding_mask.size(1) % features.size(1)
         if extra > 0:
             padding_mask = padding_mask[:, :-extra]
-        padding_mask = padding_mask.view(
-            padding_mask.size(0), features.size(1), -1)
+        padding_mask = padding_mask.reshape(
+            [padding_mask.size(0), features.size(1), -1])
         padding_mask = padding_mask.all(-1)

         return padding_mask
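
As an aside, the padding_mask hunk above downsamples a sample-level mask to frame granularity; a small self-contained sketch of the reshape-plus-all(-1) idiom (the batch size, sample count, and frame count here are invented):

    import paddle

    batch, samples, frames = 2, 12, 3                    # 4 waveform samples per feature frame
    idx = paddle.arange(samples)
    padding_mask = (idx >= 7).unsqueeze(0).expand([batch, samples])  # True marks padded samples

    extra = padding_mask.shape[1] % frames
    if extra > 0:
        padding_mask = padding_mask[:, :-extra]
    padding_mask = padding_mask.reshape([padding_mask.shape[0], frames, -1])
    padding_mask = padding_mask.all(-1)                  # a frame is padding only if all of its samples are
    print(padding_mask.shape)                            # [2, 3]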
@@ -475,14 +475,15 @@ class ConvFeatureExtractionModel(nn.Layer):
                 else:
                     x = conv(x)
                     x = x.transpose([0, 1, 3, 2]).contiguous()
-                    x = x.view(x.size(0), -1, x.size(-1))
+                    x = x.reshape([x.size(0), -1, x.size(-1)])
         else:
             for conv in self.conv_layers:
                 x = conv(x)
             if self.conv_type == "conv2d":
                 b, c, t, f = x.size()
-                # x = x.transpose(2, 3).contiguous().view(b, c * f, t)
-                x = x.transpose([0, 1, 3, 2]).contiguous().view(b, c * f, t)
+                # x = x.transpose(2, 3).contiguous().reshape([b, c * f, t])
+                x = x.transpose([0, 1, 3, 2]).contiguous().reshape(
+                    [b, c * f, t])

         return x

@@ -181,8 +181,9 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     Returns:
         float: Accuracy value (0.0 - 1.0).
     """
-    pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
-                                pad_outputs.shape[1]).argmax(2)
+    pad_pred = pad_outputs.reshape(
+        [pad_targets.shape[0], pad_targets.shape[1],
+         pad_outputs.shape[1]]).argmax(2)
     mask = pad_targets != ignore_label
     numerator = paddle.sum(

@@ -751,10 +751,10 @@ class JETSGenerator(nn.Layer):
         # integrate with SID and LID embeddings
         if self.spks is not None:
-            sid_embs = self.sid_emb(sids.view(-1))
+            sid_embs = self.sid_emb(sids.reshape([-1]))
             hs = hs + sid_embs.unsqueeze(1)
         if self.langs is not None:
-            lid_embs = self.lid_emb(lids.view(-1))
+            lid_embs = self.lid_emb(lids.reshape([-1]))
             hs = hs + lid_embs.unsqueeze(1)

         # integrate speaker embedding
