fix the websocket chunk edge bug, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 05a8a4b5fc
commit 5acb0b5252

@ -60,9 +60,9 @@ pretrained_models = {
},
"conformer2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz',
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'b450d5dfaea0ac227c595ce58d18b637',
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
@ -78,12 +78,19 @@ pretrained_models = {
},
}
# ASR server connection process class
# ASR server connection process class
class PaddleASRConnectionHanddler:
def __init__(self, asr_engine):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super().__init__()
logger.info("create an paddle asr connection handler to process the websocket connection")
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config
self.model_config = asr_engine.executor.config
self.model = asr_engine.executor.model
@ -105,14 +112,16 @@ class PaddleASRConnectionHanddler:
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# ctc decoding
# ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
# extract fbank
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window samples length and frame shift samples length
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
@ -141,7 +150,8 @@ class PaddleASRConnectionHanddler:
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = self.preprocessing(self.remained_wav,
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
if self.cached_feat is None:
@ -149,7 +159,8 @@ class PaddleASRConnectionHanddler:
else:
assert (len(x_chunk.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
@ -218,15 +229,21 @@ class PaddleASRConnectionHanddler:
return
num_frames = self.cached_feat.shape[1]
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
if num_frames < context:
logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward")
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
@ -258,14 +275,11 @@ class PaddleASRConnectionHanddler:
# update the offset
self.offset += y.shape[1]
logger.info(f"output size: {len(outputs)}")
ys = paddle.cat(outputs, 1)
if self.encoder_out is None:
self.encoder_out = ys
else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
@ -274,18 +288,17 @@ class PaddleASRConnectionHanddler:
self.searcher.search(None, ctc_probs, self.cached_feat.place)
self.hyps = self.searcher.get_one_best_hyps()
# remove the processed feat
if end == num_frames:
self.cached_feat = None
else:
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0)
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert len(
self.cached_feat.shape
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
# ys for rescoring
# return ys, masks
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
def update_result(self):
logger.info("update the final result")
@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler:
logger.info(f"best index: {best_index}")
self.hyps = [hyps[best_index][0]]
self.update_result()
# return hyps[best_index][0]
class ASRServerExecutor(ASRExecutor):
@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor):
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
# self.cfg_path = os.path.join(res_path,
# pretrained_models[tag]['cfg_path'])
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path,
pretrained_models[tag]['model'])

@ -96,7 +96,6 @@ async def websocket_endpoint(websocket: WebSocket):
asr_results = connection_handler.get_result()
resp = {'asr_results': asr_results}
print("\n")
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

Loading…
Cancel
Save