From 59a78f2a4648430227def9e872bbf612817c6db9 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 17 Jun 2022 04:40:11 +0000 Subject: [PATCH] ds2 wenetspeech to onnx and support streaming asr server --- demos/streaming_asr_server/.gitignore | 2 + .../conf/ws_ds2_application.yaml | 6 +- .../local/rtf_from_log.py | 40 ++++++ demos/streaming_asr_server/local/test.sh | 21 +++ .../{ => local}/websocket_client.py | 5 +- paddlespeech/cli/utils.py | 5 +- paddlespeech/resource/pretrained_models.py | 128 ++++++++++++++---- .../server/conf/ws_ds2_application.yaml | 8 +- .../server/engine/asr/online/onnx/__init__.py | 2 +- .../asr/online/paddleinference/__init__.py | 2 +- .../engine/asr/online/python/__init__.py | 2 +- speechx/examples/ds2_ol/onnx/README.md | 2 +- .../ds2_ol/onnx/local/onnx_infer_shape.py | 2 + speechx/examples/ds2_ol/onnx/local/ort_opt.py | 45 ++++++ speechx/examples/ds2_ol/onnx/run.sh | 21 ++- 15 files changed, 242 insertions(+), 49 deletions(-) create mode 100644 demos/streaming_asr_server/.gitignore create mode 100755 demos/streaming_asr_server/local/rtf_from_log.py create mode 100755 demos/streaming_asr_server/local/test.sh rename demos/streaming_asr_server/{ => local}/websocket_client.py (94%) create mode 100755 speechx/examples/ds2_ol/onnx/local/ort_opt.py diff --git a/demos/streaming_asr_server/.gitignore b/demos/streaming_asr_server/.gitignore new file mode 100644 index 00000000..0f09019d --- /dev/null +++ b/demos/streaming_asr_server/.gitignore @@ -0,0 +1,2 @@ +exp + diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index a4e6e9a1..e7ce59c2 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. 
protocol: 'websocket' -engine_list: ['asr_online-inference'] +engine_list: ['asr_online-onnx'] ################################################################################# @@ -21,7 +21,7 @@ engine_list: ['asr_online-inference'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: - model_type: 'deepspeech2online_aishell' + model_type: 'deepspeech2online_wenetspeech' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -53,7 +53,7 @@ asr_online-inference: ################################### ASR ######################################### ################### speech task: asr; engine_type: online-onnx ####################### asr_online-onnx: - model_type: 'deepspeech2online_aishell' + model_type: 'deepspeech2online_wenetspeech' am_model: # the pdmodel file of onnx am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py new file mode 100755 index 00000000..a5634388 --- /dev/null +++ b/demos/streaming_asr_server/local/rtf_from_log.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(prog=__doc__) + parser.add_argument( + '--logfile', type=str, required=True, help='ws client log file') + + args = parser.parse_args() + + rtfs = [] + with open(args.logfile, 'r') as f: + for line in f: + if 'RTF=' in line: + # udio duration: 6.126, elapsed time: 3.471978187561035, RTF=0.5667610492264177 + line = line.strip() + beg = line.index("audio") + line = line[beg:] + + items = line.split(',') + vals = [] + for elem in items: + if "RTF=" in elem: + continue + _, val = elem.split(":") + vals.append(eval(val)) + keys = ['T', 'P'] + meta 
= dict(zip(keys, vals)) + + rtfs.append(meta) + + T = 0.0 + P = 0.0 + n = 0 + for m in rtfs: + n += 1 + T += m['T'] + P += m['P'] + + print(f"RTF: {P/T}, utts: {n}") diff --git a/demos/streaming_asr_server/local/test.sh b/demos/streaming_asr_server/local/test.sh new file mode 100755 index 00000000..d70dd336 --- /dev/null +++ b/demos/streaming_asr_server/local/test.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +if [ $# != 1 ];then + echo "usage: $0 wav_scp" + exit -1 +fi + +scp=$1 + +# calc RTF +# wav_scp can generate from `speechx/examples/ds2_ol/aishell` + +exp=exp +mkdir -p $exp + +python3 local/websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavscp $scp &> $exp/log.rsl + +python3 local/rtf_from_log.py --logfile $exp/log.rsl + + + \ No newline at end of file diff --git a/demos/streaming_asr_server/websocket_client.py b/demos/streaming_asr_server/local/websocket_client.py similarity index 94% rename from demos/streaming_asr_server/websocket_client.py rename to demos/streaming_asr_server/local/websocket_client.py index 8e1f19a5..03712402 100644 --- a/demos/streaming_asr_server/websocket_client.py +++ b/demos/streaming_asr_server/local/websocket_client.py @@ -1,3 +1,4 @@ +#!/usr/bin/python # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#!/usr/bin/python -# -*- coding: UTF-8 -*- -# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' +# calc avg RTF(NOT Accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' import argparse import asyncio import codecs diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index 21c887e9..0161629e 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -25,10 +25,10 @@ from typing import Dict import paddle import requests +import soundfile as sf import yaml from paddle.framework import load -import paddlespeech.audio from . import download from .entry import commands try: @@ -282,7 +282,8 @@ def _note_one_stat(cls_name, params={}): if 'audio_file' in params: try: - _, sr = paddlespeech.audio.load(params['audio_file']) + # recursive import caused by: utils.DATA_HOME + _, sr = sf.read(params['audio_file']) except Exception: sr = -1 diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index c3cef499..37303331 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -135,15 +135,21 @@ asr_dynamic_pretrained_models = { }, }, "deepspeech2online_wenetspeech-zh-16k": { - '1.0': { + '1.0.3': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz', + 'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz', 'md5': - 'b0c77e7f8881e0a27b82127d1abb8d5f', + 'cfe273793e68f790f742b411c98bc75e', 'cfg_path': 'model.yaml', 'ckpt_path': 'exp/deepspeech2_online/checkpoints/avg_10', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams', + 'onnx_model': + 
'onnx/model.onnx', 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': @@ -170,14 +176,22 @@ asr_dynamic_pretrained_models = { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx', - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, "deepspeech2offline_librispeech-en-16k": { @@ -241,14 +255,44 @@ asr_static_pretrained_models = { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx', - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 
'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2online_wenetspeech-zh-16k": { + '1.0.3': { + 'url': + 'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz', + 'md5': + 'cfe273793e68f790f742b411c98bc75e', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_10', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, } @@ -258,14 +302,44 @@ asr_onnx_pretrained_models = { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx', - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2online_wenetspeech-zh-16k": { + '1.0.3': { + 'url': + 'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz', + 'md5': + 'cfe273793e68f790f742b411c98bc75e', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_10', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, } diff --git a/paddlespeech/server/conf/ws_ds2_application.yaml b/paddlespeech/server/conf/ws_ds2_application.yaml index 430e6fd1..e7ce59c2 100644 --- a/paddlespeech/server/conf/ws_ds2_application.yaml +++ b/paddlespeech/server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. 
protocol: 'websocket' -engine_list: ['asr_online-inference'] +engine_list: ['asr_online-onnx'] ################################################################################# @@ -21,7 +21,7 @@ engine_list: ['asr_online-inference'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: - model_type: 'deepspeech2online_aishell' + model_type: 'deepspeech2online_wenetspeech' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -53,7 +53,7 @@ asr_online-inference: ################################### ASR ######################################### ################### speech task: asr; engine_type: online-onnx ####################### asr_online-onnx: - model_type: 'deepspeech2online_aishell' + model_type: 'deepspeech2online_wenetspeech' am_model: # the pdmodel file of onnx am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -81,4 +81,4 @@ asr_online-onnx: window_n: 7 # frame shift_n: 4 # frame window_ms: 20 # ms - shift_ms: 10 # ms \ No newline at end of file + shift_ms: 10 # ms diff --git a/paddlespeech/server/engine/asr/online/onnx/__init__.py b/paddlespeech/server/engine/asr/online/onnx/__init__.py index c747d3e7..97043fd7 100644 --- a/paddlespeech/server/engine/asr/online/onnx/__init__.py +++ b/paddlespeech/server/engine/asr/online/onnx/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
diff --git a/paddlespeech/server/engine/asr/online/paddleinference/__init__.py b/paddlespeech/server/engine/asr/online/paddleinference/__init__.py index c747d3e7..97043fd7 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/__init__.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/paddlespeech/server/engine/asr/online/python/__init__.py b/paddlespeech/server/engine/asr/online/python/__init__.py index c747d3e7..97043fd7 100644 --- a/paddlespeech/server/engine/asr/online/python/__init__.py +++ b/paddlespeech/server/engine/asr/online/python/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md index 566a4597..eaea8b6e 100644 --- a/speechx/examples/ds2_ol/onnx/README.md +++ b/speechx/examples/ds2_ol/onnx/README.md @@ -9,7 +9,7 @@ Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and The example test with these packages installed: ``` -paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a +paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819 paddleaudio 0.2.1 paddlefsl 1.1.0 paddlenlp 2.2.6 diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py index c41e66b7..2d364c25 100755 --- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py +++ b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py @@ -492,6 +492,8 @@ class SymbolicShapeInference: skip_infer = node.op_type in [ 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ # contrib ops + + 'Attention', 'BiasGelu', \ 'EmbedLayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \ diff --git a/speechx/examples/ds2_ol/onnx/local/ort_opt.py b/speechx/examples/ds2_ol/onnx/local/ort_opt.py new file mode 100755 index 00000000..8e995bcf --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/ort_opt.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +import argparse + +import onnxruntime as ort + +# onnxruntime optimizer. 
+# https://onnxruntime.ai/docs/performance/graph-optimizations.html +# https://onnxruntime.ai/docs/api/python/api_summary.html#api + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_in', required=True, type=str, help='Path to onnx model.') + parser.add_argument( + '--opt_level', + required=True, + type=int, + default=0, + choices=[0, 1, 2], + help='Path to onnx model.') + parser.add_argument( + '--model_out', required=True, help='path to save the optimized model.') + parser.add_argument('--debug', default=False, help='output debug info.') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + sess_options = ort.SessionOptions() + + # Set graph optimization level + print(f"opt level: {args.opt_level}") + if args.opt_level == 0: + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC + elif args.opt_level == 1: + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED + else: + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + # To enable model serialization after graph optimization set this + sess_options.optimized_model_filepath = args.model_out + + session = ort.InferenceSession(args.model_in, sess_options) diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh index 57cd9416..583abda4 100755 --- a/speechx/examples/ds2_ol/onnx/run.sh +++ b/speechx/examples/ds2_ol/onnx/run.sh @@ -5,10 +5,11 @@ set -e . 
path.sh stage=0 -stop_stage=100 -#tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz -tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz -model_prefix=avg_1.jit +stop_stage=50 +tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz +#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz +model_prefix=avg_10.jit +#model_prefix=avg_1.jit model=${model_prefix}.pdmodel param=${model_prefix}.pdiparams @@ -80,6 +81,14 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then fi +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then + # ort graph optimize + ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx + + ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx +fi + + # aishell rnn hidden is 1024 # wenetspeech rnn hiddn is 2048 if [ $model_type == 'aishell' ];then @@ -90,9 +99,9 @@ else echo "not support: $model_type" exit -1 fi - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then + +if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then # wenetspeech ds2 model execed 2GB limit, will error. # simplifying onnx model ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"