support bitransformer decoder, test=asr

pull/2415/head
tianhao zhang 3 years ago
parent 1a56a6e42b
commit ecbf324286

@ -613,7 +613,8 @@ class PaddleASRConnectionHanddler:
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
decoder_out, _, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain

Loading…
Cancel
Save