|
|
|
@ -492,10 +492,9 @@ class U2Tester(U2Trainer):
|
|
|
|
|
]
|
|
|
|
|
infer_model.forward_encoder_chunk = paddle.jit.to_static(
|
|
|
|
|
infer_model.forward_encoder_chunk, input_spec=input_spec)
|
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
|
|
|
|
|
|
######################### infer_model.forward_attention_decoder ########################
|
|
|
|
|
# TODO: 512(encoder_output) be configable. 1 for B
|
|
|
|
|
# TODO: 512(encoder_output) be configable. 1 for BatchSize
|
|
|
|
|
input_spec = [
|
|
|
|
|
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
|
|
|
|
|
paddle.static.InputSpec(shape=[None], dtype='int64'),
|
|
|
|
@ -503,7 +502,6 @@ class U2Tester(U2Trainer):
|
|
|
|
|
]
|
|
|
|
|
infer_model.forward_attention_decoder = paddle.jit.to_static(
|
|
|
|
|
infer_model.forward_attention_decoder, input_spec=input_spec)
|
|
|
|
|
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
|
|
|
|
|
|
|
|
|
|
######################### infer_model.ctc_activation ########################
|
|
|
|
|
# TODO: 512(encoder_output) be configable
|
|
|
|
@ -513,8 +511,10 @@ class U2Tester(U2Trainer):
|
|
|
|
|
infer_model.ctc_activation = paddle.jit.to_static(
|
|
|
|
|
infer_model.ctc_activation, input_spec=input_spec)
|
|
|
|
|
|
|
|
|
|
paddle.jit.save(infer_model, './export.jit', combine_params=True, skip_forward=True)
|
|
|
|
|
# jit save
|
|
|
|
|
paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True)
|
|
|
|
|
|
|
|
|
|
# test dy2static
|
|
|
|
|
def flatten(out):
|
|
|
|
|
if isinstance(out, paddle.Tensor):
|
|
|
|
|
return [out]
|
|
|
|
@ -541,7 +541,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
|
|
|
|
|
from paddle.jit.layer import Layer
|
|
|
|
|
layer = Layer()
|
|
|
|
|
layer.load('./export.jit', paddle.CPUPlace())
|
|
|
|
|
layer.load(self.args.export_path, paddle.CPUPlace())
|
|
|
|
|
|
|
|
|
|
xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32')
|
|
|
|
|
offset = paddle.to_tensor([0], dtype='int32')
|
|
|
|
|