support onnx quantize

pull/2050/head
Hui Zhang 3 years ago
parent 6dfe7273e6
commit 3cf1f1f0b5

@@ -18,12 +18,13 @@ engine_list: ['asr_online-onnx']
 #                              ENGINE CONFIG                                   #
 #################################################################################
 ################################### ASR #########################################
-################### speech task: asr; engine_type: online-inference #######################
-asr_online-inference:
+################### speech task: asr; engine_type: online-onnx #######################
+asr_online-onnx:
     model_type: 'deepspeech2online_wenetspeech'
-    am_model: # the pdmodel file of am static model [optional]
+    am_model: # the pdmodel file of onnx am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:

@@ -32,11 +33,14 @@ asr_online-inference:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
+    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
     am_predictor_conf:
-        device: # set 'gpu:id' or 'cpu'
-        switch_ir_optim: True
-        glog_info: False # True -> print glog
-        summary: True # False -> do not show predictor config
+        device: 'cpu' # set 'gpu:id' or 'cpu'
+        graph_optimization_level: 0
+        intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
+        inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
+        log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
+        log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
     chunk_buffer_conf:
         frame_duration_ms: 85

@@ -49,13 +53,12 @@ asr_online-inference:
         shift_ms: 10 # ms
 ################################### ASR #########################################
-################### speech task: asr; engine_type: online-onnx #######################
-asr_online-onnx:
+################### speech task: asr; engine_type: online-inference #######################
+asr_online-inference:
     model_type: 'deepspeech2online_wenetspeech'
-    am_model: # the pdmodel file of onnx am static model [optional]
+    am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:

@@ -64,21 +67,18 @@ asr_online-onnx:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
-    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
     am_predictor_conf:
-        device: 'cpu' # set 'gpu:id' or 'cpu'
-        graph_optimization_level: 0
-        intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
-        inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
-        log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
-        log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
+        device: # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
     chunk_buffer_conf:
-        frame_duration_ms: 80
+        frame_duration_ms: 85
         shift_ms: 40
         sample_rate: 16000
         sample_width: 2
         window_n: 7 # frame
         shift_n: 4 # frame
         window_ms: 25 # ms
         shift_ms: 10 # ms
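The new `am_predictor_conf` fields for the onnx engine mirror onnxruntime's `SessionOptions` (see the API summary linked in the config). A minimal sketch of how these fields could be applied when building the session; the field-to-option mapping and the model path are illustrative assumptions, not the server's actual loader code:

```
import onnxruntime as ort

# Mirror am_predictor_conf from the YAML above (assumed mapping).
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL  # level 0
so.intra_op_num_threads = 0   # 0 lets onnxruntime pick the thread count
so.inter_op_num_threads = 0
so.log_severity_level = 2     # 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
so.log_verbosity_level = 0

# device: 'cpu' -> CPUExecutionProvider ('gpu:id' would select the CUDA provider).
sess = ort.InferenceSession(
    "exp/model.onnx",  # hypothetical path
    sess_options=so,
    providers=["CPUExecutionProvider"])
```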

@@ -18,12 +18,13 @@ engine_list: ['asr_online-onnx']
 #                              ENGINE CONFIG                                   #
 #################################################################################
 ################################### ASR #########################################
-################### speech task: asr; engine_type: online-inference #######################
-asr_online-inference:
+################### speech task: asr; engine_type: online-onnx #######################
+asr_online-onnx:
     model_type: 'deepspeech2online_wenetspeech'
-    am_model: # the pdmodel file of am static model [optional]
+    am_model: # the pdmodel file of onnx am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:

@@ -32,14 +33,17 @@ asr_online-inference:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
+    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
     am_predictor_conf:
-        device: # set 'gpu:id' or 'cpu'
-        switch_ir_optim: True
-        glog_info: False # True -> print glog
-        summary: True # False -> do not show predictor config
+        device: 'cpu' # set 'gpu:id' or 'cpu'
+        graph_optimization_level: 0
+        intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
+        inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
+        log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
+        log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
     chunk_buffer_conf:
-        frame_duration_ms: 80
+        frame_duration_ms: 85
         shift_ms: 40
         sample_rate: 16000
         sample_width: 2

@@ -49,13 +53,12 @@ asr_online-inference:
         shift_ms: 10 # ms
 ################################### ASR #########################################
-################### speech task: asr; engine_type: online-onnx #######################
-asr_online-onnx:
+################### speech task: asr; engine_type: online-inference #######################
+asr_online-inference:
     model_type: 'deepspeech2online_wenetspeech'
-    am_model: # the pdmodel file of onnx am static model [optional]
+    am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:

@@ -64,14 +67,11 @@ asr_online-onnx:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
-    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
     am_predictor_conf:
-        device: 'cpu' # set 'gpu:id' or 'cpu'
-        graph_optimization_level: 0
-        intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
-        inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
-        log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
-        log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
+        device: # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
     chunk_buffer_conf:
         frame_duration_ms: 85

@@ -81,4 +81,4 @@ asr_online-onnx:
         window_n: 7 # frame
         shift_n: 4 # frame
         window_ms: 25 # ms
         shift_ms: 10 # ms

@@ -1,9 +1,11 @@
-# DeepSpeech2 ONNX model
+# DeepSpeech2 to ONNX model

 1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
 2. check paddleinference and onnxruntime output equal.
 3. optimize onnx model
 4. check paddleinference and optimized onnxruntime output equal.
+5. quantize onnx model
+6. check paddleinference and quantized onnxruntime output equal.

 Please make sure the [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.
@@ -26,12 +28,27 @@ onnxruntime 1.11.0
 ## Using

 ```
-bash run.sh
+bash run.sh --stage 0 --stop_stage 5
 ```

 For more details please see `run.sh`.

 ## Outputs
-The optimized onnx model is `exp/model.opt.onnx`.
+The optimized onnx model is `exp/model.opt.onnx`; the quantized model is `$exp/model.optset11.quant.onnx`.

 To show the graph, please use `local/netron.sh`.
+
+## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
+
+Test machine: `CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz`
+Test script: `Streaming Server`
+
+| Acoustic Model | Model Size | Engine | Decoding Method | ctc_weight | decoding_chunk_size | num_decoding_left_chunk | RTF |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| deepspeech2online_wenetspeech | 659MB | inference | ctc_prefix_beam_search | - | 1 | - | 1.9108175171428279 (utts=80) |
+| deepspeech2online_wenetspeech | 659MB | onnx | ctc_prefix_beam_search | - | 1 | - | 0.5617182449999291 (utts=80) |
+| deepspeech2online_wenetspeech | 166MB | onnx quant | ctc_prefix_beam_search | - | 1 | - | 0.44507715475808385 (utts=80) |
+
+> Quantization support depends on the machine; not all machines support it. Instruction-set flags of the machine used for the ONNX quant test:
+> Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology eagerfpu pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl
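As a companion to the `infer_check.py` steps above, a minimal sketch of an output-agreement check between the optimized and the quantized model (the input name and shape are illustrative placeholders, not the real DeepSpeech2 interface):

```
import numpy as np
import onnxruntime as ort

def run_model(path, feeds):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    return sess.run(None, feeds)

# Illustrative input; the real check feeds actual feature chunks.
feeds = {"audio_chunk": np.random.randn(1, 100, 161).astype(np.float32)}
ref = run_model("exp/model.opt.onnx", feeds)
qnt = run_model("exp/model.optset11.quant.onnx", feeds)
for r, q in zip(ref, qnt):
    print("max abs diff:", np.abs(r - q).max())  # should stay small after quantization
```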

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
import argparse

import onnx
from onnx import version_converter

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog=__doc__)
    parser.add_argument(
        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
    parser.add_argument(
        "--save-model", type=str, required=True, help='path/to/saved/model.onnx.')
    # Models must be opset 10 or higher to be quantized.
    parser.add_argument(
        "--target-opset", type=int, default=11, help='target opset version.')
    args = parser.parse_args()

    print(f"to opset: {args.target_opset}")

    # Preprocessing: load the model to be converted.
    model_path = args.model_file
    original_model = onnx.load(model_path)
    # print('The model before conversion:\n{}'.format(original_model))

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    # Apply the version conversion on the original model.
    converted_model = version_converter.convert_version(original_model,
                                                        args.target_opset)
    # print('The model after conversion:\n{}'.format(converted_model))

    onnx.save(converted_model, args.save_model)
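To confirm the conversion took effect, the saved model's opset can be inspected (a quick sketch; the path follows the `run.sh` naming):

```
import onnx

m = onnx.load("exp/model.optset11.onnx")
# The default (ai.onnx) domain entry carries the opset version, e.g. 11.
print({op.domain or "ai.onnx": op.version for op in m.opset_import})
onnx.checker.check_model(m)  # structural sanity check
```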

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import argparse

from onnxruntime.quantization import QuantType
from onnxruntime.quantization import quantize_dynamic


def quantize_onnx_model(onnx_model_path,
                        quantized_model_path,
                        nodes_to_exclude=[]):
    print("Starting quantization...")

    quantize_dynamic(
        onnx_model_path,
        quantized_model_path,
        weight_type=QuantType.QInt8,
        nodes_to_exclude=nodes_to_exclude)

    print(f"Quantized model saved to: {quantized_model_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-in",
        type=str,
        required=True,
        help="input ONNX model")
    parser.add_argument(
        "--model-out",
        type=str,
        required=True,
        help="path to save the quantized ONNX model")
    parser.add_argument(
        "--nodes-to-exclude",
        type=str,
        required=True,
        help="nodes to exclude from quantization, e.g. conv,linear.")
    args = parser.parse_args()

    nodes_to_exclude = args.nodes_to_exclude.split(',')
    quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)


if __name__ == "__main__":
    main()
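The names passed via `--nodes-to-exclude` (e.g. `p2o.Conv.0`) are Paddle2ONNX-generated node names. One way to list candidate Conv nodes from a model (a sketch; the path follows the `run.sh` naming):

```
import onnx

m = onnx.load("exp/model.optset11.onnx")
# Collect conv node names (e.g. p2o.Conv.0) that can be kept in fp32
# by passing them to --nodes-to-exclude.
print([n.name for n in m.graph.node if n.op_type == "Conv"])
```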

@@ -15,11 +15,12 @@ pip install paddle2onnx
 pip install onnx

 # https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
+# opset 10 or higher is required for quantization
 paddle2onnx --model_dir $dir \
             --model_filename $model \
             --params_filename $param \
             --save_file $output \
             --enable_dev_version True \
-            --opset_version 9 \
+            --opset_version 11 \
             --enable_onnx_checker True
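After export, the resulting file can be sanity-checked by loading it in onnxruntime and printing its inputs (a sketch; the path assumes `$output` was `exp/model.onnx`):

```
import onnxruntime as ort

sess = ort.InferenceSession("exp/model.onnx", providers=["CPUExecutionProvider"])
for i in sess.get_inputs():
    print(i.name, i.shape, i.type)  # e.g. dynamic time axis for streaming input
```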

@@ -89,6 +89,18 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
 fi

+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ];then
+    # convert opset to 11 (required for quantization)
+    ./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx
+
+    # dynamic quantization, keeping the conv nodes in fp32
+    nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
+    ./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"
+
+    # check that paddle inference and the quantized onnx model agree
+    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
+fi

 # aishell rnn hidden is 1024
 # wenetspeech rnn hidden is 2048
 if [ $model_type == 'aishell' ];then
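With the staged layout above, the new quantization step can also be run on its own: `bash run.sh --stage 5 --stop_stage 5`.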
