support bitransformer decoder, test=asr

pull/2415/head
tianhao zhang 2 years ago
parent 0a95689461
commit 027535dec1

@ -213,3 +213,82 @@ def reverse_pad_list(ys_pad: paddle.Tensor,
r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True, pad_value)
return r_ys_pad
def st_reverse_pad_list(ys_pad: paddle.Tensor,
ys_lens: paddle.Tensor,
sos: float,
eos: float) -> paddle.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
r_hyps = paddle_gather(ys_pad, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = paddle.cat([_sos, r_hyps], dim=1)
# r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
return r_hyps

@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len, self.text_feature)
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,

@ -32,6 +32,7 @@ from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
@ -565,16 +566,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
# used for right to left decoder
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
self.ignore_id)
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
self.ignore_id)
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
@ -733,63 +730,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_hyps = hyps[:, 1:]
# (num_hyps, max_hyps_len, vocab_size)
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = paddle.max(r_hyps_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index),
[1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
r_hyps = paddle_gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, self.eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens, r_hyps, reverse_weight)
r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
return decoder_out, r_decoder_out
@ -877,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
ctc_weight=ctc_weight,
simulate_streaming=simulate_streaming)
simulate_streaming=simulate_streaming,
reverse_weight=reverse_weight)
hyps = [hyp]
else:
raise ValueError(f"Not support decoding method: {decoding_method}")

Loading…
Cancel
Save