replace view with reshape in aishell/asr1 (#3887)

Branch: pull/3890/head
Author: Wang Xin (committed via GitHub)
Parent: 6f44ac92c8
Commit: 62c21e951f
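The change is mechanical throughout: PyTorch-style Tensor.view(dim0, dim1, ...) calls are replaced with Paddle's Tensor.reshape, which takes the target shape as a single list. A minimal sketch of the two spellings (toy values, not part of this commit):

    import paddle

    x = paddle.arange(12, dtype='float32')   # shape [12]
    y = x.reshape([-1, 4])                   # Paddle-native call, shape [3, 4]
    # Replaces the torch-style call removed by this commit: x.view(-1, 4)
    print(y.shape)                           # [3, 4]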

@@ -190,7 +190,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
         loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
         acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
+            decoder_out.reshape([-1, self.vocab_size]),
             ys_out_pad,
             ignore_label=self.ignore_id, )
         return loss_att, acc_att
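For reference, a toy sketch (not from the commit) of what the changed line does: the per-token decoder logits of shape (B, Lmax, vocab_size) are flattened to (B*Lmax, vocab_size) so th_accuracy can score each token position against the padded targets.

    import paddle

    B, Lmax, V = 2, 5, 10
    decoder_out = paddle.rand([B, Lmax, V])   # per-token logits
    flat = decoder_out.reshape([-1, V])       # (B*Lmax, V) for th_accuracy
    print(flat.shape)                         # [10, 10]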
@@ -271,11 +271,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
         maxlen = encoder_out.shape[1]
         encoder_dim = encoder_out.shape[2]
         running_size = batch_size * beam_size
-        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
-            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
+        encoder_out = encoder_out.unsqueeze(1).repeat(
+            1, beam_size, 1, 1).reshape(
+                [running_size, maxlen,
+                 encoder_dim])  # (B*N, maxlen, encoder_dim)
         encoder_mask = encoder_mask.unsqueeze(1).repeat(
-            1, beam_size, 1, 1).view(running_size, 1,
-                                     maxlen)  # (B*N, 1, max_len)
+            1, beam_size, 1, 1).reshape([running_size, 1,
+                                         maxlen])  # (B*N, 1, max_len)
         hyps = paddle.ones(
             [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
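A toy sketch (illustrative only) of the replication these reshapes perform: each utterance's encoder output is copied beam_size times so the beam search runs as one batch of size B*N. Plain Paddle's tile stands in here for the repeat helper used in the model code.

    import paddle

    B, N, T, D = 2, 3, 4, 5                  # batch, beam, frames, feature dim
    enc = paddle.rand([B, T, D])
    enc = enc.unsqueeze(1).tile([1, N, 1, 1]).reshape([B * N, T, D])
    print(enc.shape)                         # [6, 4, 5]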
@@ -305,34 +307,35 @@ class U2BaseModel(ASRInterface, nn.Layer):
             # 2.3 Seconde beam prune: select topk score with history
             scores = scores + top_k_logp  # (B*N, N), broadcast add
-            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
+            scores = scores.reshape(
+                [batch_size, beam_size * beam_size])  # (B, N*N)
             scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
-            scores = scores.view(-1, 1)  # (B*N, 1)
+            scores = scores.reshape([-1, 1])  # (B*N, 1)
             # 2.4. Compute base index in top_k_index,
             # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
             # then find offset_k_index in top_k_index
-            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
+            base_k_index = paddle.arange(batch_size).reshape([-1, 1]).repeat(
                 1, beam_size)  # (B, N)
             base_k_index = base_k_index * beam_size * beam_size
-            best_k_index = base_k_index.view(-1) + offset_k_index.view(
-                -1)  # (B*N)
+            best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape(
+                [-1])  # (B*N)
             # 2.5 Update best hyps
             best_k_pred = paddle.index_select(
-                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
+                top_k_index.reshape([-1]), index=best_k_index, axis=0)  # (B*N)
             best_hyps_index = best_k_index // beam_size
             last_best_k_hyps = paddle.index_select(
                 hyps, index=best_hyps_index, axis=0)  # (B*N, i)
             hyps = paddle.cat(
-                (last_best_k_hyps, best_k_pred.view(-1, 1)),
+                (last_best_k_hyps, best_k_pred.reshape([-1, 1])),
                 dim=1)  # (B*N, i+1)
             # 2.6 Update end flag
-            end_flag = paddle.equal(hyps[:, -1], self.eos).view(-1, 1)
+            end_flag = paddle.equal(hyps[:, -1], self.eos).reshape([-1, 1])
         # 3. Select best of best
-        scores = scores.view(batch_size, beam_size)
+        scores = scores.reshape([batch_size, beam_size])
         # TODO: length normalization
         best_index = paddle.argmax(scores, axis=-1).long()  # (B)
         best_hyps_index = best_index + paddle.arange(
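A toy sketch (illustrative values, not from the commit) of the index arithmetic in this hunk: scores are reshaped to (B, N*N) so one topk per utterance picks beam_size continuations, and base_k_index shifts each utterance's offsets into the flattened (B*N*N) candidate table. As above, tile stands in for the repeat helper.

    import paddle

    batch_size, beam_size = 2, 3
    scores = paddle.rand([batch_size * beam_size, beam_size])       # (B*N, N)
    scores = scores.reshape([batch_size, beam_size * beam_size])    # (B, N*N)
    scores, offset_k_index = scores.topk(k=beam_size)               # (B, N)

    base_k_index = paddle.arange(batch_size).reshape([-1, 1]).tile(
        [1, beam_size]) * beam_size * beam_size                     # (B, N)
    best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape([-1])
    print(best_k_index.shape)   # [6], flat indices into the (B*N*N,) candidates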
@@ -379,7 +382,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
-        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
+        topk_index = topk_index.reshape([batch_size, maxlen])  # (B, maxlen)
         pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
         topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
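A toy sketch (not from the commit) of the masking step in this hunk: greedy CTC keeps the top token per frame, reshapes to (B, maxlen), then overwrites frames beyond each utterance's true length with eos. paddle.where stands in here for the in-place masked_fill_ used in the model code, and the pad mask is built by hand instead of with make_pad_mask.

    import paddle

    eos = 2
    topk_index = paddle.to_tensor([[5, 7, 2, 9], [4, 4, 1, 3]])    # (B, maxlen)
    lengths = paddle.to_tensor([3, 2])                             # valid frames
    pad_mask = paddle.arange(4).unsqueeze(0) >= lengths.unsqueeze(1)
    topk_index = paddle.where(pad_mask,
                              paddle.full_like(topk_index, eos),
                              topk_index)
    print(topk_index.numpy())   # [[5 7 2 2] [4 4 2 2]]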

@@ -129,8 +129,8 @@ class MultiHeadedAttention(nn.Layer):
         p_attn = self.dropout(attn)
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
-                                               self.d_k])  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).reshape(
+            [n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)
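A toy sketch (shapes only, not from the commit) of the head merge this rewrapped line performs: the attention output (batch, head, time1, d_k) is transposed and flattened back to (batch, time1, h*d_k) before the output projection.

    import paddle

    n_batch, h, time1, d_k = 2, 4, 5, 8
    x = paddle.rand([n_batch, h, time1, d_k])
    x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, h * d_k])
    print(x.shape)   # [2, 5, 32]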
@@ -280,8 +280,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
         x_padded = paddle.cat([zero_pad, x], dim=-1)
-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
+        x_padded = x_padded.reshape(
+            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
         x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]
         if zero_triu:
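A toy sketch (illustrative only) of the relative-shift trick this reshape implements: padding a zero column and reshaping with the last two sizes swapped (one of them grown by one) displaces each row one extra slot, and dropping the first row yields the shifted score matrix without an explicit gather.

    import paddle

    B, H, T = 1, 1, 3
    x = paddle.arange(T * T, dtype='float32').reshape([B, H, T, T])
    zero_pad = paddle.zeros([B, H, T, 1], dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)     # (B, H, T, T+1)
    x_padded = x_padded.reshape([B, H, T + 1, T])
    x_shifted = x_padded[:, :, 1:].reshape(x.shape)      # same shape as x
    print(x_shifted.numpy()[0, 0])
    # [[2. 0. 3.]
    #  [4. 5. 0.]
    #  [6. 7. 8.]]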
@@ -349,7 +349,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             new_cache = paddle.concat((k, v), axis=-1)
         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
+        p = self.linear_pos(pos_emb).reshape(
+            [n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         # (batch, head, time1, d_k)
