|
|
|
@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
self.frame_shift_in_ms = int(
|
|
|
|
|
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
|
|
|
|
|
|
|
|
|
|
self.continuous_decoding = self.config.get("continuous_decoding", False)
|
|
|
|
|
self.init_decoder()
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
def init_decoder(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
assert self.continuous_decoding is False, "ds2 model not support endpoint"
|
|
|
|
|
self.am_predictor = self.asr_engine.executor.am_predictor
|
|
|
|
|
|
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
# acoustic model
|
|
|
|
|
self.model = self.asr_engine.executor.model
|
|
|
|
|
self.continuous_decoding = self.config.continuous_decoding
|
|
|
|
|
logger.info(f"continue decoding: {self.continuous_decoding}")
|
|
|
|
|
|
|
|
|
|
# ctc decoding config
|
|
|
|
|
self.ctc_decode_config = self.asr_engine.executor.config.decode
|
|
|
|
@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# feature cache
|
|
|
|
|
# cache for audio and feat
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
|
|
|
|
|
## conformer
|
|
|
|
@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
## just for record info
|
|
|
|
|
self.chunk_num = 0 # global decoding chunk num, not used
|
|
|
|
|
|
|
|
|
|
def output_reset(self):
    """Reset the decoding output attributes to their empty initial values.

    Rebinds ``result_transcripts``, ``word_time_stamp``, ``hyps`` and
    ``time_stamp`` on the instance; fresh containers are created on
    every call, so nothing is shared between calls. Returns ``None``.
    """
    ## outputs
    # partial/ending decoding results
    self.result_transcripts = ['']
    # token timestamp result
    self.word_time_stamp = []

    ## just for record
    self.hyps = []

    # one best timestamp viterbi prob is large.
    self.time_stamp = []
|
|
|
|
|
|
|
|
|
|
def reset_continuous_decoding(self):
|
|
|
|
|
"""
|
|
|
|
|
when in continuous decoding, reset for next utterance.
|
|
|
|
@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
self.model_reset()
|
|
|
|
|
self.searcher.reset()
|
|
|
|
|
self.endpointer.reset()
|
|
|
|
|
self.output_reset()
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
# frame step of cur utterance
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
|
|
|
|
|
# cache for audio and feat
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
## endpoint
|
|
|
|
|
self.endpoint_state = False # True for detect endpoint
|
|
|
|
|
|
|
|
|
|
## conformer
|
|
|
|
|
self.model_reset()
|
|
|
|
|
|
|
|
|
|
## outputs
|
|
|
|
|
# partial/ending decoding results
|
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
# token timestamp result
|
|
|
|
|
self.word_time_stamp = []
|
|
|
|
|
|
|
|
|
|
## just for record
|
|
|
|
|
self.hyps = []
|
|
|
|
|
|
|
|
|
|
# one best timestamp viterbi prob is large.
|
|
|
|
|
self.time_stamp = []
|
|
|
|
|
self.output_reset()
|
|
|
|
|
|
|
|
|
|
def extract_feat(self, samples: ByteString):
|
|
|
|
|
logger.info("Online ASR extract the feat")
|
|
|
|
@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
if "deepspeech" in self.model_type:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# reset endpoint state
|
|
|
|
|
self.endpoint_state = False
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
"Conformer/Transformer: start to decode with advanced_decoding method"
|
|
|
|
|
)
|
|
|
|
@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
# get one best hyps
|
|
|
|
|
self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
|
|
|
|
|
# endpoint
|
|
|
|
|
if not is_finished:
|
|
|
|
|
def contain_nonsilence():
|
|
|
|
|
return len(self.hyps) > 0 and len(self.hyps[0]) > 0
|
|
|
|
|
|
|
|
|
|
decoding_something = contain_nonsilence()
|
|
|
|
|
if self.endpointer.endpoint_detected(ctc_probs.numpy(), decoding_something):
|
|
|
|
|
self.endpoint_state = True
|
|
|
|
|
logger.info(f"Endpoint is detected at {self.num_frames} frame.")
|
|
|
|
|
|
|
|
|
|
# advance cache of feat
|
|
|
|
|
assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
|
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
|
|
|
|
|
logger.info("Initialize ASR server engine successfully.")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def new_handler(self):
    """New handler from model.

    Builds a connection handler, passing this object as the handler's
    only constructor argument.

    Returns:
        PaddleASRConnectionHanddler: asr handler instance
    """
    return PaddleASRConnectionHanddler(self)
|
|
|
|
|
|
|
|
|
|
def preprocess(self, *args, **kwargs):
    """Unsupported entry point; always raises.

    All positional and keyword arguments are accepted and ignored.

    Raises:
        NotImplementedError: on every call.
    """
    raise NotImplementedError("Online not using this.")
|
|
|
|
|
|
|
|
|
|