diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 141e83bce..fdccdf159 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -478,7 +478,8 @@ class U2Tester(U2Trainer): del input_spec infer_model.eval() - ######################### infer_model.forward_encoder_chunk zero Tensor online ######################## + ######################### infer_model.forward_encoder_chunk zero Tensor online ############ + # TODO: 80(feature dim) be configable input_spec = [ paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'), paddle.static.InputSpec(shape=[1], dtype='int32'), -1, @@ -492,6 +493,7 @@ class U2Tester(U2Trainer): # paddle.jit.save(static_model, self.args.export_path, combine_params=True) ######################### infer_model.forward_attention_decoder ######################## + # TODO: 512(encoder_output) be configable. 1 for B input_spec = [ paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None], dtype='int64'), @@ -501,4 +503,12 @@ class U2Tester(U2Trainer): infer_model.forward_attention_decoder, input_spec=input_spec) # paddle.jit.save(static_model, self.args.export_path, combine_params=True) + ######################### infer_model.ctc_activation ######################## + # TODO: 512(encoder_output) be configable + input_spec = [ + paddle.static.InputSpec(shape=[1, None, 512], dtype='float32') + ] + infer_model.ctc_activation = paddle.jit.to_static( + infer_model.ctc_activation, input_spec=input_spec) + paddle.jit.save(infer_model, './export.jit', combine_params=True)