add check asr server model type, test=doc

pull/1905/head
xiongxinlei 3 years ago
parent 15271445fd
commit 67939d0d66

@ -29,7 +29,8 @@ asr_online:
cfg_path: cfg_path:
decode_method: decode_method:
force_yes: True force_yes: True
device: cpu # cpu or gpu:id device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
am_predictor_conf: am_predictor_conf:
device: # set 'gpu:id' or 'cpu' device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True switch_ir_optim: True

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os import os
import sys
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
self.pretrained_models = pretrained_models self.pretrained_models = pretrained_models
def _init_from_path(self, def _init_from_path(self,
model_type: str='deepspeech2online_aishell', model_type: str=None,
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None,
lang: str='zh', lang: str='zh',
@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type self.model_type = model_type
self.sample_rate = sample_rate self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
@ -1028,12 +1035,15 @@ class ASREngine(BaseEngine):
self.device = paddle.get_device() self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}") logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device) paddle.set_device(self.device)
except BaseException: except BaseException as e:
logger.error( logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file" f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
) )
logger.error(
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
self.executor._init_from_path( if not self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
am_model=self.config.am_model, am_model=self.config.am_model,
am_params=self.config.am_params, am_params=self.config.am_params,
@ -1041,7 +1051,11 @@ class ASREngine(BaseEngine):
sample_rate=self.config.sample_rate, sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path, cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method, decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf) am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True

Loading…
Cancel
Save