Merge pull request #2415 from Zth9730/u2++_decoder

[s2t] support bitransformer decoder
pull/2451/head
Hui Zhang 2 years ago committed by GitHub
commit 1a1ce92cb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,7 +31,6 @@ def has_tensor(val):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
# _sos = paddle.to_tensor(
# [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# _eos = paddle.to_tensor(
# [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
# ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
# ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
@ -190,3 +190,106 @@ def th_accuracy(pad_outputs: paddle.Tensor,
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
def reverse_pad_list(ys_pad: paddle.Tensor,
ys_lens: paddle.Tensor,
pad_value: float=-1.0) -> paddle.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
pad_value (int): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True, pad_value)
return r_ys_pad
def st_reverse_pad_list(ys_pad: paddle.Tensor,
ys_lens: paddle.Tensor,
sos: float,
eos: float) -> paddle.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
r_hyps = paddle_gather(ys_pad, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = paddle.cat([_sos, r_hyps], dim=1)
# r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
return r_hyps

@ -40,7 +40,7 @@ class U2Infer():
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
@ -89,7 +89,8 @@ class U2Infer():
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}")

@ -253,7 +253,6 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.test_loader.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
@ -317,6 +316,7 @@ class U2Tester(U2Trainer):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len, self.text_feature)
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
@ -350,7 +351,8 @@ class U2Tester(U2Trainer):
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(

@ -361,7 +361,7 @@ class DataLoaderFactory():
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif model == 'test' or mode == 'align':
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0

@ -31,6 +31,8 @@ from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
@ -38,6 +40,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@ -69,6 +72,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
ctc: CTCDecoderBase,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
reverse_weight: float=0.0,
lsm_weight: float=0.0,
length_normalized_loss: bool=False,
**kwargs):
@ -82,6 +86,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.encoder = encoder
self.decoder = decoder
@ -171,12 +176,21 @@ class U2BaseModel(ASRInterface, nn.Layer):
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
self.ignore_id)
# 1. Forward decoder
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
self.reverse_weight)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
r_loss_att = paddle.to_tensor(0.0)
if self.reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - self.reverse_weight
) + r_loss_att * self.reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
@ -359,6 +373,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
# Let's assume B = batch_size
# encoder_out: (B, maxlen, encoder_dim)
# encoder_mask: (B, 1, Tmax)
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
@ -500,7 +515,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
ctc_weight: float=0.0,
simulate_streaming: bool=False, ) -> List[int]:
simulate_streaming: bool=False,
reverse_weight: float=0.0, ) -> List[int]:
""" Apply attention rescoring decoding, CTC prefix beam search
is applied first to get nbest, then we resoring the nbest on
attention decoder with corresponding encoder out
@ -520,6 +536,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
if reverse_weight > 0.0:
# decoder should be a bitransformer decoder if reverse_weight > 0.0
assert hasattr(self.decoder, 'right_decoder')
device = speech.place
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
@ -541,22 +560,30 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
ori_hyps_pad = hyps_pad
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
@ -567,6 +594,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
if reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
@ -653,12 +686,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.to_static
# @jit.to_static
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
paddle.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
encoder_out: paddle.Tensor,
reverse_weight: float=0.0, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
@ -676,11 +721,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# input for right to left decoder
# this hyps_lens has count <sos> token, we need minus it.
r_hyps_lens = hyps_lens - 1
# this hyps has included <sos> token, so it should be
# convert the original hyps.
r_hyps = hyps[:, 1:]
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
return decoder_out
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
return decoder_out, r_decoder_out
@paddle.no_grad()
def decode(self,
@ -692,7 +748,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False):
simulate_streaming: bool=False,
reverse_weight: float=0.0):
"""u2 decoding.
Args:
@ -764,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
ctc_weight=ctc_weight,
simulate_streaming=simulate_streaming)
simulate_streaming=simulate_streaming,
reverse_weight=reverse_weight)
hyps = [hyp]
else:
raise ValueError(f"Not support decoding method: {decoding_method}")
@ -801,7 +859,6 @@ class U2Model(U2DecodeModel):
with DefaultInitializerContext(init_type):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
configs)
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
@ -851,10 +908,20 @@ class U2Model(U2DecodeModel):
raise ValueError(f"not support encoder type:{encoder_type}")
# decoder
decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
if decoder_type == 'transformer':
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
elif decoder_type == 'bitransformer':
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
assert configs['decoder_conf']['r_num_blocks'] > 0
decoder = BiTransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
else:
raise ValueError(f"not support decoder type:{decoder_type}")
# ctc decoder and ctc loss
model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)

@ -35,7 +35,6 @@ from paddlespeech.s2t.modules.mask import make_xs_mask
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["TransformerDecoder"]
@ -116,13 +115,19 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
memory: paddle.Tensor,
memory_mask: paddle.Tensor,
ys_in_pad: paddle.Tensor,
ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
ys_in_lens: paddle.Tensor,
r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: not used in transformer decoder, in order to unify api
with bidirectional decoder
reverse_weight: not used in transformer decoder, in order to unify
api with bidirectional decode
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, vocab_size)
@ -151,7 +156,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
# TODO(Hui Zhang): reduce_sum not support bool type
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens
return x, paddle.to_tensor(0.0), olens
def forward_one_step(
self,
@ -251,3 +256,119 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the hidden units number of position-wise feedforward
num_blocks: the number of decoder blocks
r_num_blocks: the number of right to left decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after: whether to concat attention layer's input and output
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
r_num_blocks: int=0,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
self_attention_dropout_rate: float=0.0,
src_attention_dropout_rate: float=0.0,
input_layer: str="embed",
use_output_layer: bool=True,
normalize_before: bool=True,
concat_after: bool=False,
max_len: int=5000):
assert check_argument_types()
nn.Layer.__init__(self)
self.left_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before, concat_after,
max_len)
self.right_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
r_num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before, concat_after,
max_len)
def forward(
self,
memory: paddle.Tensor,
memory_mask: paddle.Tensor,
ys_in_pad: paddle.Tensor,
ys_in_lens: paddle.Tensor,
r_ys_in_pad: paddle.Tensor,
reverse_weight: float=0.0,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
used for right to left decoder
reverse_weight: used for right to left decoder
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out,
vocab_size) if use_output_layer is True,
r_x: x: decoded token score (right to left decoder)
before softmax (batch, maxlen_out, vocab_size)
if use_output_layer is True,
olens: (batch, )
"""
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = paddle.to_tensor(0.0)
if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens)
return l_x, r_x, olens
def forward_one_step(
self,
memory: paddle.Tensor,
memory_mask: paddle.Tensor,
tgt: paddle.Tensor,
tgt_mask: paddle.Tensor,
cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
"""Forward one step.
This is only used for decoding.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
dtype=paddle.bool
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
tgt_mask, cache)

@ -612,7 +612,8 @@ class PaddleASRConnectionHanddler:
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
decoder_out, _, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain

Loading…
Cancel
Save