ctc endpoint work

pull/2015/head
Hui Zhang 3 years ago
parent 8f9b7bba48
commit 69a6da4c16

@@ -31,6 +31,8 @@ asr_online:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
     decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True

@@ -30,6 +30,9 @@ asr_online:
     decode_method:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
+    decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True

@@ -31,6 +31,8 @@ asr_online:
     force_yes: True
     device: 'cpu' # cpu or gpu:id
     decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True

@@ -29,6 +29,7 @@ asr_online:
     cfg_path:
     decode_method:
     force_yes: True
+    device: # cpu or gpu:id
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'

@@ -30,6 +30,8 @@ asr_online:
     decode_method:
     force_yes: True
    device: # cpu or gpu:id
+    continuous_decoding: True  # enable continuous decoding when an endpoint is detected
+
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
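
All of the streaming server configs gain the same switch. As an illustration only (the file name and loader below are assumptions, not part of this commit), this is how the flag can be read from a loaded YAML config, mirroring the `self.config.get("continuous_decoding", False)` pattern the handler uses:

import yaml

# hypothetical config path, for illustration only
with open("conf/ws_conformer_application.yaml") as f:
    config = yaml.safe_load(f)

asr_online = config["asr_online"]
# default to False so configs without the new key keep working,
# matching the handler's config.get("continuous_decoding", False)
continuous_decoding = asr_online.get("continuous_decoding", False)
print(f"continuous decoding enabled: {continuous_decoding}")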

@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
         self.frame_shift_in_ms = int(
             self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
+        self.continuous_decoding = self.config.get("continuous_decoding", False)
         self.init_decoder()
         self.reset()

     def init_decoder(self):
         if "deepspeech2" in self.model_type:
+            assert self.continuous_decoding is False, "ds2 model does not support endpoint"
             self.am_predictor = self.asr_engine.executor.am_predictor

             self.decoder = CTCDecoder(
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
             self.model = self.asr_engine.executor.model
+            self.continuous_decoding = self.config.continuous_decoding
+            logger.info(f"continuous decoding: {self.continuous_decoding}")

             # ctc decoding config
             self.ctc_decode_config = self.asr_engine.executor.config.decode
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
         if "deepspeech2" in self.model_type:
             return

-        # feature cache
+        # cache for audio and feat
+        self.remained_wav = None
         self.cached_feat = None

         ## conformer
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
         ## just for record info
         self.chunk_num = 0  # global decoding chunk num, not used

+    def output_reset(self):
+        ## outputs
+        # partial/ending decoding results
+        self.result_transcripts = ['']
+
+        # token timestamp result
+        self.word_time_stamp = []
+
+        ## just for record
+        self.hyps = []
+
+        # one best timestamp; viterbi prob is large.
+        self.time_stamp = []
+
     def reset_continuous_decoding(self):
         """
         when in continuous decoding, reset for next utterance.
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
         self.model_reset()
         self.searcher.reset()
         self.endpointer.reset()
+        self.output_reset()

     def reset(self):
         if "deepspeech2" in self.model_type:
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
         # frame step of cur utterance
         self.num_frames = 0

-        # cache for audio and feat
-        self.remained_wav = None
-        self.cached_feat = None
+        ## endpoint
+        self.endpoint_state = False  # True when an endpoint has been detected

         ## conformer
         self.model_reset()

         ## outputs
-        # partial/ending decoding results
-        self.result_transcripts = ['']
-
-        # token timestamp result
-        self.word_time_stamp = []
-
-        ## just for record
-        self.hyps = []
-
-        # one best timestamp viterbi prob is large.
-        self.time_stamp = []
+        self.output_reset()

     def extract_feat(self, samples: ByteString):
         logger.info("Online ASR extract the feat")
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
         if "deepspeech" in self.model_type:
             return

+        # reset endpoint state
+        self.endpoint_state = False
+
         logger.info(
             "Conformer/Transformer: start to decode with advanced_decoding method"
         )
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
         # get one best hyps
         self.hyps = self.searcher.get_one_best_hyps()

+        # endpoint
+        if not is_finished:
+            def contain_nonsilence():
+                return len(self.hyps) > 0 and len(self.hyps[0]) > 0
+            decoding_something = contain_nonsilence()
+
+            if self.endpointer.endpoint_detected(ctc_probs.numpy(),
+                                                 decoding_something):
+                self.endpoint_state = True
+                logger.info(f"Endpoint is detected at {self.num_frames} frame.")
+
         # advance cache of feat
         assert self.cached_feat.shape[0] == 1  # (B=1,T,D)
         assert end >= cached_feature_num
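
The endpointer is only consulted on partial chunks (`is_finished=False`), and it is told whether the running one-best hypothesis is non-empty so that silence-only rules can be distinguished from rules that require decoded content. A self-contained sketch of that per-chunk call pattern (the `FakeEndpointer` below is a stand-in, not the real `OnlineCTCEndpoint`):

import numpy as np

class FakeEndpointer:
    """Stand-in for OnlineCTCEndpoint; the real rules live in ctc_endpoint.py."""

    def endpoint_detected(self, ctc_probs: np.ndarray,
                          decoding_something: bool) -> bool:
        # toy rule: something was decoded and the last frame is almost all blank
        return decoding_something and float(np.exp(ctc_probs[-1, 0])) > 0.9

hyps = ['hello']                              # running one-best hypothesis
endpointer = FakeEndpointer()
ctc_probs = np.log(np.full((4, 2), 0.5))      # (T, D) chunk of CTC log-probs
ctc_probs[-1, 0] = np.log(0.95)               # last frame: blank prob 0.95

decoding_something = len(hyps) > 0 and len(hyps[0]) > 0
endpoint_state = endpointer.endpoint_detected(ctc_probs, decoding_something)
print(endpoint_state)  # True
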
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True

+    def new_handler(self):
+        """New handler from model.
+
+        Returns:
+            PaddleASRConnectionHanddler: asr handler instance
+        """
+        return PaddleASRConnectionHanddler(self)
+
     def preprocess(self, *args, **kwargs):
         raise NotImplementedError("Online not using this.")
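
`new_handler()` turns the engine into a factory, so callers such as the websocket route below no longer construct `PaddleASRConnectionHanddler` directly. A self-contained sketch of the pattern with placeholder classes (not the real engine):

class ConnectionHandlerSketch:
    def __init__(self, engine):
        self.engine = engine  # shared models and config live on the engine

class EngineSketch:
    def new_handler(self):
        # one handler per websocket connection; the engine itself stays shared
        return ConnectionHandlerSketch(self)

engine = EngineSketch()
h1, h2 = engine.new_handler(), engine.new_handler()
assert h1 is not h2 and h1.engine is h2.engine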

@@ -13,6 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
 from typing import List
+import numpy as np

 from paddlespeech.cli.log import logger
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
         ) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
         if (ans):
             logger.info(
-                f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailine_silence}, {utterance_length}"
+                f"Endpoint Rule: {rule_name} activated: {rule}"
             )
         return ans

-    def endpoint_detected(ctc_log_probs: List[List[float]],
+    def endpoint_detected(self, ctc_log_probs: np.ndarray,
                           decoding_something: bool) -> bool:
+        """Detect an endpoint.
+
+        Args:
+            ctc_log_probs (np.ndarray): (T, D)
+            decoding_something (bool): whether the hypothesis contains non-silence.
+
+        Returns:
+            bool: whether an endpoint is detected.
+        """
         for logprob in ctc_log_probs:
-            blank_prob = exp(logprob[self.opts.blank_id])
+            blank_prob = np.exp(logprob[self.opts.blank])
             self.num_frames_decoded += 1

             if blank_prob > self.opts.blank_threshold:
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
         utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
         trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
+        if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
                                trailing_silence, utterance_length):
             return True
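
The detector walks frame by frame over the chunk's CTC log-probabilities, treats any frame whose blank probability exceeds `blank_threshold` as silence, and converts the trailing-silence run into milliseconds for the rule checks. A toy, self-contained simulation of that loop (thresholds and frame shift are illustrative assumptions, not the project's defaults):

import numpy as np

BLANK_ID = 0
BLANK_THRESHOLD = 0.8          # illustrative, not the project's default
FRAME_SHIFT_MS = 40            # assume one decoder frame per 40 ms
MIN_TRAILING_SILENCE_MS = 800  # hypothetical rule threshold

trailing_silence_frames = 0
num_frames_decoded = 0

def step(logprob_frame: np.ndarray) -> bool:
    """Consume one frame of CTC log-probs (shape (D,)); True once the rule fires."""
    global trailing_silence_frames, num_frames_decoded
    num_frames_decoded += 1
    blank_prob = np.exp(logprob_frame[BLANK_ID])
    if blank_prob > BLANK_THRESHOLD:
        trailing_silence_frames += 1   # blank-dominated frame counts as silence
    else:
        trailing_silence_frames = 0    # any non-blank frame resets the run
    return trailing_silence_frames * FRAME_SHIFT_MS >= MIN_TRAILING_SILENCE_MS

speech = np.log(np.array([0.1, 0.9]))     # blank prob 0.1
silence = np.log(np.array([0.95, 0.05]))  # blank prob 0.95
frames = [speech] * 10 + [silence] * 30
fired = [i for i, f in enumerate(frames) if step(f)]
print(f"endpoint first detected at frame {fired[0]}")  # frame 29: 20 * 40 ms = 800 ms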

@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
     #2. if we accept the websocket headers, we will get the online asr engine instance
     engine_pool = get_engine_pool()
-    asr_engine = engine_pool['asr']
+    asr_model = engine_pool['asr']

     #3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
     # and each connection has its own connection instance to process the request
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
                 resp = {"status": "ok", "signal": "server_ready"}
                 # do something at beginning here
                 # create the instance to process the audio
-                connection_handler = PaddleASRConnectionHanddler(asr_engine)
+                #connection_handler = PaddleASRConnectionHanddler(asr_model)
+                connection_handler = asr_model.new_handler()
                 await websocket.send_json(resp)
             elif message['signal'] == 'end':
                 # reset single engine for a new connection
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
                 # and decode for the result in this package data
                 connection_handler.extract_feat(message)
                 connection_handler.decode(is_finished=False)

+                if connection_handler.endpoint_state:
+                    logger.info("endpoint: detected and rescoring.")
+                    connection_handler.rescoring()
+                    word_time_stamp = connection_handler.get_word_time_stamp()
+
                 asr_results = connection_handler.get_result()

-                # return the current period result
-                # if the engine create the vad instance, this connection will have many period results
+                if connection_handler.endpoint_state:
+                    if connection_handler.continuous_decoding:
+                        logger.info("endpoint: continue decoding")
+                        connection_handler.reset_continuous_decoding()
+                    else:
+                        logger.info("endpoint: exit decoding")
+                        # ending by endpoint
+                        resp = {
+                            "status": "ok",
+                            "signal": "finished",
+                            'result': asr_results,
+                            'times': word_time_stamp
+                        }
+                        await websocket.send_json(resp)
+                        break
+
+                # return the current partial result
+                # if the engine creates a vad instance, this connection will have many partial results
                 resp = {'result': asr_results}
                 await websocket.send_json(resp)
         except WebSocketDisconnect as e:
             logger.error(e)
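
On the client side, the new `{"signal": "finished"}` message means a session can now end because the server detected an endpoint, not only because the client sent `end`. A hedged client-side sketch using the third-party `websockets` package (the URI and everything beyond the response handling shown in this diff are assumptions):

import asyncio
import json
import websockets  # third-party client library; an assumption, not part of this commit

async def receive_results(uri: str):
    """Read server messages until the endpoint-driven 'finished' signal arrives.

    Only the response handling mirrors this commit; the handshake and the
    audio-sending side are omitted.
    """
    async with websockets.connect(uri) as ws:
        while True:
            msg = json.loads(await ws.recv())
            if msg.get("signal") == "finished":
                # sent when an endpoint fires and continuous_decoding is off
                print("final:", msg["result"], "times:", msg.get("times"))
                break
            if "result" in msg:
                print("partial:", msg["result"])

# asyncio.run(receive_results("ws://127.0.0.1:8090/paddlespeech/asr/streaming"))  # hypothetical URI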
