|
|
@ -478,7 +478,8 @@ class U2Tester(U2Trainer):
|
|
|
|
del input_spec
|
|
|
|
del input_spec
|
|
|
|
infer_model.eval()
|
|
|
|
infer_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
######################### infer_model.forward_encoder_chunk zero Tensor online ########################
|
|
|
|
######################### infer_model.forward_encoder_chunk zero Tensor online ############
|
|
|
|
|
|
|
|
# TODO: 80(feature dim) be configable
|
|
|
|
input_spec = [
|
|
|
|
input_spec = [
|
|
|
|
paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
|
|
|
|
paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
|
|
|
|
paddle.static.InputSpec(shape=[1], dtype='int32'), -1,
|
|
|
|
paddle.static.InputSpec(shape=[1], dtype='int32'), -1,
|
|
|
@ -492,6 +493,7 @@ class U2Tester(U2Trainer):
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
|
|
|
|
|
|
|
|
######################### infer_model.forward_attention_decoder ########################
|
|
|
|
######################### infer_model.forward_attention_decoder ########################
|
|
|
|
|
|
|
|
# TODO: 512(encoder_output) be configable. 1 for B
|
|
|
|
input_spec = [
|
|
|
|
input_spec = [
|
|
|
|
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
|
|
|
|
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
|
|
|
|
paddle.static.InputSpec(shape=[None], dtype='int64'),
|
|
|
|
paddle.static.InputSpec(shape=[None], dtype='int64'),
|
|
|
@ -501,4 +503,12 @@ class U2Tester(U2Trainer):
|
|
|
|
infer_model.forward_attention_decoder, input_spec=input_spec)
|
|
|
|
infer_model.forward_attention_decoder, input_spec=input_spec)
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
######################### infer_model.ctc_activation ########################
|
|
|
|
|
|
|
|
# TODO: 512(encoder_output) be configable
|
|
|
|
|
|
|
|
input_spec = [
|
|
|
|
|
|
|
|
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
infer_model.ctc_activation = paddle.jit.to_static(
|
|
|
|
|
|
|
|
infer_model.ctc_activation, input_spec=input_spec)
|
|
|
|
|
|
|
|
|
|
|
|
paddle.jit.save(infer_model, './export.jit', combine_params=True)
|
|
|
|
paddle.jit.save(infer_model, './export.jit', combine_params=True)
|
|
|
|