fix bug on win

pull/2425/head
Hui Zhang 3 years ago
parent d25871a7b0
commit 7382050e21

@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
@ -279,7 +279,8 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, eos)
_eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, _eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],

@ -600,7 +600,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
logger.info(
logger.debug(
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
)

Loading…
Cancel
Save