ctc endpoint work

pull/2015/head
Hui Zhang 3 years ago
parent 8f9b7bba48
commit 69a6da4c16

@@ -31,6 +31,8 @@ asr_online:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@@ -30,6 +30,9 @@ asr_online:
decode_method:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@@ -31,6 +31,8 @@ asr_online:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@@ -29,6 +29,7 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: # cpu or gpu:id
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'

@@ -30,6 +30,8 @@ asr_online:
decode_method:
force_yes: True
device: # cpu or gpu:id
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
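
Taken together, the config hunks above add a single switch, continuous_decoding, to each streaming asr_online config. A minimal sketch of how such a flag is consumed, assuming a plain dict-style config (the get default mirrors the handler code further down):

# Illustrative only: a fragment of an asr_online config as a Python dict.
config = {
    "decode_method": "attention_rescoring",
    "continuous_decoding": True,  # enable continuous decoding when an endpoint is detected
}

# Older configs without the key fall back to False, keeping the change backward compatible.
continuous_decoding = config.get("continuous_decoding", False)
print(continuous_decoding)  # True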

@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
self.continuous_decoding = self.config.get("continuous_decoding", False)
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model does not support endpoint"
self.am_predictor = self.asr_engine.executor.am_predictor
self.decoder = CTCDecoder(
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
self.model = self.asr_engine.executor.model
self.continuous_decoding = self.config.continuous_decoding
logger.info(f"continue decoding: {self.continuous_decoding}")
# ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
if "deepspeech2" in self.model_type:
return
# feature cache
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
## conformer
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
## just for record info
self.chunk_num = 0 # global decoding chunk num, not used
def output_reset(self):
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
# token timestamp result
self.word_time_stamp = []
## just for record
self.hyps = []
# timestamps of the one-best hypothesis (largest viterbi prob)
self.time_stamp = []
def reset_continuous_decoding(self):
"""
When in continuous decoding, reset for the next utterance.
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
self.model_reset()
self.searcher.reset()
self.endpointer.reset()
self.output_reset()
def reset(self):
if "deepspeech2" in self.model_type:
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
# frame step of cur utterance
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
## endpoint
self.endpoint_state = False # True when an endpoint is detected
## conformer
self.model_reset()
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
# token timestamp result
self.word_time_stamp = []
## just for record
self.hyps = []
# timestamps of the one-best hypothesis (largest viterbi prob)
self.time_stamp = []
self.output_reset()
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
if "deepspeech" in self.model_type:
return
# reset endpoint state
self.endpoint_state = False
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
# get one best hyps
self.hyps = self.searcher.get_one_best_hyps()
# endpoint
if not is_finished:
def contain_nonsilence():
return len(self.hyps) > 0 and len(self.hyps[0]) > 0
decoding_something = contain_nonsilence()
if self.endpointer.endpoint_detected(ctc_probs.numpy(), decoding_something):
self.endpoint_state = True
logger.info(f"Endpoint is detected at {self.num_frames} frame.")
# advance cache of feat
assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
assert end >= cached_feature_num
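
The endpoint hook added to decode() boils down to three steps: endpoint only on non-final chunks, treat a non-empty one-best hypothesis as "decoding something", and latch endpoint_state once the endpointer fires. A standalone restatement (a sketch; handler stands in for the PaddleASRConnectionHanddler instance, and maybe_flag_endpoint is a hypothetical name):

# Sketch of the decode-side endpoint hook, paraphrased from the hunk above.
def maybe_flag_endpoint(handler, ctc_probs, is_finished: bool) -> None:
    if is_finished:
        return  # final chunk of the utterance: no endpointing needed
    # "decoding something" == the one-best hypothesis is non-empty
    decoding_something = len(handler.hyps) > 0 and len(handler.hyps[0]) > 0
    if handler.endpointer.endpoint_detected(ctc_probs.numpy(), decoding_something):
        handler.endpoint_state = True  # stays set until the next reset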
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.")
return True
def new_handler(self):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return PaddleASRConnectionHanddler(self)
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online engine does not use this.")
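
new_handler() turns the engine into a per-connection factory, so callers no longer construct the handler class themselves. A hedged usage sketch (asr_engine is an initialized ASREngine; chunk_bytes is an illustrative audio chunk):

# Hypothetical usage: one fresh handler per websocket connection.
handler = asr_engine.new_handler()   # PaddleASRConnectionHanddler bound to the shared engine
handler.extract_feat(chunk_bytes)    # feed one audio chunk (ByteString)
handler.decode(is_finished=False)    # run a partial decoding pass
partial = handler.get_result()       # current partial transcript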

@@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List
import numpy as np
from paddlespeech.cli.log import logger
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
if (ans):
logger.info(
f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailine_silence}, {utterance_length}"
f"Endpoint Rule: {rule_name} activated: {rule}"
)
return ans
def endpoint_detected(ctc_log_probs: List[List[float]],
def endpoint_detected(self, ctc_log_probs: np.ndarray,
decoding_something: bool) -> bool:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): whether non-silence has been decoded.
Returns:
bool: whether endpoint detected.
"""
for logprob in ctc_log_probs:
blank_prob = exp(logprob[self.opts.blank_id])
blank_prob = np.exp(logprob[self.opts.blank])
self.num_frames_decoded += 1
if blank_prob > self.opts.blank_threshold:
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
trailing_silence, utterance_length):
return True
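
The detection logic is compact enough to restate on its own: a frame whose blank posterior exp(logprob[blank]) exceeds blank_threshold extends the trailing-silence run, any other frame resets it, and a rule fires once the run length (in milliseconds) and the utterance length clear their thresholds. A self-contained, stateless sketch with assumed default values, not the real OnlineCTCEndpoint API (which keeps its counters across calls):

import numpy as np

def detect_endpoint(ctc_log_probs: np.ndarray,
                    decoding_something: bool,
                    blank: int = 0,                  # assumed blank token id
                    blank_threshold: float = 0.8,    # assumed posterior threshold
                    frame_shift_in_ms: int = 10,     # assumed frame shift
                    min_trailing_silence_ms: int = 800,
                    min_utterance_ms: int = 0) -> bool:
    """Blank-based endpoint rule over (T, D) CTC log posteriors."""
    trailing_silence_frames = 0
    num_frames_decoded = 0
    for logprob in ctc_log_probs:
        num_frames_decoded += 1
        if np.exp(logprob[blank]) > blank_threshold:
            trailing_silence_frames += 1  # frame dominated by <blank>: silence
        else:
            trailing_silence_frames = 0   # non-blank frame resets the run
    trailing_silence = trailing_silence_frames * frame_shift_in_ms
    utterance_length = num_frames_decoded * frame_shift_in_ms
    # rule1 analogue: need decoded speech, then enough trailing silence
    return (decoding_something and
            trailing_silence >= min_trailing_silence_ms and
            utterance_length >= min_utterance_ms)

# 120 near-certain blank frames at a 10 ms shift => 1200 ms of trailing silence.
probs = np.log(np.full((120, 5), 1e-6))
probs[:, 0] = np.log(0.95)
print(detect_endpoint(probs, decoding_something=True))  # True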

@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_model = engine_pool['asr']
#3. for each websocket connection, we create a PaddleASRConnectionHanddler to process the audio
# and each connection has its own connection instance to process the request
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler = asr_model.new_handler()
await websocket.send_json(resp)
elif message['signal'] == 'end':
# reset single engine for a new connection
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
# and decode for the result in this package data
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
if connection_handler.endpoint_state:
logger.info("endpoint: detected and rescoring.")
connection_handler.rescoring()
word_time_stamp = connection_handler.get_word_time_stamp()
asr_results = connection_handler.get_result()
# return the current period result
# if the engine creates the vad instance, this connection will have many period results
if connection_handler.endpoint_state:
if connection_handler.continuous_decoding:
logger.info("endpoint: continue decoding")
connection_handler.reset_continuous_decoding()
else:
logger.info("endpoint: exit decoding")
# ending by endpoint
resp = {
"status": "ok",
"signal": "finished",
'result': asr_results,
'times': word_time_stamp
}
await websocket.send_json(resp)
break
# return the current partial result
# if the engine creates the vad instance, this connection will have many partial results
resp = {'result': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect as e:
logger.error(e)
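
End to end, the handler above implements this protocol: handshake with signal "start", stream binary audio chunks while reading partial results, and treat a response carrying signal "finished" as an endpoint-triggered stop (when continuous_decoding is off the server breaks out of the loop). A minimal client sketch; the websockets package, the URL/port/path, and every JSON field other than "signal" and the response keys shown above are assumptions:

import asyncio
import json

import websockets  # third-party client library

async def stream(pcm_path: str,
                 url: str = "ws://127.0.0.1:8090/paddlespeech/asr/streaming"):
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"signal": "start", "nbest": 1}))
        print(await ws.recv())  # expect {"status": "ok", "signal": "server_ready"}

        with open(pcm_path, "rb") as f:
            while chunk := f.read(8000):              # ~0.25 s of 16 kHz 16-bit audio
                await ws.send(chunk)                  # binary frames carry audio
                resp = json.loads(await ws.recv())    # one response per chunk
                print(resp.get("result", ""))
                if resp.get("signal") == "finished":  # endpoint hit, server stops decoding
                    print("times:", resp.get("times"))
                    return

        await ws.send(json.dumps({"signal": "end"}))
        print(await ws.recv())                        # final result for the utterance

asyncio.run(stream("zh.pcm"))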
