|
|
|
@ -190,7 +190,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
|
|
|
|
|
loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
|
|
|
|
|
acc_att = th_accuracy(
|
|
|
|
|
decoder_out.view(-1, self.vocab_size),
|
|
|
|
|
decoder_out.reshape([-1, self.vocab_size]),
|
|
|
|
|
ys_out_pad,
|
|
|
|
|
ignore_label=self.ignore_id, )
|
|
|
|
|
return loss_att, acc_att
|
|
|
|
@ -271,11 +271,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
encoder_dim = encoder_out.shape[2]
|
|
|
|
|
running_size = batch_size * beam_size
|
|
|
|
|
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
|
|
|
|
|
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
|
|
|
|
|
encoder_out = encoder_out.unsqueeze(1).repeat(
|
|
|
|
|
1, beam_size, 1, 1).reshape(
|
|
|
|
|
[running_size, maxlen,
|
|
|
|
|
encoder_dim]) # (B*N, maxlen, encoder_dim)
|
|
|
|
|
encoder_mask = encoder_mask.unsqueeze(1).repeat(
|
|
|
|
|
1, beam_size, 1, 1).view(running_size, 1,
|
|
|
|
|
maxlen) # (B*N, 1, max_len)
|
|
|
|
|
1, beam_size, 1, 1).reshape([running_size, 1,
|
|
|
|
|
maxlen]) # (B*N, 1, max_len)
|
|
|
|
|
|
|
|
|
|
hyps = paddle.ones(
|
|
|
|
|
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
|
|
|
|
@ -305,34 +307,35 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
|
|
|
|
|
# 2.3 Seconde beam prune: select topk score with history
|
|
|
|
|
scores = scores + top_k_logp # (B*N, N), broadcast add
|
|
|
|
|
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
|
|
|
|
|
scores = scores.reshape(
|
|
|
|
|
[batch_size, beam_size * beam_size]) # (B, N*N)
|
|
|
|
|
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
|
|
|
|
|
scores = scores.view(-1, 1) # (B*N, 1)
|
|
|
|
|
scores = scores.reshape([-1, 1]) # (B*N, 1)
|
|
|
|
|
|
|
|
|
|
# 2.4. Compute base index in top_k_index,
|
|
|
|
|
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
|
|
|
|
|
# then find offset_k_index in top_k_index
|
|
|
|
|
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
|
|
|
|
|
base_k_index = paddle.arange(batch_size).reshape([-1, 1]).repeat(
|
|
|
|
|
1, beam_size) # (B, N)
|
|
|
|
|
base_k_index = base_k_index * beam_size * beam_size
|
|
|
|
|
best_k_index = base_k_index.view(-1) + offset_k_index.view(
|
|
|
|
|
-1) # (B*N)
|
|
|
|
|
best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape(
|
|
|
|
|
[-1]) # (B*N)
|
|
|
|
|
|
|
|
|
|
# 2.5 Update best hyps
|
|
|
|
|
best_k_pred = paddle.index_select(
|
|
|
|
|
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
|
|
|
|
|
top_k_index.reshape([-1]), index=best_k_index, axis=0) # (B*N)
|
|
|
|
|
best_hyps_index = best_k_index // beam_size
|
|
|
|
|
last_best_k_hyps = paddle.index_select(
|
|
|
|
|
hyps, index=best_hyps_index, axis=0) # (B*N, i)
|
|
|
|
|
hyps = paddle.cat(
|
|
|
|
|
(last_best_k_hyps, best_k_pred.view(-1, 1)),
|
|
|
|
|
(last_best_k_hyps, best_k_pred.reshape([-1, 1])),
|
|
|
|
|
dim=1) # (B*N, i+1)
|
|
|
|
|
|
|
|
|
|
# 2.6 Update end flag
|
|
|
|
|
end_flag = paddle.equal(hyps[:, -1], self.eos).view(-1, 1)
|
|
|
|
|
end_flag = paddle.equal(hyps[:, -1], self.eos).reshape([-1, 1])
|
|
|
|
|
|
|
|
|
|
# 3. Select best of best
|
|
|
|
|
scores = scores.view(batch_size, beam_size)
|
|
|
|
|
scores = scores.reshape([batch_size, beam_size])
|
|
|
|
|
# TODO: length normalization
|
|
|
|
|
best_index = paddle.argmax(scores, axis=-1).long() # (B)
|
|
|
|
|
best_hyps_index = best_index + paddle.arange(
|
|
|
|
@ -379,7 +382,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
|
|
|
|
|
|
|
|
|
|
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
|
|
|
|
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
|
|
|
|
topk_index = topk_index.reshape([batch_size, maxlen]) # (B, maxlen)
|
|
|
|
|
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
|
|
|
|
|
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
|
|
|
|
|
|
|
|
|
|