export ctc_activation

Branch: pull/2212/head
Hui Zhang, 3 years ago
parent 812d80ab1c
commit 6149daa221

@@ -478,7 +478,8 @@ class U2Tester(U2Trainer):
         del input_spec
         infer_model.eval()
         ######################### infer_model.forward_encoder_chunk zero Tensor online ########################
+        # TODO: 80(feature dim) be configable
         input_spec = [
             paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
             paddle.static.InputSpec(shape=[1], dtype='int32'), -1,
@@ -492,6 +493,7 @@ class U2Tester(U2Trainer):
         # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
         ######################### infer_model.forward_attention_decoder ########################
+        # TODO: 512(encoder_output) be configable. 1 for B
         input_spec = [
             paddle.static.InputSpec(shape=[None, None], dtype='int64'),
             paddle.static.InputSpec(shape=[None], dtype='int64'),
@@ -501,4 +503,12 @@ class U2Tester(U2Trainer):
             infer_model.forward_attention_decoder, input_spec=input_spec)
         # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
+        ######################### infer_model.ctc_activation ########################
+        # TODO: 512(encoder_output) be configable
+        input_spec = [
+            paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
+        ]
+        infer_model.ctc_activation = paddle.jit.to_static(
+            infer_model.ctc_activation, input_spec=input_spec)
         paddle.jit.save(infer_model, './export.jit', combine_params=True)
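The added block follows the same pattern already used for forward_encoder_chunk and forward_attention_decoder: wrap the method with paddle.jit.to_static under an InputSpec that fixes the batch to 1 and the encoder output dimension to 512 while leaving the time axis dynamic, then serialize everything with a single paddle.jit.save call. Below is a minimal, self-contained sketch of that pattern; ToyCTCModel, its layer sizes, and the output path are illustrative assumptions rather than PaddleSpeech code, and it uses stock paddle.jit.save without the combine_params flag seen in the diff. To keep the sketch runnable on plain Paddle, forward is converted alongside ctc_activation so the saved layer has a static entry point.

# Minimal sketch of the export pattern above (assumed names; not PaddleSpeech code).
import paddle
from paddle.static import InputSpec


class ToyCTCModel(paddle.nn.Layer):
    """Stand-in for the U2 inference model: encoder output -> CTC log-probs."""

    def __init__(self, encoder_dim=512, vocab_size=100):
        super().__init__()
        self.ctc_fc = paddle.nn.Linear(encoder_dim, vocab_size)

    def forward(self, xs):
        return self.ctc_activation(xs)

    def ctc_activation(self, xs):
        # xs: [B, T, encoder_dim] -> log-softmax over the vocabulary
        return paddle.nn.functional.log_softmax(self.ctc_fc(xs), axis=-1)


model = ToyCTCModel()
model.eval()

# Batch fixed to 1, time axis left dynamic (None), feature dim fixed to the
# encoder output size -- mirroring the [1, None, 512] spec in the diff.
input_spec = [InputSpec(shape=[1, None, 512], dtype='float32')]
model.forward = paddle.jit.to_static(model.forward, input_spec=input_spec)
model.ctc_activation = paddle.jit.to_static(
    model.ctc_activation, input_spec=input_spec)

# paddle.jit.save serializes the converted methods together with the weights;
# './export.jit' is an arbitrary prefix for the generated model files.
paddle.jit.save(model, './export.jit')

Fixing the batch and feature dimensions while leaving the time axis as None lets the exported ctc_activation accept encoder outputs of any length at inference time, which is what streaming decoding needs.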
