refactor stream asr and fix ds2 stream bug

pull/2036/head
Hui Zhang 3 years ago
parent bca014fd92
commit b9e3e49305

@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to only streaming asr service # read the wav and pass it to only streaming asr service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address. # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wav paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# read the wav and call streaming and punc service # read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address. # If `127.0.0.1` is not accessible, you need to use the actual service IP address.

@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler:
raise ValueError(f"Not supported: {self.model_type}") raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self): def model_reset(self):
if "deepspeech2" in self.model_type:
return
# cache for audio and feat # cache for audio and feat
self.remained_wav = None self.remained_wav = None
self.cached_feat = None self.cached_feat = None
if "deepspeech2" in self.model_type:
return
## conformer ## conformer
# cache for conformer online # cache for conformer online
self.subsampling_cache = None self.subsampling_cache = None
@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource = CommonTaskResource( self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online') task='asr', model_format='dynamic', inference_mode='online')
def update_config(self)->None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in self.model_type or "transformer" in self.model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if self.decode_method:
self.config.decode.decoding_method = self.decode_method
# update num_decoding_left_chunks
if self.num_decoding_left_chunks:
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception(f"not support: {self.model_type}")
def init_model(self) -> None:
if "deepspeech2" in self.model_type :
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in self.model_type or "transformer" in self.model_type :
# load model
# model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {self.model_type}")
def _init_from_path(self, def _init_from_path(self,
model_type: str=None, model_type: str=None,
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor):
self.model_type = model_type self.model_type = model_type
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}") logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag) self.task_resource.set_task_model(model_tag=tag)
@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
if "deepspeech2" in model_type: self.update_config()
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") self.init_model()
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
# update num_decoding_left_chunks
if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
# load model
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {model_type}")
logger.info(f"create the {model_type} model success") logger.info(f"create the {model_type} model success")
return True return True
@ -835,6 +850,22 @@ class ASREngine(BaseEngine):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
logger.info("create the online asr engine resource instance") logger.info("create the online asr engine resource instance")
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
"""init engine resource """init engine resource
@ -860,16 +891,7 @@ class ASREngine(BaseEngine):
logger.info(f"paddlespeech_server set the device: {self.device}") logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.executor._init_from_path( if not self.init_model():
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
logger.error( logger.error(
"Init the ASR server occurs error, please check the server configuration yaml" "Init the ASR server occurs error, please check the server configuration yaml"
) )

@ -26,7 +26,10 @@ class EngineFactory(object):
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine() return ASREngine()
elif engine_name == 'asr' and engine_type == 'online': elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.asr_engine import ASREngine from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-onnx':
from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
return ASREngine() return ASREngine()
elif engine_name == 'tts' and engine_type == 'inference': elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine

Loading…
Cancel
Save