NOTE(review): line structure of this patch was restored from a collapsed copy.
The infer_check.py payload was revised, so the blob id on its "index" line is
stale. Leading whitespace on the run.sh context lines is reconstructed; if it
does not match the tree, apply with `git apply --ignore-whitespace`.

diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py
new file mode 100755
index 00000000..4debf4d3
--- /dev/null
+++ b/speechx/examples/ds2_ol/onnx/local/infer_check.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Compare the outputs of a Paddle jit model and an ONNX model.
+
+Both models are run on the same pickled inputs. One True/False line is
+printed per output and the exit status is non-zero if any output differs.
+"""
+import argparse
+import os
+import pickle
+import sys
+
+import numpy as np
+import onnxruntime
+import paddle
+
+
+def parse_args():
+    """Parse the command-line flags of this script."""
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--input_file',
+        type=str,
+        default="static_ds2online_inputs.pickle",
+        help="ds2 input pickle file.")
+    parser.add_argument(
+        '--model_dir',
+        type=str,
+        default=".",
+        help="paddle model dir.")
+    parser.add_argument(
+        '--onnx_model',
+        type=str,
+        default='./model.old.onnx',
+        help="onnx model.")
+    return parser.parse_args()
+
+
+def main():
+    """Run both models on the pickled inputs and compare their outputs.
+
+    Returns:
+        0 if every output pair is close (atol=1e-6), otherwise 1.
+    """
+    flags = parse_args()
+
+    # input and output
+    # NOTE: pickle.load can run arbitrary code; only pass trusted files.
+    with open(flags.input_file, 'rb') as f:
+        iodict = pickle.load(f)
+    print(iodict.keys())
+
+    audio_chunk = iodict['audio_chunk']
+    audio_chunk_lens = iodict['audio_chunk_lens']
+    chunk_state_h_box = iodict['chunk_state_h_box']
+    # NOTE(review): this key ends in '_bos', not '_box'. Presumably it
+    # matches the key used when the pickle was written -- confirm.
+    chunk_state_c_box = iodict['chunk_state_c_bos']
+
+    # paddle
+    model = paddle.jit.load(os.path.join(flags.model_dir, "avg_1.jit"))
+    res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
+        paddle.to_tensor(audio_chunk),
+        paddle.to_tensor(audio_chunk_lens),
+        paddle.to_tensor(chunk_state_h_box),
+        paddle.to_tensor(chunk_state_c_box))
+
+    # onnxruntime
+    options = onnxruntime.SessionOptions()
+    options.enable_profiling = True
+    sess = onnxruntime.InferenceSession(
+        flags.onnx_model, sess_options=options)
+    ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
+        ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'],
+        {
+            "audio_chunk": audio_chunk,
+            "audio_chunk_lens": audio_chunk_lens,
+            "chunk_state_h_box": chunk_state_h_box,
+            "chunk_state_c_box": chunk_state_c_box,
+        })
+
+    print(sess.end_profiling())
+
+    # Check that paddle equals ort: print each result as before, and
+    # remember any mismatch so the caller sees it in the exit status.
+    output_pairs = (
+        (ort_res_chunk, res_chunk),
+        (ort_res_lens, res_lens),
+        (ort_chunk_state_h, chunk_state_h),
+        (ort_chunk_state_c, chunk_state_c),
+    )
+    all_close = True
+    for ort_out, paddle_out in output_pairs:
+        is_close = np.allclose(ort_out, paddle_out, atol=1e-6)
+        print(is_close)
+        all_close = all_close and is_close
+    return 0 if all_close else 1
+
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh
index 07706749..b7c7e2fb 100755
--- a/speechx/examples/ds2_ol/onnx/run.sh
+++ b/speechx/examples/ds2_ol/onnx/run.sh
@@ -22,10 +22,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
     popd
 fi
 
+
 dir=$data/exp/deepspeech2_online/checkpoints
 model=avg_1.jit.pdmodel
 param=avg_1.jit.pdiparams
 
+
 output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
     # prune model by outputs
@@ -47,7 +49,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
         --input_shape_dict=${input_shape_dict}
 fi
 
+
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
     # to onnx
     ./local/tonnx.sh $dir $model $param $exp/model.onnx
-fi
\ No newline at end of file
+    ./local/infer_check.py --input_file 'static_ds2online_inputs.pickle' --model_dir $dir --onnx_model $exp/model.onnx
+fi
+