From d847fe29cfc83afbe5e4fc6f3717240516600eb7 Mon Sep 17 00:00:00 2001
From: WilliamZhang06
Date: Wed, 30 Mar 2022 16:44:04 +0800
Subject: [PATCH 1/2] added online asr engine, test=doc

---
 paddlespeech/s2t/frontend/audio.py            |  12 +
 paddlespeech/s2t/frontend/speech.py           |  16 +
 paddlespeech/server/bin/main.py               |  10 +-
 paddlespeech/server/conf/application.yaml     |  27 +-
 .../server/engine/asr/online/__init__.py      |  13 +
 .../server/engine/asr/online/asr_engine.py    | 355 ++++++++++++++++++
 paddlespeech/server/engine/engine_factory.py  |   3 +
 .../tests/asr/online/microphone_client.py     | 154 ++++++++
 .../tests/asr/online/websocket_client.py      | 115 ++++++
 paddlespeech/server/utils/buffer.py           |  59 +++
 paddlespeech/server/utils/vad.py              |  79 ++++
 paddlespeech/server/ws/__init__.py            |  13 +
 paddlespeech/server/ws/api.py                 |  38 ++
 paddlespeech/server/ws/asr_socket.py          | 106 ++++++
 14 files changed, 996 insertions(+), 4 deletions(-)
 create mode 100644 paddlespeech/server/engine/asr/online/__init__.py
 create mode 100644 paddlespeech/server/engine/asr/online/asr_engine.py
 create mode 100644 paddlespeech/server/tests/asr/online/microphone_client.py
 create mode 100644 paddlespeech/server/tests/asr/online/websocket_client.py
 create mode 100644 paddlespeech/server/utils/buffer.py
 create mode 100644 paddlespeech/server/utils/vad.py
 create mode 100644 paddlespeech/server/ws/__init__.py
 create mode 100644 paddlespeech/server/ws/api.py
 create mode 100644 paddlespeech/server/ws/asr_socket.py

diff --git a/paddlespeech/s2t/frontend/audio.py b/paddlespeech/s2t/frontend/audio.py
index d0368cc8..7f71e5dd 100644
--- a/paddlespeech/s2t/frontend/audio.py
+++ b/paddlespeech/s2t/frontend/audio.py
@@ -208,6 +208,18 @@ class AudioSegment():
             io.BytesIO(bytes), dtype='float32')
         return cls(samples, sample_rate)
 
+    @classmethod
+    def from_pcm(cls, samples, sample_rate):
+        """Create an audio segment from a numpy.ndarray of PCM samples.
+        :param samples: Audio samples [num_samples x num_channels].
+        :type samples: numpy.ndarray
+        :param sample_rate: Audio sample rate.
+        :type sample_rate: int
+        :return: Audio segment instance.
+        :rtype: AudioSegment
+        """
+        return cls(samples, sample_rate)
+
     @classmethod
     def concatenate(cls, *segments):
         """Concatenate an arbitrary number of audio segments together.
diff --git a/paddlespeech/s2t/frontend/speech.py b/paddlespeech/s2t/frontend/speech.py
index 8fd661c9..0340831a 100644
--- a/paddlespeech/s2t/frontend/speech.py
+++ b/paddlespeech/s2t/frontend/speech.py
@@ -107,6 +107,22 @@ class SpeechSegment(AudioSegment):
         return cls(audio.samples, audio.sample_rate, transcript, tokens,
                    token_ids)
 
+    @classmethod
+    def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
+        """Create a speech segment from raw PCM samples (online mode).
+        Args:
+            samples (numpy.ndarray): Audio samples [num_samples x num_channels].
+            sample_rate (int): Audio sample rate.
+            transcript (str): Transcript text for the speech.
+            tokens (List[str], optional): text tokens. Defaults to None.
+            token_ids (List[int], optional): text token ids. Defaults to None.
+        Returns:
+            SpeechSegment: Speech segment instance.
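+
+        Example (a sketch; assumes numpy is imported as np):
+            samples = np.zeros(16000, dtype='float32')  # 1 s of silence at 16 kHz
+            seg = SpeechSegment.from_pcm(samples, 16000, transcript=" ")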
+        """
+        audio = AudioSegment.from_pcm(samples, sample_rate)
+        return cls(audio.samples, audio.sample_rate, transcript, tokens,
+                   token_ids)
+
     @classmethod
     def concatenate(cls, *segments):
         """Concatenate an arbitrary number of speech segments together, both
diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py
index de528299..45ded33d 100644
--- a/paddlespeech/server/bin/main.py
+++ b/paddlespeech/server/bin/main.py
@@ -17,7 +17,8 @@ import uvicorn
 from fastapi import FastAPI
 
 from paddlespeech.server.engine.engine_pool import init_engine_pool
-from paddlespeech.server.restful.api import setup_router
+from paddlespeech.server.restful.api import setup_router as setup_http_router
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
 from paddlespeech.server.utils.config import get_config
 
 app = FastAPI(
@@ -35,7 +36,12 @@ def init(config):
     """
     # init api
     api_list = list(engine.split("_")[0] for engine in config.engine_list)
-    api_router = setup_router(api_list)
+    if config.protocol == "websocket":
+        api_router = setup_ws_router(api_list)
+    elif config.protocol == "http":
+        api_router = setup_http_router(api_list)
+    else:
+        raise Exception("unsupported protocol")
     app.include_router(api_router)
 
     if not init_engine_pool(config):
diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml
index 2b1a0599..40de8e3b 100644
--- a/paddlespeech/server/conf/application.yaml
+++ b/paddlespeech/server/conf/application.yaml
@@ -3,13 +3,18 @@
 #################################################################################
 #                             SERVER SETTING                                    #
 #################################################################################
-host: 127.0.0.1
+host: 0.0.0.0
 port: 8090
 
 # The task format in the engine_list is: <speech task>_<engine type>
 # task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
+# protocol: 'http'
+# engine_list: ['asr_python', 'tts_python', 'cls_python']
 
-engine_list: ['asr_python', 'tts_python', 'cls_python']
+
+# websocket, http (only choose one). websocket only supports the online engine type.
+protocol: 'websocket'
+engine_list: ['asr_online']
 
 
 #################################################################################
@@ -48,6 +53,24 @@ asr_inference:
         summary: True # False -> do not show predictor config
 
 
+################### speech task: asr; engine_type: online #######################
+asr_online:
+    model_type: 'deepspeech2online_aishell'
+    am_model: # the pdmodel file of am static model [optional]
+    am_params: # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    force_yes: True
+
+    am_predictor_conf:
+        device: # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
+
+
 ################################### TTS #########################################
 ################### speech task: tts; engine_type: python #######################
 tts_python:
diff --git a/paddlespeech/server/engine/asr/online/__init__.py b/paddlespeech/server/engine/asr/online/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/engine/asr/online/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
new file mode 100644
index 00000000..d5c1aa7b
--- /dev/null
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -0,0 +1,355 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+import time
+from typing import Optional
+import pickle
+import numpy as np
+from numpy import float32
+import soundfile
+
+import paddle
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.frontend.speech import SpeechSegment
+from paddlespeech.cli.asr.infer import ASRExecutor
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.utils import MODEL_HOME
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.modules.ctc import CTCDecoder
+from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.utils.paddle_predictor import init_predictor
+from paddlespeech.server.utils.paddle_predictor import run_model
+
+__all__ = ['ASREngine']
+
+pretrained_models = {
+    "deepspeech2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
+        'md5':
+        'd5e076217cf60486519f72c217d21b9b',
+        'cfg_path':
+        'model.yaml',
+        'ckpt_path':
+        'exp/deepspeech2_online/checkpoints/avg_1',
+        'model':
+        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
+        'params':
+        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
+}
+
+
+class ASRServerExecutor(ASRExecutor):
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def _init_from_path(self,
+                        model_type: str='wenetspeech',
+                        am_model: Optional[os.PathLike]=None,
+                        am_params: Optional[os.PathLike]=None,
+                        lang: str='zh',
+                        sample_rate: int=16000,
+                        cfg_path: Optional[os.PathLike]=None,
+                        decode_method: str='attention_rescoring',
+                        am_predictor_conf: dict=None):
+        """
+        Init model and other resources from a specific path.
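+
+        If cfg_path, am_model, or am_params is not given, the pretrained
+        model named by model_type is downloaded and used instead.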
+        """
+        if cfg_path is None or am_model is None or am_params is None:
+            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
+            tag = model_type + '-' + lang + '-' + sample_rate_str
+            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
+            self.res_path = res_path
+            self.cfg_path = os.path.join(res_path,
+                                         pretrained_models[tag]['cfg_path'])
+
+            self.am_model = os.path.join(res_path,
+                                         pretrained_models[tag]['model'])
+            self.am_params = os.path.join(res_path,
+                                          pretrained_models[tag]['params'])
+            logger.info(res_path)
+            logger.info(self.cfg_path)
+            logger.info(self.am_model)
+            logger.info(self.am_params)
+        else:
+            self.cfg_path = os.path.abspath(cfg_path)
+            self.am_model = os.path.abspath(am_model)
+            self.am_params = os.path.abspath(am_params)
+            self.res_path = os.path.dirname(
+                os.path.dirname(os.path.abspath(self.cfg_path)))
+
+        # Init body.
+        self.config = CfgNode(new_allowed=True)
+        self.config.merge_from_file(self.cfg_path)
+
+        with UpdateConfig(self.config):
+            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+                from paddlespeech.s2t.io.collator import SpeechCollator
+                self.vocab = self.config.vocab_filepath
+                self.config.decode.lang_model_path = os.path.join(
+                    MODEL_HOME, 'language_model',
+                    self.config.decode.lang_model_path)
+                self.collate_fn_test = SpeechCollator.from_config(self.config)
+                self.text_feature = TextFeaturizer(
+                    unit_type=self.config.unit_type, vocab=self.vocab)
+
+                lm_url = pretrained_models[tag]['lm_url']
+                lm_md5 = pretrained_models[tag]['lm_md5']
+                self.download_lm(
+                    lm_url,
+                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
+            elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+                raise Exception("wrong type")
+            else:
+                raise Exception("wrong type")
+
+        # AM predictor
+        self.am_predictor_conf = am_predictor_conf
+        self.am_predictor = init_predictor(
+            model_file=self.am_model,
+            params_file=self.am_params,
+            predictor_conf=self.am_predictor_conf)
+
+        # decoder
+        self.decoder = CTCDecoder(
+            odim=self.config.output_dim,  # <blank> is in vocab
+            enc_n_units=self.config.rnn_layer_size * 2,
+            blank_id=self.config.blank_id,
+            dropout_rate=0.0,
+            reduction=True,  # sum
+            batch_average=True,  # sum / batch_size
+            grad_norm_type=self.config.get('ctc_grad_norm_type', None))
+
+        # init decoder
+        cfg = self.config.decode
+        decode_batch_size = 1  # for online
+        self.decoder.init_decoder(
+            decode_batch_size, self.text_feature.vocab_list,
+            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+            cfg.num_proc_bsearch)
+
+        # init state box
+        self.chunk_state_h_box = np.zeros(
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+            dtype=float32)
+        self.chunk_state_c_box = np.zeros(
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+            dtype=float32)
+
+    def reset_decoder_and_chunk(self):
+        """reset decoder and chunk state for a new audio"""
+        self.decoder.reset_decoder(batch_size=1)
+        # init state box, for a new audio request
+        self.chunk_state_h_box = np.zeros(
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+            dtype=float32)
+        self.chunk_state_c_box = np.zeros(
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+            dtype=float32)
+
+    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
+        """decode one chunk
+
+        Args:
+            x_chunk (numpy.ndarray): shape [B, T, D]
+            x_chunk_lens (numpy.ndarray): shape [B]
+            model_type (str): online model type
+
+        Returns:
+            str: the best transcription result of this chunk
+        """
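+        # DeepSpeech2 online streaming: feed one feature chunk plus the RNN
+        # hidden/cell states carried over from the previous chunk into the
+        # static predictor, keep the updated states for the next call, and
+        # let the CTC decoder accumulate the partial transcript.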
+        if "deepspeech2online" in model_type:
+            input_names = self.am_predictor.get_input_names()
+            audio_handle = self.am_predictor.get_input_handle(input_names[0])
+            audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
+            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
+            c_box_handle = self.am_predictor.get_input_handle(input_names[3])
+
+            audio_handle.reshape(x_chunk.shape)
+            audio_handle.copy_from_cpu(x_chunk)
+
+            audio_len_handle.reshape(x_chunk_lens.shape)
+            audio_len_handle.copy_from_cpu(x_chunk_lens)
+
+            h_box_handle.reshape(self.chunk_state_h_box.shape)
+            h_box_handle.copy_from_cpu(self.chunk_state_h_box)
+
+            c_box_handle.reshape(self.chunk_state_c_box.shape)
+            c_box_handle.copy_from_cpu(self.chunk_state_c_box)
+
+            output_names = self.am_predictor.get_output_names()
+            output_handle = self.am_predictor.get_output_handle(output_names[0])
+            output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
+            output_state_h_handle = self.am_predictor.get_output_handle(
+                output_names[2])
+            output_state_c_handle = self.am_predictor.get_output_handle(
+                output_names[3])
+
+            self.am_predictor.run()
+
+            output_chunk_probs = output_handle.copy_to_cpu()
+            output_chunk_lens = output_lens_handle.copy_to_cpu()
+            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
+            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
+
+            self.decoder.next(output_chunk_probs, output_chunk_lens)
+            trans_best, trans_beam = self.decoder.decode()
+
+            return trans_best[0]
+        elif "conformer" in model_type or "transformer" in model_type:
+            raise Exception("invalid model name")
+        else:
+            raise Exception("invalid model name")
+
+    def _pcm16to32(self, audio):
+        """pcm int16 to float32
+
+        Args:
+            audio (numpy.ndarray): numpy.int16 samples
+
+        Returns:
+            numpy.ndarray: numpy.float32 samples
+        """
+        if audio.dtype == np.int16:
+            audio = audio.astype("float32")
+            bits = np.iinfo(np.int16).bits
+            audio = audio / (2**(bits - 1))
+        return audio
+
+    def extract_feat(self, samples, sample_rate):
+        """extract features
+
+        Args:
+            samples (numpy.ndarray): numpy.float32 samples
+            sample_rate (int): sample rate
+
+        Returns:
+            x_chunk (numpy.ndarray): shape [B, T, D]
+            x_chunk_lens (numpy.ndarray): shape [B]
+        """
+        # pcm16 -> pcm32
+        samples = self._pcm16to32(samples)
+
+        # read audio
+        speech_segment = SpeechSegment.from_pcm(
+            samples, sample_rate, transcript=" ")
+        # audio augment
+        self.collate_fn_test.augmentation.transform_audio(speech_segment)
+
+        # extract speech feature
+        spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
+            speech_segment, self.collate_fn_test.keep_transcription_text)
+        # CMVN spectrum
+        if self.collate_fn_test._normalizer:
+            spectrum = self.collate_fn_test._normalizer.apply(spectrum)
+
+        # spectrum augment
+        audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
+
+        audio_len = audio.shape[0]
+        audio = paddle.to_tensor(audio, dtype='float32')
+        # audio_len = paddle.to_tensor(audio_len)
+        audio = paddle.unsqueeze(audio, axis=0)
+
+        x_chunk = audio.numpy()
+        x_chunk_lens = np.array([audio_len])
+
+        return x_chunk, x_chunk_lens
+
+
+class ASREngine(BaseEngine):
+    """ASR server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
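+
+    Note: one engine instance is shared by all connections, so reset() is
+    called between audio streams to clear decoder and RNN state.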
+    """
+
+    def __init__(self):
+        super(ASREngine, self).__init__()
+
+    def init(self, config: dict) -> bool:
+        """init engine resource
+
+        Args:
+            config (dict): engine config
+
+        Returns:
+            bool: init success or failure
+        """
+        self.input = None
+        self.output = ""
+        self.executor = ASRServerExecutor()
+        self.config = config
+
+        self.executor._init_from_path(
+            model_type=self.config.model_type,
+            am_model=self.config.am_model,
+            am_params=self.config.am_params,
+            lang=self.config.lang,
+            sample_rate=self.config.sample_rate,
+            cfg_path=self.config.cfg_path,
+            decode_method=self.config.decode_method,
+            am_predictor_conf=self.config.am_predictor_conf)
+
+        logger.info("Initialize ASR server engine successfully.")
+        return True
+
+    def preprocess(self, samples, sample_rate):
+        """preprocess
+
+        Args:
+            samples (numpy.ndarray): numpy.float32 samples
+            sample_rate (int): sample rate
+
+        Returns:
+            x_chunk (numpy.ndarray): shape [B, T, D]
+            x_chunk_lens (numpy.ndarray): shape [B]
+        """
+        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
+        return x_chunk, x_chunk_lens
+
+    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
+        """run online engine
+
+        Args:
+            x_chunk (numpy.ndarray): shape [B, T, D]
+            x_chunk_lens (numpy.ndarray): shape [B]
+            decoder_chunk_size (int): decoder chunk size
+        """
+        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
+
+    def postprocess(self):
+        """postprocess"""
+        return self.output
+
+    def reset(self):
+        """reset engine decoder and inference state"""
+        self.executor.reset_decoder_and_chunk()
+        self.output = ""
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index c39c44ca..2a39fb79 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -25,6 +25,9 @@ class EngineFactory(object):
         elif engine_name == 'asr' and engine_type == 'python':
             from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
             return ASREngine()
+        elif engine_name == 'asr' and engine_type == 'online':
+            from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
+            return ASREngine()
         elif engine_name == 'tts' and engine_type == 'inference':
             from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
             return TTSEngine()
diff --git a/paddlespeech/server/tests/asr/online/microphone_client.py b/paddlespeech/server/tests/asr/online/microphone_client.py
new file mode 100644
index 00000000..74d457c5
--- /dev/null
+++ b/paddlespeech/server/tests/asr/online/microphone_client.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
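+#
+# Usage sketch (assumes a websocket ASR server listening on
+# ws://127.0.0.1:8090/ws/asr):
+#   python microphone_client.py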
+
+"""
+record wave from the mic
+"""
+
+import threading
+import pyaudio
+import wave
+import logging
+import asyncio
+import websockets
+import json
+from signal import SIGINT, SIGTERM
+
+
+class ASRAudioHandler(threading.Thread):
+    def __init__(self,
+                 url="127.0.0.1",
+                 port=8090):
+        threading.Thread.__init__(self)
+        self.url = url
+        self.port = port
+        self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
+        self.fileName = "./output.wav"
+        self.chunk = 5120
+        self.format = pyaudio.paInt16
+        self.channels = 1
+        self.rate = 16000
+        self._running = True
+        self._frames = []
+        self.data_backup = []
+
+    def startrecord(self):
+        """
+        start a new thread to record wave
+        """
+        threading._start_new_thread(self.recording, ())
+
+    def recording(self):
+        """
+        recording wave
+        """
+        self._running = True
+        self._frames = []
+        p = pyaudio.PyAudio()
+        stream = p.open(format=self.format,
+                        channels=self.channels,
+                        rate=self.rate,
+                        input=True,
+                        frames_per_buffer=self.chunk)
+        while(self._running):
+            data = stream.read(self.chunk)
+            self._frames.append(data)
+            self.data_backup.append(data)
+
+        stream.stop_stream()
+        stream.close()
+        p.terminate()
+
+    def save(self):
+        """
+        save wave data
+        """
+        p = pyaudio.PyAudio()
+        wf = wave.open(self.fileName, 'wb')
+        wf.setnchannels(self.channels)
+        wf.setsampwidth(p.get_sample_size(self.format))
+        wf.setframerate(self.rate)
+        wf.writeframes(b''.join(self.data_backup))
+        wf.close()
+        p.terminate()
+
+    def stoprecord(self):
+        """
+        stop recording
+        """
+        self._running = False
+
+    async def run(self):
+        aa = input("Start recording? (y/n) ")
+        if aa.strip() == "y":
+            self.startrecord()
+            logging.info("*" * 10 + " recording started, please speak")
+
+            async with websockets.connect(self.url) as ws:
+                # send the start signal
+                audio_info = json.dumps({
+                    "name": "test.wav",
+                    "signal": "start",
+                    "nbest": 5
+                }, sort_keys=True, indent=4, separators=(',', ': '))
+                await ws.send(audio_info)
+                msg = await ws.recv()
+                logging.info("receive msg={}".format(msg))
+
+                # send bytes data
+                logging.info("press Ctrl+C to stop recording, Enter to continue")
+                try:
+                    while True:
+                        while len(self._frames) > 0:
+                            await ws.send(self._frames.pop(0))
+                            msg = await ws.recv()
+                            logging.info("receive msg={}".format(msg))
+                except asyncio.CancelledError:
+                    # quit
+                    # send finished signal
+                    audio_info = json.dumps({
+                        "name": "test.wav",
+                        "signal": "end",
+                        "nbest": 5
+                    }, sort_keys=True, indent=4, separators=(',', ': '))
+                    await ws.send(audio_info)
+                    msg = await ws.recv()
+                    logging.info("receive msg={}".format(msg))
+
+                    self.stoprecord()
+                    logging.info("*" * 10 + " recording finished")
+                    self.save()
+        elif aa.strip() == "n":
+            exit()
+        else:
+            print("invalid input!")
+            exit()
+
+
+if __name__ == "__main__":
+
+    logging.basicConfig(level=logging.INFO)
+    logging.info("asr websocket client start")
+
+    handler = ASRAudioHandler("127.0.0.1", 8090)
+    loop = asyncio.get_event_loop()
+    main_task = asyncio.ensure_future(handler.run())
+    for signal in [SIGINT, SIGTERM]:
+        loop.add_signal_handler(signal, main_task.cancel)
+    try:
+        loop.run_until_complete(main_task)
+    finally:
+        loop.close()
+
+    logging.info("asr websocket client finished")
diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py
new file mode 100644
index 00000000..d849ffea
--- /dev/null
+++ b/paddlespeech/server/tests/asr/online/websocket_client.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+
+import argparse
+import logging
+import time
+import os
+import json
+import wave
+import numpy as np
+import asyncio
+import websockets
+import soundfile
+
+
+class ASRAudioHandler:
+    def __init__(self,
+                 url="127.0.0.1",
+                 port=8090):
+        self.url = url
+        self.port = port
+        self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
+
+    def read_wave(self, wavfile_path: str):
+        samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
+        x_len = len(samples)
+        chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
+        chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
+
+        if (x_len - chunk_size) % chunk_stride != 0:
+            padding_len_x = chunk_stride - (x_len - chunk_size
+                                            ) % chunk_stride
+        else:
+            padding_len_x = 0
+
+        padding = np.zeros(
+            (padding_len_x), dtype=samples.dtype)
+        padded_x = np.concatenate([samples, padding], axis=0)
+
+        num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
+        num_chunk = int(num_chunk)
+
+        for i in range(0, num_chunk):
+            start = i * chunk_stride
+            end = start + chunk_size
+            x_chunk = padded_x[start:end]
+            yield x_chunk
+
+    async def run(self, wavfile_path: str):
+        logging.info("send a message to the server")
+        # read the audio
+        # self.read_wave()
+        # send the websocket handshake request
+        async with websockets.connect(self.url) as ws:
+            # the server has received the handshake request
+            # send the start signal
+            audio_info = json.dumps({
+                "name": "test.wav",
+                "signal": "start",
+                "nbest": 5
+            }, sort_keys=True, indent=4, separators=(',', ': '))
+            await ws.send(audio_info)
+            msg = await ws.recv()
+            logging.info("receive msg={}".format(msg))
+
+            # send chunk audio data to engine
+            for chunk_data in self.read_wave(wavfile_path):
+                await ws.send(chunk_data.tobytes())
+                msg = await ws.recv()
+                logging.info("receive msg={}".format(msg))
+
+            # finished
+            audio_info = json.dumps({
+                "name": "test.wav",
+                "signal": "end",
+                "nbest": 5
+            }, sort_keys=True, indent=4, separators=(',', ': '))
+            await ws.send(audio_info)
+            msg = await ws.recv()
+            logging.info("receive msg={}".format(msg))
+
+
+def main(args):
+    logging.basicConfig(level=logging.INFO)
+    logging.info("asr websocket client start")
+    handler = ASRAudioHandler("127.0.0.1", 8090)
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(handler.run(args.wavfile))
+    logging.info("asr websocket client finished")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wavfile",
+        action="store",
+        help="wav file path",
+        default="./16_audio.wav")
+    args = parser.parse_args()
+
+    main(args)
diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py
new file mode 100644
index 00000000..4c1a3958
--- /dev/null
+++ b/paddlespeech/server/utils/buffer.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Frame(object):
+    """Represents a "frame" of audio data."""
+
+    def __init__(self, bytes, timestamp, duration):
+        self.bytes = bytes
+        self.timestamp = timestamp
+        self.duration = duration
+
+
+class ChunkBuffer(object):
+    def __init__(self,
+                 frame_duration_ms=80,
+                 shift_ms=40,
+                 sample_rate=16000,
+                 sample_width=2):
+        self.sample_rate = sample_rate
+        self.frame_duration_ms = frame_duration_ms
+        self.shift_ms = shift_ms
+        self.remained_audio = b''
+        self.sample_width = sample_width  # int16 = 2; float32 = 4
+
+    def frame_generator(self, audio):
+        """Generates audio frames from PCM audio data.
+        Frame duration and shift are taken from the constructor arguments;
+        leftover bytes are kept for the next call.
+        Yields Frames of the requested duration.
+        """
+        audio = self.remained_audio + audio
+        self.remained_audio = b''
+
+        n = int(self.sample_rate *
+                (self.frame_duration_ms / 1000.0) * self.sample_width)
+        shift_n = int(self.sample_rate *
+                      (self.shift_ms / 1000.0) * self.sample_width)
+        offset = 0
+        timestamp = 0.0
+        duration = (float(n) / self.sample_rate) / self.sample_width
+        shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
+        while offset + n <= len(audio):
+            yield Frame(audio[offset:offset + n], timestamp, duration)
+            timestamp += shift_duration
+            offset += shift_n
+
+        self.remained_audio += audio[offset:]
diff --git a/paddlespeech/server/utils/vad.py b/paddlespeech/server/utils/vad.py
new file mode 100644
index 00000000..e9b55717
--- /dev/null
+++ b/paddlespeech/server/utils/vad.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
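+#
+# webrtcvad only accepts 10, 20 or 30 ms frames of 16-bit mono PCM, so this
+# wrapper slices buffered bytes into fixed-size frames and gates them on
+# voice activity with a padded ring buffer.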
+import collections
+import logging
+
+import webrtcvad
+
+
+class VADAudio():
+    def __init__(self,
+                 aggressiveness,
+                 rate,
+                 frame_duration_ms,
+                 sample_width=2,
+                 padding_ms=200,
+                 padding_ratio=0.9):
+        """Initializes VAD with the given aggressiveness and sets up internal queues"""
+        self.vad = webrtcvad.Vad(aggressiveness)
+        self.rate = rate
+        self.sample_width = sample_width
+        self.frame_duration_ms = frame_duration_ms
+        self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
+                                 self.sample_width)
+        self._buffer_queue = collections.deque()
+        self.ring_buffer = collections.deque(maxlen=padding_ms //
+                                             frame_duration_ms)
+        self._ratio = padding_ratio
+        self.triggered = False
+
+    def add_audio(self, audio):
+        """Adds new audio to the internal queue"""
+        for x in audio:
+            self._buffer_queue.append(x)
+
+    def frame_generator(self):
+        """Generator that yields audio frames of frame_duration_ms"""
+        while len(self._buffer_queue) > self._frame_length:
+            frame = bytearray()
+            for _ in range(self._frame_length):
+                frame.append(self._buffer_queue.popleft())
+            yield bytes(frame)
+
+    def vad_collector(self):
+        """Generator that yields series of consecutive audio frames comprising each utterance, separated by yielding a single None.
+        Determines voice activity by the ratio of voiced frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
+        Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
+                  |---utterance---|       |---utterance---|
+        """
+        for frame in self.frame_generator():
+            is_speech = self.vad.is_speech(frame, self.rate)
+            if not self.triggered:
+                self.ring_buffer.append((frame, is_speech))
+                num_voiced = len(
+                    [f for f, speech in self.ring_buffer if speech])
+                if num_voiced > self._ratio * self.ring_buffer.maxlen:
+                    self.triggered = True
+                    for f, s in self.ring_buffer:
+                        yield f
+                    self.ring_buffer.clear()
+            else:
+                yield frame
+                self.ring_buffer.append((frame, is_speech))
+                num_unvoiced = len(
+                    [f for f, speech in self.ring_buffer if not speech])
+                if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
+                    self.triggered = False
+                    yield None
+                    self.ring_buffer.clear()
diff --git a/paddlespeech/server/ws/__init__.py b/paddlespeech/server/ws/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/ws/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/ws/api.py b/paddlespeech/server/ws/api.py
new file mode 100644
index 00000000..10664d11
--- /dev/null
+++ b/paddlespeech/server/ws/api.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+from fastapi import APIRouter
+
+from paddlespeech.server.ws.asr_socket import router as asr_router
+
+_router = APIRouter()
+
+
+def setup_router(api_list: List):
+    """setup router for fastapi
+    Args:
+        api_list (List): [asr, tts]
+    Returns:
+        APIRouter
+    """
+    for api_name in api_list:
+        if api_name == 'asr':
+            _router.include_router(asr_router)
+        elif api_name == 'tts':
+            pass
+        else:
+            pass
+
+    return _router
diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py
new file mode 100644
index 00000000..5cc9472c
--- /dev/null
+++ b/paddlespeech/server/ws/asr_socket.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import traceback
+from typing import Union
+import random
+import numpy as np
+import json
+
+from fastapi import APIRouter
+from fastapi import WebSocket
+from fastapi import WebSocketDisconnect
+from starlette.websockets import WebSocketState as WebSocketState
+
+from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
+from paddlespeech.server.engine.engine_pool import get_engine_pool
+from paddlespeech.server.utils.buffer import ChunkBuffer
+from paddlespeech.server.utils.vad import VADAudio
+
+
+router = APIRouter()
+
+@router.websocket('/ws/asr')
+async def websocket_endpoint(websocket: WebSocket):
+
+    await websocket.accept()
+
+    # init buffer
+    chunk_buffer = ChunkBuffer(sample_width=2)
+    # init vad
+    vad = VADAudio(2, 16000, 20)
+
+    try:
+        while True:
+            # careful here, changed the source code from starlette.websockets
+            assert websocket.application_state == WebSocketState.CONNECTED
+            message = await websocket.receive()
+            websocket._raise_on_disconnect(message)
+            if "text" in message:
+                message = json.loads(message["text"])
+                if 'signal' not in message:
+                    resp = {
+                        "status": "ok",
+                        "message": "no valid json data"
+                    }
+                    await websocket.send_json(resp)
+
+                if message['signal'] == 'start':
+                    resp = {
+                        "status": "ok",
+                        "signal": "server_ready"
+                    }
+                    # do something at the beginning here
+                    await websocket.send_json(resp)
+                elif message['signal'] == 'end':
+                    engine_pool = get_engine_pool()
+                    asr_engine = engine_pool['asr']
+                    # reset the engine for a new connection
+                    asr_engine.reset()
+                    resp = {
+                        "status": "ok",
+                        "signal": "finished"
+                    }
+                    await websocket.send_json(resp)
+                    break
+                else:
+                    resp = {
+                        "status": "ok",
+                        "message": "no valid json data"
+                    }
+                    await websocket.send_json(resp)
+            elif "bytes" in message:
+                message = message["bytes"]
+
+                # vad for input bytes audio
+                vad.add_audio(message)
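+                # keep only voiced frames; vad_collector() yields None as an
+                # utterance separator, which is dropped here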
+                message = b''.join(f for f in vad.vad_collector() if f is not None)
+
+                engine_pool = get_engine_pool()
+                asr_engine = engine_pool['asr']
+                asr_results = ""
+                frames = chunk_buffer.frame_generator(message)
+                for frame in frames:
+                    samples = np.frombuffer(frame.bytes, dtype=np.int16)
+                    sample_rate = asr_engine.config.sample_rate
+                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples, sample_rate)
+                    asr_engine.run(x_chunk, x_chunk_lens)
+                    asr_results = asr_engine.postprocess()
+
+                asr_results = asr_engine.postprocess()
+                resp = {'asr_results': asr_results}
+
+                await websocket.send_json(resp)
+    except WebSocketDisconnect:
+        pass

From 2ec8d608bf1ad5b0be9c36dc4339702271c27a6e Mon Sep 17 00:00:00 2001
From: WilliamZhang06
Date: Thu, 31 Mar 2022 16:06:16 +0800
Subject: [PATCH 2/2] fixed comments, test=doc

---
 paddlespeech/server/conf/application.yaml     | 13 ++--
 paddlespeech/server/conf/ws_application.yaml  | 51 ++++++++++++++++
 .../tests/asr/online/microphone_client.py     | 61 +++++++++++--------
 .../tests/asr/online/websocket_client.py      | 56 ++++++++---------
 paddlespeech/server/utils/vad.py              |  7 +--
 paddlespeech/server/ws/asr_socket.py          | 48 +++++++--------
 6 files changed, 142 insertions(+), 94 deletions(-)
 create mode 100644 paddlespeech/server/conf/ws_application.yaml

diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml
index 40de8e3b..849349c2 100644
--- a/paddlespeech/server/conf/application.yaml
+++ b/paddlespeech/server/conf/application.yaml
@@ -3,18 +3,15 @@
 #################################################################################
 #                             SERVER SETTING                                    #
 #################################################################################
-host: 0.0.0.0
+host: 127.0.0.1
 port: 8090
 
 # The task format in the engine_list is: <speech task>_<engine type>
 # task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
-# protocol: 'http'
-# engine_list: ['asr_python', 'tts_python', 'cls_python']
-
-
-# websocket, http (only choose one). websocket only supports the online engine type.
-protocol: 'websocket'
-engine_list: ['asr_online']
+# protocol = ['websocket', 'http'] (only one can be selected).
+# http only supports offline engine types.
+protocol: 'http'
+engine_list: ['asr_python', 'tts_python', 'cls_python']
 
 
 #################################################################################
diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml
new file mode 100644
index 00000000..ef23593e
--- /dev/null
+++ b/paddlespeech/server/conf/ws_application.yaml
@@ -0,0 +1,51 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                             SERVER SETTING                                    #
+#################################################################################
+host: 0.0.0.0
+port: 8091
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_online', 'tts_online']
+# protocol = ['websocket', 'http'] (only one can be selected).
+# websocket only supports the online engine type.
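+# chunk_buffer_conf and vad_conf under asr_online are read once per websocket
+# connection when the stream is set up.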
+protocol: 'websocket'
+engine_list: ['asr_online']
+
+
+#################################################################################
+#                                ENGINE CONFIG                                  #
+#################################################################################
+
+################################### ASR #########################################
+################### speech task: asr; engine_type: online #######################
+asr_online:
+    model_type: 'deepspeech2online_aishell'
+    am_model: # the pdmodel file of am static model [optional]
+    am_params: # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    force_yes: True
+
+    am_predictor_conf:
+        device: # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
+
+    chunk_buffer_conf:
+        frame_duration_ms: 80
+        shift_ms: 40
+        sample_rate: 16000
+        sample_width: 2
+
+    vad_conf:
+        aggressiveness: 2
+        sample_rate: 16000
+        frame_duration_ms: 20
+        sample_width: 2
+        padding_ms: 200
+        padding_ratio: 0.9
diff --git a/paddlespeech/server/tests/asr/online/microphone_client.py b/paddlespeech/server/tests/asr/online/microphone_client.py
index 74d457c5..2ceaf6d0 100644
--- a/paddlespeech/server/tests/asr/online/microphone_client.py
+++ b/paddlespeech/server/tests/asr/online/microphone_client.py
@@ -11,25 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 record wave from the mic
 """
-
+import asyncio
+import json
+import logging
 import threading
-import pyaudio
 import wave
-import logging
-import asyncio
+from signal import SIGINT
+from signal import SIGTERM
+
+import pyaudio
 import websockets
-import json
-from signal import SIGINT, SIGTERM
 
 
 class ASRAudioHandler(threading.Thread):
-    def __init__(self,
-                 url="127.0.0.1",
-                 port=8090):
+    def __init__(self, url="127.0.0.1", port=8091):
         threading.Thread.__init__(self)
         self.url = url
         self.port = port
@@ -56,12 +54,13 @@ class ASRAudioHandler(threading.Thread):
         self._running = True
         self._frames = []
         p = pyaudio.PyAudio()
-        stream = p.open(format=self.format,
-                        channels=self.channels,
-                        rate=self.rate,
-                        input=True,
-                        frames_per_buffer=self.chunk)
-        while(self._running):
+        stream = p.open(
+            format=self.format,
+            channels=self.channels,
+            rate=self.rate,
+            input=True,
+            frames_per_buffer=self.chunk)
+        while (self._running):
             data = stream.read(self.chunk)
             self._frames.append(data)
             self.data_backup.append(data)
@@ -97,11 +96,15 @@ class ASRAudioHandler(threading.Thread):
 
             async with websockets.connect(self.url) as ws:
                 # send the start signal
-                audio_info = json.dumps({
-                    "name": "test.wav",
-                    "signal": "start",
-                    "nbest": 5
-                }, sort_keys=True, indent=4, separators=(',', ': '))
+                audio_info = json.dumps(
+                    {
+                        "name": "test.wav",
+                        "signal": "start",
+                        "nbest": 5
+                    },
+                    sort_keys=True,
+                    indent=4,
+                    separators=(',', ': '))
                 await ws.send(audio_info)
                 msg = await ws.recv()
                 logging.info("receive msg={}".format(msg))
@@ -117,11 +120,15 @@ class ASRAudioHandler(threading.Thread):
                 except asyncio.CancelledError:
                     # quit
                     # send finished signal
-                    audio_info = json.dumps({
-                        "name": "test.wav",
-                        "signal": "end",
-                        "nbest": 5
-                    }, sort_keys=True, indent=4, separators=(',', ': '))
+                    audio_info = json.dumps(
+                        {
+                            "name": "test.wav",
+                            "signal": "end",
+                            "nbest": 5
+                        },
+                        sort_keys=True,
+                        indent=4,
+                        separators=(',', ': '))
                     await ws.send(audio_info)
                     msg = await ws.recv()
                     logging.info("receive msg={}".format(msg))
@@ -141,7 +148,7 @@ if __name__ == "__main__":
 
     logging.basicConfig(level=logging.INFO)
     logging.info("asr websocket client start")
-    handler = ASRAudioHandler("127.0.0.1", 8090)
+    handler = ASRAudioHandler("127.0.0.1", 8091)
     loop = asyncio.get_event_loop()
     main_task = asyncio.ensure_future(handler.run())
     for signal in [SIGINT, SIGTERM]:
diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py
index d849ffea..58b1a452 100644
--- a/paddlespeech/server/tests/asr/online/websocket_client.py
+++ b/paddlespeech/server/tests/asr/online/websocket_client.py
@@ -11,26 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 #!/usr/bin/python
 # -*- coding: UTF-8 -*-
-
 import argparse
-import logging
-import time
-import os
+import asyncio
 import json
-import wave
+import logging
+
 import numpy as np
-import asyncio
-import websockets
 import soundfile
+import websockets
 
 
 class ASRAudioHandler:
-    def __init__(self,
-                 url="127.0.0.1",
-                 port=8090):
+    def __init__(self, url="127.0.0.1", port=8090):
         self.url = url
         self.port = port
         self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
@@ -38,17 +32,15 @@ class ASRAudioHandler:
     def read_wave(self, wavfile_path: str):
         samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
         x_len = len(samples)
-        chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
-        chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
+        chunk_stride = 40 * 16  #40ms, sample_rate = 16kHz
+        chunk_size = 80 * 16  #80ms, sample_rate = 16kHz
 
         if (x_len - chunk_size) % chunk_stride != 0:
-            padding_len_x = chunk_stride - (x_len - chunk_size
-                                            ) % chunk_stride
+            padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
         else:
             padding_len_x = 0
 
-        padding = np.zeros(
-            (padding_len_x), dtype=samples.dtype)
+        padding = np.zeros((padding_len_x), dtype=samples.dtype)
         padded_x = np.concatenate([samples, padding], axis=0)
 
         num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
@@ -68,11 +60,15 @@ class ASRAudioHandler:
         async with websockets.connect(self.url) as ws:
             # the server has received the handshake request
             # send the start signal
-            audio_info = json.dumps({
-                "name": "test.wav",
-                "signal": "start",
-                "nbest": 5
-            }, sort_keys=True, indent=4, separators=(',', ': '))
+            audio_info = json.dumps(
+                {
+                    "name": "test.wav",
+                    "signal": "start",
+                    "nbest": 5
+                },
+                sort_keys=True,
+                indent=4,
+                separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
             logging.info("receive msg={}".format(msg))
@@ -84,11 +80,15 @@ class ASRAudioHandler:
                 logging.info("receive msg={}".format(msg))
 
             # finished
-            audio_info = json.dumps({
-                "name": "test.wav",
-                "signal": "end",
-                "nbest": 5
-            }, sort_keys=True, indent=4, separators=(',', ': '))
+            audio_info = json.dumps(
+                {
+                    "name": "test.wav",
+                    "signal": "end",
+                    "nbest": 5
+                },
+                sort_keys=True,
+                indent=4,
+                separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
             logging.info("receive msg={}".format(msg))
@@ -97,7 +97,7 @@ class ASRAudioHandler:
 def main(args):
     logging.basicConfig(level=logging.INFO)
     logging.info("asr websocket client start")
-    handler = ASRAudioHandler("127.0.0.1", 8090)
+    handler = ASRAudioHandler("127.0.0.1", 8091)
     loop = asyncio.get_event_loop()
     loop.run_until_complete(handler.run(args.wavfile))
     logging.info("asr websocket client finished")
diff --git a/paddlespeech/server/utils/vad.py b/paddlespeech/server/utils/vad.py
index e9b55717..a2dcf68b 100644
--- a/paddlespeech/server/utils/vad.py
+++ b/paddlespeech/server/utils/vad.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
-import logging
 
 import webrtcvad
 
 
 class VADAudio():
     def __init__(self,
-                 aggressiveness,
-                 rate,
-                 frame_duration_ms,
+                 aggressiveness=2,
+                 rate=16000,
+                 frame_duration_ms=20,
                  sample_width=2,
                  padding_ms=200,
                  padding_ratio=0.9):
diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py
index 5cc9472c..ea19816b 100644
--- a/paddlespeech/server/ws/asr_socket.py
+++ b/paddlespeech/server/ws/asr_socket.py
@@ -11,35 +11,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
-import traceback
-from typing import Union
-import random
-import numpy as np
 import json
 
+import numpy as np
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect
 from starlette.websockets import WebSocketState as WebSocketState
 
-from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.utils.buffer import ChunkBuffer
 from paddlespeech.server.utils.vad import VADAudio
 
-
 router = APIRouter()
 
+
 @router.websocket('/ws/asr')
 async def websocket_endpoint(websocket: WebSocket):
 
     await websocket.accept()
+    engine_pool = get_engine_pool()
+    asr_engine = engine_pool['asr']
 
     # init buffer
-    chunk_buffer = ChunkBuffer(sample_width=2)
+    chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
+    chunk_buffer = ChunkBuffer(
+        sample_rate=chunk_buffer_conf['sample_rate'],
+        sample_width=chunk_buffer_conf['sample_width'])
     # init vad
-    vad = VADAudio(2, 16000, 20)
+    vad_conf = asr_engine.config.vad_conf
+    vad = VADAudio(
+        aggressiveness=vad_conf['aggressiveness'],
+        rate=vad_conf['sample_rate'],
+        frame_duration_ms=vad_conf['frame_duration_ms'])
 
     try:
         while True:
@@ -50,17 +54,11 @@ async def websocket_endpoint(websocket: WebSocket):
             if "text" in message:
                 message = json.loads(message["text"])
                 if 'signal' not in message:
-                    resp = {
-                        "status": "ok",
-                        "message": "no valid json data"
-                    }
+                    resp = {"status": "ok", "message": "no valid json data"}
                     await websocket.send_json(resp)
 
                 if message['signal'] == 'start':
-                    resp = {
-                        "status": "ok",
-                        "signal": "server_ready"
-                    }
+                    resp = {"status": "ok", "signal": "server_ready"}
                     # do something at the beginning here
                     await websocket.send_json(resp)
                 elif message['signal'] == 'end':
@@ -68,24 +66,19 @@ async def websocket_endpoint(websocket: WebSocket):
                     asr_engine = engine_pool['asr']
                     # reset the engine for a new connection
                     asr_engine.reset()
-                    resp = {
-                        "status": "ok",
-                        "signal": "finished"
-                    }
+                    resp = {"status": "ok", "signal": "finished"}
                     await websocket.send_json(resp)
                     break
                 else:
-                    resp = {
-                        "status": "ok",
-                        "message": "no valid json data"
-                    }
+                    resp = {"status": "ok", "message": "no valid json data"}
                     await websocket.send_json(resp)
             elif "bytes" in message:
                 message = message["bytes"]
 
                 # vad for input bytes audio
                 vad.add_audio(message)
-                message = b''.join(f for f in vad.vad_collector() if f is not None)
+                message = b''.join(f for f in vad.vad_collector()
+                                   if f is not None)
 
                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
@@ -94,7 +87,8 @@ async def websocket_endpoint(websocket: WebSocket):
                 for frame in frames:
                     samples = np.frombuffer(frame.bytes, dtype=np.int16)
                     sample_rate = asr_engine.config.sample_rate
-                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples, sample_rate)
+                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
+                                                                  sample_rate)
                     asr_engine.run(x_chunk, x_chunk_lens)
                     asr_results = asr_engine.postprocess()
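
A minimal sketch of the wire protocol both test clients above implement,
assuming a server started with ws_application.yaml (ws://127.0.0.1:8091/ws/asr)
and 16 kHz int16 PCM input; chunk size mirrors websocket_client.py:

    import asyncio
    import json

    import numpy as np
    import websockets


    async def recognize(pcm: np.ndarray):
        async with websockets.connect("ws://127.0.0.1:8091/ws/asr") as ws:
            # handshake; the server answers {"status": "ok", "signal": "server_ready"}
            await ws.send(json.dumps({"name": "test.wav", "signal": "start", "nbest": 5}))
            print(await ws.recv())
            # stream 80 ms chunks (80 * 16 samples at 16 kHz); every reply
            # carries the running partial transcript as {"asr_results": "..."}
            for start in range(0, len(pcm), 80 * 16):
                await ws.send(pcm[start:start + 80 * 16].tobytes())
                print(await ws.recv())
            # teardown; the server resets its decoder state and answers
            # {"status": "ok", "signal": "finished"}
            await ws.send(json.dumps({"name": "test.wav", "signal": "end", "nbest": 5}))
            print(await ws.recv())


    asyncio.get_event_loop().run_until_complete(
        recognize(np.zeros(16000, dtype=np.int16)))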