Merge pull request #2415 from Zth9730/u2++_decoder

[s2t] support bitransformer decoder
Hui Zhang committed 2 years ago via GitHub
commit 1a1ce92cb4

@@ -31,7 +31,6 @@ def has_tensor(val):
                 return True
     elif isinstance(val, dict):
         for k, v in val.items():
-            print(k)
             if has_tensor(v):
                 return True
     else:
@@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                 [ 7,  8,  9, 11, -1, -1]])
     """
     # TODO(Hui Zhang): using comment code,
-    #_sos = paddle.to_tensor(
-    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #_eos = paddle.to_tensor(
-    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
-    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
-    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
-    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
+    # _sos = paddle.to_tensor(
+    #     [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # _eos = paddle.to_tensor(
+    #     [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
+    # ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
+    # ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
+    # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(
+    #     ys_out, padding_value=ignore_id).transpose([1,0])
     B = ys_pad.shape[0]
     _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
     _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
@@ -190,3 +190,106 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     # denominator = paddle.sum(mask)
     denominator = paddle.sum(mask.type_as(pad_targets))
     return float(numerator) / float(denominator)
+
+
+def reverse_pad_list(ys_pad: paddle.Tensor,
+                     ys_lens: paddle.Tensor,
+                     pad_value: float=-1.0) -> paddle.Tensor:
+    """Reverse padding for the list of tensors.
+
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B,).
+        pad_value (float): Value for padding.
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> reverse_pad_list(x, paddle.to_tensor([4, 3, 2]), 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+    """
+    r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
+                             for y, i in zip(ys_pad, ys_lens)], True, pad_value)
+    return r_ys_pad
+
+
+def st_reverse_pad_list(ys_pad: paddle.Tensor,
+                        ys_lens: paddle.Tensor,
+                        sos: float,
+                        eos: float) -> paddle.Tensor:
+    """Reverse padding for the list of tensors, then add <sos> and <eos>.
+
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B,).
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax + 1).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> st_reverse_pad_list(x, paddle.to_tensor([4, 3, 2]), sos, eos)
+        tensor([[sos, 4, 3, 2, 1],
+                [sos, 7, 6, 5, eos],
+                [sos, 9, 8, eos, eos]])
+    """
+    # Equal to:
+    #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
+    #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
+    B = ys_pad.shape[0]
+    _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
+    max_len = paddle.max(ys_lens)
+    index_range = paddle.arange(0, max_len, 1)
+    seq_len_expand = ys_lens.unsqueeze(1)
+    seq_mask = seq_len_expand > index_range  # (beam, max_len)
+
+    index = (seq_len_expand - 1) - index_range  # (beam, max_len)
+    # >>> index
+    # >>> tensor([[ 2,  1,  0],
+    # >>>         [ 2,  1,  0],
+    # >>>         [ 0, -1, -2]])
+    index = index * seq_mask
+    # >>> index
+    # >>> tensor([[2, 1, 0],
+    # >>>         [2, 1, 0],
+    # >>>         [0, 0, 0]])
+
+    def paddle_gather(x, dim, index):
+        index_shape = index.shape
+        index_flatten = index.flatten()
+        if dim < 0:
+            dim = len(x.shape) + dim
+        nd_index = []
+        for k in range(len(x.shape)):
+            if k == dim:
+                nd_index.append(index_flatten)
+            else:
+                reshape_shape = [1] * len(x.shape)
+                reshape_shape[k] = x.shape[k]
+                x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+                x_arange = x_arange.reshape(reshape_shape)
+                dim_index = paddle.expand(x_arange, index_shape).flatten()
+                nd_index.append(dim_index)
+        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+        paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+        return paddle_out
+
+    r_hyps = paddle_gather(ys_pad, 1, index)
+    # >>> r_hyps
+    # >>> tensor([[3, 2, 1],
+    # >>>         [4, 8, 9],
+    # >>>         [2, 2, 2]])
+    eos = paddle.full([1], eos, dtype=r_hyps.dtype)
+    r_hyps = paddle.where(seq_mask, r_hyps, eos)
+    # >>> r_hyps
+    # >>> tensor([[3, 2, 1],
+    # >>>         [4, 8, 9],
+    # >>>         [2, eos, eos]])
+    r_hyps = paddle.concat([_sos, r_hyps], axis=1)
+    # r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
+    # >>> r_hyps
+    # >>> tensor([[sos, 3, 2, 1],
+    # >>>         [sos, 4, 8, 9],
+    # >>>         [sos, 2, eos, eos]])
+    return r_hyps
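For orientation, a minimal sketch (not part of the patch) of what `st_reverse_pad_list` computes on a toy batch; sos=10 and eos=11 are arbitrary placeholder ids:

import paddle

# Three hypotheses of lengths 4, 3, 2, padded with 0.
ys_pad = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
ys_lens = paddle.to_tensor([4, 3, 2])
r_hyps = st_reverse_pad_list(ys_pad, ys_lens, sos=10, eos=11)
# Each sequence is reversed in place, padding positions become eos,
# and sos is prepended:
# [[10, 4, 3, 2, 1],
#  [10, 7, 6, 5, 11],
#  [10, 9, 8, 11, 11]]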

@@ -40,7 +40,7 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,

@@ -89,7 +89,8 @@ class U2Infer():
                 ctc_weight=decode_config.ctc_weight,
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-                simulate_streaming=decode_config.simulate_streaming)
+                simulate_streaming=decode_config.simulate_streaming,
+                reverse_weight=self.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {result_transcripts[0][0]}")

@@ -253,7 +253,6 @@ class U2Trainer(Trainer):
             model_conf.output_dim = self.test_loader.vocab_size

         model = U2Model.from_config(model_conf)
-
         if self.parallel:
             model = paddle.DataParallel(model)

@@ -317,6 +316,7 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """

@@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
+
         result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,

@@ -350,7 +351,8 @@ class U2Tester(U2Trainer):
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-            simulate_streaming=decode_config.simulate_streaming)
+            simulate_streaming=decode_config.simulate_streaming,
+            reverse_weight=self.reverse_weight)
         decode_time = time.time() - start_time

         for utt, target, result, rec_tids in zip(

@@ -361,7 +361,7 @@ class DataLoaderFactory():
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False
-        elif model == 'test' or mode == 'align':
+        elif mode == 'test' or mode == 'align':
             config['manifest'] = config.test_manifest
             config['train_mode'] = False
             config['dither'] = 0.0

@@ -31,6 +31,8 @@ from paddle import nn
 from paddlespeech.audio.utils.tensor_utils import add_sos_eos
 from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
 from paddlespeech.s2t.frontend.utility import IGNORE_ID

@@ -38,6 +40,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn
 from paddlespeech.s2t.models.asr_interface import ASRInterface
 from paddlespeech.s2t.modules.cmvn import GlobalCMVN
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase
+from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
 from paddlespeech.s2t.modules.decoder import TransformerDecoder
 from paddlespeech.s2t.modules.encoder import ConformerEncoder
 from paddlespeech.s2t.modules.encoder import TransformerEncoder

@@ -69,6 +72,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
                  ctc: CTCDecoderBase,
                  ctc_weight: float=0.5,
                  ignore_id: int=IGNORE_ID,
+                 reverse_weight: float=0.0,
                  lsm_weight: float=0.0,
                  length_normalized_loss: bool=False,
                  **kwargs):

@@ -82,6 +86,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
+        self.reverse_weight = reverse_weight

         self.encoder = encoder
         self.decoder = decoder
@@ -171,12 +176,21 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                                 self.ignore_id)
         ys_in_lens = ys_pad_lens + 1

+        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
+        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
+                                                self.ignore_id)
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
-                                      ys_in_lens)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
+            self.reverse_weight)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        r_loss_att = paddle.to_tensor(0.0)
+        if self.reverse_weight > 0.0:
+            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
+        loss_att = loss_att * (1 - self.reverse_weight
+                               ) + r_loss_att * self.reverse_weight
         acc_att = th_accuracy(
             decoder_out.view(-1, self.vocab_size),
             ys_out_pad,
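The new attention loss is a convex blend of the two decoding directions. A small standalone illustration of the arithmetic (loss values invented for the example):

reverse_weight = 0.3
loss_att = 2.0    # left-to-right attention loss (example value)
r_loss_att = 2.4  # right-to-left attention loss (example value)
loss = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
# 2.0 * 0.7 + 2.4 * 0.3 = 2.12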
@@ -359,6 +373,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # Let's assume B = batch_size
         # encoder_out: (B, maxlen, encoder_dim)
         # encoder_mask: (B, 1, Tmax)
+
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks, simulate_streaming)
@@ -500,7 +515,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             ctc_weight: float=0.0,
-            simulate_streaming: bool=False, ) -> List[int]:
+            simulate_streaming: bool=False,
+            reverse_weight: float=0.0, ) -> List[int]:
         """ Apply attention rescoring decoding: CTC prefix beam search
             is applied first to get the nbest, then we rescore the nbest on
             the attention decoder with the corresponding encoder output
@@ -520,6 +536,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
+        if reverse_weight > 0.0:
+            # decoder should be a bitransformer decoder if reverse_weight > 0.0
+            assert hasattr(self.decoder, 'right_decoder')
         device = speech.place
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
@@ -541,22 +560,30 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
+        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=device,
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at beginning
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
+                                         self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
+            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         decoder_out = decoder_out.numpy()
+        # r_decoder_out will be 0.0 if reverse_weight is 0.0 or the decoder
+        # is a conventional transformer decoder.
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        r_decoder_out = r_decoder_out.numpy()

         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
@@ -567,6 +594,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 score += decoder_out[i][j][w]
             # last decoder output token is `eos`, for the last decoder input token.
             score += decoder_out[i][len(hyp[0])][self.eos]
+            if reverse_weight > 0:
+                r_score = 0.0
+                for j, w in enumerate(hyp[0]):
+                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+                r_score += r_decoder_out[i][len(hyp[0])][self.eos]
+                score = score * (1 - reverse_weight) + r_score * reverse_weight
             # add ctc score (which is in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
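The loop above fuses left-to-right and right-to-left decoder scores per hypothesis. A condensed sketch of the same formula (pure Python, function and argument names hypothetical):

def rescore(hyp_tokens, ctc_score, decoder_out_i, r_decoder_out_i, eos,
            ctc_weight, reverse_weight):
    n = len(hyp_tokens)
    # left-to-right score, read in token order, plus the trailing <eos>
    left = sum(decoder_out_i[j][w] for j, w in enumerate(hyp_tokens))
    left += decoder_out_i[n][eos]
    # right-to-left score: the mirrored index n - j - 1 undoes the reversal
    right = sum(r_decoder_out_i[n - j - 1][w]
                for j, w in enumerate(hyp_tokens))
    right += r_decoder_out_i[n][eos]
    score = left * (1 - reverse_weight) + right * reverse_weight
    return score + ctc_weight * ctc_score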
@@ -653,12 +686,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)

-    @jit.to_static
+    # @jit.to_static
+    def is_bidirectional_decoder(self) -> bool:
+        """
+        Returns:
+            bool: True if the decoder has a right-to-left branch.
+        """
+        if hasattr(self.decoder, 'right_decoder'):
+            return True
+        else:
+            return False
+
+    # @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
             hyps_lens: paddle.Tensor,
-            encoder_out: paddle.Tensor, ) -> paddle.Tensor:
+            encoder_out: paddle.Tensor,
+            reverse_weight: float=0.0, ) -> paddle.Tensor:
         """ Export interface for c++ call, forward decoder with multiple
             hypotheses from ctc prefix beam search and one encoder output
         Args:
@@ -676,11 +721,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # (B, 1, T)
         encoder_mask = paddle.ones(
             [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
+        # input for the right-to-left decoder
+        # hyps_lens counts the <sos> token, so subtract it.
+        r_hyps_lens = hyps_lens - 1
+        # hyps includes the <sos> token, so strip it to recover the
+        # original hyps.
+        r_hyps = hyps[:, 1:]
         # (num_hyps, max_hyps_len, vocab_size)
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
-                                      hyps_lens)
+        r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        return decoder_out
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        return decoder_out, r_decoder_out

     @paddle.no_grad()
     def decode(self,
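A hedged sketch of how a caller (for example, an export runtime) might drive the extended interface; nbest_pad and nbest_lens are hypothetical names for the padded CTC n-best hypotheses and their lengths:

hyps_pad, _ = add_sos_eos(nbest_pad, model.sos, model.eos, model.ignore_id)
hyps_lens = nbest_lens + 1  # account for the prepended <sos>
decoder_out, r_decoder_out = model.forward_attention_decoder(
    hyps_pad, hyps_lens, encoder_out, reverse_weight=0.3)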
@@ -692,7 +748,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                ctc_weight: float=0.0,
                decoding_chunk_size: int=-1,
                num_decoding_left_chunks: int=-1,
-               simulate_streaming: bool=False):
+               simulate_streaming: bool=False,
+               reverse_weight: float=0.0):
         """u2 decoding.

         Args:
@@ -764,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 ctc_weight=ctc_weight,
-                simulate_streaming=simulate_streaming)
+                simulate_streaming=simulate_streaming,
+                reverse_weight=reverse_weight)
             hyps = [hyp]
         else:
             raise ValueError(f"Not support decoding method: {decoding_method}")
@@ -801,7 +859,6 @@ class U2Model(U2DecodeModel):
         with DefaultInitializerContext(init_type):
             vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
                 configs)
-
         super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
@@ -851,10 +908,20 @@ class U2Model(U2DecodeModel):
             raise ValueError(f"not support encoder type:{encoder_type}")

         # decoder
-        decoder = TransformerDecoder(vocab_size,
-                                     encoder.output_size(),
-                                     **configs['decoder_conf'])
+        decoder_type = configs.get('decoder', 'transformer')
+        logger.debug(f"U2 Decoder type: {decoder_type}")
+        if decoder_type == 'transformer':
+            decoder = TransformerDecoder(vocab_size,
+                                         encoder.output_size(),
+                                         **configs['decoder_conf'])
+        elif decoder_type == 'bitransformer':
+            assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
+            assert configs['decoder_conf']['r_num_blocks'] > 0
+            decoder = BiTransformerDecoder(vocab_size,
+                                           encoder.output_size(),
+                                           **configs['decoder_conf'])
+        else:
+            raise ValueError(f"not support decoder type:{decoder_type}")

         # ctc decoder and ctc loss
         model_conf = configs.get('model_conf', dict())
         dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
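For reference, a sketch of the config keys this branch reads; the numeric values are placeholders, not recommendations:

configs = {
    'decoder': 'bitransformer',
    'decoder_conf': {
        'attention_heads': 4,
        'linear_units': 2048,
        'num_blocks': 3,    # left-to-right blocks
        'r_num_blocks': 3,  # right-to-left blocks, must be > 0
    },
    'model_conf': {
        'ctc_weight': 0.3,
        'reverse_weight': 0.3,  # must lie strictly between 0 and 1
    },
}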

@@ -35,7 +35,6 @@ from paddlespeech.s2t.modules.mask import make_xs_mask
 from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
 from paddlespeech.s2t.utils.log import Log
-
 logger = Log(__name__).getlog()

 __all__ = ["TransformerDecoder"]
@@ -116,13 +115,19 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             memory: paddle.Tensor,
             memory_mask: paddle.Tensor,
             ys_in_pad: paddle.Tensor,
-            ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
+            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Forward decoder.
         Args:
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoder memory mask, (batch, 1, maxlen_in)
             ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
             ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: not used in transformer decoder; kept to unify the
+                api with the bidirectional decoder
+            reverse_weight: not used in transformer decoder; kept to unify
+                the api with the bidirectional decoder
         Returns:
             (tuple): tuple containing:
             x: decoded token score before softmax (batch, maxlen_out, vocab_size)

@@ -151,7 +156,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         # TODO(Hui Zhang): reduce_sum not support bool type
         # olens = tgt_mask.sum(1)
         olens = tgt_mask.astype(paddle.int).sum(1)
-        return x, olens
+        return x, paddle.to_tensor(0.0), olens

     def forward_one_step(
             self,
@@ -251,3 +256,119 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         state_list = [[states[i][b] for i in range(n_layers)]
                       for b in range(n_batch)]
         return logp, state_list
+
+
+class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
+    """Bi-directional Transformer decoder module.
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the hidden units number of position-wise feedforward
+        num_blocks: the number of decoder blocks
+        r_num_blocks: the number of right to left decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before:
+            True: use layer_norm before each sub-block of a layer.
+            False: use layer_norm after each sub-block of a layer.
+        concat_after: whether to concat attention layer's input and output
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 encoder_output_size: int,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 r_num_blocks: int=0,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 self_attention_dropout_rate: float=0.0,
+                 src_attention_dropout_rate: float=0.0,
+                 input_layer: str="embed",
+                 use_output_layer: bool=True,
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 max_len: int=5000):
+        assert check_argument_types()
+        nn.Layer.__init__(self)
+        self.left_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+
+        self.right_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            r_num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+
+    def forward(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            ys_in_pad: paddle.Tensor,
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor,
+            reverse_weight: float=0.0,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Forward decoder.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+            ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+                used for the right to left decoder
+            reverse_weight: used for the right to left decoder
+        Returns:
+            (tuple): tuple containing:
+                x: decoded token score before softmax (batch, maxlen_out,
+                    vocab_size) if use_output_layer is True,
+                r_x: decoded token score of the right to left decoder
+                    before softmax (batch, maxlen_out, vocab_size)
+                    if use_output_layer is True,
+                olens: (batch, )
+        """
+        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
+                                          ys_in_lens)
+        r_x = paddle.to_tensor(0.0)
+        if reverse_weight > 0.0:
+            r_x, _, olens = self.right_decoder(memory, memory_mask,
+                                               r_ys_in_pad, ys_in_lens)
+        return l_x, r_x, olens
+
+    def forward_one_step(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            tgt: paddle.Tensor,
+            tgt_mask: paddle.Tensor,
+            cache: Optional[List[paddle.Tensor]]=None,
+    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
+        """Forward one step.
+            This is only used for decoding.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                dtype=paddle.bool
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+        """
+        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
+                                                  tgt_mask, cache)
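A minimal usage sketch of the new decoder's forward contract; the shapes, vocab size, and constructor arguments are illustrative assumptions:

import paddle

decoder = BiTransformerDecoder(
    vocab_size=100, encoder_output_size=256, num_blocks=3, r_num_blocks=3)
memory = paddle.randn([2, 50, 256])                  # (B, maxlen_in, feat)
memory_mask = paddle.ones([2, 1, 50], dtype='bool')  # (B, 1, maxlen_in)
ys_in_pad = paddle.randint(0, 100, [2, 12])          # <sos> + tokens
ys_in_lens = paddle.to_tensor([12, 9])
r_ys_in_pad = paddle.randint(0, 100, [2, 12])        # <sos> + reversed tokens

l_x, r_x, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens,
                          r_ys_in_pad, reverse_weight=0.3)
# l_x: (2, 12, 100); r_x: (2, 12, 100) when reverse_weight > 0,
# otherwise a scalar 0.0 tensor; olens: (2,)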

@@ -612,7 +612,8 @@ class PaddleASRConnectionHanddler:
         encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.model.decoder(
+        decoder_out, _, _ = self.model.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
