@@ -49,7 +49,7 @@ class PaddleASRConnectionHanddler:
             asr_engine (ASREngine): the global asr engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
            "create an paddle asr connection handler to process the websocket connection"
        )
         self.config = asr_engine.config # server config
@@ -107,7 +107,7 @@ class PaddleASRConnectionHanddler:
         # acoustic model
         self.model = self.asr_engine.executor.model
         self.continuous_decoding = self.config.continuous_decoding
-        logger.info(f"continue decoding: {self.continuous_decoding}")
+        logger.debug(f"continue decoding: {self.continuous_decoding}")
 
         # ctc decoding config
         self.ctc_decode_config = self.asr_engine.executor.config.decode
@@ -207,7 +207,7 @@ class PaddleASRConnectionHanddler:
         assert samples.ndim == 1
 
         self.num_samples += samples.shape[0]
-        logger.info(
+        logger.debug(
             f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
         )
 
@@ -218,7 +218,7 @@ class PaddleASRConnectionHanddler:
         else:
             assert self.remained_wav.ndim == 1 # (T,)
             self.remained_wav = np.concatenate([self.remained_wav, samples])
-            logger.info(
+            logger.debug(
                 f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
             )
 
@@ -252,14 +252,14 @@ class PaddleASRConnectionHanddler:
         # update remained wav
         self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
 
-        logger.info(
+        logger.debug(
             f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
         )
-        logger.info(
+        logger.debug(
             f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
         )
-        logger.info(f"global samples: {self.num_samples}")
-        logger.info(f"global frames: {self.num_frames}")
+        logger.debug(f"global samples: {self.num_samples}")
+        logger.debug(f"global frames: {self.num_frames}")
 
     def decode(self, is_finished=False):
         """advance decoding
@@ -283,24 +283,24 @@ class PaddleASRConnectionHanddler:
         stride = subsampling * decoding_chunk_size
 
         if self.cached_feat is None:
-            logger.info("no audio feat, please input more pcm data")
+            logger.debug("no audio feat, please input more pcm data")
             return
 
         num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
             f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
         )
 
         # the cached feat must be larger decoding_window
         if num_frames < decoding_window and not is_finished:
-            logger.info(
+            logger.debug(
                 f"frame feat num is less than {decoding_window}, please input more pcm data"
             )
             return None, None
 
         # if is_finished=True, we need at least context frames
         if num_frames < context:
-            logger.info(
+            logger.debug(
                 "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
             )
             return None, None
@@ -354,7 +354,7 @@ class PaddleASRConnectionHanddler:
         Returns:
             logprob: poster probability.
         """
-        logger.info("start to decoce one chunk for deepspeech2")
+        logger.debug("start to decoce one chunk for deepspeech2")
         input_names = self.am_predictor.get_input_names()
         audio_handle = self.am_predictor.get_input_handle(input_names[0])
         audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
@@ -391,7 +391,7 @@ class PaddleASRConnectionHanddler:
 
         self.decoder.next(output_chunk_probs, output_chunk_lens)
         trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
+        logger.debug(f"decode one best result for deepspeech2: {trans_best[0]}")
         return trans_best[0]
 
     @paddle.no_grad()
@@ -402,7 +402,7 @@ class PaddleASRConnectionHanddler:
         # reset endpiont state
         self.endpoint_state = False
 
-        logger.info(
+        logger.debug(
             "Conformer/Transformer: start to decode with advanced_decoding method"
         )
         cfg = self.ctc_decode_config
@@ -427,25 +427,25 @@ class PaddleASRConnectionHanddler:
         stride = subsampling * decoding_chunk_size
 
         if self.cached_feat is None:
-            logger.info("no audio feat, please input more pcm data")
+            logger.debug("no audio feat, please input more pcm data")
             return
 
         # (B=1,T,D)
         num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
             f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
         )
 
         # the cached feat must be larger decoding_window
         if num_frames < decoding_window and not is_finished:
-            logger.info(
+            logger.debug(
                 f"frame feat num is less than {decoding_window}, please input more pcm data"
             )
             return None, None
 
         # if is_finished=True, we need at least context frames
         if num_frames < context:
-            logger.info(
+            logger.debug(
                 "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
             )
             return None, None
@@ -489,7 +489,7 @@ class PaddleASRConnectionHanddler:
             self.encoder_out = ys
         else:
             self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
-        logger.info(
+        logger.debug(
             f"This connection handler encoder out shape: {self.encoder_out.shape}"
         )
 
@@ -513,7 +513,8 @@ class PaddleASRConnectionHanddler:
         if self.endpointer.endpoint_detected(ctc_probs.numpy(),
                                              decoding_something):
             self.endpoint_state = True
-            logger.info(f"Endpoint is detected at {self.num_frames} frame.")
+            logger.debug(
+                f"Endpoint is detected at {self.num_frames} frame.")
 
         # advance cache of feat
         assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
@@ -526,7 +527,7 @@ class PaddleASRConnectionHanddler:
     def update_result(self):
         """Conformer/Transformer hyps to result.
         """
-        logger.info("update the final result")
+        logger.debug("update the final result")
         hyps = self.hyps
 
         # output results and tokenids
@@ -560,16 +561,16 @@ class PaddleASRConnectionHanddler:
         only for conformer and transformer model.
         """
         if "deepspeech2" in self.model_type:
-            logger.info("deepspeech2 not support rescoring decoding.")
+            logger.debug("deepspeech2 not support rescoring decoding.")
             return
 
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
-            logger.info(
+            logger.debug(
                 f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
             )
             return
 
-        logger.info("rescoring the final result")
+        logger.debug("rescoring the final result")
 
         # last decoding for last audio
         self.searcher.finalize_search()
@@ -685,7 +686,6 @@ class PaddleASRConnectionHanddler:
                 "bg": global_offset_in_sec + start,
                 "ed": global_offset_in_sec + end
             })
-            # logger.info(f"{word_time_stamp[-1]}")
 
         self.word_time_stamp = word_time_stamp
         logger.info(f"word time stamp: {self.word_time_stamp}")
@@ -707,13 +707,13 @@ class ASRServerExecutor(ASRExecutor):
 
             lm_url = self.task_resource.res_dict['lm_url']
             lm_md5 = self.task_resource.res_dict['lm_md5']
-            logger.info(f"Start to load language model {lm_url}")
+            logger.debug(f"Start to load language model {lm_url}")
             self.download_lm(
                 lm_url,
                 os.path.dirname(self.config.decode.lang_model_path), lm_md5)
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             with UpdateConfig(self.config):
-                logger.info("start to create the stream conformer asr engine")
+                logger.debug("start to create the stream conformer asr engine")
                 # update the decoding method
                 if self.decode_method:
                     self.config.decode.decoding_method = self.decode_method
@@ -726,7 +726,7 @@ class ASRServerExecutor(ASRExecutor):
                 if self.config.decode.decoding_method not in [
                         "ctc_prefix_beam_search", "attention_rescoring"
                 ]:
-                    logger.info(
+                    logger.debug(
                         "we set the decoding_method to attention_rescoring")
                     self.config.decode.decoding_method = "attention_rescoring"
 
@@ -739,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
     def init_model(self) -> None:
         if "deepspeech2" in self.model_type:
             # AM predictor
-            logger.info("ASR engine start to init the am predictor")
+            logger.debug("ASR engine start to init the am predictor")
             self.am_predictor = init_predictor(
                 model_file=self.am_model,
                 params_file=self.am_params,
@@ -748,7 +748,7 @@ class ASRServerExecutor(ASRExecutor):
             # load model
             # model_type: {model_name}_{dataset}
             model_name = self.model_type[:self.model_type.rindex('_')]
-            logger.info(f"model name: {model_name}")
+            logger.debug(f"model name: {model_name}")
             model_class = self.task_resource.get_model_class(model_name)
             model = model_class.from_config(self.config)
             self.model = model
@@ -782,7 +782,7 @@ class ASRServerExecutor(ASRExecutor):
         self.num_decoding_left_chunks = num_decoding_left_chunks
         # conf for paddleinference predictor or onnx
         self.am_predictor_conf = am_predictor_conf
-        logger.info(f"model_type: {self.model_type}")
+        logger.debug(f"model_type: {self.model_type}")
 
         sample_rate_str = '16k' if sample_rate == 16000 else '8k'
         tag = model_type + '-' + lang + '-' + sample_rate_str
@@ -804,12 +804,12 @@ class ASRServerExecutor(ASRExecutor):
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
 
-        logger.info("Load the pretrained model:")
-        logger.info(f" tag = {tag}")
-        logger.info(f" res_path: {self.res_path}")
-        logger.info(f" cfg path: {self.cfg_path}")
-        logger.info(f" am_model path: {self.am_model}")
-        logger.info(f" am_params path: {self.am_params}")
+        logger.debug("Load the pretrained model:")
+        logger.debug(f" tag = {tag}")
+        logger.debug(f" res_path: {self.res_path}")
+        logger.debug(f" cfg path: {self.cfg_path}")
+        logger.debug(f" am_model path: {self.am_model}")
+        logger.debug(f" am_params path: {self.am_params}")
 
         #Init body.
         self.config = CfgNode(new_allowed=True)
@@ -818,7 +818,7 @@ class ASRServerExecutor(ASRExecutor):
         if self.config.spm_model_prefix:
             self.config.spm_model_prefix = os.path.join(
                 self.res_path, self.config.spm_model_prefix)
-            logger.info(f"spm model path: {self.config.spm_model_prefix}")
+            logger.debug(f"spm model path: {self.config.spm_model_prefix}")
 
         self.vocab = self.config.vocab_filepath
 
@@ -832,7 +832,7 @@ class ASRServerExecutor(ASRExecutor):
         # AM predictor
         self.init_model()
 
-        logger.info(f"create the {model_type} model success")
+        logger.debug(f"create the {model_type} model success")
         return True
 
 
@@ -883,7 +883,7 @@ class ASREngine(BaseEngine):
                 "If all GPU or XPU is used, you can set the server to 'cpu'")
             sys.exit(-1)
 
-        logger.info(f"paddlespeech_server set the device: {self.device}")
+        logger.debug(f"paddlespeech_server set the device: {self.device}")
 
         if not self.init_model():
             logger.error(
@@ -891,7 +891,9 @@ class ASREngine(BaseEngine):
             )
             return False
 
-        logger.info("Initialize ASR server engine successfully.")
+        logger.info("Initialize ASR server engine successfully on device: %s." %
+                    (self.device))
+
         return True
 
     def new_handler(self):