|
|
|
@ -131,7 +131,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
if self.ctc_weight != 1.0:
|
|
|
|
|
start = time.time()
|
|
|
|
|
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
|
|
|
|
|
text, text_lengths, self.reverse_weight)
|
|
|
|
|
text, text_lengths,
|
|
|
|
|
self.reverse_weight)
|
|
|
|
|
decoder_time = time.time() - start
|
|
|
|
|
#logger.debug(f"decoder time: {decoder_time}")
|
|
|
|
|
|
|
|
|
@ -152,13 +153,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
|
|
|
|
|
return loss, loss_att, loss_ctc
|
|
|
|
|
|
|
|
|
|
def _calc_att_loss(
|
|
|
|
|
self,
|
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|
|
encoder_mask: paddle.Tensor,
|
|
|
|
|
ys_pad: paddle.Tensor,
|
|
|
|
|
ys_pad_lens: paddle.Tensor,
|
|
|
|
|
reverse_weight: float) -> Tuple[paddle.Tensor, float]:
|
|
|
|
|
def _calc_att_loss(self,
|
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|
|
encoder_mask: paddle.Tensor,
|
|
|
|
|
ys_pad: paddle.Tensor,
|
|
|
|
|
ys_pad_lens: paddle.Tensor,
|
|
|
|
|
reverse_weight: float) -> Tuple[paddle.Tensor, float]:
|
|
|
|
|
"""Calc attention loss.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -188,8 +188,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
r_loss_att = paddle.to_tensor(0.0)
|
|
|
|
|
if reverse_weight > 0.0:
|
|
|
|
|
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
|
|
|
|
|
loss_att = loss_att * (1 - reverse_weight
|
|
|
|
|
) + r_loss_att * reverse_weight
|
|
|
|
|
loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
|
|
|
|
|
acc_att = th_accuracy(
|
|
|
|
|
decoder_out.view(-1, self.vocab_size),
|
|
|
|
|
ys_out_pad,
|
|
|
|
@ -599,8 +598,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
score = score * (1 - reverse_weight
|
|
|
|
|
) + r_score * reverse_weight
|
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
|
if score > best_score:
|
|
|
|
|