add check asr server model type, test=doc

pull/1905/head
xiongxinlei 3 years ago
parent 15271445fd
commit 67939d0d66

@ -29,7 +29,8 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: cpu # cpu or gpu:id
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@ -13,6 +13,7 @@
# limitations under the License.
import copy
import os
import sys
from typing import Optional
import numpy as np
@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
self.pretrained_models = pretrained_models
def _init_from_path(self,
model_type: str='deepspeech2online_aishell',
model_type: str=None,
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
"""
Init model and other resources from a specific path.
"""
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type
self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
@ -1028,12 +1035,15 @@ class ASREngine(BaseEngine):
self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device)
except BaseException:
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
)
logger.error(
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
self.executor._init_from_path(
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
@ -1041,7 +1051,11 @@ class ASREngine(BaseEngine):
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.")
return True

Loading…
Cancel
Save