diff --git a/demos/streaming_asr_server/conf/application.yaml b/demos/streaming_asr_server/conf/application.yaml
index e9a89c19d..683d86f03 100644
--- a/demos/streaming_asr_server/conf/application.yaml
+++ b/demos/streaming_asr_server/conf/application.yaml
@@ -31,6 +31,8 @@ asr_online:
     force_yes: True
     device: 'cpu'  # cpu or gpu:id
     decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device:  # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
index 6a10741bd..9dbc82b6f 100644
--- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
@@ -30,6 +30,9 @@ asr_online:
     decode_method:
     force_yes: True
     device: 'cpu'  # cpu or gpu:id
+    decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device:  # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
index e9a89c19d..683d86f03 100644
--- a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
@@ -31,6 +31,8 @@ asr_online:
     force_yes: True
    device: 'cpu'  # cpu or gpu:id
     decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device:  # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml
index dee8d78ba..d6f5a227c 100644
--- a/paddlespeech/server/conf/ws_application.yaml
+++ b/paddlespeech/server/conf/ws_application.yaml
@@ -29,6 +29,7 @@ asr_online:
     cfg_path:
     decode_method:
     force_yes: True
+    device:  # cpu or gpu:id
 
     am_predictor_conf:
         device:  # set 'gpu:id' or 'cpu'
diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml
index 9c0425345..dd5e67ca3 100644
--- a/paddlespeech/server/conf/ws_conformer_application.yaml
+++ b/paddlespeech/server/conf/ws_conformer_application.yaml
@@ -30,6 +30,8 @@ asr_online:
     decode_method:
     force_yes: True
     device:  # cpu or gpu:id
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device:  # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
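Editor's note: a minimal sketch (not part of the patch) of how the new `continuous_decoding` key is consumed. The handler reads it with `config.get("continuous_decoding", False)`, so configs written before this patch fall back to the previous single-utterance behavior; the file path below is illustrative.

```python
import yaml  # third-party: pip install pyyaml

# path is illustrative; any of the patched application.yaml files works
with open("demos/streaming_asr_server/conf/application.yaml") as f:
    config = yaml.safe_load(f)

# mirrors self.config.get("continuous_decoding", False) in the handler:
# older configs without the key simply default to False
continuous_decoding = config["asr_online"].get("continuous_decoding", False)
print(f"continuous decoding enabled: {continuous_decoding}")
```
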
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 2dce35cb5..8f99077c7 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -55,7 +55,7 @@ class PaddleASRConnectionHanddler:
         self.config = asr_engine.config  # server config
         self.model_config = asr_engine.executor.config
         self.asr_engine = asr_engine
-
+        # model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
         self.model_type = self.asr_engine.executor.model_type
         self.sample_rate = self.asr_engine.executor.sample_rate
@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
         self.frame_shift_in_ms = int(
             self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
 
+        self.continuous_decoding = self.config.get("continuous_decoding", False)
         self.init_decoder()
         self.reset()
 
     def init_decoder(self):
         if "deepspeech2" in self.model_type:
+            assert self.continuous_decoding is False, "ds2 model does not support endpointing"
             self.am_predictor = self.asr_engine.executor.am_predictor
 
             self.decoder = CTCDecoder(
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
             self.model = self.asr_engine.executor.model
+            self.continuous_decoding = self.config.continuous_decoding
+            logger.info(f"continuous decoding: {self.continuous_decoding}")
 
             # ctc decoding config
             self.ctc_decode_config = self.asr_engine.executor.config.decode
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
         if "deepspeech2" in self.model_type:
             return
 
-        # feature cache
+        # cache for audio and feat
+        self.remained_wav = None
         self.cached_feat = None
 
         ## conformer
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
         ## just for record info
         self.chunk_num = 0  # global decoding chunk num, not used
 
+    def output_reset(self):
+        ## outputs
+        # partial/ending decoding results
+        self.result_transcripts = ['']
+        # token timestamp result
+        self.word_time_stamp = []
+
+        ## just for record
+        self.hyps = []
+
+        # one best timestamp; viterbi prob is large.
+        self.time_stamp = []
+
     def reset_continuous_decoding(self):
         """
         when in continuous decoding, reset for the next utterance.
         """
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
         self.model_reset()
         self.searcher.reset()
         self.endpointer.reset()
+        self.output_reset()
 
     def reset(self):
         if "deepspeech2" in self.model_type:
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
         # frame step of cur utterance
         self.num_frames = 0
 
-        # cache for audio and feat
-        self.remained_wav = None
-        self.cached_feat = None
+        ## endpoint
+        self.endpoint_state = False  # True when an endpoint is detected
 
         ## conformer
         self.model_reset()
 
         ## outputs
-        # partial/ending decoding results
-        self.result_transcripts = ['']
-        # token timestamp result
-        self.word_time_stamp = []
-
-        ## just for record
-        self.hyps = []
-
-        # one best timestamp viterbi prob is large.
-        self.time_stamp = []
+        self.output_reset()
 
     def extract_feat(self, samples: ByteString):
         logger.info("Online ASR extract the feat")
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
         if "deepspeech" in self.model_type:
             return
 
+        # reset endpoint state
+        self.endpoint_state = False
+
         logger.info(
             "Conformer/Transformer: start to decode with advanced_decoding method"
         )
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
             # get one best hyps
             self.hyps = self.searcher.get_one_best_hyps()
 
+            # endpoint
+            if not is_finished:
+                def contain_nonsilence():
+                    return len(self.hyps) > 0 and len(self.hyps[0]) > 0
+
+                decoding_something = contain_nonsilence()
+                if self.endpointer.endpoint_detected(ctc_probs.numpy(),
+                                                     decoding_something):
+                    self.endpoint_state = True
+                    logger.info(f"Endpoint is detected at {self.num_frames} frame.")
+
             # advance cache of feat
             assert self.cached_feat.shape[0] == 1  # (B=1,T,D)
             assert end >= cached_feature_num
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True
 
+    def new_handler(self):
+        """Create a new connection handler from this engine.
+
+        Returns:
+            PaddleASRConnectionHanddler: asr handler instance
+        """
+        return PaddleASRConnectionHanddler(self)
+
     def preprocess(self, *args, **kwargs):
         raise NotImplementedError("Online not using this.")
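Editor's note: the pieces added above (`new_handler`, `endpoint_state`, `output_reset`, `reset_continuous_decoding`) combine into the server-side loop sketched below. This is illustrative only, assuming an already initialized `ASREngine` and an iterable of raw PCM byte chunks; it is not code from the patch.

```python
from typing import Iterable


def stream_decode(engine, pcm_chunks: Iterable[bytes]) -> str:
    """engine: an initialized ASREngine; pcm_chunks: raw audio byte chunks."""
    handler = engine.new_handler()  # same as PaddleASRConnectionHanddler(engine)
    for chunk in pcm_chunks:
        handler.extract_feat(chunk)
        handler.decode(is_finished=False)
        if handler.endpoint_state:
            # the endpointer fired inside decode(): close out this utterance
            handler.rescoring()
            print("utterance:", handler.get_result())
            if handler.continuous_decoding:
                handler.reset_continuous_decoding()  # keep going with fresh state
            else:
                break
    return handler.get_result()
```
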
diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py
index 70146d6dc..9686969aa 100644
--- a/paddlespeech/server/engine/asr/online/ctc_endpoint.py
+++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
 from typing import List
+import numpy as np
 
 from paddlespeech.cli.log import logger
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
         ) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
         if (ans):
             logger.info(
-                f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailine_silence}, {utterance_length}"
+                f"Endpoint Rule: {rule_name} activated: {rule}"
             )
         return ans
 
-    def endpoint_detected(ctc_log_probs: List[List[float]],
+    def endpoint_detected(self, ctc_log_probs: np.ndarray,
                           decoding_something: bool) -> bool:
+        """Detect an endpoint.
+
+        Args:
+            ctc_log_probs (np.ndarray): (T, D) ctc log-posteriors.
+            decoding_something (bool): whether non-silence has been decoded.
+
+        Returns:
+            bool: whether an endpoint is detected.
+        """
         for logprob in ctc_log_probs:
-            blank_prob = exp(logprob[self.opts.blank_id])
+            blank_prob = np.exp(logprob[self.opts.blank])
             self.num_frames_decoded += 1
 
             if blank_prob > self.opts.blank_threshold:
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
 
         utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
         trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
+
         if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
                                trailing_silence, utterance_length):
             return True
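Editor's note: `endpoint_detected` above keeps its counters across calls on the endpointer object; the self-contained sketch below shows the same blank-probability logic over a single (T, D) array of log-posteriors. The threshold and duration values are illustrative, not the library defaults.

```python
import numpy as np


def detect_endpoint(ctc_log_probs: np.ndarray,
                    decoding_something: bool,
                    blank_id: int = 0,
                    blank_threshold: float = 0.8,
                    frame_shift_in_ms: int = 40) -> bool:
    """Single-shot version of the blank-based endpoint rules.

    ctc_log_probs: (T, D) CTC log-posteriors accumulated so far.
    decoding_something: whether the searcher already produced non-silence.
    """
    trailing_silence_frames = 0
    for logprob in ctc_log_probs:
        # a frame counts as silence when the blank token dominates
        if np.exp(logprob[blank_id]) > blank_threshold:
            trailing_silence_frames += 1
        else:
            trailing_silence_frames = 0

    utterance_length = len(ctc_log_probs) * frame_shift_in_ms
    trailing_silence = trailing_silence_frames * frame_shift_in_ms

    # rule1-style: nothing decoded yet, but a long stretch of silence
    if not decoding_something and trailing_silence >= 5000:
        return True
    # rule2-style: something decoded, then enough trailing silence
    if decoding_something and trailing_silence >= 1000:
        return True
    # rule3-style: hard cap on total utterance length
    return utterance_length >= 20000
```
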
diff --git a/paddlespeech/server/ws/asr_api.py b/paddlespeech/server/ws/asr_api.py
index 0faa131aa..2bd2c4ca3 100644
--- a/paddlespeech/server/ws/asr_api.py
+++ b/paddlespeech/server/ws/asr_api.py
@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
     #2. if we accept the websocket headers, we will get the online asr engine instance
     engine_pool = get_engine_pool()
-    asr_engine = engine_pool['asr']
+    asr_model = engine_pool['asr']
 
     #3. for each websocket connection, we create a PaddleASRConnectionHanddler to process its audio,
     #   and each connection has its own handler instance to process the request
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
                 resp = {"status": "ok", "signal": "server_ready"}
                 # do something at the beginning here
                 # create the instance to process the audio
-                connection_handler = PaddleASRConnectionHanddler(asr_engine)
+                #connection_handler = PaddleASRConnectionHanddler(asr_model)
+                connection_handler = asr_model.new_handler()
                 await websocket.send_json(resp)
             elif message['signal'] == 'end':
                 # reset single engine for a new connection
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
             # and decode for the result in this package data
             connection_handler.extract_feat(message)
             connection_handler.decode(is_finished=False)
+
+            if connection_handler.endpoint_state:
+                logger.info("endpoint: detected and rescoring.")
+                connection_handler.rescoring()
+                word_time_stamp = connection_handler.get_word_time_stamp()
+
             asr_results = connection_handler.get_result()
-            # return the current period result
-            # if the engine create the vad instance, this connection will have many period results
+
+            if connection_handler.endpoint_state:
+                if connection_handler.continuous_decoding:
+                    logger.info("endpoint: continue decoding")
+                    connection_handler.reset_continuous_decoding()
+                else:
+                    logger.info("endpoint: exit decoding")
+                    # ended by endpoint
+                    resp = {
+                        "status": "ok",
+                        "signal": "finished",
+                        'result': asr_results,
+                        'times': word_time_stamp
+                    }
+                    await websocket.send_json(resp)
+                    break
+
+            # return the current partial result
+            # if the engine creates the vad instance, this connection will produce many partial results
             resp = {'result': asr_results}
             await websocket.send_json(resp)
+
     except WebSocketDisconnect as e:
         logger.error(e)
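Editor's note: with this change the server can end a session on its own by pushing a `"signal": "finished"` message when an endpoint fires and `continuous_decoding` is off. Below is a hypothetical client sketch of the resulting wire protocol using the third-party `websockets` package; the URL and the exact fields of the start/end messages are assumptions here, and PaddleSpeech ships its own client handler for real use.

```python
import asyncio
import json

import websockets  # third-party: pip install websockets


async def run(url: str, pcm_chunks) -> dict:
    async with websockets.connect(url) as ws:
        # handshake: ask the server to create a connection handler
        await ws.send(json.dumps({"name": "test.wav", "signal": "start", "nbest": 1}))
        print(await ws.recv())  # expect {"status": "ok", "signal": "server_ready"}

        for chunk in pcm_chunks:  # binary PCM frames
            await ws.send(chunk)
            msg = json.loads(await ws.recv())
            print("partial:", msg.get("result"))
            if msg.get("signal") == "finished":
                # endpoint reached with continuous_decoding disabled
                return msg

        # no endpoint fired: close the utterance explicitly
        await ws.send(json.dumps({"name": "test.wav", "signal": "end", "nbest": 1}))
        return json.loads(await ws.recv())

# asyncio.run(run("ws://127.0.0.1:8090/paddlespeech/asr/streaming", chunks))
```
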