From 67939d0d6691f7be48e496ddfb92c19bffd8c39a Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sat, 14 May 2022 12:52:35 +0800 Subject: [PATCH] add check asr server model type, test=doc --- .../conf/application.yaml | 5 ++- .../server/engine/asr/online/asr_engine.py | 40 +++++++++++++------ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/demos/streaming_asr_server/conf/application.yaml b/demos/streaming_asr_server/conf/application.yaml index f576d704..e9a89c19 100644 --- a/demos/streaming_asr_server/conf/application.yaml +++ b/demos/streaming_asr_server/conf/application.yaml @@ -29,7 +29,8 @@ asr_online: cfg_path: decode_method: force_yes: True - device: cpu # cpu or gpu:id + device: 'cpu' # cpu or gpu:id + decode_method: "attention_rescoring" am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True @@ -42,4 +43,4 @@ asr_online: window_ms: 25 # ms shift_ms: 10 # ms sample_rate: 16000 - sample_width: 2 \ No newline at end of file + sample_width: 2 diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 79b0ddb7..6280093f 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,6 +13,7 @@ # limitations under the License. import copy import os +import sys from typing import Optional import numpy as np @@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor): self.pretrained_models = pretrained_models def _init_from_path(self, - model_type: str='deepspeech2online_aishell', + model_type: str=None, am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', @@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. """ + if not model_type or not lang or not sample_rate: + logger.error( + "The model type or lang or sample rate is None, please input an valid server parameter yaml" + ) + return False + self.model_type = model_type self.sample_rate = sample_rate sample_rate_str = '16k' if sample_rate == 16000 else '8k' @@ -1028,20 +1035,27 @@ class ASREngine(BaseEngine): self.device = paddle.get_device() logger.info(f"paddlespeech_server set the device: {self.device}") paddle.set_device(self.device) - except BaseException: + except BaseException as e: logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" ) - - self.executor._init_from_path( - model_type=self.config.model_type, - am_model=self.config.am_model, - am_params=self.config.am_params, - lang=self.config.lang, - sample_rate=self.config.sample_rate, - cfg_path=self.config.cfg_path, - decode_method=self.config.decode_method, - am_predictor_conf=self.config.am_predictor_conf) + logger.error( + "If all GPU or XPU is used, you can set the server to 'cpu'") + sys.exit(-1) + + if not self.executor._init_from_path( + model_type=self.config.model_type, + am_model=self.config.am_model, + am_params=self.config.am_params, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + am_predictor_conf=self.config.am_predictor_conf): + logger.error( + "Init the ASR server occurs error, please check the server configuration yaml" + ) + return False logger.info("Initialize ASR server engine successfully.") return True