From 5acb0b5252e77018fdca05435c97638ac48f5d6a Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 21:46:57 +0800 Subject: [PATCH] fix the websocket chunk edge bug, test=doc --- .../server/engine/asr/online/asr_engine.py | 121 ++++++++++-------- paddlespeech/server/ws/asr_socket.py | 1 - 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 696d223a..a8e25f4b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -60,9 +60,9 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': - 'b450d5dfaea0ac227c595ce58d18b637', + '0ac93d390552336f2a906aec9e33c5fa', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -78,12 +78,19 @@ pretrained_models = { }, } -# ASR server connection process class +# ASR server connection process class class PaddleASRConnectionHanddler: def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ super().__init__() - logger.info("create an paddle asr connection handler to process the websocket connection") + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) self.config = asr_engine.config self.model_config = asr_engine.executor.config self.model = asr_engine.executor.model @@ -98,24 +105,26 @@ class PaddleASRConnectionHanddler: pass elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: self.sample_rate = self.asr_engine.executor.sample_rate - + # acoustic model self.model = self.asr_engine.executor.model - + # tokens to text self.text_feature = self.asr_engine.executor.text_feature - - # ctc decoding + + # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract fbank + # extract feat, new only fbank in conformer model self.preprocess_conf = self.model_config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) + + # frame window samples length and frame shift samples length self.win_length = self.preprocess_conf.process[0]['win_length'] self.n_shift = self.preprocess_conf.process[0]['n_shift'] - + def extract_feat(self, samples): if "deepspeech2online" in self.model_type: pass @@ -123,10 +132,10 @@ class PaddleASRConnectionHanddler: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 - + logger.info(f"This package receive {samples.shape[0]} pcm data") self.num_samples += samples.shape[0] - + # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: @@ -141,19 +150,21 @@ class PaddleASRConnectionHanddler: return 0 # fbank - x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = self.preprocessing(self.remained_wav, + **self.preprocess_args) x_chunk = paddle.to_tensor( x_chunk, dtype="float32").unsqueeze(axis=0) if self.cached_feat is None: self.cached_feat = x_chunk else: - assert(len(x_chunk.shape) == 3) - assert(len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) - + assert (len(x_chunk.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + # set the feat device if self.device is None: - self.device = self.cached_feat.place + self.device = self.cached_feat.place num_frames = x_chunk.shape[1] self.num_frames += num_frames @@ -161,7 +172,7 @@ class PaddleASRConnectionHanddler: logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) + ) logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) @@ -209,24 +220,30 @@ class PaddleASRConnectionHanddler: subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - cached_feature_num = context - subsampling # processed chunk feature cached for next chunk + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return - + num_frames = self.cached_feat.shape[1] - logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") - + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: - logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data") + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) return None, None if num_frames < context: - logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward") + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) return None, None logger.info("start to do model forward") @@ -235,17 +252,17 @@ class PaddleASRConnectionHanddler: # num_frames - context + 1 ensure that current frame can get context window if is_finished: - # if get the finished chunk, we need process the last context + # if get the finished chunk, we need process the last context left_frames = context else: # we only process decoding_window frames for one chunk - left_frames = decoding_window - + left_frames = decoding_window + # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) - + self.chunk_num += 1 chunk_xs = self.cached_feat[:, cur:end, :] (y, self.subsampling_cache, self.elayers_output_cache, @@ -257,35 +274,31 @@ class PaddleASRConnectionHanddler: # update the offset self.offset += y.shape[1] - - logger.info(f"output size: {len(outputs)}") + ys = paddle.cat(outputs, 1) if self.encoder_out is None: - self.encoder_out = ys + self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) - # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - # masks = masks.unsqueeze(1) - + # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) self.searcher.search(None, ctc_probs, self.cached_feat.place) - + self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0, end - + cached_feature_num:, :].unsqueeze(0) + assert len( + self.cached_feat.shape + ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - # remove the processed feat - if end == num_frames: - self.cached_feat = None - else: - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num - self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0) - assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - - # ys for rescoring - # return ys, masks + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) def update_result(self): logger.info("update the final result") @@ -304,8 +317,8 @@ class PaddleASRConnectionHanddler: def rescoring(self): logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: - return - + return + self.searcher.finalize_search() self.update_result() @@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler: logger.info(f"best index: {best_index}") self.hyps = [hyps[best_index][0]] self.update_result() - # return hyps[best_index][0] - class ASRServerExecutor(ASRExecutor): @@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - # self.cfg_path = os.path.join(res_path, - # pretrained_models[tag]['cfg_path']) + + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -639,7 +650,7 @@ class ASRServerExecutor(ASRExecutor): subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - + # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context num_frames = xs.shape[1] diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ae7c5eb4..82b05bc5 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -96,7 +96,6 @@ async def websocket_endpoint(websocket: WebSocket): asr_results = connection_handler.get_result() resp = {'asr_results': asr_results} - print("\n") await websocket.send_json(resp) except WebSocketDisconnect: pass