From d21ccd02875fea5d8c90483a31cd8b6f4a148d2e Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Fri, 15 Apr 2022 18:42:46 +0800 Subject: [PATCH] add conformer online server, test=doc --- paddlespeech/cli/asr/infer.py | 56 +++-- paddlespeech/s2t/models/u2/u2.py | 8 +- paddlespeech/s2t/modules/ctc.py | 3 +- paddlespeech/s2t/modules/encoder.py | 2 + paddlespeech/server/conf/ws_application.yaml | 54 +++- .../server/engine/asr/online/asr_engine.py | 231 ++++++++++++------ .../tests/asr/online/websocket_client.py | 2 +- paddlespeech/server/ws/asr_socket.py | 29 ++- 8 files changed, 272 insertions(+), 113 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index b12b9f6f..53f71a70 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -91,6 +91,20 @@ pretrained_models = { 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' + }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'config.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_30', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, "deepspeech2offline_librispeech-en-16k": { 'url': @@ -115,6 +129,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer2online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": @@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + logger.info("start to init the model") if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor): self.ckpt_path = os.path.join( res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.ckpt_path) + else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method - else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( @@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor): else: raise Exception("wrong type") + logger.info("audio feat process success") + @paddle.no_grad() def infer(self, model_type: str): """ Model inference and result stored in self.output. 
""" - + logger.info("start to infer the model to get the output") cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] @@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - self._outputs["result"] = result_transcripts[0][0] + logger.info(f"we will use the transformer like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: raise Exception("invalid model name") diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6a98607b..f0d2711d 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer): num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: + print("offline decode from the asr") encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) + print("offline decode success") return encoder_out, encoder_mask def recognize( @@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer): List[List[int]]: transcripts. 
""" batch_size = feats.shape[0] + print("start to decode the audio feat") if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: - logger.fatal( + logger.error( f'decoding mode {decoding_method} must be running with batch_size == 1' ) + logger.error(f"current batch_size is {batch_size}") sys.exit(1) - + print(f"use the {decoding_method} to decode the audio feat") if decoding_method == 'attention': hyps = self.recognize( feats, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472d..bd1219b1 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + + from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index c843c0e2..347035cd 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer): outputs = [] offset = 0 # Feed forward overlap input step by step + print(f"context: {context}") + print(f"stride: {stride}") for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index ef23593e..6b82edcb 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8091 +port: 8096 # The task format in the engin_list is: _ # task choices = ['asr_online', 'tts_online'] @@ -18,10 +18,44 @@ engine_list: ['asr_online'] # ENGINE CONFIG # ################################################################################# +# ################################### ASR ######################################### +# ################### speech task: asr; engine_type: online ####################### +# asr_online: +# model_type: 'deepspeech2online_aishell' +# am_model: # the pdmodel file of am static model [optional] +# am_params: # the pdiparams file of am static model [optional] +# lang: 'zh' +# sample_rate: 16000 +# cfg_path: +# decode_method: +# force_yes: True + +# am_predictor_conf: +# device: # set 'gpu:id' or 'cpu' +# switch_ir_optim: True +# glog_info: False # True -> print glog +# summary: True # False -> do not show predictor config + +# chunk_buffer_conf: +# frame_duration_ms: 80 +# shift_ms: 40 +# sample_rate: 16000 +# sample_width: 2 + +# vad_conf: +# aggressiveness: 2 +# sample_rate: 16000 +# frame_duration_ms: 20 +# sample_width: 2 +# padding_ms: 200 +# padding_ratio: 0.9 + + + ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'deepspeech2online_aishell' + model_type: 'conformer2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -37,15 +71,15 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: - frame_duration_ms: 80 + frame_duration_ms: 85 shift_ms: 40 sample_rate: 16000 sample_width: 2 - 
vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 + # vad_conf: + # aggressiveness: 2 + # sample_rate: 16000 + # frame_duration_ms: 20 + # sample_width: 2 + # padding_ms: 200 + # padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index cd5300fc..a5b9ab48 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -20,11 +20,15 @@ from numpy import float32 from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.asr.infer import model_alias +from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float @@ -51,6 +55,24 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'exp/chunk_conformer//conf/config.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_30/', + 'model': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, } @@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. 
""" - + self.model_type = model_type + self.sample_rate = sample_rate if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" + # self.cfg_path = os.path.join(res_path, + # pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor): lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - # 开发 conformer 的流式模型 logger.info("start to create the stream conformer asr engine") - # 复用cli里面的代码 - + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.config.vocab_filepath = os.path.join( + self.res_path, self.config.vocab_filepath) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + # update the decoding method + if decode_method: + self.config.decode.decoding_method = decode_method else: raise Exception("wrong type") - - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - # decoder - logger.info("ASR engine start to create the ctc decoder instance") - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - logger.info("ASR engine start to init the ctc decoder") - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + # decoder + logger.info("ASR engine start to create the ctc decoder instance") + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + 
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + # init decoder + logger.info("ASR engine start to init the ctc decoder") + cfg = self.config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + # init state box + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + logger.info(f"model name: {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + logger.info("create the transformer like model success") def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio @@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor): Returns: [type]: [description] """ + logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) @@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: - raise Exception("invalid model name") + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + cfg = self.config.decode + result_transcripts = self.model.decode( + x_chunk, + x_chunk_lens, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + + return result_transcripts[0][0] + except Exception as e: + logger.exception(e) else: raise Exception("invalid model name") @@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor): """ # pcm16 -> pcm 32 samples = pcm2float(samples) - - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature(spectrum) - - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) - - return x_chunk, x_chunk_lens + if "deepspeech2online" in self.model_type: + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, sample_rate, transcript=" ") + # audio augment + 
self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + x_chunk = audio.numpy() + x_chunk_lens = np.array([audio_len]) + + return x_chunk, x_chunk_lens + elif "conformer2online" in self.model_type: + + if sample_rate != self.sample_rate: + logger.info(f"audio sample rate {sample_rate} is not match," \ + "the model sample_rate is {self.sample_rate}") + logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("Create the preprocess instance") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + + logger.info("Read the audio file") + logger.info(f"audio shape: {samples.shape}") + # fbank + x_chunk = preprocessing(samples, **preprocess_args) + x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + logger.info( + f"process the audio feature success, feat shape: {x_chunk.shape}" + ) + return x_chunk, x_chunk_lens class ASREngine(BaseEngine): @@ -310,7 +395,10 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, samples, sample_rate): + def preprocess(self, + samples, + sample_rate, + model_type="deepspeech2online_aishell-zh-16k"): """preprocess Args: @@ -321,6 +409,7 @@ class ASREngine(BaseEngine): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ + # if "deepspeech" in model_type: x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) return x_chunk, x_chunk_lens diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 049d707e..a26838f8 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -103,7 +103,7 @@ class ASRAudioHandler: def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8090) + handler = ASRAudioHandler("127.0.0.1", 8096) loop = asyncio.get_event_loop() # support to process single audio file diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ea19816b..442f26cb 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -14,6 +14,7 @@ import json import numpy as np +import json from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect @@ -28,7 +29,7 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - + print("websocket protocal receive the dataset") await websocket.accept() engine_pool = get_engine_pool() @@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket): # init buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( + 
frame_duration_ms=chunk_buffer_conf['frame_duration_ms'], sample_rate=chunk_buffer_conf['sample_rate'], sample_width=chunk_buffer_conf['sample_width']) # init vad - vad_conf = asr_engine.config.vad_conf - vad = VADAudio( - aggressiveness=vad_conf['aggressiveness'], - rate=vad_conf['sample_rate'], - frame_duration_ms=vad_conf['frame_duration_ms']) + # print(asr_engine.config) + # print(type(asr_engine.config)) + vad_conf = asr_engine.config.get('vad_conf', None) + if vad_conf: + vad = VADAudio( + aggressiveness=vad_conf['aggressiveness'], + rate=vad_conf['sample_rate'], + frame_duration_ms=vad_conf['frame_duration_ms']) try: while True: @@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_engine.reset() + # asr_engine.reset() resp = {"status": "ok", "signal": "finished"} await websocket.send_json(resp) break @@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket): elif "bytes" in message: message = message["bytes"] - # vad for input bytes audio - vad.add_audio(message) - message = b''.join(f for f in vad.vad_collector() - if f is not None) - + # # vad for input bytes audio + # vad.add_audio(message) + # message = b''.join(f for f in vad.vad_collector() + # if f is not None) engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" frames = chunk_buffer.frame_generator(message) for frame in frames: + # get the pcm data from the bytes samples = np.frombuffer(frame.bytes, dtype=np.int16) sample_rate = asr_engine.config.sample_rate x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
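
Editor's note (not part of the patch): the hunks above wire the new `/ws/asr` websocket route to the streaming conformer engine and move the server port to 8096 in ws_application.yaml. What follows is a minimal, hypothetical client sketch that mirrors the message flow visible in paddlespeech/server/ws/asr_socket.py: a JSON "start" signal, raw 16 kHz / 16-bit PCM chunks sent as binary frames, then a JSON "end" signal. The `websockets` dependency, the reply after the start signal, and the ~85 ms chunk size (taken from chunk_buffer_conf) are assumptions for illustration only, not something this patch defines.

    # sketch only -- the exact reply schema is not shown in this patch,
    # so the client just prints whatever JSON the server sends back
    import asyncio
    import json
    import wave

    import websockets  # assumed third-party dependency


    async def stream_wav(wav_path: str,
                         url: str = "ws://127.0.0.1:8096/ws/asr"):
        async with websockets.connect(url) as ws:
            # tell the server a new utterance begins
            await ws.send(json.dumps({"signal": "start"}))
            print(await ws.recv())  # assumed acknowledgement

            with wave.open(wav_path, "rb") as reader:
                # ~85 ms of 16 kHz mono PCM16 per message,
                # mirroring chunk_buffer_conf in ws_application.yaml
                frames_per_chunk = int(16000 * 0.085)
                chunk = reader.readframes(frames_per_chunk)
                while chunk:
                    await ws.send(chunk)    # binary frame -> "bytes" branch
                    print(await ws.recv())  # partial recognition result
                    chunk = reader.readframes(frames_per_chunk)

            # tell the server the utterance is finished
            await ws.send(json.dumps({"signal": "end"}))
            print(await ws.recv())


    if __name__ == "__main__":
        asyncio.run(stream_wav("zh.wav"))

In practice this is roughly what the bundled paddlespeech/server/tests/asr/online/websocket_client.py already does after its port was updated to 8096 by this patch; the sketch only makes the message flow explicit. The URL and port must match the `port:` value in ws_application.yaml, and the server side would be launched from that config (for example with the paddlespeech_server start command, assuming the standard server CLI).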