Merge pull request #2050 from zh794390558/onnx_quant

[speechx] ds2 wenetspeech onnx quant
pull/2061/head
Hui Zhang 3 years ago committed by GitHub
commit e04cd18846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,12 +18,13 @@ engine_list: ['asr_online-onnx']
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
@ -32,11 +33,14 @@ asr_online-inference:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf:
frame_duration_ms: 85
@ -49,13 +53,12 @@ asr_online-inference:
shift_ms: 10 # ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
@ -64,17 +67,14 @@ asr_online-onnx:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 80
frame_duration_ms: 85
shift_ms: 40
sample_rate: 16000
sample_width: 2

@ -33,8 +33,9 @@ if __name__ == '__main__':
P = 0.0
n = 0
for m in rtfs:
# not accurate, may have duplicate log
n += 1
T += m['T']
P += m['P']
print(f"RTF: {P/T}, utts: {n}")
print(f"RTF: {P/T}")

@ -18,7 +18,6 @@
import argparse
import asyncio
import codecs
import logging
import os
from paddlespeech.cli.log import logger
@ -44,7 +43,7 @@ def main(args):
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
logger.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:

@ -4,19 +4,19 @@
## Speech-to-Text Models
### Speech Recognition Model
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM) <br> 0.2417 (test\_meeting, w/o LM) <br> 0.053 (aishell, w/ LM) |-| 10000 h |-
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |-
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1)
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1)
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0)
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1)
[Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1)
[Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2)
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link | Inference Type |
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: | :-----: |
[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.4.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM) <br> 0.2417 (test\_meeting, w/o LM) <br> 0.053 (aishell, w/ LM) |-| 10000 h | - | onnx/inference/python |
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
[Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
[Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2) | python |
### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions

@ -155,6 +155,26 @@ asr_dynamic_pretrained_models = {
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
'1.0.4': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.4.model.tar.gz',
'md5':
'c595cb76902b5a5d01409171375989f4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_aishell-zh-16k": {
'1.0': {
@ -294,6 +314,26 @@ asr_static_pretrained_models = {
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
'1.0.4': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.4.model.tar.gz',
'md5':
'c595cb76902b5a5d01409171375989f4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
@ -341,6 +381,26 @@ asr_onnx_pretrained_models = {
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
'1.0.4': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.4.model.tar.gz',
'md5':
'c595cb76902b5a5d01409171375989f4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}

@ -18,12 +18,13 @@ engine_list: ['asr_online-onnx']
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
@ -32,14 +33,17 @@ asr_online-inference:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf:
frame_duration_ms: 80
frame_duration_ms: 85
shift_ms: 40
sample_rate: 16000
sample_width: 2
@ -49,13 +53,12 @@ asr_online-inference:
shift_ms: 10 # ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
@ -64,14 +67,11 @@ asr_online-onnx:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 85

@ -1,9 +1,11 @@
# DeepSpeech2 ONNX model
# DeepSpeech2 to ONNX model
1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
2. check paddleinference and onnxruntime output equal.
3. optimize onnx model
4. check paddleinference and optimized onnxruntime output equal.
5. quantize onnx model
4. check paddleinference and optimized onnxruntime output equal.
Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct.
@ -26,12 +28,27 @@ onnxruntime 1.11.0
## Using
```
bash run.sh
bash run.sh --stage 0 --stop_stage 5
```
For more details please see `run.sh`.
## Outputs
The optimized onnx model is `exp/model.opt.onnx`.
The optimized onnx model is `exp/model.opt.onnx`, quanted model is `$exp/model.optset11.quant.onnx`.
To show the graph, please using `local/netron.sh`.
## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
机器硬件:`CPUIntel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz`
测试脚本:`Streaming Server`
Acoustic Model | Model Size | enigne | dedoding_method | ctc_weight | decoding_chunk_size | num_decoding_left_chunk | RTF |
|:-------------:| :-----: | :-----: | :------------:| :-----: | :-----: | :-----: |:-----:|
| deepspeech2online_wenetspeech | 659MB | infernece | ctc_prefix_beam_search | - | 1 | - | 1.9108175171428279(utts=80) |
| deepspeech2online_wenetspeech | 659MB | onnx | ctc_prefix_beam_search | - | 1 | - | 0.5617182449999291 (utts=80) |
| deepspeech2online_wenetspeech | 166MB | onnx quant | ctc_prefix_beam_search | - | 1 | - | 0.44507715475808385 (utts=80) |
> quant 和机器有关不是所有机器都支持。ONNX quant测试机器指令集支持:
> Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology eagerfpu pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import argparse
import onnx
from onnx import version_converter
if __name__ == '__main__':
parser = argparse.ArgumentParser(prog=__doc__)
parser.add_argument(
"--model-file", type=str, required=True, help='path/to/the/model.onnx.')
parser.add_argument(
"--save-model",
type=str,
required=True,
help='path/to/saved/model.onnx.')
# Models must be opset10 or higher to be quantized.
parser.add_argument(
"--target-opset", type=int, default=11, help='path/to/the/model.onnx.')
args = parser.parse_args()
print(f"to opset: {args.target_opset}")
# Preprocessing: load the model to be converted.
model_path = args.model_file
original_model = onnx.load(model_path)
# print('The model before conversion:\n{}'.format(original_model))
# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model
converted_model = version_converter.convert_version(original_model,
args.target_opset)
# print('The model after conversion:\n{}'.format(converted_model))
onnx.save(converted_model, args.save_model)

@ -494,6 +494,8 @@ class SymbolicShapeInference:
# contrib ops
'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \

@ -0,0 +1,48 @@
#!/usr/bin/env python3
import argparse
from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization import QuantType
def quantize_onnx_model(onnx_model_path,
quantized_model_path,
nodes_to_exclude=[]):
print("Starting quantization...")
quantize_dynamic(
onnx_model_path,
quantized_model_path,
weight_type=QuantType.QInt8,
nodes_to_exclude=nodes_to_exclude)
print(f"Quantized model saved to: {quantized_model_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-in",
type=str,
required=True,
help="ONNX model", )
parser.add_argument(
"--model-out",
type=str,
required=True,
default='model.quant.onnx',
help="ONNX model", )
parser.add_argument(
"--nodes-to-exclude",
type=str,
required=True,
help="nodes to exclude. e.g. conv,linear.", )
args = parser.parse_args()
nodes_to_exclude = args.nodes_to_exclude.split(',')
quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)
if __name__ == "__main__":
main()

@ -15,11 +15,12 @@ pip install paddle2onnx
pip install onnx
# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
# opset10 support quantize
paddle2onnx --model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_file $output \
--enable_dev_version True \
--opset_version 9 \
--opset_version 11 \
--enable_onnx_checker True

@ -89,6 +89,18 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ];then
# convert opset_num to 11
./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx
# quant model
nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
fi
# aishell rnn hidden is 1024
# wenetspeech rnn hiddn is 2048
if [ $model_type == 'aishell' ];then

Loading…
Cancel
Save