|
|
@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler:
|
|
|
|
raise ValueError(f"Not supported: {self.model_type}")
|
|
|
|
raise ValueError(f"Not supported: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
def model_reset(self):
|
|
|
|
def model_reset(self):
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# cache for audio and feat
|
|
|
|
# cache for audio and feat
|
|
|
|
self.remained_wav = None
|
|
|
|
self.remained_wav = None
|
|
|
|
self.cached_feat = None
|
|
|
|
self.cached_feat = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
## conformer
|
|
|
|
## conformer
|
|
|
|
# cache for conformer online
|
|
|
|
# cache for conformer online
|
|
|
|
self.subsampling_cache = None
|
|
|
|
self.subsampling_cache = None
|
|
|
@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
self.task_resource = CommonTaskResource(
|
|
|
|
self.task_resource = CommonTaskResource(
|
|
|
|
task='asr', model_format='dynamic', inference_mode='online')
|
|
|
|
task='asr', model_format='dynamic', inference_mode='online')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_config(self)->None:
|
|
|
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
# download lm
|
|
|
|
|
|
|
|
self.config.decode.lang_model_path = os.path.join(
|
|
|
|
|
|
|
|
MODEL_HOME, 'language_model',
|
|
|
|
|
|
|
|
self.config.decode.lang_model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm_url = self.task_resource.res_dict['lm_url']
|
|
|
|
|
|
|
|
lm_md5 = self.task_resource.res_dict['lm_md5']
|
|
|
|
|
|
|
|
logger.info(f"Start to load language model {lm_url}")
|
|
|
|
|
|
|
|
self.download_lm(
|
|
|
|
|
|
|
|
lm_url,
|
|
|
|
|
|
|
|
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
|
|
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
logger.info("start to create the stream conformer asr engine")
|
|
|
|
|
|
|
|
# update the decoding method
|
|
|
|
|
|
|
|
if self.decode_method:
|
|
|
|
|
|
|
|
self.config.decode.decoding_method = self.decode_method
|
|
|
|
|
|
|
|
# update num_decoding_left_chunks
|
|
|
|
|
|
|
|
if self.num_decoding_left_chunks:
|
|
|
|
|
|
|
|
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
|
|
|
|
|
|
|
|
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
|
|
|
|
|
|
|
|
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
|
|
|
|
|
|
|
|
# Generally we set the decoding_method to attention_rescoring
|
|
|
|
|
|
|
|
if self.config.decode.decoding_method not in [
|
|
|
|
|
|
|
|
"ctc_prefix_beam_search", "attention_rescoring"
|
|
|
|
|
|
|
|
]:
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
"we set the decoding_method to attention_rescoring")
|
|
|
|
|
|
|
|
self.config.decode.decoding_method = "attention_rescoring"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert self.config.decode.decoding_method in [
|
|
|
|
|
|
|
|
"ctc_prefix_beam_search", "attention_rescoring"
|
|
|
|
|
|
|
|
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise Exception(f"not support: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_model(self) -> None:
|
|
|
|
|
|
|
|
if "deepspeech2" in self.model_type :
|
|
|
|
|
|
|
|
# AM predictor
|
|
|
|
|
|
|
|
logger.info("ASR engine start to init the am predictor")
|
|
|
|
|
|
|
|
self.am_predictor = init_predictor(
|
|
|
|
|
|
|
|
model_file=self.am_model,
|
|
|
|
|
|
|
|
params_file=self.am_params,
|
|
|
|
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type :
|
|
|
|
|
|
|
|
# load model
|
|
|
|
|
|
|
|
# model_type: {model_name}_{dataset}
|
|
|
|
|
|
|
|
model_name = self.model_type[:self.model_type.rindex('_')]
|
|
|
|
|
|
|
|
logger.info(f"model name: {model_name}")
|
|
|
|
|
|
|
|
model_class = self.task_resource.get_model_class(model_name)
|
|
|
|
|
|
|
|
model = model_class.from_config(self.config)
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
self.model.set_state_dict(paddle.load(self.am_model))
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise Exception(f"not support: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_from_path(self,
|
|
|
|
def _init_from_path(self,
|
|
|
|
model_type: str=None,
|
|
|
|
model_type: str=None,
|
|
|
|
am_model: Optional[os.PathLike]=None,
|
|
|
|
am_model: Optional[os.PathLike]=None,
|
|
|
@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
self.model_type = model_type
|
|
|
|
self.model_type = model_type
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
|
|
|
|
self.decode_method = decode_method
|
|
|
|
|
|
|
|
self.num_decoding_left_chunks = num_decoding_left_chunks
|
|
|
|
|
|
|
|
# conf for paddleinference predictor or onnx
|
|
|
|
|
|
|
|
self.am_predictor_conf = am_predictor_conf
|
|
|
|
logger.info(f"model_type: {self.model_type}")
|
|
|
|
logger.info(f"model_type: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
tag = model_type + '-' + lang + '-' + sample_rate_str
|
|
|
|
tag = model_type + '-' + lang + '-' + sample_rate_str
|
|
|
|
self.task_resource.set_task_model(model_tag=tag)
|
|
|
|
self.task_resource.set_task_model(model_tag=tag)
|
|
|
@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
|
|
|
|
|
|
|
|
if "deepspeech2" in model_type:
|
|
|
|
self.update_config()
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
# download lm
|
|
|
|
# AM predictor
|
|
|
|
self.config.decode.lang_model_path = os.path.join(
|
|
|
|
self.init_model()
|
|
|
|
MODEL_HOME, 'language_model',
|
|
|
|
|
|
|
|
self.config.decode.lang_model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm_url = self.task_resource.res_dict['lm_url']
|
|
|
|
|
|
|
|
lm_md5 = self.task_resource.res_dict['lm_md5']
|
|
|
|
|
|
|
|
logger.info(f"Start to load language model {lm_url}")
|
|
|
|
|
|
|
|
self.download_lm(
|
|
|
|
|
|
|
|
lm_url,
|
|
|
|
|
|
|
|
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# AM predictor
|
|
|
|
|
|
|
|
logger.info("ASR engine start to init the am predictor")
|
|
|
|
|
|
|
|
self.am_predictor_conf = am_predictor_conf
|
|
|
|
|
|
|
|
self.am_predictor = init_predictor(
|
|
|
|
|
|
|
|
model_file=self.am_model,
|
|
|
|
|
|
|
|
params_file=self.am_params,
|
|
|
|
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
logger.info("start to create the stream conformer asr engine")
|
|
|
|
|
|
|
|
# update the decoding method
|
|
|
|
|
|
|
|
if decode_method:
|
|
|
|
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
|
|
|
|
# update num_decoding_left_chunks
|
|
|
|
|
|
|
|
if num_decoding_left_chunks:
|
|
|
|
|
|
|
|
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
|
|
|
|
|
|
|
|
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
|
|
|
|
|
|
|
|
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
|
|
|
|
|
|
|
|
# Generally we set the decoding_method to attention_rescoring
|
|
|
|
|
|
|
|
if self.config.decode.decoding_method not in [
|
|
|
|
|
|
|
|
"ctc_prefix_beam_search", "attention_rescoring"
|
|
|
|
|
|
|
|
]:
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
"we set the decoding_method to attention_rescoring")
|
|
|
|
|
|
|
|
self.config.decode.decoding_method = "attention_rescoring"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert self.config.decode.decoding_method in [
|
|
|
|
|
|
|
|
"ctc_prefix_beam_search", "attention_rescoring"
|
|
|
|
|
|
|
|
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# load model
|
|
|
|
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
|
|
|
|
|
logger.info(f"model name: {model_name}")
|
|
|
|
|
|
|
|
model_class = self.task_resource.get_model_class(model_name)
|
|
|
|
|
|
|
|
model = model_class.from_config(self.config)
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
self.model.set_state_dict(paddle.load(self.am_model))
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise Exception(f"not support: {model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"create the {model_type} model success")
|
|
|
|
logger.info(f"create the {model_type} model success")
|
|
|
|
return True
|
|
|
|
return True
|
|
|
@ -835,6 +850,22 @@ class ASREngine(BaseEngine):
|
|
|
|
super(ASREngine, self).__init__()
|
|
|
|
super(ASREngine, self).__init__()
|
|
|
|
logger.info("create the online asr engine resource instance")
|
|
|
|
logger.info("create the online asr engine resource instance")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_model(self) -> bool:
|
|
|
|
|
|
|
|
if not self.executor._init_from_path(
|
|
|
|
|
|
|
|
model_type=self.config.model_type,
|
|
|
|
|
|
|
|
am_model=self.config.am_model,
|
|
|
|
|
|
|
|
am_params=self.config.am_params,
|
|
|
|
|
|
|
|
lang=self.config.lang,
|
|
|
|
|
|
|
|
sample_rate=self.config.sample_rate,
|
|
|
|
|
|
|
|
cfg_path=self.config.cfg_path,
|
|
|
|
|
|
|
|
decode_method=self.config.decode_method,
|
|
|
|
|
|
|
|
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
|
|
|
|
|
|
|
|
am_predictor_conf=self.config.am_predictor_conf):
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init(self, config: dict) -> bool:
|
|
|
|
def init(self, config: dict) -> bool:
|
|
|
|
"""init engine resource
|
|
|
|
"""init engine resource
|
|
|
|
|
|
|
|
|
|
|
@ -860,16 +891,7 @@ class ASREngine(BaseEngine):
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"paddlespeech_server set the device: {self.device}")
|
|
|
|
logger.info(f"paddlespeech_server set the device: {self.device}")
|
|
|
|
|
|
|
|
|
|
|
|
if not self.executor._init_from_path(
|
|
|
|
if not self.init_model():
|
|
|
|
model_type=self.config.model_type,
|
|
|
|
|
|
|
|
am_model=self.config.am_model,
|
|
|
|
|
|
|
|
am_params=self.config.am_params,
|
|
|
|
|
|
|
|
lang=self.config.lang,
|
|
|
|
|
|
|
|
sample_rate=self.config.sample_rate,
|
|
|
|
|
|
|
|
cfg_path=self.config.cfg_path,
|
|
|
|
|
|
|
|
decode_method=self.config.decode_method,
|
|
|
|
|
|
|
|
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
|
|
|
|
|
|
|
|
am_predictor_conf=self.config.am_predictor_conf):
|
|
|
|
|
|
|
|
logger.error(
|
|
|
|
logger.error(
|
|
|
|
"Init the ASR server occurs error, please check the server configuration yaml"
|
|
|
|
"Init the ASR server occurs error, please check the server configuration yaml"
|
|
|
|
)
|
|
|
|
)
|