diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index bdda0fd8c..09d0982ec 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -519,6 +519,7 @@ class U2Tester(U2Trainer): infer_model.ctc_activation, input_spec=input_spec) ######################### infer_model.forward_attention_decoder ######################## + reverse_weight = 0.3 input_spec = [ # hyps, (B, U) paddle.static.InputSpec(shape=[None, None], dtype='int64'), @@ -526,7 +527,8 @@ class U2Tester(U2Trainer): paddle.static.InputSpec(shape=[None], dtype='int64'), # encoder_out, (B,T,D) paddle.static.InputSpec( - shape=[batch_size, None, model_size], dtype='float32') + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight ] infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder, input_spec=input_spec) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 632e0a615..e1e9b05d7 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -716,7 +716,7 @@ class U2BaseModel(ASRInterface, nn.Layer): hyps: paddle.Tensor, hyps_lens: paddle.Tensor, encoder_out: paddle.Tensor, - reverse_weight: float=0.0, ) -> paddle.Tensor: + reverse_weight: float=0.0) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: