diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index 2e1c14ac1..b7196c644 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -190,7 +190,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
         loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
         acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
+            decoder_out.reshape([-1, self.vocab_size]),
             ys_out_pad,
             ignore_label=self.ignore_id, )
         return loss_att, acc_att
@@ -271,11 +271,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
         maxlen = encoder_out.shape[1]
         encoder_dim = encoder_out.shape[2]
         running_size = batch_size * beam_size
-        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
-            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
+        encoder_out = encoder_out.unsqueeze(1).repeat(
+            1, beam_size, 1, 1).reshape(
+                [running_size, maxlen,
+                 encoder_dim])  # (B*N, maxlen, encoder_dim)
         encoder_mask = encoder_mask.unsqueeze(1).repeat(
-            1, beam_size, 1, 1).view(running_size, 1,
-                                     maxlen)  # (B*N, 1, max_len)
+            1, beam_size, 1, 1).reshape([running_size, 1,
+                                         maxlen])  # (B*N, 1, max_len)
 
         hyps = paddle.ones(
             [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
@@ -305,34 +307,35 @@ class U2BaseModel(ASRInterface, nn.Layer):
 
             # 2.3 Seconde beam prune: select topk score with history
             scores = scores + top_k_logp  # (B*N, N), broadcast add
-            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
+            scores = scores.reshape(
+                [batch_size, beam_size * beam_size])  # (B, N*N)
             scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
-            scores = scores.view(-1, 1)  # (B*N, 1)
+            scores = scores.reshape([-1, 1])  # (B*N, 1)
 
             # 2.4. Compute base index in top_k_index,
             # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
             # then find offset_k_index in top_k_index
-            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
+            base_k_index = paddle.arange(batch_size).reshape([-1, 1]).repeat(
                 1, beam_size)  # (B, N)
             base_k_index = base_k_index * beam_size * beam_size
-            best_k_index = base_k_index.view(-1) + offset_k_index.view(
-                -1)  # (B*N)
+            best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape(
+                [-1])  # (B*N)
 
             # 2.5 Update best hyps
             best_k_pred = paddle.index_select(
-                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
+                top_k_index.reshape([-1]), index=best_k_index, axis=0)  # (B*N)
             best_hyps_index = best_k_index // beam_size
             last_best_k_hyps = paddle.index_select(
                 hyps, index=best_hyps_index, axis=0)  # (B*N, i)
             hyps = paddle.cat(
-                (last_best_k_hyps, best_k_pred.view(-1, 1)),
+                (last_best_k_hyps, best_k_pred.reshape([-1, 1])),
                 dim=1)  # (B*N, i+1)
 
             # 2.6 Update end flag
-            end_flag = paddle.equal(hyps[:, -1], self.eos).view(-1, 1)
+            end_flag = paddle.equal(hyps[:, -1], self.eos).reshape([-1, 1])
 
         # 3. Select best of best
-        scores = scores.view(batch_size, beam_size)
+        scores = scores.reshape([batch_size, beam_size])
         # TODO: length normalization
         best_index = paddle.argmax(scores, axis=-1).long()  # (B)
         best_hyps_index = best_index + paddle.arange(
@@ -379,7 +382,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         ctc_probs = self.ctc.log_softmax(
             encoder_out)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
-        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
+        topk_index = topk_index.reshape([batch_size, maxlen])  # (B, maxlen)
         pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
         topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
 
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 7f040d3e2..5d75b3281 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -129,8 +129,8 @@ class MultiHeadedAttention(nn.Layer):
         p_attn = self.dropout(attn)
 
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
-                                               self.d_k])  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).reshape(
+            [n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
 
         return self.linear_out(x)  # (batch, time1, d_model)
 
@@ -280,8 +280,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
         x_padded = paddle.cat([zero_pad, x], dim=-1)
 
-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
+        x_padded = x_padded.reshape(
+            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
         x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]
 
         if zero_triu:
@@ -349,7 +349,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         new_cache = paddle.concat((k, v), axis=-1)
 
         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
+        p = self.linear_pos(pos_emb).reshape(
+            [n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
 
         # (batch, head, time1, d_k)
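The rule this patch applies is mechanical: torch-style `Tensor.view(d0, d1, ...)` takes target dimensions as positional arguments, while Paddle's `Tensor.reshape` takes a single list (or tuple) of dimensions. Below is a minimal standalone sketch of that mapping, not part of the patch; the tensor shapes and the `n_head`/`d_k` values are illustrative, only the call pattern mirrors the patched code.

```python
# Minimal sketch of the view -> reshape migration (illustrative sizes only).
import paddle

x = paddle.rand([2, 3, 4])

# torch-style (the old code):      x.view(-1, 4)
# Paddle (dims passed as a list):
y = x.reshape([-1, 4])  # shape: [6, 4]

# Same rule for the multi-head attention reshapes: split d_model into
# (head, d_k), then move the head axis forward, as in attention.py.
batch, time, n_head, d_k = 5, 10, 2, 6
h = paddle.rand([batch, time, n_head * d_k])  # (batch, time, d_model)
h = h.reshape([batch, -1, n_head, d_k])       # (batch, time, head, d_k)
h = h.transpose([0, 2, 1, 3])                 # (batch, head, time, d_k)

print(y.shape, h.shape)  # [6, 4] [5, 2, 10, 6]
```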