|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import copy
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
self.pretrained_models = pretrained_models
|
|
|
|
|
|
|
|
|
|
def _init_from_path(self,
|
|
|
|
|
model_type: str='deepspeech2online_aishell',
|
|
|
|
|
model_type: str=None,
|
|
|
|
|
am_model: Optional[os.PathLike]=None,
|
|
|
|
|
am_params: Optional[os.PathLike]=None,
|
|
|
|
|
lang: str='zh',
|
|
|
|
@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
|
"""
|
|
|
|
|
if not model_type or not lang or not sample_rate:
|
|
|
|
|
logger.error(
|
|
|
|
|
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
|
|
|
|
|
)
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
self.model_type = model_type
|
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
@ -730,6 +737,8 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
# update the ctc decoding
|
|
|
|
|
self.searcher = CTCPrefixBeamSearch(self.config.decode)
|
|
|
|
|
self.transformer_decode_reset()
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def reset_decoder_and_chunk(self):
|
|
|
|
|
"""reset decoder and chunk state for an new audio
|
|
|
|
@ -1028,20 +1037,27 @@ class ASREngine(BaseEngine):
|
|
|
|
|
self.device = paddle.get_device()
|
|
|
|
|
logger.info(f"paddlespeech_server set the device: {self.device}")
|
|
|
|
|
paddle.set_device(self.device)
|
|
|
|
|
except BaseException:
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error(
|
|
|
|
|
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
|
|
|
|
|
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.executor._init_from_path(
|
|
|
|
|
model_type=self.config.model_type,
|
|
|
|
|
am_model=self.config.am_model,
|
|
|
|
|
am_params=self.config.am_params,
|
|
|
|
|
lang=self.config.lang,
|
|
|
|
|
sample_rate=self.config.sample_rate,
|
|
|
|
|
cfg_path=self.config.cfg_path,
|
|
|
|
|
decode_method=self.config.decode_method,
|
|
|
|
|
am_predictor_conf=self.config.am_predictor_conf)
|
|
|
|
|
logger.error(
|
|
|
|
|
"If all GPU or XPU is used, you can set the server to 'cpu'")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
if not self.executor._init_from_path(
|
|
|
|
|
model_type=self.config.model_type,
|
|
|
|
|
am_model=self.config.am_model,
|
|
|
|
|
am_params=self.config.am_params,
|
|
|
|
|
lang=self.config.lang,
|
|
|
|
|
sample_rate=self.config.sample_rate,
|
|
|
|
|
cfg_path=self.config.cfg_path,
|
|
|
|
|
decode_method=self.config.decode_method,
|
|
|
|
|
am_predictor_conf=self.config.am_predictor_conf):
|
|
|
|
|
logger.error(
|
|
|
|
|
"Init the ASR server occurs error, please check the server configuration yaml"
|
|
|
|
|
)
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
logger.info("Initialize ASR server engine successfully.")
|
|
|
|
|
return True
|
|
|
|
|