|
|
|
@ -689,24 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
"""
|
|
|
|
|
return self.ctc.log_softmax(xs)
|
|
|
|
|
|
|
|
|
|
@jit.to_static
|
|
|
|
|
# @jit.to_static
|
|
|
|
|
def is_bidirectional_decoder(self) -> bool:
    """Report whether ``self.decoder`` exposes a ``right_decoder`` attribute.

    The check is purely structural: the presence of ``right_decoder`` on
    the decoder object is what this model treats as "bidirectional".

    Returns:
        bool: True if ``self.decoder`` has a ``right_decoder`` attribute,
            False otherwise.
    """
    # hasattr already yields a bool, so no if/else is needed.
    return hasattr(self.decoder, 'right_decoder')
|
|
|
|
|
|
|
|
|
|
@jit.to_static
|
|
|
|
|
# @jit.to_static
|
|
|
|
|
def forward_attention_decoder(
|
|
|
|
|
self,
|
|
|
|
|
hyps: paddle.Tensor,
|
|
|
|
|
hyps_lens: paddle.Tensor,
|
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|
|
reverse_weight: float=0, ) -> paddle.Tensor:
|
|
|
|
|
reverse_weight: float=0.0, ) -> paddle.Tensor:
|
|
|
|
|
""" Export interface for c++ call, forward decoder with multiple
|
|
|
|
|
hypothesis from ctc prefix beam search and one encoder output
|
|
|
|
|
Args:
|
|
|
|
@ -783,7 +783,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
|
# >>> [2, eos, eos]])
|
|
|
|
|
r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1)
|
|
|
|
|
r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
|
|
|
|
|
# >>> r_hyps
|
|
|
|
|
# >>> tensor([[sos, 3, 2, 1],
|
|
|
|
|
# >>> [sos, 4, 8, 9],
|
|
|
|
@ -791,7 +791,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
|
|
|
|
|
hyps_lens, r_hyps, reverse_weight)
|
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
|
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
|
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
|
return decoder_out, r_decoder_out
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|