opt onnx script

pull/2037/head
Hui Zhang 3 years ago
parent b4c6a52beb
commit ff3b1ff817

@ -15,6 +15,7 @@
__all__ = [ __all__ = [
'asr_dynamic_pretrained_models', 'asr_dynamic_pretrained_models',
'asr_static_pretrained_models', 'asr_static_pretrained_models',
'asr_onnx_pretrained_models',
'cls_dynamic_pretrained_models', 'cls_dynamic_pretrained_models',
'cls_static_pretrained_models', 'cls_static_pretrained_models',
'st_dynamic_pretrained_models', 'st_dynamic_pretrained_models',
@ -224,6 +225,43 @@ asr_static_pretrained_models = {
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
} }
}, },
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5':
'df5ddeac8b679a470176649ac4b78726',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
asr_onnx_pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5': 'df5ddeac8b679a470176649ac4b78726',
'cfg_path': model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
},
},
} }
# --------------------------------- # ---------------------------------

@ -27,7 +27,12 @@ def parse_args():
'--input_file', '--input_file',
type=str, type=str,
default="static_ds2online_inputs.pickle", default="static_ds2online_inputs.pickle",
help="ds2 input pickle file.", ) help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", )
parser.add_argument(
'--model_type',
type=str,
default="aishell",
help="aishell(1024) or wenetspeech(2048)", )
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default=".", help="paddle model dir.") '--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument( parser.add_argument(
@ -52,10 +57,17 @@ if __name__ == '__main__':
iodict = pickle.load(f) iodict = pickle.load(f)
print(iodict.keys()) print(iodict.keys())
audio_chunk = iodict['audio_chunk'] audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens'] audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box'] chunk_state_h_box = iodict['chunk_state_h_box']
chunk_state_c_box = iodict['chunk_state_c_bos'] chunk_state_c_box = iodict['chunk_state_c_bos']
print("raw state shape: ", chunk_state_c_box.shape)
if FLAGS.model_type == 'wenetspeech':
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
print("state shape: ", chunk_state_c_box.shape)
# paddle # paddle
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix)) model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
@ -82,5 +94,7 @@ if __name__ == '__main__':
# assert paddle equal ort # assert paddle equal ort
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) if FLAGS.model_type == 'aishell':
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))

@ -6,6 +6,11 @@ set -e
stage=0 stage=0
stop_stage=100 stop_stage=100
#tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams
. utils/parse_options.sh . utils/parse_options.sh
@ -14,27 +19,25 @@ exp=exp
mkdir -p $data $exp mkdir -p $data $exp
dir=$data/exp/deepspeech2_online/checkpoints
# wenetspeech or aishell
model_type=$(echo $tarfile | cut -d '_' -f 4)
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
test -f $data/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz || wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz -P $data test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile
# wenetspeech ds2 model # wenetspeech ds2 model
pushd $data pushd $data
tar zxvf asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz tar zxvf $tarfile
popd popd
# ds2 model demo inputs # ds2 model demo inputs
pushd $exp pushd $exp
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
popd popd
fi fi
dir=$data/exp/deepspeech2_online/checkpoints
model=avg_1.jit.pdmodel
param=avg_1.jit.pdiparams
output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
# prune model by outputs # prune model by outputs
@ -44,10 +47,20 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
./local/prune.sh $dir $model $param $output_names $exp/prune ./local/prune.sh $dir $model $param $output_names $exp/prune
fi fi
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}" # aishell rnn hidden is 1024
# wenetspeech rnn hiddn is 2048
if [ $model_type == 'aishell' ];then
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
elif [ $model_type == 'wenetspeech' ];then
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 2048], 'chunk_state_h_box':[5,1,2048]}"
else
echo "not support: $model_type"
exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
# infer shape by new shape # infer shape by new shape
mkdir -p $exp/shape mkdir -p $exp/shape
echo $input_shape_dict
python3 local/pd_infer_shape.py \ python3 local/pd_infer_shape.py \
--model_dir $dir \ --model_dir $dir \
--model_filename $model \ --model_filename $model \
@ -63,14 +76,26 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
# to onnx # to onnx
./local/tonnx.sh $dir $model $param $exp/model.onnx ./local/tonnx.sh $dir $model $param $exp/model.onnx
./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.onnx ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi fi
# aishell rnn hidden is 1024
# wenetspeech rnn hiddn is 2048
if [ $model_type == 'aishell' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
echo "not support: $model_type"
exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" # wenetspeech ds2 model execed 2GB limit, will error.
# simplifying onnx model # simplifying onnx model
./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape" ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.opt.onnx ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi fi

Loading…
Cancel
Save