From b9e3e49305983ff1b07d8d649dcadebfb1a71e32 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 15 Jun 2022 07:48:14 +0000 Subject: [PATCH 1/8] refactor stream asr and fix ds2 stream bug --- demos/streaming_asr_server/test.sh | 2 +- .../asr/online/{ => python}/asr_engine.py | 160 ++++++++++-------- paddlespeech/server/engine/engine_factory.py | 5 +- 3 files changed, 96 insertions(+), 71 deletions(-) rename paddlespeech/server/engine/asr/online/{ => python}/asr_engine.py (96%) diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh index f3075454d..f09068d47 100755 --- a/demos/streaming_asr_server/test.sh +++ b/demos/streaming_asr_server/test.sh @@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav # read the wav and pass it to only streaming asr service # If `127.0.0.1` is not accessible, you need to use the actual service IP address. # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav -paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wav +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav # read the wav and call streaming and punc service # If `127.0.0.1` is not accessible, you need to use the actual service IP address. diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py similarity index 96% rename from paddlespeech/server/engine/asr/online/asr_engine.py rename to paddlespeech/server/engine/asr/online/python/asr_engine.py index f230b8b9d..9801a6fcf 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler: raise ValueError(f"Not supported: {self.model_type}") def model_reset(self): - if "deepspeech2" in self.model_type: - return - # cache for audio and feat self.remained_wav = None self.cached_feat = None + + if "deepspeech2" in self.model_type: + return + ## conformer # cache for conformer online self.subsampling_cache = None @@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor): self.task_resource = CommonTaskResource( task='asr', model_format='dynamic', inference_mode='online') + def update_config(self)->None: + if "deepspeech2" in self.model_type: + with UpdateConfig(self.config): + # download lm + self.config.decode.lang_model_path = os.path.join( + MODEL_HOME, 'language_model', + self.config.decode.lang_model_path) + + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + elif "conformer" in self.model_type or "transformer" in self.model_type: + with UpdateConfig(self.config): + logger.info("start to create the stream conformer asr engine") + # update the decoding method + if self.decode_method: + self.config.decode.decoding_method = self.decode_method + # update num_decoding_left_chunks + if self.num_decoding_left_chunks: + assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" + self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks + # we only support ctc_prefix_beam_search and attention_rescoring dedoding method + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding_method = "attention_rescoring" + + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" + else: + raise Exception(f"not support: {self.model_type}") + + def init_model(self) -> None: + if "deepspeech2" in self.model_type : + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + elif "conformer" in self.model_type or "transformer" in self.model_type : + # load model + # model_type: {model_name}_{dataset} + model_name = self.model_type[:self.model_type.rindex('_')] + logger.info(f"model name: {model_name}") + model_class = self.task_resource.get_model_class(model_name) + model = model_class.from_config(self.config) + self.model = model + self.model.set_state_dict(paddle.load(self.am_model)) + self.model.eval() + else: + raise Exception(f"not support: {self.model_type}") + + def _init_from_path(self, model_type: str=None, am_model: Optional[os.PathLike]=None, @@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor): self.model_type = model_type self.sample_rate = sample_rate + self.decode_method = decode_method + self.num_decoding_left_chunks = num_decoding_left_chunks + # conf for paddleinference predictor or onnx + self.am_predictor_conf = am_predictor_conf logger.info(f"model_type: {self.model_type}") + sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(model_tag=tag) @@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) - if "deepspeech2" in model_type: - with UpdateConfig(self.config): - # download lm - self.config.decode.lang_model_path = os.path.join( - MODEL_HOME, 'language_model', - self.config.decode.lang_model_path) - - lm_url = self.task_resource.res_dict['lm_url'] - lm_md5 = self.task_resource.res_dict['lm_md5'] - logger.info(f"Start to load language model {lm_url}") - self.download_lm( - lm_url, - os.path.dirname(self.config.decode.lang_model_path), lm_md5) - - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - elif "conformer" in model_type or "transformer" in model_type: - with UpdateConfig(self.config): - logger.info("start to create the stream conformer asr engine") - # update the decoding method - if decode_method: - self.config.decode.decoding_method = decode_method - # update num_decoding_left_chunks - if num_decoding_left_chunks: - assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" - self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks - # we only support ctc_prefix_beam_search and attention_rescoring dedoding method - # Generally we set the decoding_method to attention_rescoring - if self.config.decode.decoding_method not in [ - "ctc_prefix_beam_search", "attention_rescoring" - ]: - logger.info( - "we set the decoding_method to attention_rescoring") - self.config.decode.decoding_method = "attention_rescoring" - - assert self.config.decode.decoding_method in [ - "ctc_prefix_beam_search", "attention_rescoring" - ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" - - # load model - model_name = model_type[:model_type.rindex( - '_')] # model_type: {model_name}_{dataset} - logger.info(f"model name: {model_name}") - model_class = self.task_resource.get_model_class(model_name) - model = model_class.from_config(self.config) - self.model = model - self.model.set_state_dict(paddle.load(self.am_model)) - self.model.eval() - else: - raise Exception(f"not support: {model_type}") + self.update_config() + + # AM predictor + self.init_model() logger.info(f"create the {model_type} model success") return True @@ -835,6 +850,22 @@ class ASREngine(BaseEngine): super(ASREngine, self).__init__() logger.info("create the online asr engine resource instance") + + def init_model(self) -> bool: + if not self.executor._init_from_path( + model_type=self.config.model_type, + am_model=self.config.am_model, + am_params=self.config.am_params, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + num_decoding_left_chunks=self.config.num_decoding_left_chunks, + am_predictor_conf=self.config.am_predictor_conf): + return False + return True + + def init(self, config: dict) -> bool: """init engine resource @@ -860,16 +891,7 @@ class ASREngine(BaseEngine): logger.info(f"paddlespeech_server set the device: {self.device}") - if not self.executor._init_from_path( - model_type=self.config.model_type, - am_model=self.config.am_model, - am_params=self.config.am_params, - lang=self.config.lang, - sample_rate=self.config.sample_rate, - cfg_path=self.config.cfg_path, - decode_method=self.config.decode_method, - num_decoding_left_chunks=self.config.num_decoding_left_chunks, - am_predictor_conf=self.config.am_predictor_conf): + if not self.init_model(): logger.error( "Init the ASR server occurs error, please check the server configuration yaml" ) diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 5fdaaccea..019e46849 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -26,7 +26,10 @@ class EngineFactory(object): from paddlespeech.server.engine.asr.python.asr_engine import ASREngine return ASREngine() elif engine_name == 'asr' and engine_type == 'online': - from paddlespeech.server.engine.asr.online.asr_engine import ASREngine + from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine + return ASREngine() + elif engine_name == 'asr' and engine_type == 'online-onnx': + from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine return ASREngine() elif engine_name == 'tts' and engine_type == 'inference': from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine From c8574c7e35a85215d88a6461f27f930d50434ab9 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 15 Jun 2022 08:44:36 +0000 Subject: [PATCH 2/8] ds2 inference as sepearte engine for streaming asr --- .../conf/ws_ds2_application.yaml | 4 +- paddlespeech/cli/asr/infer.py | 2 +- paddlespeech/resource/pretrained_models.py | 20 + ...plication.yaml => ws_ds2_application.yaml} | 4 +- .../asr/online/paddleinference/__init__.py | 0 .../asr/online/paddleinference/asr_engine.py | 539 ++++++++++++++++++ .../engine/asr/online/python/asr_engine.py | 17 +- paddlespeech/server/engine/engine_factory.py | 3 + paddlespeech/server/ws/asr_api.py | 2 +- utils/zh_tn.py | 2 +- 10 files changed, 575 insertions(+), 18 deletions(-) rename paddlespeech/server/conf/{ws_application.yaml => ws_ds2_application.yaml} (96%) create mode 100644 paddlespeech/server/engine/asr/online/paddleinference/__init__.py create mode 100644 paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index d19bd26dc..4f75c07bf 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online'] +engine_list: ['asr_online-inference'] ################################################################################# @@ -20,7 +20,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### -asr_online: +asr_online-inference: model_type: 'deepspeech2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 00cad150e..a943ccfa7 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -187,7 +187,7 @@ class ASRExecutor(BaseExecutor): elif "conformer" in model_type or "transformer" in model_type: self.config.decode.decoding_method = decode_method if num_decoding_left_chunks: - assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" + assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0" self.config.num_decoding_left_chunks = num_decoding_left_chunks else: diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index f79961d64..eb6ca0cc0 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -224,6 +224,26 @@ asr_static_pretrained_models = { '29e02312deb2e59b3c8686c7966d4fe3' } }, + "deepspeech2online_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz', + 'md5': + 'df5ddeac8b679a470176649ac4b78726', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, } # --------------------------------- diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_ds2_application.yaml similarity index 96% rename from paddlespeech/server/conf/ws_application.yaml rename to paddlespeech/server/conf/ws_ds2_application.yaml index 43d83f2d4..fb16e5bdc 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket', 'http'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online'] +engine_list: ['asr_online-inference'] ################################################################################# @@ -20,7 +20,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### -asr_online: +asr_online-inference: model_type: 'deepspeech2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] diff --git a/paddlespeech/server/engine/asr/online/paddleinference/__init__.py b/paddlespeech/server/engine/asr/online/paddleinference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py new file mode 100644 index 000000000..93edd7011 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -0,0 +1,539 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from typing import ByteString +from typing import Optional + +import numpy as np +import paddle +from numpy import float32 +from yacs.config import CfgNode + +from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.paddle_predictor import init_predictor + +__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] + + +# ASR server connection process class +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ + super().__init__() + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) + self.config = asr_engine.config # server config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer + self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window and frame shift, in samples unit + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, ( + self.sample_rate, self.preprocess_conf.process[0]['fs']) + self.frame_shift_in_ms = int( + self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000) + + self.continuous_decoding = self.config.get("continuous_decoding", False) + self.init_decoder() + self.reset() + + def init_decoder(self): + if "deepspeech2" in self.model_type: + assert self.continuous_decoding is False, "ds2 model not support endpoint" + self.am_predictor = self.asr_engine.executor.am_predictor + + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + else: + raise ValueError(f"Not supported: {self.model_type}") + + def model_reset(self): + # cache for audio and feat + self.remained_wav = None + self.cached_feat = None + + def output_reset(self): + ## outputs + # partial/ending decoding results + self.result_transcripts = [''] + + def reset_continuous_decoding(self): + """ + when in continous decoding, reset for next utterance. + """ + self.global_frame_offset = self.num_frames + self.model_reset() + + def reset(self): + if "deepspeech2" in self.model_type: + # for deepspeech2 + # init state + self.chunk_state_h_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.decoder.reset_decoder(batch_size=1) + else: + raise NotImplementedError(f"{self.model_type} not support.") + + self.device = None + + ## common + # global sample and frame step + self.num_samples = 0 + self.global_frame_offset = 0 + # frame step of cur utterance + self.num_frames = 0 + + ## endpoint + self.endpoint_state = False # True for detect endpoint + + ## conformer + self.model_reset() + + ## outputs + self.output_reset() + + def extract_feat(self, samples: ByteString): + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + self.num_samples += samples.shape[0] + logger.info( + f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" + ) + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 # (T,) + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" + ) + + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + # cur frame step + num_frames = x_chunk.shape[1] + + # global frame step + self.num_frames += num_frames + + # update remained wav + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" + ) + logger.info(f"global samples: {self.num_samples}") + logger.info(f"global frames: {self.num_frames}") + + def decode(self, is_finished=False): + """advance decoding + + Args: + is_finished (bool, optional): Is last frame or not. Defaults to False. + + Returns: + None: + """ + if "deepspeech2" in self.model_type: + decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + + context = 7 # context=7, in audio frame unit + subsampling = 4 # subsampling=4, in audio frame unit + + cached_feature_num = context - subsampling + # decoding window for model, in audio frame unit + decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride for model, in audio frame unit + stride = subsampling * decoding_chunk_size + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + + logger.info("start to do model forward") + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + # update feat cache + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] + + # return trans_best[0] + else: + raise Exception(f"{self.model_type} not support paddleinference.") + + @paddle.no_grad() + def decode_one_chunk(self, x_chunk, x_chunk_lens): + """forward one chunk frames + + Args: + x_chunk (np.ndarray): (B,T,D), audio frames. + x_chunk_lens ([type]): (B,), audio frame lens + + Returns: + logprob: poster probability. + """ + logger.info("start to decoce one chunk for deepspeech2") + input_names = self.am_predictor.get_input_names() + audio_handle = self.am_predictor.get_input_handle(input_names[0]) + audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) + h_box_handle = self.am_predictor.get_input_handle(input_names[2]) + c_box_handle = self.am_predictor.get_input_handle(input_names[3]) + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(self.chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(self.chunk_state_h_box) + + c_box_handle.reshape(self.chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(self.chunk_state_c_box) + + output_names = self.am_predictor.get_output_names() + output_handle = self.am_predictor.get_output_handle(output_names[0]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.am_predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.am_predictor.get_output_handle( + output_names[3]) + + self.am_predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() + self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") + return trans_best[0] + + def get_result(self): + """return partial/ending asr result. + + Returns: + str: one best result of partial/ending. + """ + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + + +class ASRServerExecutor(ASRExecutor): + def __init__(self): + super().__init__() + self.task_resource = CommonTaskResource( + task='asr', model_format='static', inference_mode='online') + + def update_config(self) -> None: + if "deepspeech2" in self.model_type: + with UpdateConfig(self.config): + # download lm + self.config.decode.lang_model_path = os.path.join( + MODEL_HOME, 'language_model', + self.config.decode.lang_model_path) + + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + else: + raise NotImplementedError( + f"{self.model_type} not support paddleinference.") + + def init_model(self) -> None: + + if "deepspeech2" in self.model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + else: + raise NotImplementedError( + f"{self.model_type} not support paddleinference.") + + def _init_from_path(self, + model_type: str=None, + am_model: Optional[os.PathLike]=None, + am_params: Optional[os.PathLike]=None, + lang: str='zh', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, + am_predictor_conf: dict=None): + """ + Init model and other resources from a specific path. + """ + if not model_type or not lang or not sample_rate: + logger.error( + "The model type or lang or sample rate is None, please input an valid server parameter yaml" + ) + return False + + self.model_type = model_type + self.sample_rate = sample_rate + self.decode_method = decode_method + self.num_decoding_left_chunks = num_decoding_left_chunks + # conf for paddleinference predictor or onnx + self.am_predictor_conf = am_predictor_conf + logger.info(f"model_type: {self.model_type}") + + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(model_tag=tag) + + if cfg_path is None or am_model is None or am_params is None: + self.res_path = self.task_resource.res_dir + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.am_model = os.path.abspath(am_model) + self.am_params = os.path.abspath(am_params) + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + + logger.info("Load the pretrained model:") + logger.info(f" tag = {tag}") + logger.info(f" res_path: {self.res_path}") + logger.info(f" cfg path: {self.cfg_path}") + logger.info(f" am_model path: {self.am_model}") + logger.info(f" am_params path: {self.am_params}") + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + logger.info(f"spm model path: {self.config.spm_model_prefix}") + + self.vocab = self.config.vocab_filepath + + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + + self.update_config() + + # AM predictor + self.init_model() + + logger.info(f"create the {model_type} model success") + return True + + +class ASREngine(BaseEngine): + """ASR model resource + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self): + super(ASREngine, self).__init__() + logger.info("create the online asr engine resource instance") + + def init_model(self) -> bool: + if not self.executor._init_from_path( + model_type=self.config.model_type, + am_model=self.config.am_model, + am_params=self.config.am_params, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + num_decoding_left_chunks=self.config.num_decoding_left_chunks, + am_predictor_conf=self.config.am_predictor_conf): + return False + return True + + def init(self, config: dict) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ + self.config = config + self.executor = ASRServerExecutor() + + try: + self.device = self.config.get("device", paddle.get_device()) + paddle.set_device(self.device) + except BaseException as e: + logger.error( + f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" + ) + logger.error( + "If all GPU or XPU is used, you can set the server to 'cpu'") + sys.exit(-1) + + logger.info(f"paddlespeech_server set the device: {self.device}") + + if not self.init_model(): + logger.error( + "Init the ASR server occurs error, please check the server configuration yaml" + ) + return False + + logger.info("Initialize ASR server engine successfully.") + return True + + def new_handler(self): + """New handler from model. + + Returns: + PaddleASRConnectionHanddler: asr handler instance + """ + return PaddleASRConnectionHanddler(self) + + def preprocess(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") + + def run(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") + + def postprocess(self): + raise NotImplementedError("Online not using this.") diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 9801a6fcf..231137af6 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -125,7 +125,6 @@ class PaddleASRConnectionHanddler: self.remained_wav = None self.cached_feat = None - if "deepspeech2" in self.model_type: return @@ -698,7 +697,7 @@ class ASRServerExecutor(ASRExecutor): self.task_resource = CommonTaskResource( task='asr', model_format='dynamic', inference_mode='online') - def update_config(self)->None: + def update_config(self) -> None: if "deepspeech2" in self.model_type: with UpdateConfig(self.config): # download lm @@ -720,7 +719,7 @@ class ASRServerExecutor(ASRExecutor): self.config.decode.decoding_method = self.decode_method # update num_decoding_left_chunks if self.num_decoding_left_chunks: - assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" + assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0" self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks # we only support ctc_prefix_beam_search and attention_rescoring dedoding method # Generally we set the decoding_method to attention_rescoring @@ -738,17 +737,17 @@ class ASRServerExecutor(ASRExecutor): raise Exception(f"not support: {self.model_type}") def init_model(self) -> None: - if "deepspeech2" in self.model_type : + if "deepspeech2" in self.model_type: # AM predictor logger.info("ASR engine start to init the am predictor") self.am_predictor = init_predictor( model_file=self.am_model, params_file=self.am_params, predictor_conf=self.am_predictor_conf) - elif "conformer" in self.model_type or "transformer" in self.model_type : + elif "conformer" in self.model_type or "transformer" in self.model_type: # load model # model_type: {model_name}_{dataset} - model_name = self.model_type[:self.model_type.rindex('_')] + model_name = self.model_type[:self.model_type.rindex('_')] logger.info(f"model name: {model_name}") model_class = self.task_resource.get_model_class(model_name) model = model_class.from_config(self.config) @@ -758,7 +757,6 @@ class ASRServerExecutor(ASRExecutor): else: raise Exception(f"not support: {self.model_type}") - def _init_from_path(self, model_type: str=None, am_model: Optional[os.PathLike]=None, @@ -786,7 +784,6 @@ class ASRServerExecutor(ASRExecutor): self.am_predictor_conf = am_predictor_conf logger.info(f"model_type: {self.model_type}") - sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(model_tag=tag) @@ -831,7 +828,7 @@ class ASRServerExecutor(ASRExecutor): spm_model_prefix=self.config.spm_model_prefix) self.update_config() - + # AM predictor self.init_model() @@ -850,7 +847,6 @@ class ASREngine(BaseEngine): super(ASREngine, self).__init__() logger.info("create the online asr engine resource instance") - def init_model(self) -> bool: if not self.executor._init_from_path( model_type=self.config.model_type, @@ -865,7 +861,6 @@ class ASREngine(BaseEngine): return False return True - def init(self, config: dict) -> bool: """init engine resource diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 019e46849..cfb0deb35 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -28,6 +28,9 @@ class EngineFactory(object): elif engine_name == 'asr' and engine_type == 'online': from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine return ASREngine() + elif engine_name == 'asr' and engine_type == 'online-inference': + from paddlespeech.server.engine.asr.online.paddleinference.asr_engine import ASREngine + return ASREngine() elif engine_name == 'asr' and engine_type == 'online-onnx': from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine return ASREngine() diff --git a/paddlespeech/server/ws/asr_api.py b/paddlespeech/server/ws/asr_api.py index 23609b41a..ae1c88310 100644 --- a/paddlespeech/server/ws/asr_api.py +++ b/paddlespeech/server/ws/asr_api.py @@ -92,7 +92,7 @@ async def websocket_endpoint(websocket: WebSocket): else: resp = {"status": "ok", "message": "no valid json data"} await websocket.send_json(resp) - + elif "bytes" in message: # bytes for the pcm data message = message["bytes"] diff --git a/utils/zh_tn.py b/utils/zh_tn.py index 73bb8af22..6fee626bd 100755 --- a/utils/zh_tn.py +++ b/utils/zh_tn.py @@ -747,7 +747,7 @@ def num2chn(number_string, previous_symbol, (CNU, type(None))): if next_symbol.power != 1 and ( (previous_symbol is None) or - (previous_symbol.power != 1)): + (previous_symbol.power != 1)): # noqa: E129 result_symbols[i] = liang # if big is True, 'δΈ€' will not be used and `alt_two` has no impact on output From 3cee7db021b971d5221324cbaa17638f1b4bb9f1 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 15 Jun 2022 10:20:44 +0000 Subject: [PATCH 3/8] onnx ds2 straming asr --- .../conf/ws_ds2_application.yaml | 43 +- paddlespeech/resource/pretrained_models.py | 16 + .../server/conf/ws_ds2_application.yaml | 49 +- .../engine/asr/online/onnx/asr_engine.py | 520 ++++++++++++++++++ .../asr/online/paddleinference/asr_engine.py | 1 - .../engine/asr/online/python/asr_engine.py | 1 - paddlespeech/server/engine/engine_factory.py | 3 +- paddlespeech/server/utils/onnx_infer.py | 26 +- 8 files changed, 638 insertions(+), 21 deletions(-) create mode 100644 paddlespeech/server/engine/asr/online/onnx/asr_engine.py diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index 4f75c07bf..f0a98e72a 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -7,11 +7,11 @@ host: 0.0.0.0 port: 8090 # The task format in the engin_list is: _ -# task choices = ['asr_online'] +# task choices = ['asr_online-inference', 'asr_online-onnx'] # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online-inference'] +engine_list: ['asr_online-onnx'] ################################################################################# @@ -19,10 +19,10 @@ engine_list: ['asr_online-inference'] ################################################################################# ################################### ASR ######################################### -################### speech task: asr; engine_type: online ####################### +################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] + am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 @@ -47,3 +47,38 @@ asr_online-inference: shift_n: 4 # frame window_ms: 20 # ms shift_ms: 10 # ms + + + +################################### ASR ######################################### +################### speech task: asr; engine_type: online-onnx ####################### +asr_online-onnx: + model_type: 'deepspeech2online_aishell' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + num_decoding_left_chunks: + force_yes: True + device: 'cpu' # cpu or gpu:id + + # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession + am_predictor_conf: + device: 'cpu' # set 'gpu:id' or 'cpu' + graph_optimization_level: 0 + intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes. + inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes). + log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. + log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0. + + chunk_buffer_conf: + frame_duration_ms: 80 + shift_ms: 40 + sample_rate: 16000 + sample_width: 2 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 20 # ms + shift_ms: 10 # ms diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index eb6ca0cc0..ba4a79d9d 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -15,6 +15,7 @@ __all__ = [ 'asr_dynamic_pretrained_models', 'asr_static_pretrained_models', + 'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models', 'cls_static_pretrained_models', 'st_dynamic_pretrained_models', @@ -246,6 +247,21 @@ asr_static_pretrained_models = { }, } + +asr_onnx_pretrained_models = { + "deepspeech2online_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz', + 'md5': 'b0c77e7f8881e0a27b82127d1abb8d5f', + 'cfg_path':'model.yaml', + 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_10', + 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, +} + # --------------------------------- # -------------- CLS -------------- # --------------------------------- diff --git a/paddlespeech/server/conf/ws_ds2_application.yaml b/paddlespeech/server/conf/ws_ds2_application.yaml index fb16e5bdc..f0a98e72a 100644 --- a/paddlespeech/server/conf/ws_ds2_application.yaml +++ b/paddlespeech/server/conf/ws_ds2_application.yaml @@ -7,11 +7,11 @@ host: 0.0.0.0 port: 8090 # The task format in the engin_list is: _ -# task choices = ['asr_online', 'tts_online'] -# protocol = ['websocket', 'http'] (only one can be selected). +# task choices = ['asr_online-inference', 'asr_online-onnx'] +# protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online-inference'] +engine_list: ['asr_online-onnx'] ################################################################################# @@ -19,18 +19,18 @@ engine_list: ['asr_online-inference'] ################################################################################# ################################### ASR ######################################### -################### speech task: asr; engine_type: online ####################### +################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] + am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 cfg_path: decode_method: - num_decoding_left_chunks: + num_decoding_left_chunks: force_yes: True - device: # cpu or gpu:id + device: 'cpu' # cpu or gpu:id am_predictor_conf: device: # set 'gpu:id' or 'cpu' @@ -47,3 +47,38 @@ asr_online-inference: shift_n: 4 # frame window_ms: 20 # ms shift_ms: 10 # ms + + + +################################### ASR ######################################### +################### speech task: asr; engine_type: online-onnx ####################### +asr_online-onnx: + model_type: 'deepspeech2online_aishell' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + num_decoding_left_chunks: + force_yes: True + device: 'cpu' # cpu or gpu:id + + # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession + am_predictor_conf: + device: 'cpu' # set 'gpu:id' or 'cpu' + graph_optimization_level: 0 + intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes. + inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes). + log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. + log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0. + + chunk_buffer_conf: + frame_duration_ms: 80 + shift_ms: 40 + sample_rate: 16000 + sample_width: 2 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 20 # ms + shift_ms: 10 # ms diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py new file mode 100644 index 000000000..0bd2f950f --- /dev/null +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -0,0 +1,520 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from typing import ByteString +from typing import Optional + +import numpy as np +import paddle +from numpy import float32 +from yacs.config import CfgNode + +from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils import onnx_infer + +__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] + + +# ASR server connection process class +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ + super().__init__() + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) + self.config = asr_engine.config # server config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer + self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window and frame shift, in samples unit + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, ( + self.sample_rate, self.preprocess_conf.process[0]['fs']) + self.frame_shift_in_ms = int( + self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000) + + self.continuous_decoding = self.config.get("continuous_decoding", False) + self.init_decoder() + self.reset() + + def init_decoder(self): + if "deepspeech2" in self.model_type: + assert self.continuous_decoding is False, "ds2 model not support endpoint" + self.am_predictor = self.asr_engine.executor.am_predictor + + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + else: + raise ValueError(f"Not supported: {self.model_type}") + + def model_reset(self): + # cache for audio and feat + self.remained_wav = None + self.cached_feat = None + + def output_reset(self): + ## outputs + # partial/ending decoding results + self.result_transcripts = [''] + + def reset_continuous_decoding(self): + """ + when in continous decoding, reset for next utterance. + """ + self.global_frame_offset = self.num_frames + self.model_reset() + + def reset(self): + if "deepspeech2" in self.model_type: + # for deepspeech2 + # init state + self.chunk_state_h_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.decoder.reset_decoder(batch_size=1) + else: + raise NotImplementedError(f"{self.model_type} not support.") + + self.device = None + + ## common + # global sample and frame step + self.num_samples = 0 + self.global_frame_offset = 0 + # frame step of cur utterance + self.num_frames = 0 + + ## endpoint + self.endpoint_state = False # True for detect endpoint + + ## conformer + self.model_reset() + + ## outputs + self.output_reset() + + def extract_feat(self, samples: ByteString): + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + self.num_samples += samples.shape[0] + logger.info( + f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" + ) + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 # (T,) + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" + ) + + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + # cur frame step + num_frames = x_chunk.shape[1] + + # global frame step + self.num_frames += num_frames + + # update remained wav + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" + ) + logger.info(f"global samples: {self.num_samples}") + logger.info(f"global frames: {self.num_frames}") + + def decode(self, is_finished=False): + """advance decoding + + Args: + is_finished (bool, optional): Is last frame or not. Defaults to False. + + Returns: + None: + """ + if "deepspeech2" in self.model_type: + decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + + context = 7 # context=7, in audio frame unit + subsampling = 4 # subsampling=4, in audio frame unit + + cached_feature_num = context - subsampling + # decoding window for model, in audio frame unit + decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride for model, in audio frame unit + stride = subsampling * decoding_chunk_size + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + + logger.info("start to do model forward") + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + # update feat cache + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] + + # return trans_best[0] + else: + raise Exception(f"{self.model_type} not support paddleinference.") + + @paddle.no_grad() + def decode_one_chunk(self, x_chunk, x_chunk_lens): + """forward one chunk frames + + Args: + x_chunk (np.ndarray): (B,T,D), audio frames. + x_chunk_lens ([type]): (B,), audio frame lens + + Returns: + logprob: poster probability. + """ + logger.info("start to decoce one chunk for deepspeech2") + # state_c, state_h, audio_lens, audio + # 'chunk_state_c_box', 'chunk_state_h_box', 'audio_chunk_lens', 'audio_chunk' + input_names = [n.name for n in self.am_predictor.get_inputs()] + logger.info(f"ort inputs: {input_names}") + # 'softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0' + # audio, audio_lens, state_h, state_c + output_names = [n.name for n in self.am_predictor.get_outputs()] + logger.info(f"ort outpus: {output_names}") + assert (len(input_names) == len(output_names)) + assert isinstance(input_names[0], str) + + input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk] + feeds = dict(zip(input_names, input_datas)) + + outputs = self.am_predictor.run( + [*output_names], + {**feeds}) + + output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") + return trans_best[0] + + def get_result(self): + """return partial/ending asr result. + + Returns: + str: one best result of partial/ending. + """ + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + + +class ASRServerExecutor(ASRExecutor): + def __init__(self): + super().__init__() + self.task_resource = CommonTaskResource( + task='asr', model_format='static', inference_mode='online') + + def update_config(self) -> None: + if "deepspeech2" in self.model_type: + with UpdateConfig(self.config): + # download lm + self.config.decode.lang_model_path = os.path.join( + MODEL_HOME, 'language_model', + self.config.decode.lang_model_path) + + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + else: + raise NotImplementedError( + f"{self.model_type} not support paddleinference.") + + def init_model(self) -> None: + + if "deepspeech2" in self.model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor = onnx_infer.get_sess( + model_path=self.am_model, sess_conf=self.am_predictor_conf) + else: + raise NotImplementedError( + f"{self.model_type} not support paddleinference.") + + def _init_from_path(self, + model_type: str=None, + am_model: Optional[os.PathLike]=None, + am_params: Optional[os.PathLike]=None, + lang: str='zh', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, + am_predictor_conf: dict=None): + """ + Init model and other resources from a specific path. + """ + if not model_type or not lang or not sample_rate: + logger.error( + "The model type or lang or sample rate is None, please input an valid server parameter yaml" + ) + return False + assert am_params is None, "am_params not used in onnx engine" + + self.model_type = model_type + self.sample_rate = sample_rate + self.decode_method = decode_method + self.num_decoding_left_chunks = num_decoding_left_chunks + # conf for paddleinference predictor or onnx + self.am_predictor_conf = am_predictor_conf + logger.info(f"model_type: {self.model_type}") + + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(model_tag=tag) + + if cfg_path is None: + self.res_path = self.task_resource.res_dir + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params) + + logger.info("Load the pretrained model:") + logger.info(f" tag = {tag}") + logger.info(f" res_path: {self.res_path}") + logger.info(f" cfg path: {self.cfg_path}") + logger.info(f" am_model path: {self.am_model}") + logger.info(f" am_params path: {self.am_params}") + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + logger.info(f"spm model path: {self.config.spm_model_prefix}") + + self.vocab = self.config.vocab_filepath + + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + + self.update_config() + + # AM predictor + self.init_model() + + logger.info(f"create the {model_type} model success") + return True + + +class ASREngine(BaseEngine): + """ASR model resource + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self): + super(ASREngine, self).__init__() + + def init_model(self) -> bool: + if not self.executor._init_from_path( + model_type=self.config.model_type, + am_model=self.config.am_model, + am_params=self.config.am_params, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + num_decoding_left_chunks=self.config.num_decoding_left_chunks, + am_predictor_conf=self.config.am_predictor_conf): + return False + return True + + def init(self, config: dict) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ + self.config = config + self.executor = ASRServerExecutor() + + try: + self.device = self.config.get("device", paddle.get_device()) + paddle.set_device(self.device) + except BaseException as e: + logger.error( + f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" + ) + logger.error( + "If all GPU or XPU is used, you can set the server to 'cpu'") + sys.exit(-1) + + logger.info(f"paddlespeech_server set the device: {self.device}") + + if not self.init_model(): + logger.error( + "Init the ASR server occurs error, please check the server configuration yaml" + ) + return False + + logger.info("Initialize ASR server engine successfully.") + return True + + def new_handler(self): + """New handler from model. + + Returns: + PaddleASRConnectionHanddler: asr handler instance + """ + return PaddleASRConnectionHanddler(self) + + def preprocess(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") + + def run(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") + + def postprocess(self): + raise NotImplementedError("Online not using this.") diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index 93edd7011..fb24cab98 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -471,7 +471,6 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine resource instance") def init_model(self) -> bool: if not self.executor._init_from_path( diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 231137af6..c22cbbe5f 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -845,7 +845,6 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine resource instance") def init_model(self) -> bool: if not self.executor._init_from_path( diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index cfb0deb35..3c1c3d533 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Text +from ..utils.log import logger __all__ = ['EngineFactory'] - class EngineFactory(object): @staticmethod def get_engine(engine_name: Text, engine_type: Text): + logger.info(f"{engine_name} : {engine_type} engine.") if engine_name == 'asr' and engine_type == 'inference': from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine return ASREngine() diff --git a/paddlespeech/server/utils/onnx_infer.py b/paddlespeech/server/utils/onnx_infer.py index ac11c534b..4287477f2 100644 --- a/paddlespeech/server/utils/onnx_infer.py +++ b/paddlespeech/server/utils/onnx_infer.py @@ -16,21 +16,33 @@ from typing import Optional import onnxruntime as ort +from .log import logger + def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): + logger.info(f"ort sessconf: {sess_conf}") sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + if sess_conf.get('graph_optimization_level', 99) == 0: + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - if "gpu" in sess_conf["device"]: + # "gpu:0" + providers = ['CPUExecutionProvider'] + if "gpu" in sess_conf.get("device", ""): + providers = ['CUDAExecutionProvider'] # fastspeech2/mb_melgan can't use trt now! - if sess_conf["use_trt"]: + if sess_conf.get("use_trt", 0): providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif sess_conf["device"] == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = sess_conf["cpu_threads"] + logger.info(f"ort providers: {providers}") + + if 'cpu_threads' in sess_conf: + sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0) + else: + sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0) + + sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0) + sess = ort.InferenceSession( model_path, providers=providers, sess_options=sess_options) return sess From 42d28b961ca16adf3f0e7280bd6f16d8fd11c8f3 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jun 2022 05:00:30 +0000 Subject: [PATCH 4/8] fix pretrian model error --- paddlespeech/resource/pretrained_models.py | 73 ++++++++++++++-------- 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 196edd50e..f13713476 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -170,14 +170,22 @@ asr_dynamic_pretrained_models = { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx' - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, "deepspeech2offline_librispeech-en-16k": { @@ -241,32 +249,47 @@ asr_static_pretrained_models = { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx' - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, } - asr_onnx_pretrained_models = { "deepspeech2online_aishell-zh-16k": { '1.0.2': { 'url': 'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz', - 'md5': '4dd42cfce9aaa54db0ec698da6c48ec5', - 'cfg_path': 'model.yaml', - 'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1', - 'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'onnx_model': 'onnx/model.onnx' - 'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3' + 'md5': + '4dd42cfce9aaa54db0ec698da6c48ec5', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'onnx_model': + 'onnx/model.onnx', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, }, } From 9106daa2a3b6cce4017fd4b268461b33d2418b18 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jun 2022 05:01:08 +0000 Subject: [PATCH 5/8] code format --- .../conf/ws_ds2_application.yaml | 2 +- paddlespeech/resource/resource.py | 4 +++- .../engine/asr/online/onnx/asr_engine.py | 20 ++++++++++--------- paddlespeech/server/engine/engine_factory.py | 3 +++ paddlespeech/server/utils/onnx_infer.py | 9 +++++---- .../examples/ds2_ol/onnx/local/infer_check.py | 4 ++-- 6 files changed, 25 insertions(+), 17 deletions(-) diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index f0a98e72a..f67d3157a 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online-onnx'] +engine_list: ['asr_online-inference'] ################################################################################# diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 369dba900..2e637f0f1 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -164,9 +164,11 @@ class CommonTaskResource: try: import_models = '{}_{}_pretrained_models'.format(self.task, self.model_format) + print(f"from .pretrained_models import {import_models}") exec('from .pretrained_models import {}'.format(import_models)) models = OrderedDict(locals()[import_models]) - except ImportError: + except Exception as e: + print(e) models = OrderedDict({}) # no models. finally: return models diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index 0bd2f950f..97addc7a3 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler: assert (len(input_names) == len(output_names)) assert isinstance(input_names[0], str) - input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk] + input_datas = [ + self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, + x_chunk + ] feeds = dict(zip(input_names, input_datas)) - outputs = self.am_predictor.run( - [*output_names], - {**feeds}) + outputs = self.am_predictor.run([*output_names], {**feeds}) output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs self.decoder.next(output_chunk_probs, output_chunk_lens) @@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() self.task_resource = CommonTaskResource( - task='asr', model_format='static', inference_mode='online') + task='asr', model_format='onnx', inference_mode='online') def update_config(self) -> None: if "deepspeech2" in self.model_type: @@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - self.am_model = os.path.join(self.res_path, - self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model) - self.am_params = os.path.join(self.res_path, - self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params) + self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[ + 'model']) if am_model is None else os.path.abspath(am_model) + self.am_params = os.path.join( + self.res_path, self.task_resource.res_dict[ + 'params']) if am_params is None else os.path.abspath(am_params) logger.info("Load the pretrained model:") logger.info(f" tag = {tag}") diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 3c1c3d533..6a66a002e 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Text + from ..utils.log import logger __all__ = ['EngineFactory'] + class EngineFactory(object): @staticmethod def get_engine(engine_name: Text, engine_type: Text): logger.info(f"{engine_name} : {engine_type} engine.") + if engine_name == 'asr' and engine_type == 'inference': from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine return ASREngine() diff --git a/paddlespeech/server/utils/onnx_infer.py b/paddlespeech/server/utils/onnx_infer.py index 4287477f2..1c9d878f8 100644 --- a/paddlespeech/server/utils/onnx_infer.py +++ b/paddlespeech/server/utils/onnx_infer.py @@ -35,14 +35,15 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): if sess_conf.get("use_trt", 0): providers = ['TensorrtExecutionProvider'] logger.info(f"ort providers: {providers}") - + if 'cpu_threads' in sess_conf: - sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0) + sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0) else: - sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0) + sess_options.intra_op_num_threads = sess_conf.get( + "intra_op_num_threads", 0) sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0) - + sess = ort.InferenceSession( model_path, providers=providers, sess_options=sess_options) return sess diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py index a5ec7ce34..f821baa12 100755 --- a/speechx/examples/ds2_ol/onnx/local/infer_check.py +++ b/speechx/examples/ds2_ol/onnx/local/infer_check.py @@ -27,7 +27,8 @@ def parse_args(): '--input_file', type=str, default="static_ds2online_inputs.pickle", - help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", ) + help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", + ) parser.add_argument( '--model_type', type=str, @@ -57,7 +58,6 @@ if __name__ == '__main__': iodict = pickle.load(f) print(iodict.keys()) - audio_chunk = iodict['audio_chunk'] audio_chunk_lens = iodict['audio_chunk_lens'] chunk_state_h_box = iodict['chunk_state_h_box'] From 5e03d753acb7e63a37dd34e0647a12c782b1cb13 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jun 2022 06:49:19 +0000 Subject: [PATCH 6/8] add ds2 steaming asr onnx --- .../conf/ws_ds2_application.yaml | 6 +++--- paddlespeech/resource/resource.py | 1 - .../server/conf/ws_ds2_application.yaml | 10 +++++----- .../engine/asr/online/onnx/asr_engine.py | 18 +++++++++++++----- .../asr/online/paddleinference/asr_engine.py | 6 ++++++ 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index f67d3157a..a4e6e9a1f 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -22,8 +22,8 @@ engine_list: ['asr_online-inference'] ################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] - am_params: # the pdiparams file of am static model [optional] + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 cfg_path: @@ -54,7 +54,7 @@ asr_online-inference: ################### speech task: asr; engine_type: online-onnx ####################### asr_online-onnx: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] + am_model: # the pdmodel file of onnx am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 2e637f0f1..15112ba7d 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -168,7 +168,6 @@ class CommonTaskResource: exec('from .pretrained_models import {}'.format(import_models)) models = OrderedDict(locals()[import_models]) except Exception as e: - print(e) models = OrderedDict({}) # no models. finally: return models diff --git a/paddlespeech/server/conf/ws_ds2_application.yaml b/paddlespeech/server/conf/ws_ds2_application.yaml index f0a98e72a..430e6fd12 100644 --- a/paddlespeech/server/conf/ws_ds2_application.yaml +++ b/paddlespeech/server/conf/ws_ds2_application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket'] (only one can be selected). # websocket only support online engine type. protocol: 'websocket' -engine_list: ['asr_online-onnx'] +engine_list: ['asr_online-inference'] ################################################################################# @@ -22,8 +22,8 @@ engine_list: ['asr_online-onnx'] ################### speech task: asr; engine_type: online-inference ####################### asr_online-inference: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] - am_params: # the pdiparams file of am static model [optional] + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 cfg_path: @@ -54,7 +54,7 @@ asr_online-inference: ################### speech task: asr; engine_type: online-onnx ####################### asr_online-onnx: model_type: 'deepspeech2online_aishell' - am_model: # the pdmodel file of am static model [optional] + am_model: # the pdmodel file of onnx am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' sample_rate: 16000 @@ -81,4 +81,4 @@ asr_online-onnx: window_n: 7 # frame shift_n: 4 # frame window_ms: 20 # ms - shift_ms: 10 # ms + shift_ms: 10 # ms \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index 97addc7a3..aab29f78e 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -331,6 +331,13 @@ class PaddleASRConnectionHanddler: else: return '' + def get_word_time_stamp(self): + return [] + + @paddle.no_grad() + def rescoring(self): + ... + class ASRServerExecutor(ASRExecutor): def __init__(self): @@ -409,17 +416,18 @@ class ASRServerExecutor(ASRExecutor): os.path.dirname(os.path.abspath(self.cfg_path))) self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[ - 'model']) if am_model is None else os.path.abspath(am_model) - self.am_params = os.path.join( - self.res_path, self.task_resource.res_dict[ - 'params']) if am_params is None else os.path.abspath(am_params) + 'onnx_model']) if am_model is None else os.path.abspath(am_model) + + # self.am_params = os.path.join( + # self.res_path, self.task_resource.res_dict[ + # 'params']) if am_params is None else os.path.abspath(am_params) logger.info("Load the pretrained model:") logger.info(f" tag = {tag}") logger.info(f" res_path: {self.res_path}") logger.info(f" cfg path: {self.cfg_path}") logger.info(f" am_model path: {self.am_model}") - logger.info(f" am_params path: {self.am_params}") + # logger.info(f" am_params path: {self.am_params}") #Init body. self.config = CfgNode(new_allowed=True) diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index fb24cab98..b3b31a5a0 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -345,6 +345,12 @@ class PaddleASRConnectionHanddler: else: return '' + def get_word_time_stamp(self): + return [] + + @paddle.no_grad() + def rescoring(self): + ... class ASRServerExecutor(ASRExecutor): def __init__(self): From 27f2833bf7c7f0d6682753c3fc0e2cb5d5ede37f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jun 2022 07:10:32 +0000 Subject: [PATCH 7/8] format --- .../server/engine/asr/online/paddleinference/asr_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index b3b31a5a0..a450e430b 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -352,6 +352,7 @@ class PaddleASRConnectionHanddler: def rescoring(self): ... + class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() From 0f8e9cdd32cef875200a9f0c90cc1dd5630a82aa Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jun 2022 08:41:10 +0000 Subject: [PATCH 8/8] add init file --- .../server/engine/asr/online/onnx/__init__.py | 13 +++++++++++++ .../engine/asr/online/paddleinference/__init__.py | 13 +++++++++++++ .../server/engine/asr/online/python/__init__.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 paddlespeech/server/engine/asr/online/onnx/__init__.py create mode 100644 paddlespeech/server/engine/asr/online/python/__init__.py diff --git a/paddlespeech/server/engine/asr/online/onnx/__init__.py b/paddlespeech/server/engine/asr/online/onnx/__init__.py new file mode 100644 index 000000000..c747d3e7a --- /dev/null +++ b/paddlespeech/server/engine/asr/online/onnx/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/paddleinference/__init__.py b/paddlespeech/server/engine/asr/online/paddleinference/__init__.py index e69de29bb..c747d3e7a 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/__init__.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/python/__init__.py b/paddlespeech/server/engine/asr/online/python/__init__.py new file mode 100644 index 000000000..c747d3e7a --- /dev/null +++ b/paddlespeech/server/engine/asr/online/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file