Merge pull request #2045 from zh794390558/wenetspeech_onnx

[server] ds2 wenetspeech to onnx and support streaming asr server
Jackwaterveg committed via GitHub, 3 years ago
commit 6dfe7273e6

@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online-inference']
engine_list: ['asr_online-onnx']
#################################################################################
@@ -21,7 +21,7 @@ engine_list: ['asr_online-inference']
################################### ASR #########################################
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_aishell'
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
@@ -53,7 +53,7 @@ asr_online-inference:
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_aishell'
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
@@ -80,5 +80,5 @@ asr_online-onnx:
sample_width: 2
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
window_ms: 25 # ms
shift_ms: 10 # ms
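With `engine_list` switched to `asr_online-onnx` and the model set to `deepspeech2online_wenetspeech`, the server now defaults to the ONNX runtime; `window_ms` moves from 20 to 25 ms, presumably to match the model's 25 ms fbank analysis window. A typical launch looks like this (the config path below is illustrative, not taken from this diff):

```bash
# start the streaming ASR server with the edited config (path is assumed)
paddlespeech_server start --config_file ./conf/ws_ds2_application.yaml
```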

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--logfile', type=str, required=True, help='ws client log file')
    args = parser.parse_args()

    rtfs = []
    with open(args.logfile, 'r') as f:
        for line in f:
            if 'RTF=' in line:
                # audio duration: 6.126, elapsed time: 3.471978187561035, RTF=0.5667610492264177
                line = line.strip()
                beg = line.index("audio")
                line = line[beg:]
                items = line.split(',')
                vals = []
                for elem in items:
                    if "RTF=" in elem:
                        continue
                    _, val = elem.split(":")
                    vals.append(float(val))
                # T: audio duration (s), P: processing (elapsed) time (s)
                keys = ['T', 'P']
                meta = dict(zip(keys, vals))
                rtfs.append(meta)

    T = 0.0
    P = 0.0
    n = 0
    for m in rtfs:
        n += 1
        T += m['T']
        P += m['P']

    print(f"RTF: {P/T}, utts: {n}")

@@ -0,0 +1,21 @@
#!/bin/bash

if [ $# != 1 ];then
    echo "usage: $0 wav_scp"
    exit -1
fi

scp=$1

# calc RTF
# wav_scp can be generated from `speechx/examples/ds2_ol/aishell`

exp=exp
mkdir -p $exp

python3 local/websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavscp $scp &> $exp/log.rsl
python3 local/rtf_from_log.py --logfile $exp/log.rsl
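The `wav_scp` argument is assumed to follow the usual Kaldi convention, one `utt_id wav_path` pair per line (the ids and paths below are illustrative):

```
BAC009S0764W0121 /data/aishell/wav/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /data/aishell/wav/test/S0764/BAC009S0764W0122.wav
```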

@@ -1,3 +1,4 @@
#!/usr/bin/python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
# calc avg RTF (NOT accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
import argparse
import asyncio
import codecs

@@ -3,11 +3,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to only streaming asr service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav
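The `--punc.server_ip`/`--punc.port` flags point the client at a separately deployed punctuation service: the first command returns raw streaming transcripts, while the second additionally restores punctuation.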

@@ -25,10 +25,10 @@ from typing import Dict
import paddle
import requests
import soundfile as sf
import yaml
from paddle.framework import load
import paddlespeech.audio
from . import download
from .entry import commands
try:
@@ -282,7 +282,8 @@ def _note_one_stat(cls_name, params={}):
if 'audio_file' in params:
try:
_, sr = paddlespeech.audio.load(params['audio_file'])
# avoid a recursive import caused by utils.DATA_HOME
_, sr = sf.read(params['audio_file'])
except Exception:
sr = -1
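`sf.read` returns a `(data, samplerate)` tuple, so only the sample rate is kept here; importing `soundfile` directly also sidesteps the circular import through `paddlespeech.audio`.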

@@ -135,15 +135,21 @@ asr_dynamic_pretrained_models = {
},
},
"deepspeech2online_wenetspeech-zh-16k": {
'1.0': {
'1.0.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz',
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
'md5':
'b0c77e7f8881e0a27b82127d1abb8d5f',
'cfe273793e68f790f742b411c98bc75e',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
@@ -170,14 +176,22 @@ asr_dynamic_pretrained_models = {
'1.0.2': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
'md5':
'4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_librispeech-en-16k": {
@@ -241,14 +255,44 @@ asr_static_pretrained_models = {
'1.0.2': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
'md5':
'4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2online_wenetspeech-zh-16k": {
'1.0.3': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
'md5':
'cfe273793e68f790f742b411c98bc75e',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
@@ -258,14 +302,44 @@ asr_onnx_pretrained_models = {
'1.0.2': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
'md5':
'4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2online_wenetspeech-zh-16k": {
'1.0.3': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
'md5':
'cfe273793e68f790f742b411c98bc75e',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'model':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
'onnx_model':
'onnx/model.onnx',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
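These dictionaries are keyed by tags of the form `{model_type}-{lang}-{sample_rate}`; a minimal sketch of the lookup (the tag construction is inferred from the entries above, not shown in this diff):

```python
# hypothetical lookup sketch; assumes asr_onnx_pretrained_models from above
# is in scope. The real resolution lives in the ASR engine code.
model_type, lang, sample_rate = 'deepspeech2online_wenetspeech', 'zh', 16000
tag = f'{model_type}-{lang}-{sample_rate // 1000}k'  # deepspeech2online_wenetspeech-zh-16k
res = asr_onnx_pretrained_models[tag]['1.0.3']
print(res['url'], res['onnx_model'])
```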

@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online-inference']
engine_list: ['asr_online-onnx']
#################################################################################
@@ -21,7 +21,7 @@ engine_list: ['asr_online-inference']
################################### ASR #########################################
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_aishell'
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
@@ -45,7 +45,7 @@ asr_online-inference:
sample_width: 2
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
window_ms: 25 # ms
shift_ms: 10 # ms
@@ -53,7 +53,7 @@ asr_online-inference:
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_aishell'
model_type: 'deepspeech2online_wenetspeech'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'

@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.

@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.

@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.

@@ -9,7 +9,7 @@ Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and
The example was tested with these packages installed:
```
paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a
paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819
paddleaudio 0.2.1
paddlefsl 1.1.0
paddlenlp 2.2.6

@@ -492,6 +492,8 @@ class SymbolicShapeInference:
skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
# contrib ops
'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \
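The added entries (`Attention`, `BiasGelu`, `FastGelu`, ...) are onnxruntime contrib ops rather than standard ONNX ops, so the stock ONNX shape inference pass is skipped for them here.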

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
import argparse

import onnxruntime as ort

# onnxruntime graph optimizer.
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
# https://onnxruntime.ai/docs/api/python/api_summary.html#api


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_in', required=True, type=str, help='Path to onnx model.')
    parser.add_argument(
        '--opt_level',
        required=True,
        type=int,
        default=0,
        choices=[0, 1, 2],
        help='Graph optimization level: 0=basic, 1=extended, 2=all.')
    parser.add_argument(
        '--model_out', required=True, help='Path to save the optimized model.')
    parser.add_argument(
        '--debug', action='store_true', help='Output debug info.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    sess_options = ort.SessionOptions()

    # Set graph optimization level
    print(f"opt level: {args.opt_level}")
    if args.opt_level == 0:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif args.opt_level == 1:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    else:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # To enable model serialization after graph optimization, set this;
    # creating the session below runs the optimizer and writes the result.
    sess_options.optimized_model_filepath = args.model_out

    session = ort.InferenceSession(args.model_in, sess_options)
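Creating the `InferenceSession` is what does the work here: with `optimized_model_filepath` set, onnxruntime optimizes the graph at the chosen level on session construction and serializes the result to that path. Usage mirrors stage 4 of `run.sh`:

```bash
python3 local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx
```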

@@ -5,10 +5,11 @@ set -e
. path.sh
stage=0
stop_stage=100
#tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_1.jit
stop_stage=50
tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_10.jit
#model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams
@@ -80,6 +81,14 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
    # ort graph optimize
    ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
fi

# aishell rnn hidden size is 1024
# wenetspeech rnn hidden size is 2048
if [ $model_type == 'aishell' ];then
@@ -90,9 +99,9 @@ else
    echo "not support: $model_type"
    exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
    # the wenetspeech ds2 model exceeds onnx's 2GB protobuf limit and will error,
    # so simplify the onnx model
    ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
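`onnx_opt.sh` itself is not shown in this diff; assuming it wraps onnx-simplifier (onnxsim), the call would look roughly like this (flags per onnxsim 0.3.x, an assumption):

```bash
# assumed onnx-simplifier invocation; $input_shape comes from the model_type branch above
python3 -m onnxsim $exp/model.onnx $exp/model.opt.onnx --input-shape "$input_shape"
```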
