|
|
|
@ -298,8 +298,8 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks,
|
|
|
|
|
simulate_streaming) # (B, maxlen, encoder_dim)
|
|
|
|
|
maxlen = encoder_out.size(1)
|
|
|
|
|
encoder_dim = encoder_out.size(2)
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
encoder_dim = encoder_out.shape[2]
|
|
|
|
|
running_size = batch_size * beam_size
|
|
|
|
|
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
|
|
|
|
|
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
|
|
|
|
@ -404,7 +404,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
|
|
maxlen = encoder_out.size(1)
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
|
|
|
|
|
|
|
|
|
@ -455,7 +455,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks,
|
|
|
|
|
simulate_streaming) # (B, maxlen, encoder_dim)
|
|
|
|
|
maxlen = encoder_out.size(1)
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
|
|
|
|
@ -583,7 +583,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
|
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
|
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
|
|
|
|
|
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
|
|
|
|
|
decoder_out, _ = self.decoder(
|
|
|
|
|
encoder_out, encoder_mask, hyps_pad,
|
|
|
|
|
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
@ -690,13 +690,13 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
Returns:
|
|
|
|
|
paddle.Tensor: decoder output, (B, L)
|
|
|
|
|
"""
|
|
|
|
|
assert encoder_out.size(0) == 1
|
|
|
|
|
num_hyps = hyps.size(0)
|
|
|
|
|
assert hyps_lens.size(0) == num_hyps
|
|
|
|
|
assert encoder_out.shape[0] == 1
|
|
|
|
|
num_hyps = hyps.shape[0]
|
|
|
|
|
assert hyps_lens.shape[0] == num_hyps
|
|
|
|
|
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
|
|
|
|
|
# (B, 1, T)
|
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
|
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
|
|
|
|
|
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
|
|
|
|
|
# (num_hyps, max_hyps_len, vocab_size)
|
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
|
|
|
|
|
hyps_lens)
|
|
|
|
@ -751,7 +751,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
Returns:
|
|
|
|
|
List[List[int]]: transcripts.
|
|
|
|
|
"""
|
|
|
|
|
batch_size = feats.size(0)
|
|
|
|
|
batch_size = feats.shape[0]
|
|
|
|
|
if decoding_method in ['ctc_prefix_beam_search',
|
|
|
|
|
'attention_rescoring'] and batch_size > 1:
|
|
|
|
|
logger.fatal(
|
|
|
|
@ -779,7 +779,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
# result in List[int], change it to List[List[int]] for compatible
|
|
|
|
|
# with other batch decoding mode
|
|
|
|
|
elif decoding_method == 'ctc_prefix_beam_search':
|
|
|
|
|
assert feats.size(0) == 1
|
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
|
hyp = self.ctc_prefix_beam_search(
|
|
|
|
|
feats,
|
|
|
|
|
feats_lengths,
|
|
|
|
@ -789,7 +789,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
|
hyps = [hyp]
|
|
|
|
|
elif decoding_method == 'attention_rescoring':
|
|
|
|
|
assert feats.size(0) == 1
|
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
|
hyp = self.attention_rescoring(
|
|
|
|
|
feats,
|
|
|
|
|
feats_lengths,
|
|
|
|
|