diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index f79961d6..f0a6ef31 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -15,6 +15,7 @@ __all__ = [ 'asr_dynamic_pretrained_models', 'asr_static_pretrained_models', + 'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models', 'cls_static_pretrained_models', 'st_dynamic_pretrained_models', @@ -166,23 +167,17 @@ asr_dynamic_pretrained_models = { }, }, "deepspeech2online_aishell-zh-16k": { - '1.0': { + '1.0.2': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz', - 'md5': - 'df5ddeac8b679a470176649ac4b78726', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'model': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' + 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', + 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': 'model.yaml', + 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', + 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': 'onnx/model.onnx', + 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' }, }, "deepspeech2offline_librispeech-en-16k": { @@ -224,6 +219,38 @@ asr_static_pretrained_models = { '29e02312deb2e59b3c8686c7966d4fe3' } }, + "deepspeech2online_aishell-zh-16k": { + '1.0.2': { + 'url': + 
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', + 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': 'model.yaml', + 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', + 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': 'onnx/model.onnx', + 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + }, + }, +} + + +asr_onnx_pretrained_models = { + "deepspeech2online_aishell-zh-16k": { + '1.0.2': { + 'url': + 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', + 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': 'model.yaml', + 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', + 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': 'onnx/model.onnx', + 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + }, + }, } # --------------------------------- diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py index 307a764c..f821baa1 100755 --- a/speechx/examples/ds2_ol/onnx/local/infer_check.py +++ b/speechx/examples/ds2_ol/onnx/local/infer_check.py @@ -27,7 +27,13 @@ def parse_args(): '--input_file', type=str, default="static_ds2online_inputs.pickle", - help="ds2 input pickle file.", ) + help="aishell ds2 input data file. 
For wenetspeech, we only feed for infer model", + ) + parser.add_argument( + '--model_type', + type=str, + default="aishell", + help="aishell(1024) or wenetspeech(2048)", ) parser.add_argument( '--model_dir', type=str, default=".", help="paddle model dir.") parser.add_argument( @@ -56,6 +62,12 @@ if __name__ == '__main__': audio_chunk_lens = iodict['audio_chunk_lens'] chunk_state_h_box = iodict['chunk_state_h_box'] chunk_state_c_box = iodict['chunk_state_c_bos'] + print("raw state shape: ", chunk_state_c_box.shape) + + if FLAGS.model_type == 'wenetspeech': + chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1) + chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1) + print("state shape: ", chunk_state_c_box.shape) # paddle model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix)) @@ -82,5 +94,7 @@ if __name__ == '__main__': # assert paddle equal ort print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) - print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) - print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) + + if FLAGS.model_type == 'aishell': + print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) + print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh index dda5a57a..57cd9416 100755 --- a/speechx/examples/ds2_ol/onnx/run.sh +++ b/speechx/examples/ds2_ol/onnx/run.sh @@ -6,6 +6,11 @@ set -e stage=0 stop_stage=100 +#tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz +tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz +model_prefix=avg_1.jit +model=${model_prefix}.pdmodel +param=${model_prefix}.pdiparams . 
utils/parse_options.sh @@ -14,27 +19,25 @@ exp=exp mkdir -p $data $exp +dir=$data/exp/deepspeech2_online/checkpoints + +# wenetspeech or aishell +model_type=$(echo $tarfile | cut -d '_' -f 4) + if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then - test -f $data/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz || wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz -P $data + test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile # wenetspeech ds2 model pushd $data - tar zxvf asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz + tar zxvf $tarfile popd # ds2 model demo inputs pushd $exp wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle popd - - fi -dir=$data/exp/deepspeech2_online/checkpoints -model=avg_1.jit.pdmodel -param=avg_1.jit.pdiparams - - output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then # prune model by outputs @@ -44,10 +47,20 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then ./local/prune.sh $dir $model $param $output_names $exp/prune fi -input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}" +# aishell rnn hidden is 1024 +# wenetspeech rnn hidden is 2048 +if [ $model_type == 'aishell' ];then + input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}" +elif [ $model_type == 'wenetspeech' ];then + input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 2048], 'chunk_state_h_box':[5,1,2048]}" +else + echo "not support: $model_type" + exit -1 +fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then # infer shape by new shape mkdir -p $exp/shape + echo $input_shape_dict python3 
local/pd_infer_shape.py \ --model_dir $dir \ --model_filename $model \ @@ -63,14 +76,26 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then # to onnx ./local/tonnx.sh $dir $model $param $exp/model.onnx - ./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.onnx + ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx fi +# aishell rnn hidden is 1024 +# wenetspeech rnn hidden is 2048 +if [ $model_type == 'aishell' ];then + input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" +elif [ $model_type == 'wenetspeech' ];then + input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048" +else + echo "not support: $model_type" + exit -1 +fi + + if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then - input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" + # wenetspeech ds2 model exceeds the 2GB limit, which will cause an error. # simplifying onnx model ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape" - ./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.opt.onnx -fi \ No newline at end of file + ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx +fi