|
|
@ -32,6 +32,7 @@ from paddle import nn
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
|
|
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import th_accuracy
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import th_accuracy
|
|
|
|
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
|
|
|
|
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
|
|
|
|
from paddlespeech.s2t.frontend.utility import IGNORE_ID
|
|
|
|
from paddlespeech.s2t.frontend.utility import IGNORE_ID
|
|
|
@ -565,16 +566,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
|
|
|
|
|
|
|
|
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
|
|
|
|
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
|
|
|
|
|
|
|
|
|
|
|
|
# used for right to left decoder
|
|
|
|
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
|
|
|
|
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
|
|
|
|
self.eos)
|
|
|
|
self.ignore_id)
|
|
|
|
|
|
|
|
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
|
|
|
|
|
|
|
|
self.ignore_id)
|
|
|
|
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
|
|
|
|
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
|
|
|
|
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
|
|
|
@ -733,63 +730,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
r_hyps = hyps[:, 1:]
|
|
|
|
r_hyps = hyps[:, 1:]
|
|
|
|
# (num_hyps, max_hyps_len, vocab_size)
|
|
|
|
# (num_hyps, max_hyps_len, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
# Equal to:
|
|
|
|
r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
|
|
|
|
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
|
|
|
|
|
|
|
|
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
max_len = paddle.max(r_hyps_lens)
|
|
|
|
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
|
|
|
|
index_range = paddle.arange(0, max_len, 1)
|
|
|
|
|
|
|
|
seq_len_expand = r_hyps_lens.unsqueeze(1)
|
|
|
|
|
|
|
|
seq_mask = seq_len_expand > index_range # (beam, max_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
index = (seq_len_expand - 1) - index_range # (beam, max_len)
|
|
|
|
|
|
|
|
# >>> index
|
|
|
|
|
|
|
|
# >>> tensor([[ 2, 1, 0],
|
|
|
|
|
|
|
|
# >>> [ 2, 1, 0],
|
|
|
|
|
|
|
|
# >>> [ 0, -1, -2]])
|
|
|
|
|
|
|
|
index = index * seq_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# >>> index
|
|
|
|
|
|
|
|
# >>> tensor([[2, 1, 0],
|
|
|
|
|
|
|
|
# >>> [2, 1, 0],
|
|
|
|
|
|
|
|
# >>> [0, 0, 0]])
|
|
|
|
|
|
|
|
def paddle_gather(x, dim, index):
    """Pick elements of ``x`` along axis ``dim`` at the positions in ``index``.

    Every element of ``index`` is turned into a full N-d coordinate into
    ``x``: the coordinate on axis ``dim`` is the index value itself, and the
    coordinate on every other axis is that element's own position along
    that axis. The coordinates are then resolved with ``paddle.gather_nd``.

    Args:
        x: Tensor to read values from.
        dim: Axis of ``x`` to index along; a negative value counts from the
            last axis.
        index: Integer tensor of positions along ``dim``.
            NOTE(review): the per-axis ranges are expanded to ``index.shape``,
            so ``index`` presumably has to match ``x`` on every axis other
            than ``dim`` -- confirm against callers.

    Returns:
        Tensor with the same shape as ``index`` holding the selected values.
    """
    rank = len(x.shape)
    out_shape = index.shape
    # Normalise a negative axis to its non-negative equivalent.
    axis = dim + rank if dim < 0 else dim

    def _axis_coords(k):
        # Flattened coordinate along axis ``k`` for every element of ``index``.
        if k == axis:
            return index.flatten()
        # Lay 0..x.shape[k]-1 along axis k (size 1 elsewhere) so that it can
        # be expanded to the shape of ``index``.
        view_shape = [1] * rank
        view_shape[k] = x.shape[k]
        positions = paddle.arange(x.shape[k], dtype=index.dtype)
        positions = positions.reshape(view_shape)
        return paddle.expand(positions, out_shape).flatten()

    per_axis = [_axis_coords(k) for k in range(rank)]
    # Stack to (rank, numel) and transpose to (numel, rank): one coordinate
    # row per output element, as gather_nd takes them.
    coords = paddle.transpose(paddle.stack(per_axis), [1, 0]).astype("int64")
    return paddle.gather_nd(x, coords).reshape(out_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r_hyps = paddle_gather(r_hyps, 1, index)
|
|
|
|
|
|
|
|
# >>> r_hyps
|
|
|
|
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
|
|
|
|
# >>> [2, 2, 2]])
|
|
|
|
|
|
|
|
r_hyps = paddle.where(seq_mask, r_hyps, self.eos)
|
|
|
|
|
|
|
|
# >>> r_hyps
|
|
|
|
|
|
|
|
# >>> tensor([[3, 2, 1],
|
|
|
|
|
|
|
|
# >>> [4, 8, 9],
|
|
|
|
|
|
|
|
# >>> [2, eos, eos]])
|
|
|
|
|
|
|
|
r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
|
|
|
|
|
|
|
|
# >>> r_hyps
|
|
|
|
|
|
|
|
# >>> tensor([[sos, 3, 2, 1],
|
|
|
|
|
|
|
|
# >>> [sos, 4, 8, 9],
|
|
|
|
|
|
|
|
# >>> [sos, 2, eos, eos]])
|
|
|
|
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
|
|
|
|
|
|
|
|
hyps_lens, r_hyps, reverse_weight)
|
|
|
|
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
return decoder_out, r_decoder_out
|
|
|
|
return decoder_out, r_decoder_out
|
|
|
@ -877,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
ctc_weight=ctc_weight,
|
|
|
|
ctc_weight=ctc_weight,
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
simulate_streaming=simulate_streaming,
|
|
|
|
|
|
|
|
reverse_weight=reverse_weight)
|
|
|
|
hyps = [hyp]
|
|
|
|
hyps = [hyp]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Not support decoding method: {decoding_method}")
|
|
|
|
raise ValueError(f"Not support decoding method: {decoding_method}")
|
|
|
|