support bitransformer decoder

pull/2415/head
tianhao zhang 3 years ago
parent 455379b88e
commit 0a95689461

@@ -40,7 +40,7 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -90,7 +90,7 @@ class U2Infer():
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                 simulate_streaming=decode_config.simulate_streaming,
-                reverse_weight=self.config.model_conf.reverse_weight)
+                reverse_weight=self.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
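Note: the new `self.reverse_weight` attribute is read with `getattr(..., 'reverse_weight', 0.0)`, so configs whose `model_conf` never defined the key (non-bitransformer models) keep working instead of raising `AttributeError`. A minimal sketch of that fallback behaviour, using a hypothetical `SimpleNamespace` stand-in rather than the real PaddleSpeech config object:

    from types import SimpleNamespace

    # Hypothetical stand-ins for config.model_conf; the real object is the
    # YAML-backed config, not a SimpleNamespace.
    old_conf = SimpleNamespace(ctc_weight=0.3)                      # no reverse_weight key
    new_conf = SimpleNamespace(ctc_weight=0.3, reverse_weight=0.3)  # bitransformer config

    # Same pattern as the patch: fall back to 0.0 (pure left-to-right rescoring)
    # instead of failing on older configs.
    print(getattr(old_conf, 'reverse_weight', 0.0))  # 0.0
    print(getattr(new_conf, 'reverse_weight', 0.0))  # 0.3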

@@ -316,7 +316,7 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
-        self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """

@@ -689,24 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)

-    @jit.to_static
+    # @jit.to_static
     def is_bidirectional_decoder(self) -> bool:
         """
         Returns:
-            torch.Tensor: decoder output
+            paddle.Tensor: decoder output
         """
         if hasattr(self.decoder, 'right_decoder'):
             return True
         else:
             return False

-    @jit.to_static
+    # @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
             hyps_lens: paddle.Tensor,
             encoder_out: paddle.Tensor,
-            reverse_weight: float=0, ) -> paddle.Tensor:
+            reverse_weight: float=0.0, ) -> paddle.Tensor:
         """ Export interface for c++ call, forward decoder with multiple
             hypothesis from ctc prefix beam search and one encoder output
         Args:
@@ -783,7 +783,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # >>> tensor([[3, 2, 1],
         # >>>         [4, 8, 9],
         # >>>         [2, eos, eos]])
-        r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1)
+        r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
         # >>> r_hyps
         # >>> tensor([[sos, 3, 2, 1],
         # >>>         [sos, 4, 8, 9],
@@ -791,7 +791,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                       hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
         return decoder_out, r_decoder_out

     @paddle.no_grad()
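Note: `reverse_weight` is what the exported `forward_attention_decoder` hands to the bitransformer decoder; at rescoring time the right-to-left decoder score is blended with the left-to-right one. A minimal sketch of that blending with made-up numbers, not the real `decoder_out` / `r_decoder_out` values:

    # Hypothetical per-hypothesis scores; in the model they are log-probabilities
    # summed over each hypothesis from decoder_out and r_decoder_out.
    hyps = ['hello there', 'hello their']
    forward_scores = [-12.4, -13.1]   # left-to-right decoder
    reverse_scores = [-12.9, -12.2]   # right-to-left decoder
    ctc_scores = [-20.0, -21.5]       # CTC prefix beam search scores
    reverse_weight = 0.3
    ctc_weight = 0.5

    best_hyp, best_score = None, -float('inf')
    for hyp, fwd, rev, ctc in zip(hyps, forward_scores, reverse_scores, ctc_scores):
        # Blend the two attention decoders, then add the weighted CTC score.
        att = (1.0 - reverse_weight) * fwd + reverse_weight * rev
        score = att + ctc_weight * ctc
        if score > best_score:
            best_hyp, best_score = hyp, score
    print(best_hyp)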

@@ -363,9 +363,8 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoded memory mask, (batch, 1, maxlen_in)
             tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                      dtype=paddle.bool
             cache: cached output list of (batch, max_time_out-1, size)
         Returns:
             y, cache: NN output value and cache per `self.decoders`.
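Note: the corrected docstring describes what the decoder actually receives: a square boolean causal mask per batch element. A sketch of building such a mask with Paddle (shapes are illustrative; in the repo the mask typically comes from combining a padding mask with a subsequent/causal mask):

    import paddle

    batch, maxlen_out = 2, 5
    # Lower-triangular causal mask: position i may attend to positions <= i.
    ones = paddle.ones([batch, maxlen_out, maxlen_out], dtype='int32')
    tgt_mask = paddle.tril(ones).astype('bool')   # (batch, maxlen_out, maxlen_out), dtype=paddle.bool
    print(tgt_mask.shape, tgt_mask.dtype)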
