support bitransformer decoder

pull/2415/head
tianhao zhang 3 years ago
parent 455379b88e
commit 0a95689461

@ -40,7 +40,7 @@ class U2Infer():
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
@ -90,7 +90,7 @@ class U2Infer():
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.config.model_conf.reverse_weight)
reverse_weight=self.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}")

@ -316,7 +316,7 @@ class U2Tester(U2Trainer):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """

@ -689,24 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.to_static
# @jit.to_static
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
torch.Tensor: decoder output
paddle.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
@jit.to_static
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor,
reverse_weight: float=0, ) -> paddle.Tensor:
reverse_weight: float=0.0, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
@ -783,7 +783,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1)
r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
@ -791,7 +791,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens, r_hyps, reverse_weight)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
return decoder_out, r_decoder_out
@paddle.no_grad()

@ -363,9 +363,8 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
dtype=paddle.bool
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.

Loading…
Cancel
Save