From b1dddddbe08282468d3ae289772f1c7fc6e16ad2 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sat, 30 Apr 2022 21:47:11 +0800 Subject: [PATCH 01/11] add vector server, test=doc --- demos/speaker_verification/README.md | 2 +- demos/speaker_verification/README_cn.md | 6 +- .../streaming_asr_server/websocket_client.py | 9 +- paddlespeech/cli/vector/infer.py | 3 +- paddlespeech/server/README_cn.md | 20 ++ .../server/bin/paddlespeech_client.py | 99 ++++++++- paddlespeech/server/conf/application.yaml | 13 +- .../server/conf/vector_application.yaml | 32 +++ .../server/engine/asr/online/asr_engine.py | 8 + paddlespeech/server/engine/engine_factory.py | 3 + paddlespeech/server/engine/vector/__init__.py | 0 .../server/engine/vector/python/__init__.py | 0 .../engine/vector/python/vector_engine.py | 200 ++++++++++++++++++ paddlespeech/server/restful/api.py | 4 +- paddlespeech/server/restful/request.py | 39 +++- paddlespeech/server/restful/response.py | 56 +++++ paddlespeech/server/restful/vector_api.py | 151 +++++++++++++ paddlespeech/server/utils/audio_handler.py | 101 +++++++++ 18 files changed, 735 insertions(+), 11 deletions(-) create mode 100644 paddlespeech/server/conf/vector_application.yaml create mode 100644 paddlespeech/server/engine/vector/__init__.py create mode 100644 paddlespeech/server/engine/vector/python/__init__.py create mode 100644 paddlespeech/server/engine/vector/python/vector_engine.py create mode 100644 paddlespeech/server/restful/vector_api.py diff --git a/demos/speaker_verification/README.md b/demos/speaker_verification/README.md index b79f3f7a..b6a1d9bc 100644 --- a/demos/speaker_verification/README.md +++ b/demos/speaker_verification/README.md @@ -14,7 +14,7 @@ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/doc You can choose one way from easy, meduim and hard to install paddlespeech. ### 2. Prepare Input File -The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. +The input of this cli demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. Here are sample files for this demo that can be downloaded: ```bash diff --git a/demos/speaker_verification/README_cn.md b/demos/speaker_verification/README_cn.md index db382f29..90bba38a 100644 --- a/demos/speaker_verification/README_cn.md +++ b/demos/speaker_verification/README_cn.md @@ -4,16 +4,16 @@ ## 介绍 声纹识别是一项用计算机程序自动提取说话人特征的技术。 -这个 demo 是一个从给定音频文件提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 +这个 demo 是从一个给定音频文件中提取说话人特征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 ## 使用方法 ### 1. 安装 请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 -你可以从 easy,medium,hard 三中方式中选择一种方式安装。 +你可以从easy medium,hard 三种方式中选择一种方式安装。 ### 2. 
准备输入 -这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 +声纹cli demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 可以下载此 demo 的示例音频: ```bash diff --git a/demos/streaming_asr_server/websocket_client.py b/demos/streaming_asr_server/websocket_client.py index 523ef482..3cadd72a 100644 --- a/demos/streaming_asr_server/websocket_client.py +++ b/demos/streaming_asr_server/websocket_client.py @@ -28,6 +28,7 @@ def main(args): handler = ASRWsAudioHandler( args.server_ip, args.port, + endpoint=args.endpoint, punc_server_ip=args.punc_server_ip, punc_server_port=args.punc_server_port) loop = asyncio.get_event_loop() @@ -36,7 +37,7 @@ def main(args): if args.wavfile and os.path.exists(args.wavfile): logger.info(f"start to process the wavscp: {args.wavfile}") result = loop.run_until_complete(handler.run(args.wavfile)) - result = result["result"] + # result = result["result"] logger.info(f"asr websocket client finished : {result}") # support to process batch audios from wav.scp @@ -69,7 +70,11 @@ if __name__ == "__main__": default=8091, dest="punc_server_port", help='Punctuation server port') - + parser.add_argument( + "--endpoint", + type=str, + default="/paddlespeech/asr/streaming", + help="ASR websocket endpoint") parser.add_argument( "--wavfile", action="store", diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 37e19391..8afb0f5c 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -272,7 +272,8 @@ class VectorExecutor(BaseExecutor): model_type: str='ecapatdnn_voxceleb12', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None): + ckpt_path: Optional[os.PathLike]=None, + task=None): """Init the neural network from the model path Args: diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index e799bca8..010d3d51 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -63,3 +63,23 @@ paddlespeech_server start --config_file conf/tts_online_application.yaml ``` paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --input "您好,欢迎使用百度飞桨深度学习框架!" 
--output output.wav ``` + +## 声纹识别 + +### 启动声纹识别服务 + +``` +paddlespeech_server start --config_file conf/vector_application.yaml +``` + +### 获取说话人音频声纹 + +``` +paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav +``` + +### 两个说话人音频声纹打分 + +``` +paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 123456789.wav --test 85236145389.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index 2f1ce385..f1f02d16 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -35,7 +35,7 @@ from paddlespeech.server.utils.util import wav2base64 __all__ = [ 'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor', - 'ASROnlineClientExecutor', 'CLSClientExecutor' + 'ASROnlineClientExecutor', 'CLSClientExecutor', 'VectorClientExecutor' ] @@ -583,3 +583,100 @@ class TextClientExecutor(BaseExecutor): response_dict = res.json() punc_text = response_dict["result"]["punc_text"] return punc_text + + +@cli_client_register( + name='paddlespeech_client.vector', description='visit the vector service') +class VectorClientExecutor(BaseExecutor): + def __init__(self): + super(VectorClientExecutor, self).__init__() + self.parser = argparse.ArgumentParser( + prog='paddlespeech_client.vector', add_help=True) + self.parser.add_argument( + '--server_ip', type=str, default='127.0.0.1', help='server ip') + self.parser.add_argument( + '--port', type=int, default=8090, help='server port') + self.parser.add_argument( + '--input', + type=str, + default=None, + help='sentence to be process by text server.') + self.parser.add_argument( + '--task', type=str, default="spk", help="The vector service task") + self.parser.add_argument( + "--enroll", type=str, default=None, help="The enroll audio") + self.parser.add_argument( + "--test", type=str, default=None, help="The test audio") + + def execute(self, argv: List[str]) -> bool: + """Execute the request from the argv. + + Args: + argv (List): the request arguments + + Returns: + str: the request flag + """ + args = self.parser.parse_args(argv) + input_ = args.input + server_ip = args.server_ip + port = args.port + task = args.task + + try: + time_start = time.time() + res = self( + input=input_, + server_ip=server_ip, + port=port, + enroll_audio=args.enroll, + test_audio=args.test, + task=task) + time_end = time.time() + logger.info(f"The vector: {res}") + logger.info("Response time %f s." % (time_end - time_start)) + return True + except Exception as e: + logger.error("Failed to extract vector.") + logger.error(e) + return False + + @stats_wrapper + def __call__(self, + input: str, + server_ip: str="127.0.0.1", + port: int=8090, + audio_format: str="wav", + sample_rate: int=16000, + enroll_audio: str=None, + test_audio: str=None, + task="spk"): + """ + Python API to call text executor. + + Args: + input (str): the request sentence text + server_ip (str, optional): the server ip. Defaults to "127.0.0.1". + port (int, optional): the server port. Defaults to 8090. 
+ + Returns: + str: the punctuation text + """ + if task == "spk": + from paddlespeech.server.utils.audio_handler import VectorHttpHandler + logger.info("vector http client start") + logger.info(f"the input audio: {input}") + handler = VectorHttpHandler(server_ip=server_ip, port=port) + res = handler.run(input, audio_format, sample_rate) + return res + elif task == "score": + from paddlespeech.server.utils.audio_handler import VectorScoreHttpHandler + logger.info("vector score http client start") + logger.info( + f"enroll audio: {enroll_audio}, test audio: {test_audio}") + handler = VectorScoreHttpHandler(server_ip=server_ip, port=port) + res = handler.run(enroll_audio, test_audio, audio_format, + sample_rate) + logger.info(f"The vector score is: {res}") + else: + logger.error(f"Sorry, we have not support such task {task}") diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml index c8753059..b6a9942e 100644 --- a/paddlespeech/server/conf/application.yaml +++ b/paddlespeech/server/conf/application.yaml @@ -11,7 +11,7 @@ port: 8090 # protocol = ['websocket', 'http'] (only one can be selected). # http only support offline engine type. protocol: 'http' -engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python'] +engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python', 'vector_python'] ################################################################################# @@ -166,4 +166,15 @@ text_python: cfg_path: # [optional] ckpt_path: # [optional] vocab_file: # [optional] + device: # set 'gpu:id' or 'cpu' + + +################################### Vector ###################################### +################### Vector task: spk; engine_type: python ####################### +vector_python: + task: spk + model_type: 'ecapatdnn_voxceleb12' + sample_rate: 16000 + cfg_path: # [optional] + ckpt_path: # [optional] device: # set 'gpu:id' or 'cpu' \ No newline at end of file diff --git a/paddlespeech/server/conf/vector_application.yaml b/paddlespeech/server/conf/vector_application.yaml new file mode 100644 index 00000000..c78659e3 --- /dev/null +++ b/paddlespeech/server/conf/vector_application.yaml @@ -0,0 +1,32 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# protocol = ['http'] (only one can be selected). +# http only support offline engine type. +protocol: 'http' +engine_list: ['vector_python'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### Vector ###################################### +################### Vector task: spk; engine_type: python ####################### +vector_python: + task: spk + model_type: 'ecapatdnn_voxceleb12' + sample_rate: 16000 + cfg_path: # [optional] + ckpt_path: # [optional] + device: # set 'gpu:id' or 'cpu' + + + + diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 990590b4..2e61bb4e 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,6 +13,7 @@ # limitations under the License. 
import copy import os +import time from typing import Optional import numpy as np @@ -153,6 +154,12 @@ class PaddleASRConnectionHanddler: self.n_shift = self.preprocess_conf.process[0]['n_shift'] def extract_feat(self, samples): + + # we compute the elapsed time of first char occuring + # and we record the start time at the first pcm sample arraving + # if self.first_char_occur_elapsed is not None: + # self.first_char_occur_elapsed = time.time() + if "deepspeech2online" in self.model_type: # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples @@ -290,6 +297,7 @@ class PaddleASRConnectionHanddler: self.chunk_num = 0 self.global_frame_offset = 0 self.result_transcripts = [''] + self.first_char_occur_elapsed = None def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 30e48de7..6cf95d75 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -49,5 +49,8 @@ class EngineFactory(object): elif engine_name.lower() == 'text' and engine_type.lower() == 'python': from paddlespeech.server.engine.text.python.text_engine import TextEngine return TextEngine() + elif engine_name.lower() == 'vector' and engine_type.lower() == 'python': + from paddlespeech.server.engine.vector.python.vector_engine import VectorEngine + return VectorEngine() else: return None diff --git a/paddlespeech/server/engine/vector/__init__.py b/paddlespeech/server/engine/vector/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/paddlespeech/server/engine/vector/python/__init__.py b/paddlespeech/server/engine/vector/python/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py new file mode 100644 index 00000000..866c2229 --- /dev/null +++ b/paddlespeech/server/engine/vector/python/vector_engine.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import io
+import sys
+from collections import OrderedDict
+
+import numpy as np
+import paddle
+
+from paddleaudio.backends import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.vector.infer import VectorExecutor
+from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.vector.io.batch import feature_normalize
+
+
+class PaddleVectorConnectionHandler:
+    def __init__(self, vector_engine):
+        """The PaddleSpeech Vector Server Connection Handler
+           This connection handler processes every server request
+        Args:
+            vector_engine (VectorEngine): The Vector engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleVectorConnectionHandler to process the vector request")
+        self.vector_engine = vector_engine
+        self.executor = self.vector_engine.executor
+        self.task = self.vector_engine.executor.task
+        self.model = self.vector_engine.executor.model
+        self.config = self.vector_engine.executor.config
+
+        self._inputs = OrderedDict()
+        self._outputs = OrderedDict()
+
+    @paddle.no_grad()
+    def run(self, audio_data, task="spk"):
+        """Process the audio data of one http request
+
+        Args:
+            audio_data (bytes): the base64.b64decode result of the request audio
+            task (str, optional): the request task type. Defaults to "spk".
+
+        Returns:
+            numpy.ndarray: the audio embedding
+        """
+        logger.info(
+            f"start to extract the {self.task} vector from the http request")
+        if self.task == "spk" and task == "spk":
+            embedding = self.extract_audio_embedding(audio_data)
+            return embedding
+        else:
+            logger.error(
+                "The request task is not matched with the server model task")
+            logger.error(
+                f"The server model task is: {self.task}, but the request task is: {task}"
+            )
+
+            return np.array([
+                0.0,
+            ])
+
+    @paddle.no_grad()
+    def get_enroll_test_score(self, enroll_audio, test_audio):
+        """Get the score between the enroll and test audio
+
+        Args:
+            enroll_audio (bytes): the base64-decoded enroll audio data
+            test_audio (bytes): the base64-decoded test audio data
+
+        Returns:
+            float: the score between enroll and test audio
+        """
+        logger.info("start to extract the enroll audio embedding")
+        enroll_emb = self.extract_audio_embedding(enroll_audio)
+
+        logger.info("start to extract the test audio embedding")
+        test_emb = self.extract_audio_embedding(test_audio)
+
+        logger.info(
+            "start to get the score between the enroll and test embedding")
+        score = self.executor.get_embeddings_score(enroll_emb, test_emb)
+
+        logger.info(f"get the enroll vs test score: {score}")
+        return score
+
+    @paddle.no_grad()
+    def extract_audio_embedding(self, audio: bytes, sample_rate: int=16000):
+        """Extract the audio embedding from the audio data
+
+        Args:
+            audio (bytes): the base64-decoded audio data
+            sample_rate (int, optional): the audio sample rate. Defaults to 16000. 
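+
+        Returns:
+            numpy.ndarray: the embedding of the input audio, or a
+                one-dimensional zero vector if the sample rate check fails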
+        """
+        # we can not reuse the cached io.BytesIO(audio) data,
+        # because the soundfile will seek the io.BytesIO(audio) to the end,
+        # thus we should convert the base64 string to io.BytesIO whenever we need the audio data
+        if not self.executor._check(io.BytesIO(audio), sample_rate):
+            logger.info("the audio sample rate check failed")
+            return np.array([0.0])
+
+        waveform, sr = load_audio(io.BytesIO(audio))
+        logger.info(f"load the audio sample points, shape is: {waveform.shape}")
+
+        # stage 2: get the audio feat
+        # Note: Now we only support fbank feature
+        try:
+            feats = melspectrogram(
+                x=waveform,
+                sr=self.config.sr,
+                n_mels=self.config.n_mels,
+                window_size=self.config.window_size,
+                hop_length=self.config.hop_size)
+            logger.info(f"extract the audio feats, shape is: {feats.shape}")
+        except Exception as e:
+            logger.info(f"feature extraction raised an exception: {e}")
+            sys.exit(-1)
+
+        feats = paddle.to_tensor(feats).unsqueeze(0)
+        # in the inference period, the lengths are all one without padding
+        lengths = paddle.ones([1])
+
+        # stage 3: we do feature normalize,
+        #          now we assume that the feats must be normalized
+        feats = feature_normalize(feats, mean_norm=True, std_norm=False)
+
+        # stage 4: run the model backbone to extract the audio embedding
+        logger.info(f"feats shape: {feats.shape}")
+        logger.info("audio extract the feats success")
+
+        logger.info("start to extract the audio embedding")
+        embedding = self.model.backbone(feats, lengths).squeeze().numpy()
+        logger.info(f"embedding size: {embedding.shape}")
+
+        return embedding
+
+
+class VectorServerExecutor(VectorExecutor):
+    def __init__(self):
+        """The wrapper for VectorExecutor
+        """
+        super().__init__()
+
+
+class VectorEngine(BaseEngine):
+    def __init__(self):
+        """The Vector Engine
+        """
+        super(VectorEngine, self).__init__()
+        logger.info("Create the VectorEngine Instance")
+
+    def init(self, config: dict):
+        """Init the Vector Engine
+
+        Args:
+            config (dict): The server configuration
+
+        Returns:
+            bool: True on success, False on failure
+        """
+        logger.info("Init the vector engine")
+        try:
+            self.config = config
+            if self.config.device:
+                self.device = self.config.device
+            else:
+                self.device = paddle.get_device()
+
+            paddle.set_device(self.device)
+            logger.info(f"Vector Engine set the device: {self.device}")
+        except BaseException as e:
+            logger.error(
+                "Set device failed, please check if the device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize Vector server engine failed on device: %s."
+ % (self.device)) + return False + + self.executor = VectorServerExecutor() + + self.executor._init_from_path( + model_type=config.model_type, + cfg_path=config.cfg_path, + ckpt_path=config.ckpt_path, + task=config.task) + + logger.info("Init the Vector engine successfully") + return True diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py index d5e422e3..f1e4ffc8 100644 --- a/paddlespeech/server/restful/api.py +++ b/paddlespeech/server/restful/api.py @@ -21,7 +21,7 @@ from paddlespeech.server.restful.asr_api import router as asr_router from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.text_api import router as text_router from paddlespeech.server.restful.tts_api import router as tts_router - +from paddlespeech.server.restful.vector_api import router as vec_router _router = APIRouter() @@ -43,6 +43,8 @@ def setup_router(api_list: List): _router.include_router(cls_router) elif api_name == 'text': _router.include_router(text_router) + elif api_name.lower() == 'vector': + _router.include_router(vec_router) else: logger.error( f"PaddleSpeech has not support such service: {api_name}") diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py index 50416627..b23ed76d 100644 --- a/paddlespeech/server/restful/request.py +++ b/paddlespeech/server/restful/request.py @@ -15,7 +15,7 @@ from typing import Optional from pydantic import BaseModel -__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest'] +__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest', 'VectorRequest'] #****************************************************************************************/ @@ -85,3 +85,40 @@ class CLSRequest(BaseModel): #****************************************************************************************/ class TextRequest(BaseModel): text: str + + +#****************************************************************************************/ +#************************************ Vecotr request ************************************/ +#****************************************************************************************/ +class VectorRequest(BaseModel): + """ + request body example + { + "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", + "task": "spk", + "audio_format": "wav", + "sample_rate": 16000, + } + """ + audio: str + task: str + audio_format: str + sample_rate: int + + +class VectorScoreRequest(BaseModel): + """ + request body example + { + "enroll_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", + "test_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", + "task": "spk", + "audio_format": "wav", + "sample_rate": 16000, + } + """ + enroll_audio: str + test_audio: str + task: str + audio_format: str + sample_rate: int diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py index 5792959e..f8cdb3cf 100644 --- a/paddlespeech/server/restful/response.py +++ b/paddlespeech/server/restful/response.py @@ -129,6 +129,11 @@ class CLSResponse(BaseModel): result: CLSResult +#****************************************************************************************/ +#************************************ Text response **************************************/ +#****************************************************************************************/ + + class TextResult(BaseModel): punc_text: str @@ -153,6 +158,57 @@ class TextResponse(BaseModel): result: TextResult 
+#****************************************************************************************/ +#************************************ Vector response **************************************/ +#****************************************************************************************/ + + +class VectorResult(BaseModel): + vec: list + + +class VectorResponse(BaseModel): + """ + response example + { + "success": true, + "code": 0, + "message": { + "description": "success" + }, + "result": { + "vec": [1.0, 1.0] + } + } + """ + success: bool + code: int + message: Message + result: VectorResult + + +class VectorScoreResult(BaseModel): + score: float + +class VectorScoreResponse(BaseModel): + """ + response example + { + "success": true, + "code": 0, + "message": { + "description": "success" + }, + "result": { + "score": 1.0 + } + } + """ + success: bool + code: int + message: Message + result: VectorScoreResult + #****************************************************************************************/ #********************************** Error response **************************************/ #****************************************************************************************/ diff --git a/paddlespeech/server/restful/vector_api.py b/paddlespeech/server/restful/vector_api.py new file mode 100644 index 00000000..6e04f48e --- /dev/null +++ b/paddlespeech/server/restful/vector_api.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import traceback +from typing import Union + +import numpy as np +from fastapi import APIRouter + +from paddlespeech.cli.log import logger +from paddlespeech.server.engine.engine_pool import get_engine_pool +from paddlespeech.server.engine.vector.python.vector_engine import PaddleVectorConnectionHandler +from paddlespeech.server.restful.request import VectorRequest +from paddlespeech.server.restful.request import VectorScoreRequest +from paddlespeech.server.restful.response import ErrorResponse +from paddlespeech.server.restful.response import VectorResponse +from paddlespeech.server.restful.response import VectorScoreResponse +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.errors import failed_response +from paddlespeech.server.utils.exception import ServerBaseException +router = APIRouter() + + +@router.get('/paddlespeech/vector/help') +def help(): + """help + + Returns: + json: The /paddlespeech/vector api response content + """ + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "vector": [2.3, 3.5, 5.5, 6.2, 2.8, 1.2, 0.3, 3.6] + } + return response + + +@router.post( + "/paddlespeech/vector", response_model=Union[VectorResponse, ErrorResponse]) +def vector(request_body: VectorRequest): + """vector api + + Args: + request_body (VectorRequest): the vector request body + + Returns: + json: the vector response body + """ + try: + # 1. 
get the audio data
+        #    the audio must be base64 format
+        audio_data = base64.b64decode(request_body.audio)
+
+        # 2. get a single engine from the engine pool,
+        #    and use the vector_engine to create a connection handler to process the request
+        engine_pool = get_engine_pool()
+        vector_engine = engine_pool['vector']
+        connection_handler = PaddleVectorConnectionHandler(vector_engine)
+
+        # 3. we use the connection handler to process the audio
+        audio_vec = connection_handler.run(audio_data, request_body.task)
+
+        # 4. the result of the vector instance must be a numpy.ndarray
+        if not isinstance(audio_vec, np.ndarray):
+            logger.error(
+                f"the vector type is not numpy.ndarray, that is: {type(audio_vec)}"
+            )
+            error_response = ErrorResponse()
+            error_response.message.description = f"the vector type is not numpy.ndarray, that is: {type(audio_vec)}"
+            return error_response
+
+        response = {
+            "success": True,
+            "code": 200,
+            "message": {
+                "description": "success"
+            },
+            "result": {
+                "vec": audio_vec.tolist()
+            }
+        }
+
+    except ServerBaseException as e:
+        response = failed_response(e.error_code, e.msg)
+    except BaseException:
+        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+        traceback.print_exc()
+
+    return response
+
+
+@router.post(
+    "/paddlespeech/vector/score",
+    response_model=Union[VectorScoreResponse, ErrorResponse])
+def score(request_body: VectorScoreRequest):
+    """vector score api
+
+    Args:
+        request_body (VectorScoreRequest): the vector score request body
+
+    Returns:
+        json: the vector score response body
+    """
+    try:
+        # 1. get the audio data
+        #    the audio must be base64 format
+        enroll_data = base64.b64decode(request_body.enroll_audio)
+        test_data = base64.b64decode(request_body.test_audio)
+
+        # 2. get a single engine from the engine pool,
+        #    and use the vector_engine to create a connection handler to process the request
+        engine_pool = get_engine_pool()
+        vector_engine = engine_pool['vector']
+        connection_handler = PaddleVectorConnectionHandler(vector_engine)
+
+        # 3. we use the connection handler to process the two audios
+        score = connection_handler.get_enroll_test_score(enroll_data, test_data)
+
+        response = {
+            "success": True,
+            "code": 200,
+            "message": {
+                "description": "success"
+            },
+            "result": {
+                "score": score
+            }
+        }
+
+    except ServerBaseException as e:
+        response = failed_response(e.error_code, e.msg)
+    except BaseException:
+        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+        traceback.print_exc()
+
+    return response
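For reference, the two endpoints above can be exercised with a plain HTTP client. Below is a minimal sketch (an illustration, not part of this patch), assuming a server started from `conf/vector_application.yaml` on `127.0.0.1:8090` and using the sample `85236145389.wav` referenced in the README:

```python
import base64
import json

import requests

# encode the request audio as base64, which is what the
# /paddlespeech/vector endpoint expects in the "audio" field of VectorRequest
with open("85236145389.wav", "rb") as f:
    audio = base64.b64encode(f.read()).decode("utf8")

data = {
    "audio": audio,
    "task": "spk",
    "audio_format": "wav",
    "sample_rate": 16000,
}
res = requests.post(
    url="http://127.0.0.1:8090/paddlespeech/vector", data=json.dumps(data))

# on success the body matches VectorResponse; result.vec holds the embedding
print(len(res.json()["result"]["vec"]))
```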
diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py
index f0ec0eaa..a088929f 100644
--- a/paddlespeech/server/utils/audio_handler.py
+++ b/paddlespeech/server/utils/audio_handler.py
@@ -142,6 +142,7 @@ class ASRWsAudioHandler:
             return ""
 
         # 1. send websocket handshake protocal
+        start_time = time.time()
        async with websockets.connect(self.url) as ws:
            # 2. server has already received handshake protocal
            # client start to send the command
@@ -187,7 +188,14 @@
            if self.punc_server:
                msg["result"] = self.punc_server.run(msg["result"])

+            # 6. log the final result and compute the statistics
+            elapsed_time = time.time() - start_time
+            audio_info = soundfile.info(wavfile_path)
            logger.info("client final receive msg={}".format(msg))
+            logger.info(
+                f"audio duration: {audio_info.duration}, elapsed time: {elapsed_time}, RTF={elapsed_time/audio_info.duration}"
+            )
+
            result = msg
            return result
 
@@ -456,3 +464,96 @@ class TTSHttpHandler:
             self.stream.stop_stream()
             self.stream.close()
             self.p.terminate()
+
+
+class VectorHttpHandler:
+    def __init__(self, server_ip=None, port=None):
+        """The Vector client http request
+
+        Args:
+            server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
+            port (int, optional): the http vector server port. Defaults to 8090.
+        """
+        super().__init__()
+        self.server_ip = server_ip
+        self.port = port
+        if server_ip is None or port is None:
+            self.url = None
+        else:
+            self.url = 'http://' + self.server_ip + ":" + str(
+                self.port) + '/paddlespeech/vector'
+
+    def run(self, input, audio_format, sample_rate, task="spk"):
+        """Call the http vector server to process the audio
+
+        Args:
+            input (str): the audio file path
+            audio_format (str): the audio format
+            sample_rate (int): the audio sample rate
+            task (str, optional): the vector task. Defaults to "spk".
+
+        Returns:
+            dict: the http response body, which contains the audio vector
+        """
+        if self.url is None:
+            logger.error("No vector server, please input valid ip and port")
+            return ""
+
+        audio = wav2base64(input)
+        data = {
+            "audio": audio,
+            "task": task,
+            "audio_format": audio_format,
+            "sample_rate": sample_rate,
+        }
+
+        logger.info(self.url)
+        res = requests.post(url=self.url, data=json.dumps(data))
+
+        return res.json()
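The `VectorScoreHttpHandler` below wraps the `/paddlespeech/vector/score` endpoint. For intuition about the returned score: the server compares the two speaker embeddings with a cosine-style similarity (the CLI executor's `get_embeddings_score` is built on cosine similarity). A minimal NumPy sketch of that comparison, for illustration only:

```python
import numpy as np

def cosine_score(enroll_emb: np.ndarray, test_emb: np.ndarray) -> float:
    """Cosine similarity between two speaker embeddings, in [-1, 1];
    higher values mean the utterances are more likely the same speaker."""
    denom = np.linalg.norm(enroll_emb) * np.linalg.norm(test_emb)
    return float(np.dot(enroll_emb, test_emb) / denom)
```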
+
+
+class VectorScoreHttpHandler:
+    def __init__(self, server_ip=None, port=None):
+        """The Vector score client http request
+
+        Args:
+            server_ip (str, optional): the http vector server ip. Defaults to "127.0.0.1".
+            port (int, optional): the http vector server port. Defaults to 8090.
+        """
+        super().__init__()
+        self.server_ip = server_ip
+        self.port = port
+        if server_ip is None or port is None:
+            self.url = None
+        else:
+            self.url = 'http://' + self.server_ip + ":" + str(
+                self.port) + '/paddlespeech/vector/score'
+
+    def run(self, enroll_audio, test_audio, audio_format, sample_rate):
+        """Call the http vector score server to compare two audios
+
+        Args:
+            enroll_audio (str): the enroll audio file path
+            test_audio (str): the test audio file path
+            audio_format (str): the audio format
+            sample_rate (int): the audio sample rate
+
+        Returns:
+            dict: the http response body, which contains the score
+        """
+        if self.url is None:
+            logger.error("No vector server, please input valid ip and port")
+            return ""
+
+        enroll_audio = wav2base64(enroll_audio)
+        test_audio = wav2base64(test_audio)
+        data = {
+            "enroll_audio": enroll_audio,
+            "test_audio": test_audio,
+            "task": "score",
+            "audio_format": audio_format,
+            "sample_rate": sample_rate,
+        }
+
+        res = requests.post(url=self.url, data=json.dumps(data))
+
+        return res.json()

From 3950557e043e239526162cec4b42d334457d2a41 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Sun, 1 May 2022 23:50:08 +0800
Subject: [PATCH 02/11] update the vector server note, test=doc

---
 paddlespeech/server/bin/paddlespeech_client.py | 10 +++++++---
 paddlespeech/server/restful/request.py         |  5 ++++-
 paddlespeech/server/restful/response.py        |  7 ++++++-
 3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py
index f1f02d16..32f78942 100644
--- a/paddlespeech/server/bin/paddlespeech_client.py
+++ b/paddlespeech/server/bin/paddlespeech_client.py
@@ -655,12 +655,16 @@ class VectorClientExecutor(BaseExecutor):
         Python API to call text executor.
 
         Args:
-            input (str): the request sentence text
+            input (str): the request audio data
             server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
             port (int, optional): the server port. Defaults to 8090.
-
+            audio_format (str, optional): audio format. Defaults to "wav".
+            sample_rate (int, optional): audio sample rate. Defaults to 16000.
+            enroll_audio (str, optional): enroll audio data. Defaults to None.
+            test_audio (str, optional): test audio data. Defaults to None.
+            task (str, optional): the task type, "spk" or "score". 
Defaults to "spk" Returns: - str: the punctuation text + str: the audio embedding or score between enroll and test audio """ if task == "spk": from paddlespeech.server.utils.audio_handler import VectorHttpHandler diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py index b23ed76d..4e88280a 100644 --- a/paddlespeech/server/restful/request.py +++ b/paddlespeech/server/restful/request.py @@ -15,7 +15,10 @@ from typing import Optional from pydantic import BaseModel -__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest', 'VectorRequest'] +__all__ = [ + 'ASRRequest', 'TTSRequest', 'CLSRequest', 'VectorRequest', + 'VectorScoreRequest' +] #****************************************************************************************/ diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py index f8cdb3cf..c91b3899 100644 --- a/paddlespeech/server/restful/response.py +++ b/paddlespeech/server/restful/response.py @@ -15,7 +15,10 @@ from typing import List from pydantic import BaseModel -__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse'] +__all__ = [ + 'ASRResponse', 'TTSResponse', 'CLSResponse', 'TextResponse', + 'VectorResponse', 'VectorScoreResponse' +] class Message(BaseModel): @@ -190,6 +193,7 @@ class VectorResponse(BaseModel): class VectorScoreResult(BaseModel): score: float + class VectorScoreResponse(BaseModel): """ response example @@ -209,6 +213,7 @@ class VectorScoreResponse(BaseModel): message: Message result: VectorScoreResult + #****************************************************************************************/ #********************************** Error response **************************************/ #****************************************************************************************/ From c78653850b020ef54590a744eebe80b6a096af56 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 2 May 2022 20:11:34 +0800 Subject: [PATCH 03/11] join streaming asr and punc server, test=doc --- demos/streaming_asr_server/README.md | 272 +++++++++++++++++- demos/streaming_asr_server/README_cn.md | 272 ++++++++++++++++++ .../conf/punc_application.yaml | 35 +++ .../conf/ws_conformer_application.yaml | 4 +- demos/streaming_asr_server/punc_server.py | 38 +++ demos/streaming_asr_server/server.sh | 5 + .../streaming_asr_server.py | 38 +++ demos/streaming_asr_server/test.sh | 7 +- .../server/bin/paddlespeech_client.py | 42 ++- 9 files changed, 703 insertions(+), 10 deletions(-) create mode 100644 demos/streaming_asr_server/conf/punc_application.yaml create mode 100644 demos/streaming_asr_server/punc_server.py create mode 100755 demos/streaming_asr_server/server.sh create mode 100644 demos/streaming_asr_server/streaming_asr_server.py diff --git a/demos/streaming_asr_server/README.md b/demos/streaming_asr_server/README.md index 3de2f386..48cfbaf3 100644 --- a/demos/streaming_asr_server/README.md +++ b/demos/streaming_asr_server/README.md @@ -355,4 +355,274 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav [2022-04-21 15:59:08,016] [ INFO] - receive msg={'asr_results': '我认为跑步最重要的就是给我带来了身体健康'} [2022-04-21 15:59:08,024] [ INFO] - receive msg={'asr_results': '我认为跑步最重要的就是给我带来了身体健康'} [2022-04-21 15:59:12,883] [ INFO] - final receive msg={'status': 'ok', 'signal': 'finished', 'asr_results': '我认为跑步最重要的就是给我带来了身体健康'} - ``` \ No newline at end of file + ``` + + +## Punctuation service + +### 1. 
Server usage + +- Command Line + ``` bash + In PaddleSpeech/demos/streaming_asr_server directory to lanuch punctuation service + paddlespeech_server start --config_file conf/punc_application.yaml + ``` + + + Usage: + ```bash + paddlespeech_server start --help + ``` + + Arguments: + - `config_file`: configuration file. + - `log_file`: log file. + + + Output: + ``` bash + [2022-05-02 17:59:26,285] [ INFO] - Create the TextEngine Instance + [2022-05-02 17:59:26,285] [ INFO] - Init the text engine + [2022-05-02 17:59:26,285] [ INFO] - Text Engine set the device: gpu:0 + [2022-05-02 17:59:26,286] [ INFO] - File /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar.gz md5 checking... + [2022-05-02 17:59:30,810] [ INFO] - Use pretrained model stored in: /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar + W0502 17:59:31.486552 9595 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 10.2, Runtime API Version: 10.2 + W0502 17:59:31.491360 9595 device_context.cc:465] device: 0, cuDNN Version: 7.6. + [2022-05-02 17:59:34,688] [ INFO] - Already cached /home/users/xiongxinlei/.paddlenlp/models/ernie-1.0/vocab.txt + [2022-05-02 17:59:34,701] [ INFO] - Init the text engine successfully + INFO: Started server process [9595] + [2022-05-02 17:59:34] [INFO] [server.py:75] Started server process [9595] + INFO: Waiting for application startup. + [2022-05-02 17:59:34] [INFO] [on.py:45] Waiting for application startup. + INFO: Application startup complete. + [2022-05-02 17:59:34] [INFO] [on.py:59] Application startup complete. + INFO: Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + [2022-05-02 17:59:34] [INFO] [server.py:206] Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + ``` + +- Python API + + ```python + # 在 PaddleSpeech/demos/streaming_asr_server 目录 + from paddlespeech.server.bin.paddlespeech_server import ServerExecutor + + server_executor = ServerExecutor() + server_executor( + config_file="./conf/punc_application.yaml", + log_file="./log/paddlespeech.log") + ``` + + Output: + ``` + [2022-05-02 18:09:02,542] [ INFO] - Create the TextEngine Instance + [2022-05-02 18:09:02,543] [ INFO] - Init the text engine + [2022-05-02 18:09:02,543] [ INFO] - Text Engine set the device: gpu:0 + [2022-05-02 18:09:02,545] [ INFO] - File /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar.gz md5 checking... + [2022-05-02 18:09:06,919] [ INFO] - Use pretrained model stored in: /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar + W0502 18:09:07.523002 22615 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 10.2, Runtime API Version: 10.2 + W0502 18:09:07.527882 22615 device_context.cc:465] device: 0, cuDNN Version: 7.6. + [2022-05-02 18:09:10,900] [ INFO] - Already cached /home/users/xiongxinlei/.paddlenlp/models/ernie-1.0/vocab.txt + [2022-05-02 18:09:10,913] [ INFO] - Init the text engine successfully + INFO: Started server process [22615] + [2022-05-02 18:09:10] [INFO] [server.py:75] Started server process [22615] + INFO: Waiting for application startup. + [2022-05-02 18:09:10] [INFO] [on.py:45] Waiting for application startup. + INFO: Application startup complete. + [2022-05-02 18:09:10] [INFO] [on.py:59] Application startup complete. 
+ INFO: Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + [2022-05-02 18:09:10] [INFO] [server.py:206] Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + ``` + +### 2. Client usage +**Note** The response time will be slightly longer when using the client for the first time + +- Command line + ``` + paddlespeech_client text --server_ip 127.0.0.1 --port 8190 --input "我认为跑步最重要的就是给我带来了身体健康" + ``` + + Output + ``` + [2022-05-02 18:12:29,767] [ INFO] - The punc text: 我认为跑步最重要的就是给我带来了身体健康。 + [2022-05-02 18:12:29,767] [ INFO] - Response time 0.096548 s. + ``` + +- Python3 API + + ```python + from paddlespeech.server.bin.paddlespeech_client import TextClientExecutor + + textclient_executor = TextClientExecutor() + res = textclient_executor( + input="我认为跑步最重要的就是给我带来了身体健康", + server_ip="127.0.0.1", + port=8190,) + print(res) + ``` + + Output: + ``` bash + 我认为跑步最重要的就是给我带来了身体健康。 + ``` + + +## Join streaming asr and punctuation server +We use `streaming_ asr_server.py` and `punc_server.py` two services to lanuch streaming speech recognition and punctuation prediction services respectively. And the `websocket_client.py` script can be used to call streaming speech recognition and punctuation prediction services at the same time. + +### 1. Start two server + +``` bash +Note: streaming speech recognition and punctuation prediction are configured on different graphics cards through configuration files +bash server.sh +``` + +### 2. Call client +- Command line + ``` + paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav + ``` + Output: + ``` + [2022-05-02 18:57:46,961] [ INFO] - asr websocket client start + [2022-05-02 18:57:46,961] [ INFO] - endpoint: ws://127.0.0.1:8290/paddlespeech/asr/streaming + [2022-05-02 18:57:46,982] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} + [2022-05-02 18:57:46,999] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,011] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,023] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,035] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,046] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,057] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,068] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,079] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,222] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,230] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,239] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,247] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,255] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,263] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,271] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,462] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,525] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,589] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,649] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,708] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,766] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,824] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,881] 
[ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:48,130] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,200] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,265] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,327] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,389] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,448] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,505] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,754] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,821] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,881] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,939] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,011] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,080] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,146] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,210] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,452] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,516] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,581] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,645] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,706] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,763] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,818] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:50,064] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,125] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,186] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,245] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,301] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,358] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,414] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,469] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,712] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,776] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,837] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,897] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,956] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,012] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,276] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,277] [ INFO] - asr websocket client finished + [2022-05-02 18:57:51,277] [ INFO] - 我认为跑步最重要的就是给我带来了身体健康。 + [2022-05-02 18:57:51,277] [ INFO] - Response time 4.316903 s. 
+ ``` + +- Use script + ``` + python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav + ``` + Output: + ``` + [2022-05-02 18:29:22,039] [ INFO] - Start to do streaming asr client + [2022-05-02 18:29:22,040] [ INFO] - asr websocket client start + [2022-05-02 18:29:22,040] [ INFO] - endpoint: ws://127.0.0.1:8290/paddlespeech/asr/streaming + [2022-05-02 18:29:22,041] [ INFO] - start to process the wavscp: ./zh.wav + [2022-05-02 18:29:22,122] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} + [2022-05-02 18:29:22,351] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,360] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,368] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,376] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,384] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,392] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,400] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,408] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,549] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,558] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,567] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,575] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,583] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,591] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,599] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,822] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,879] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,937] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,995] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,052] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,107] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,161] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,213] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,454] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,515] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,575] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,630] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,684] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,736] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,789] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:24,030] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,095] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,156] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,213] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,268] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,323] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,377] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,429] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 
18:29:24,671] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:24,736] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:24,797] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:24,857] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:24,918] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:24,975] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:25,029] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:29:25,271] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,336] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,398] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,458] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,521] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,579] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,652] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,722] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:29:25,969] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,034] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,095] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,163] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,229] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,294] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,565] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:29:26,566] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康。 + ``` + + \ No newline at end of file diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md index bb1d3772..67f62860 100644 --- a/demos/streaming_asr_server/README_cn.md +++ b/demos/streaming_asr_server/README_cn.md @@ -363,3 +363,275 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav [2022-04-21 15:59:08,024] [ INFO] - receive msg={'asr_results': '我认为跑步最重要的就是给我带来了身体健康'} [2022-04-21 15:59:12,883] [ INFO] - final receive msg={'status': 'ok', 'signal': 'finished', 'asr_results': '我认为跑步最重要的就是给我带来了身体健康'} ``` + + + +## 标点预测 + +### 1. 服务端使用方法 + +- 命令行 + ``` bash + 在 PaddleSpeech/demos/streaming_asr_server 目录下启动标点预测服务 + paddlespeech_server start --config_file conf/punc_application.yaml + ``` + + + 使用方法: + + ```bash + paddlespeech_server start --help + ``` + + 参数: + - `config_file`: 服务的配置文件。 + - `log_file`: log 文件。 + + + 输出: + ``` bash + [2022-05-02 17:59:26,285] [ INFO] - Create the TextEngine Instance + [2022-05-02 17:59:26,285] [ INFO] - Init the text engine + [2022-05-02 17:59:26,285] [ INFO] - Text Engine set the device: gpu:0 + [2022-05-02 17:59:26,286] [ INFO] - File /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar.gz md5 checking... 
+ [2022-05-02 17:59:30,810] [ INFO] - Use pretrained model stored in: /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar + W0502 17:59:31.486552 9595 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 10.2, Runtime API Version: 10.2 + W0502 17:59:31.491360 9595 device_context.cc:465] device: 0, cuDNN Version: 7.6. + [2022-05-02 17:59:34,688] [ INFO] - Already cached /home/users/xiongxinlei/.paddlenlp/models/ernie-1.0/vocab.txt + [2022-05-02 17:59:34,701] [ INFO] - Init the text engine successfully + INFO: Started server process [9595] + [2022-05-02 17:59:34] [INFO] [server.py:75] Started server process [9595] + INFO: Waiting for application startup. + [2022-05-02 17:59:34] [INFO] [on.py:45] Waiting for application startup. + INFO: Application startup complete. + [2022-05-02 17:59:34] [INFO] [on.py:59] Application startup complete. + INFO: Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + [2022-05-02 17:59:34] [INFO] [server.py:206] Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + ``` + +- Python API + + ```python + # 在 PaddleSpeech/demos/streaming_asr_server 目录 + from paddlespeech.server.bin.paddlespeech_server import ServerExecutor + + server_executor = ServerExecutor() + server_executor( + config_file="./conf/punc_application.yaml", + log_file="./log/paddlespeech.log") + ``` + + 输出 + ``` + [2022-05-02 18:09:02,542] [ INFO] - Create the TextEngine Instance + [2022-05-02 18:09:02,543] [ INFO] - Init the text engine + [2022-05-02 18:09:02,543] [ INFO] - Text Engine set the device: gpu:0 + [2022-05-02 18:09:02,545] [ INFO] - File /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar.gz md5 checking... + [2022-05-02 18:09:06,919] [ INFO] - Use pretrained model stored in: /home/users/xiongxinlei/.paddlespeech/models/ernie_linear_p3_wudao-punc-zh/ernie_linear_p3_wudao-punc-zh.tar + W0502 18:09:07.523002 22615 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 10.2, Runtime API Version: 10.2 + W0502 18:09:07.527882 22615 device_context.cc:465] device: 0, cuDNN Version: 7.6. + [2022-05-02 18:09:10,900] [ INFO] - Already cached /home/users/xiongxinlei/.paddlenlp/models/ernie-1.0/vocab.txt + [2022-05-02 18:09:10,913] [ INFO] - Init the text engine successfully + INFO: Started server process [22615] + [2022-05-02 18:09:10] [INFO] [server.py:75] Started server process [22615] + INFO: Waiting for application startup. + [2022-05-02 18:09:10] [INFO] [on.py:45] Waiting for application startup. + INFO: Application startup complete. + [2022-05-02 18:09:10] [INFO] [on.py:59] Application startup complete. + INFO: Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + [2022-05-02 18:09:10] [INFO] [server.py:206] Uvicorn running on http://0.0.0.0:8190 (Press CTRL+C to quit) + ``` + +### 2. 标点预测客户端使用方法 +**注意:** 初次使用客户端时响应时间会略长 + +- 命令行 (推荐使用) + ``` + paddlespeech_client text --server_ip 127.0.0.1 --port 8190 --input "我认为跑步最重要的就是给我带来了身体健康" + ``` + + 输出 + ``` + [2022-05-02 18:12:29,767] [ INFO] - The punc text: 我认为跑步最重要的就是给我带来了身体健康。 + [2022-05-02 18:12:29,767] [ INFO] - Response time 0.096548 s. 
+ ``` + +- Python3 API + + ```python + from paddlespeech.server.bin.paddlespeech_client import TextClientExecutor + + textclient_executor = TextClientExecutor() + res = textclient_executor( + input="我认为跑步最重要的就是给我带来了身体健康", + server_ip="127.0.0.1", + port=8190,) + print(res) + ``` + + 输出: + ``` bash + 我认为跑步最重要的就是给我带来了身体健康。 + ``` + + +## 联合流式语音识别和标点预测 +使用 `streaming_asr_server.py` 和 `punc_server.py` 两个服务,分别启动流式语音识别和标点预测服务。调用 `websocket_client.py` 脚本可以同时调用流式语音识别和标点预测服务。 + +### 1. 启动服务 + +``` bash +注意:流式语音识别和标点预测通过配置文件配置到不同的显卡上 +bash server.sh +``` + +### 2. 调用服务 +- 使用命令行: + ``` + paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav + ``` + 输出: + ``` + [2022-05-02 18:57:46,961] [ INFO] - asr websocket client start + [2022-05-02 18:57:46,961] [ INFO] - endpoint: ws://127.0.0.1:8290/paddlespeech/asr/streaming + [2022-05-02 18:57:46,982] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} + [2022-05-02 18:57:46,999] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,011] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,023] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,035] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,046] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,057] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,068] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,079] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,222] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,230] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,239] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,247] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,255] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,263] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,271] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:57:47,462] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,525] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,589] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,649] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,708] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,766] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,824] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:47,881] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:57:48,130] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,200] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,265] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,327] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,389] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,448] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,505] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:57:48,754] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,821] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,881] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:48,939] [ INFO] 
- client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,011] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,080] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,146] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,210] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:57:49,452] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,516] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,581] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,645] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,706] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,763] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:49,818] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'} + [2022-05-02 18:57:50,064] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,125] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,186] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,245] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,301] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,358] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,414] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,469] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'} + [2022-05-02 18:57:50,712] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,776] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,837] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,897] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:50,956] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,012] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,276] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '我认为跑步最重要的就是给我带来了身体健康。'} + [2022-05-02 18:57:51,277] [ INFO] - asr websocket client finished + [2022-05-02 18:57:51,277] [ INFO] - 我认为跑步最重要的就是给我带来了身体健康。 + [2022-05-02 18:57:51,277] [ INFO] - Response time 4.316903 s. 
+ ``` + +- 使用脚本调用 + ``` + python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav + ``` + 输出: + ``` + [2022-05-02 18:29:22,039] [ INFO] - Start to do streaming asr client + [2022-05-02 18:29:22,040] [ INFO] - asr websocket client start + [2022-05-02 18:29:22,040] [ INFO] - endpoint: ws://127.0.0.1:8290/paddlespeech/asr/streaming + [2022-05-02 18:29:22,041] [ INFO] - start to process the wavscp: ./zh.wav + [2022-05-02 18:29:22,122] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} + [2022-05-02 18:29:22,351] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,360] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,368] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,376] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,384] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,392] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,400] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,408] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,549] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,558] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,567] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,575] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,583] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,591] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,599] [ INFO] - client receive msg={'result': ''} + [2022-05-02 18:29:22,822] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,879] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,937] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:22,995] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,052] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,107] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,161] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,213] [ INFO] - client receive msg={'result': '我认为,跑'} + [2022-05-02 18:29:23,454] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,515] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,575] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,630] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,684] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,736] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:23,789] [ INFO] - client receive msg={'result': '我认为,跑步最重要的。'} + [2022-05-02 18:29:24,030] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,095] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,156] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,213] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,268] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,323] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,377] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 18:29:24,429] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是。'} + [2022-05-02 
18:29:24,671] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:24,736] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:24,797] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:24,857] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:24,918] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:24,975] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:25,029] [ INFO] - client receive msg={'result': '我认为,跑步最重要的就是给。'}
+  [2022-05-02 18:29:25,271] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,336] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,398] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,458] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,521] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,579] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,652] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,722] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了。'}
+  [2022-05-02 18:29:25,969] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,034] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,095] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,163] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,229] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,294] [ INFO] - client receive msg={'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,565] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '我认为跑步最重要的就是给我带来了身体健康。'}
+  [2022-05-02 18:29:26,566] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康。
+  ```
+
+
\ No newline at end of file
diff --git a/demos/streaming_asr_server/conf/punc_application.yaml b/demos/streaming_asr_server/conf/punc_application.yaml
new file mode 100644
index 00000000..e0d06871
--- /dev/null
+++ b/demos/streaming_asr_server/conf/punc_application.yaml
@@ -0,0 +1,35 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                             SERVER SETTING                                    #
+#################################################################################
+host: 0.0.0.0
+port: 8190
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['text_python']
+# protocol = ['http'] (only one can be selected).
+# http only supports the offline engine type.
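+#
+# Once the server is running, it can be exercised with the CLI client shown
+# earlier in this demo, e.g.:
+#   paddlespeech_client text --server_ip 127.0.0.1 --port 8190 --input "我认为跑步最重要的就是给我带来了身体健康"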
+protocol: 'http' +engine_list: ['text_python'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### Text ######################################### +################### text task: punc; engine_type: python ####################### +text_python: + task: punc + model_type: 'ernie_linear_p3_wudao' + lang: 'zh' + sample_rate: 16000 + cfg_path: # [optional] + ckpt_path: # [optional] + vocab_file: # [optional] + device: gpu:0 # set 'gpu:id' or 'cpu' + + + + diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml index 50c7a727..42473555 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8090 +port: 8290 # The task format in the engin_list is: _ # task choices = ['asr_online'] @@ -29,7 +29,7 @@ asr_online: cfg_path: decode_method: force_yes: True - device: # cpu or gpu:id + device: gpu:3 # cpu or gpu:id am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/punc_server.py b/demos/streaming_asr_server/punc_server.py new file mode 100644 index 00000000..eefa0fb4 --- /dev/null +++ b/demos/streaming_asr_server/punc_server.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
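+"""Standalone launcher for the punctuation (text) server.
+
+A thin wrapper around ServerExecutor; starting it this way is equivalent to
+running `paddlespeech_server start --config_file conf/punc_application.yaml`.
+"""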
+import argparse
+
+from paddlespeech.cli.log import logger
+from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='paddlespeech_server.start', add_help=True)
+    parser.add_argument(
+        "--config_file",
+        action="store",
+        help="yaml file of the app",
+        default=None,
+        required=True)
+
+    parser.add_argument(
+        "--log_file",
+        action="store",
+        help="log file",
+        default="./log/paddlespeech.log")
+    logger.info("start to parse the args")
+    args = parser.parse_args()
+
+    logger.info("start to launch the punctuation server")
+    punc_server = ServerExecutor()
+    punc_server(config_file=args.config_file, log_file=args.log_file)
diff --git a/demos/streaming_asr_server/server.sh b/demos/streaming_asr_server/server.sh
new file mode 100755
index 00000000..04858321
--- /dev/null
+++ b/demos/streaming_asr_server/server.sh
@@ -0,0 +1,5 @@
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+
+nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &
+
+nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
diff --git a/demos/streaming_asr_server/streaming_asr_server.py b/demos/streaming_asr_server/streaming_asr_server.py
new file mode 100644
index 00000000..011b009a
--- /dev/null
+++ b/demos/streaming_asr_server/streaming_asr_server.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
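+"""Standalone launcher for the streaming ASR server.
+
+A thin wrapper around ServerExecutor; starting it this way is equivalent to
+running `paddlespeech_server start --config_file conf/ws_conformer_application.yaml`.
+"""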
+import argparse
+
+from paddlespeech.cli.log import logger
+from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='paddlespeech_server.start', add_help=True)
+    parser.add_argument(
+        "--config_file",
+        action="store",
+        help="yaml file of the app",
+        default=None,
+        required=True)
+
+    parser.add_argument(
+        "--log_file",
+        action="store",
+        help="log file",
+        default="./log/paddlespeech.log")
+    logger.info("start to parse the args")
+    args = parser.parse_args()
+
+    logger.info("start to launch the streaming asr server")
+    streaming_asr_server = ServerExecutor()
+    streaming_asr_server(config_file=args.config_file, log_file=args.log_file)
diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh
index fe8155cf..912d67a2 100644
--- a/demos/streaming_asr_server/test.sh
+++ b/demos/streaming_asr_server/test.sh
@@ -1,5 +1,8 @@
 # download the test wav
 wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 
-# read the wav and pass it to service
-python3 websocket_client.py --wavfile ./zh.wav
+# read the wav and pass it to the streaming asr service only
+python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
+
+# read the wav and call both the streaming asr and the punctuation service
+python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py
index 2f1ce385..9d5c1b21 100644
--- a/paddlespeech/server/bin/paddlespeech_client.py
+++ b/paddlespeech/server/bin/paddlespeech_client.py
@@ -411,6 +411,18 @@ class ASROnlineClientExecutor(BaseExecutor):
             '--lang', type=str, default="zh_cn", help='language')
         self.parser.add_argument(
             '--audio_format', type=str, default="wav", help='audio format')
+        self.parser.add_argument(
+            '--punc.server_ip',
+            type=str,
+            default=None,
+            dest="punc_server_ip",
+            help='Punctuation server ip')
+        self.parser.add_argument(
+            '--punc.port',
+            type=int,
+            default=8190,
+            dest="punc_server_port",
+            help='Punctuation server port')
 
     def execute(self, argv: List[str]) -> bool:
         args = self.parser.parse_args(argv)
@@ -428,7 +440,9 @@ class ASROnlineClientExecutor(BaseExecutor):
             port=port,
             sample_rate=sample_rate,
             lang=lang,
-            audio_format=audio_format)
+            audio_format=audio_format,
+            punc_server_ip=args.punc_server_ip,
+            punc_server_port=args.punc_server_port)
         time_end = time.time()
         logger.info(res)
         logger.info("Response time %f s." % (time_end - time_start))
@@ -445,12 +459,30 @@ class ASROnlineClientExecutor(BaseExecutor):
             port: int=8091,
             sample_rate: int=16000,
             lang: str="zh_cn",
-            audio_format: str="wav"):
-        """
-        Python API to call an executor.
+            audio_format: str="wav",
+            punc_server_ip: str=None,
+            punc_server_port: str=None):
+        """Python API to call the asr online executor.
+
+        Args:
+            input (str): the audio file to be sent to the streaming asr service.
+            server_ip (str, optional): streaming asr server ip. Defaults to "127.0.0.1".
+            port (int, optional): streaming asr server port. Defaults to 8091.
+            sample_rate (int, optional): audio sample rate. Defaults to 16000.
+            lang (str, optional): audio language type. Defaults to "zh_cn".
+            audio_format (str, optional): audio format. Defaults to "wav".
+            punc_server_ip (str, optional): punctuation server ip. Defaults to None.
+            punc_server_port (str, optional): punctuation server port. Defaults to None.
+ + Returns: + str: the audio text """ logger.info("asr websocket client start") - handler = ASRWsAudioHandler(server_ip, port) + handler = ASRWsAudioHandler( + server_ip, + port, + punc_server_ip=punc_server_ip, + punc_server_port=punc_server_port) loop = asyncio.get_event_loop() res = loop.run_until_complete(handler.run(input)) logger.info("asr websocket client finished") From 2ab96187aaad5f7e05788fe61b3baa2c1fc5d103 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Wed, 4 May 2022 16:42:12 +0800 Subject: [PATCH 04/11] streaming asr server add time stamp, test=doc --- .../server/engine/asr/online/asr_engine.py | 34 ++++++++ .../server/engine/asr/online/ctc_search.py | 87 ++++++++++++++++--- paddlespeech/server/ws/asr_socket.py | 4 +- 3 files changed, 110 insertions(+), 15 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 990590b4..a98268f0 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -290,6 +290,7 @@ class PaddleASRConnectionHanddler: self.chunk_num = 0 self.global_frame_offset = 0 self.result_transcripts = [''] + self.word_time_stamp = None def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: @@ -505,6 +506,12 @@ class PaddleASRConnectionHanddler: else: return '' + def get_word_time_stamp(self): + if self.word_time_stamp is None: + return [] + else: + return self.word_time_stamp + @paddle.no_grad() def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: @@ -569,8 +576,35 @@ class PaddleASRConnectionHanddler: # update the one best result logger.info(f"best index: {best_index}") self.hyps = [hyps[best_index][0]] + + # update the hyps time stamp + self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[ + best_index][3] else hyps[best_index][6] + logger.info(f"time stamp: {self.time_stamp}") + self.update_result() + # update each word start and end time stamp + frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate + logger.info(f"frame shift ms: {frame_shift_in_ms}") + word_time_stamp = [] + for idx, _ in enumerate(self.time_stamp): + start = (self.time_stamp[idx - 1] + self.time_stamp[idx] + ) / 2.0 if idx > 0 else 0 + start = start * frame_shift_in_ms + + end = (self.time_stamp[idx] + self.time_stamp[idx + 1] + ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset + end = end * frame_shift_in_ms + word_time_stamp.append({ + "w": self.result_transcripts[0][idx], + "bg": start, + "ed": end + }) + # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}") + self.word_time_stamp = word_time_stamp + logger.info(f"word time stamp: {self.word_time_stamp}") + class ASRServerExecutor(ASRExecutor): def __init__(self): diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index be5fb15b..3a808587 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
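+# Besides the usual blank / non-blank ending scores, every hypothesis tracked
+# by the prefix beam search below also carries the Viterbi (single best path)
+# scores and the decoding-step indices at which each token of the prefix was
+# emitted. For example, a times_viterbi_non_blank of [3, 7] for the prefix
+# (a, b) means token a was emitted at step 3 and token b at step 7 on the best
+# alignment; asr_engine.py later converts these step indices into time offsets
+# using the frame shift.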
+import copy
 from collections import defaultdict
 
 import paddle
@@ -54,14 +55,24 @@ class CTCPrefixBeamSearch:
         assert len(ctc_probs.shape) == 2
 
         # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
-        # blank_ending_score and none_blank_ending_score in ln domain
+        # 0. blank_ending_score,
+        # 1. none_blank_ending_score,
+        # 2. viterbi_blank ending,
+        # 3. viterbi_non_blank,
+        # 4. current_token_prob,
+        # 5. times_viterbi_blank,
+        # 6. times_viterbi_non_blank
         if self.cur_hyps is None:
-            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
+            self.cur_hyps = [(tuple(), (0.0, -float('inf'), 0.0, 0.0,
+                                        -float('inf'), [], []))]
+            # self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
         # 2. CTC beam search step by step
         for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
             # key: prefix, value (pb, pnb), default value(-inf, -inf)
-            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
+            # next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
+            next_hyps = defaultdict(
+                lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], []))
 
             # 2.1 First beam prune: select topk best
             #     do token passing process
@@ -69,36 +80,83 @@ class CTCPrefixBeamSearch:
             for s in top_k_index:
                 s = s.item()
                 ps = logp[s].item()
-                for prefix, (pb, pnb) in self.cur_hyps:
+                for prefix, (pb, pnb, v_s, v_ns, cur_token_prob, times_s,
+                             times_ns) in self.cur_hyps:
                     last = prefix[-1] if len(prefix) > 0 else None
                     if s == blank_id:  # blank
-                        n_pb, n_pnb = next_hyps[prefix]
+                        n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
+                            prefix]
                         n_pb = log_add([n_pb, pb + ps, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
+
+                        pre_times = times_s if v_s > v_ns else times_ns
+                        n_times_s = copy.deepcopy(pre_times)
+                        viterbi_score = v_s if v_s > v_ns else v_ns
+                        n_v_s = viterbi_score + ps
+                        next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
+                                             n_cur_token_prob, n_times_s,
+                                             n_times_ns)
                     elif s == last:
                         #  Update *ss -> *s;
-                        n_pb, n_pnb = next_hyps[prefix]
+                        # case1: *a + a => *a
+                        n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
+                            prefix]
                         n_pnb = log_add([n_pnb, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
+                        if n_v_ns < v_ns + ps:
+                            n_v_ns = v_ns + ps
+                            if n_cur_token_prob < ps:
+                                n_cur_token_prob = ps
+                                n_times_ns = copy.deepcopy(times_ns)
+                                n_times_ns[
+                                    -1] = self.abs_time_step  # note: reuse the absolute time step here
+                        next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
+                                             n_cur_token_prob, n_times_s,
+                                             n_times_ns)
+
                         # Update *s-s -> *ss, - is for blank
+                        # Case 2: *aε + a => *aa
                         n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[
+                            n_prefix]
+                        if n_v_ns < v_s + ps:
+                            n_v_ns = v_s + ps
+                            n_cur_token_prob = ps
+                            n_times_ns = copy.deepcopy(times_s)
+                            n_times_ns.append(self.abs_time_step)
                         n_pnb = log_add([n_pnb, pb + ps])
-                        next_hyps[n_prefix] = (n_pb, n_pnb)
+                        next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns,
+                                               n_cur_token_prob, n_times_s,
+                                               n_times_ns)
                     else:
+                        # Case 3: *a + b => *ab, *aε + b => *ab
                         n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_n = next_hyps[
+                            n_prefix]
+                        viterbi_score = v_s if v_s > v_ns else v_ns
+                        pre_times = times_s if v_s > v_ns else times_ns
+                        if n_v_ns < viterbi_score + ps:
+                            n_v_ns = viterbi_score + ps
+                            n_cur_token_prob = ps
+                            n_times_ns = copy.deepcopy(pre_times)
+                            n_times_ns.append(self.abs_time_step)
+
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
-                        
next_hyps[n_prefix] = (n_pb, n_pnb) + next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, + n_cur_token_prob, n_times_s, + n_times_ns) # 2.2 Second beam prune next_hyps = sorted( next_hyps.items(), - key=lambda x: log_add(list(x[1])), + key=lambda x: log_add([x[1][0], x[1][1]]), reverse=True) self.cur_hyps = next_hyps[:beam_size] - self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + # 2.3 update the absolute time step + self.abs_time_step += 1 + # self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3], + y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps] + logger.info("ctc prefix search success") return self.hyps @@ -123,6 +181,7 @@ class CTCPrefixBeamSearch: """ self.cur_hyps = None self.hyps = None + self.abs_time_step = 0 def finalize_search(self): """do nothing in ctc_prefix_beam_search diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 68686d3d..0f7dcddd 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -78,12 +78,14 @@ async def websocket_endpoint(websocket: WebSocket): connection_handler.decode(is_finished=True) connection_handler.rescoring() asr_results = connection_handler.get_result() + word_time_stamp = connection_handler.get_word_time_stamp() connection_handler.reset() resp = { "status": "ok", "signal": "finished", - 'result': asr_results + 'result': asr_results, + 'times': word_time_stamp } await websocket.send_json(resp) break From 10da21a77b64e39626ea4f9481e8a0c483e0ef74 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 2 May 2022 15:37:41 +0800 Subject: [PATCH 05/11] update the vector cli for server, test=doc --- demos/streaming_asr_server/websocket_client.py | 2 +- paddlespeech/cli/vector/infer.py | 3 +++ paddlespeech/server/bin/paddlespeech_client.py | 6 +++++- paddlespeech/server/engine/vector/python/vector_engine.py | 4 ++-- paddlespeech/server/restful/request.py | 2 +- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/demos/streaming_asr_server/websocket_client.py b/demos/streaming_asr_server/websocket_client.py index 3cadd72a..3451b8d0 100644 --- a/demos/streaming_asr_server/websocket_client.py +++ b/demos/streaming_asr_server/websocket_client.py @@ -37,7 +37,7 @@ def main(args): if args.wavfile and os.path.exists(args.wavfile): logger.info(f"start to process the wavscp: {args.wavfile}") result = loop.run_until_complete(handler.run(args.wavfile)) - # result = result["result"] + result = result["result"] logger.info(f"asr websocket client finished : {result}") # support to process batch audios from wav.scp diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 8afb0f5c..3111badf 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -285,8 +285,10 @@ class VectorExecutor(BaseExecutor): Defaults to None. ckpt_path (Optional[os.PathLike], optional): the pretrained model path, which is stored in the disk. Defaults to None. 
+            task (str, optional): the task type of the vector model, e.g. 'spk'. Defaults to None.
        """
        # stage 0: avoid initializing the model again
+        self.task = task
        if hasattr(self, "model"):
            logger.info("Model has been initialized")
            return
@@ -435,6 +437,7 @@ class VectorExecutor(BaseExecutor):
        if self.sample_rate != 16000 and self.sample_rate != 8000:
            logger.error(
                "invalid sample rate, please input --sr 8000 or --sr 16000")
+            logger.error(f"The model sample rate: {self.sample_rate}, the external sample rate is: {sample_rate}")
            return False

        if isinstance(audio_file, (str, os.PathLike)):
diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py
index 32f78942..cd1cd51a 100644
--- a/paddlespeech/server/bin/paddlespeech_client.py
+++ b/paddlespeech/server/bin/paddlespeech_client.py
@@ -602,7 +602,11 @@ class VectorClientExecutor(BaseExecutor):
            default=None,
            help='sentence to be processed by text server.')
        self.parser.add_argument(
-            '--task', type=str, default="spk", help="The vector service task")
+            '--task',
+            type=str,
+            default="spk",
+            choices=["spk", "score"],
+            help="The vector service task")
        self.parser.add_argument(
            "--enroll", type=str, default=None, help="The enroll audio")
        self.parser.add_argument(
diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py
index 866c2229..2fd8dec6 100644
--- a/paddlespeech/server/engine/vector/python/vector_engine.py
+++ b/paddlespeech/server/engine/vector/python/vector_engine.py
@@ -99,8 +99,8 @@ class PaddleVectorConnectionHandler:
        """extract the audio embedding

        Args:
-            audio (_type_): _description_
-            sample_rate (int, optional): _description_. Defaults to 16000.
+            audio (str): the audio data
+            sample_rate (int, optional): the audio sample rate. Defaults to 16000.
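+
+        Returns:
+            the audio embedding extracted from the input audio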
""" # we can not reuse the cache io.BytesIO(audio) data, # because the soundfile will change the io.BytesIO(audio) to the end diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py index 4e88280a..b7a32481 100644 --- a/paddlespeech/server/restful/request.py +++ b/paddlespeech/server/restful/request.py @@ -115,7 +115,7 @@ class VectorScoreRequest(BaseModel): { "enroll_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", "test_audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", - "task": "spk", + "task": "score", "audio_format": "wav", "sample_rate": 16000, } From 624ab2c57afafef0260230e707ebe849308bf82f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 5 May 2022 11:08:00 +0000 Subject: [PATCH 06/11] update asr1 config --- examples/aishell/asr1/conf/chunk_conformer.yaml | 6 +++--- examples/aishell/asr1/conf/conformer.yaml | 2 +- examples/aishell/asr1/conf/transformer.yaml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/aishell/asr1/conf/chunk_conformer.yaml b/examples/aishell/asr1/conf/chunk_conformer.yaml index 3cfe9b1b..b389e367 100644 --- a/examples/aishell/asr1/conf/chunk_conformer.yaml +++ b/examples/aishell/asr1/conf/chunk_conformer.yaml @@ -10,7 +10,7 @@ encoder_conf: attention_heads: 4 linear_units: 2048 # the number of units of position-wise feed forward num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 + dropout_rate: 0.1 # sublayer output dropout positional_dropout_rate: 0.1 attention_dropout_rate: 0.0 input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 @@ -30,7 +30,7 @@ decoder_conf: attention_heads: 4 linear_units: 2048 num_blocks: 6 - dropout_rate: 0.1 + dropout_rate: 0.1 # sublayer output dropout positional_dropout_rate: 0.1 self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 @@ -39,7 +39,7 @@ model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false - init_type: 'kaiming_uniform' + init_type: 'kaiming_uniform' # !Warning: need to convergence ########################################### # Data # diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index a150a04d..2419d07a 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -37,7 +37,7 @@ model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false - init_type: 'kaiming_uniform' + init_type: 'kaiming_uniform' # !Warning: need to convergence ########################################### # Data # diff --git a/examples/aishell/asr1/conf/transformer.yaml b/examples/aishell/asr1/conf/transformer.yaml index 9e08ea0e..4e068420 100644 --- a/examples/aishell/asr1/conf/transformer.yaml +++ b/examples/aishell/asr1/conf/transformer.yaml @@ -10,7 +10,7 @@ encoder_conf: attention_heads: 4 linear_units: 2048 # the number of units of position-wise feed forward num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 + dropout_rate: 0.1 # sublayer output dropout positional_dropout_rate: 0.1 attention_dropout_rate: 0.0 input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 @@ -21,7 +21,7 @@ decoder_conf: attention_heads: 4 linear_units: 2048 num_blocks: 6 - dropout_rate: 0.1 + dropout_rate: 0.1 # sublayer output dropout positional_dropout_rate: 0.1 self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 From d1eb6269ff42fee7d3d8cbfbe0234e222cd40c54 Mon Sep 17 00:00:00 
2001
From: xiongxinlei
Date: Thu, 5 May 2022 19:06:53 +0800
Subject: [PATCH 07/11] update the streaming asr and punc server to cpu device,
 test=doc
---
 demos/streaming_asr_server/README.md                  | 11 ++++++++---
 demos/streaming_asr_server/README_cn.md               |  8 ++++++--
 demos/streaming_asr_server/conf/punc_application.yaml |  2 +-
 .../conf/ws_conformer_application.yaml                |  4 ++--
 demos/streaming_asr_server/test.sh                    |  0
 5 files changed, 17 insertions(+), 8 deletions(-)
 mode change 100644 => 100755 demos/streaming_asr_server/test.sh
diff --git a/demos/streaming_asr_server/README.md b/demos/streaming_asr_server/README.md
index 48cfbaf3..d693dc41 100644
--- a/demos/streaming_asr_server/README.md
+++ b/demos/streaming_asr_server/README.md
@@ -29,7 +29,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ### 3. Server Usage
 - Command Line (Recommended)
-
+  **Note:** The default deployment of the server is on the 'CPU' device, which can be deployed on the 'GPU' by modifying the 'device' parameter in the service configuration file.
   ```bash
   # in PaddleSpeech/demos/streaming_asr_server start the service
   paddlespeech_server start --config_file ./conf/ws_conformer_application.yaml
@@ -110,6 +110,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
   ```
 
 - Python API
+  **Note:** The default deployment of the server is on the 'CPU' device, which can be deployed on the 'GPU' by modifying the 'device' parameter in the service configuration file.
   ```python
   # in PaddleSpeech/demos/streaming_asr_server directory
   from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
@@ -361,8 +362,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ## Punctuation service
 
 ### 1. Server usage
-
 - Command Line
+  **Note:** The default deployment of the server is on the 'CPU' device, which can be deployed on the 'GPU' by modifying the 'device' parameter in the service configuration file.
  ``` bash
  In PaddleSpeech/demos/streaming_asr_server directory to launch punctuation service
  paddlespeech_server start --config_file conf/punc_application.yaml
@@ -401,7 +403,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
  ```
 
- Python API
-
+  **Note:** The default deployment of the server is on the 'CPU' device, which can be deployed on the 'GPU' by modifying the 'device' parameter in the service configuration file.
  ```python
  # 在 PaddleSpeech/demos/streaming_asr_server 目录
  from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
@@ -467,6 +469,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 
 ## Join streaming asr and punctuation server
+
+By default, each server is deployed on the 'CPU' device, and speech recognition and punctuation prediction can be deployed on different GPUs by modifying the 'device' parameter in the respective service configuration files.
+
 We use the `streaming_asr_server.py` and `punc_server.py` scripts to launch the streaming speech recognition and punctuation prediction services respectively. And the `websocket_client.py` script can be used to call the streaming speech recognition and punctuation prediction services at the same time.
 
 ### 1. Start two servers
diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md
index 67f62860..db9cbb5e 100644
--- a/demos/streaming_asr_server/README_cn.md
+++ b/demos/streaming_asr_server/README_cn.md
@@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ### 3. 服务端使用方法
 
 - 命令行 (推荐使用)
-
+  **注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中 `device` 参数部署在 `gpu` 上。
   ```bash
   # 在 PaddleSpeech/demos/streaming_asr_server 目录启动服务
   paddlespeech_server start --config_file ./conf/ws_conformer_application.yaml
@@ -117,6 +117,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
   ```
 
 - Python API
+  **注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中 `device` 参数部署在 `gpu` 上。
   ```python
   # 在 PaddleSpeech/demos/streaming_asr_server 目录
   from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
@@ -371,6 +372,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ### 1. 服务端使用方法
 
 - 命令行
+  **注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中 `device` 参数部署在 `gpu` 上。
  ``` bash
  在 PaddleSpeech/demos/streaming_asr_server 目录下启动标点预测服务
  paddlespeech_server start --config_file conf/punc_application.yaml
@@ -410,7 +412,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
  ```
 
- Python API
-
+  **注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中 `device` 参数部署在 `gpu` 上。
  ```python
  # 在 PaddleSpeech/demos/streaming_asr_server 目录
  from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
@@ -476,6 +478,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 
 ## 联合流式语音识别和标点预测
+**注意:** By default, each server is deployed on the 'CPU' device. Voice recognition and punctuation prediction can be deployed on different 'GPUs' by modifying the 'device' parameter in the service configuration file.
+
 使用 `streaming_asr_server.py` 和 `punc_server.py` 两个服务,分别启动流式语音识别和标点预测服务。调用 `websocket_client.py` 脚本可以同时调用流式语音识别和标点预测服务。
 
 ### 1. 启动服务
diff --git a/demos/streaming_asr_server/conf/punc_application.yaml b/demos/streaming_asr_server/conf/punc_application.yaml
index e0d06871..f947525e 100644
--- a/demos/streaming_asr_server/conf/punc_application.yaml
+++ b/demos/streaming_asr_server/conf/punc_application.yaml
@@ -28,7 +28,7 @@ text_python:
     cfg_path: # [optional]
     ckpt_path: # [optional]
     vocab_file: # [optional]
-    device: gpu:0 # set 'gpu:id' or 'cpu'
+    device: 'cpu' # set 'gpu:id' or 'cpu'
 
 
diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
index 42473555..20a50008 100644
--- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
@@ -29,7 +29,7 @@ asr_online:
     cfg_path:
     decode_method:
     force_yes: True
-    device: gpu:3 # cpu or gpu:id
+    device: 'cpu' # cpu or gpu:id
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
@@ -42,4 +42,4 @@ asr_online:
         window_ms: 25 # ms
         shift_ms: 10 # ms
         sample_rate: 16000
-        sample_width: 2
\ No newline at end of file
+        sample_width: 2
diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh
old mode 100644
new mode 100755
From 0e2372edd20fde7331b7618302babf358e7c6c5b Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 5 May 2022 19:43:12 +0800
Subject: [PATCH 08/11] update readme_cn.md, test=doc
---
 demos/streaming_asr_server/README_cn.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md
index db9cbb5e..b768c435 100644
--- a/demos/streaming_asr_server/README_cn.md
+++ b/demos/streaming_asr_server/README_cn.md
@@ -478,7 +478,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 
 ## 联合流式语音识别和标点预测
 
-**注意:** By default, each server is deployed on the 'CPU' device. Voice recognition and punctuation prediction can be deployed on different 'GPUs' by modifying the 'device' parameter in the service configuration file.
+**注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中 `device` 参数将语音识别和标点预测部署在不同的 `gpu` 上。
 
 使用 `streaming_asr_server.py` 和 `punc_server.py` 两个服务,分别启动流式语音识别和标点预测服务。调用 `websocket_client.py` 脚本可以同时调用流式语音识别和标点预测服务。
From b7a77eebcaad28ef9036b25dd0f9f6fbf030e0f5 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 5 May 2022 23:22:24 +0800
Subject: [PATCH 09/11] update the time stamp type, test=doc
---
 .../server/engine/asr/online/asr_engine.py    | 20 ++++++++++-----
 .../server/engine/asr/online/ctc_search.py    | 25 +++++++++----------
 2 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index a2559816..99d34a30 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import copy
 import os
-import time
 from typing import Optional
 
 import numpy as np
@@ -297,7 +296,8 @@ class PaddleASRConnectionHanddler:
         self.chunk_num = 0
         self.global_frame_offset = 0
         self.result_transcripts = ['']
-        self.word_time_stamp = None
+        self.word_time_stamp = []
+        self.time_stamp = []
         self.first_char_occur_elapsed = None
 
     def decode(self, is_finished=False):
@@ -515,10 +515,7 @@ class PaddleASRConnectionHanddler:
             return ''
 
     def get_word_time_stamp(self):
-        if self.word_time_stamp is None:
-            return []
-        else:
-            return self.word_time_stamp
+        return self.word_time_stamp
 
     @paddle.no_grad()
     def rescoring(self):
@@ -582,7 +579,18 @@ class PaddleASRConnectionHanddler:
                 best_index = i
 
         # update the one best result
+        # hyps stores the beam results and each field is:
         logger.info(f"best index: {best_index}")
+        # logger.info(f'best result: {hyps[best_index]}')
+        # the fields of the hyps are:
+        # hyps[0][0]: the sentence word-id in the vocab with a tuple
+        # hyps[0][1]: the sentence decoding probability with all paths
+        # hyps[0][2]: viterbi_blank ending probability
+        # hyps[0][3]: viterbi_non_blank probability
+        # hyps[0][4]: current_token_prob,
+        # hyps[0][5]: times_viterbi_blank,
+        # hyps[0][6]: times_viterbi_non_blank
         self.hyps = [hyps[best_index][0]]
 
         # update the hyps time stamp
diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py
index 3a808587..4c9ac3ac 100644
--- a/paddlespeech/server/engine/asr/online/ctc_search.py
+++ b/paddlespeech/server/engine/asr/online/ctc_search.py
@@ -27,7 +27,7 @@ class CTCPrefixBeamSearch:
         """Implement the ctc prefix beam search
 
         Args:
-            config (yacs.config.CfgNode): _description_
+            config (yacs.config.CfgNode): the ctc prefix beam search configuration
         """
         self.config = config
         self.reset()
@@ -69,7 +69,6 @@ class CTCPrefixBeamSearch:
         # 2.
CTC beam search step by step for t in range(0, maxlen): logp = ctc_probs[t] # (vocab_size,) - # key: prefix, value (pb, pnb), default value(-inf, -inf) # next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) next_hyps = defaultdict( lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], [])) @@ -80,7 +79,7 @@ class CTCPrefixBeamSearch: for s in top_k_index: s = s.item() ps = logp[s].item() - for prefix, (pb, pnb, v_s, v_ns, cur_token_prob, times_s, + for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_s, times_ns) in self.cur_hyps: last = prefix[-1] if len(prefix) > 0 else None if s == blank_id: # blank @@ -88,9 +87,9 @@ class CTCPrefixBeamSearch: prefix] n_pb = log_add([n_pb, pb + ps, pnb + ps]) - pre_times = times_s if v_s > v_ns else times_ns + pre_times = times_s if v_b_s > v_nb_s else times_ns n_times_s = copy.deepcopy(pre_times) - viterbi_score = v_s if v_s > v_ns else v_ns + viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s n_v_s = viterbi_score + ps next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, @@ -101,8 +100,8 @@ class CTCPrefixBeamSearch: n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ prefix] n_pnb = log_add([n_pnb, pnb + ps]) - if n_v_ns < v_ns + ps: - n_v_ns = v_ns + ps + if n_v_ns < v_nb_s + ps: + n_v_ns = v_nb_s + ps if n_cur_token_prob < ps: n_cur_token_prob = ps n_times_ns = copy.deepcopy(times_ns) @@ -117,8 +116,8 @@ class CTCPrefixBeamSearch: n_prefix = prefix + (s, ) n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ n_prefix] - if n_v_ns < v_s + ps: - n_v_ns = v_s + ps + if n_v_ns < v_b_s + ps: + n_v_ns = v_b_s + ps n_cur_token_prob = ps n_times_ns = copy.deepcopy(times_s) n_times_ns.append(self.abs_time_step) @@ -129,10 +128,10 @@ class CTCPrefixBeamSearch: else: # Case 3: *a + b => *ab, *aε + b => *ab n_prefix = prefix + (s, ) - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_n = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ n_prefix] - viterbi_score = v_s if v_s > v_ns else v_ns - pre_times = times_s if v_s > v_ns else times_ns + viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s + pre_times = times_s if v_b_s > v_nb_s else times_ns if n_v_ns < viterbi_score + ps: n_v_ns = viterbi_score + ps n_cur_token_prob = ps @@ -153,7 +152,7 @@ class CTCPrefixBeamSearch: # 2.3 update the absolute time step self.abs_time_step += 1 - # self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3], y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps] From 5d5266abff63a32c8f1c97351a299371b4b40abc Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 6 May 2022 02:36:37 +0000 Subject: [PATCH 10/11] rm to_float32 flags, default is fbank --- speechx/examples/ds2_ol/aishell/run.sh | 1 - .../ds2_ol/decoder/recognizer_test_main.cc | 4 ++- .../ds2_ol/feat/compute_fbank_main.cc | 1 + .../feat/linear-spectrogram-wo-db-norm-ol.cc | 3 +- .../ds2_ol/websocket/websocket_server.sh | 1 - speechx/speechx/decoder/param.h | 31 ++++++++++++------- .../speechx/frontend/audio/feature_pipeline.h | 6 ++-- 7 files changed, 28 insertions(+), 19 deletions(-) diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index b44200b0..650cb140 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -155,7 +155,6 @@ if [ ${stage} -le 5 ] && [ 
${stop_stage} -ge 5 ]; then --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --cmvn_file=$cmvn \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --to_float32=true \ --streaming_chunk=30 \ --param_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc index 00764f53..476fac05 100644 --- a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc +++ b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc @@ -19,6 +19,7 @@ DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_int32(sample_rate, 16000, "sample rate"); int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -30,7 +31,8 @@ int main(int argc, char* argv[]) { kaldi::SequentialTableReader wav_reader( FLAGS_wav_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - int sample_rate = 16000; + + int sample_rate = FLAGS_sample_rate; float streaming_chunk = FLAGS_streaming_chunk; int chunk_sample_size = streaming_chunk * sample_rate; LOG(INFO) << "sr: " << sample_rate; diff --git a/speechx/examples/ds2_ol/feat/compute_fbank_main.cc b/speechx/examples/ds2_ol/feat/compute_fbank_main.cc index 7beaa587..67683eeb 100644 --- a/speechx/examples/ds2_ol/feat/compute_fbank_main.cc +++ b/speechx/examples/ds2_ol/feat/compute_fbank_main.cc @@ -69,6 +69,7 @@ int main(int argc, char* argv[]) { feat_cache_opts.frame_chunk_stride = 1; feat_cache_opts.frame_chunk_size = 1; ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); int sample_rate = 16000; diff --git a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc index c3652ad4..bbf0e690 100644 --- a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc +++ b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc @@ -56,6 +56,7 @@ int main(int argc, char* argv[]) { opt.frame_opts.remove_dc_offset = false; opt.frame_opts.window_type = "hanning"; opt.frame_opts.preemph_coeff = 0.0; + LOG(INFO) << "linear feature: " << true; LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms; LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms; @@ -77,7 +78,7 @@ int main(int argc, char* argv[]) { int sample_rate = 16000; float streaming_chunk = FLAGS_streaming_chunk; int chunk_sample_size = streaming_chunk * sample_rate; - LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "sample rate: " << sample_rate; LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sample): " << chunk_sample_size; diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh index 0e389f89..fc57e326 100755 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ b/speechx/examples/ds2_ol/websocket/websocket_server.sh @@ -63,7 +63,6 @@ websocket_server_main \ --cmvn_file=$cmvn \ --model_path=$model_dir/avg_1.jit.pdmodel \ --streaming_chunk=0.1 \ - --to_float32=true \ --param_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h index 85de08ca..9905bc6e 100644 --- 
a/speechx/speechx/decoder/param.h +++ b/speechx/speechx/decoder/param.h @@ -19,23 +19,23 @@ #include "decoder/ctc_tlg_decoder.h" #include "frontend/audio/feature_pipeline.h" +// feature +DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); +// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear feature, or fbank"); +DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); -DEFINE_bool(to_float32, true, "audio convert to pcm32"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); -DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); -DEFINE_int32(max_active, 7500, "max active"); -DEFINE_double(beam, 15.0, "decoder beam"); -DEFINE_double(lattice_beam, 7.5, "decoder beam"); +// feature sliding window DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=5) downsampling module."); DEFINE_int32(downsampling_rate, 4, "two CNN(kernel=5) module downsampling rate."); + +// nnet +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string( model_input_names, "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", @@ -47,8 +47,14 @@ DEFINE_string(model_cache_names, "chunk_state_h_box,chunk_state_c_box", "model cache names"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_bool(use_fbank, false, "use fbank or linear feature"); -DEFINE_int32(num_bins, 161, "num bins of mel"); + +// decoder +DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "TLG", "decoder graph"); +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_int32(max_active, 7500, "max active"); +DEFINE_double(beam, 15.0, "decoder beam"); +DEFINE_double(lattice_beam, 7.5, "decoder beam"); namespace ppspeech { // todo refactor later @@ -56,17 +62,18 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { FeaturePipelineOptions opts; opts.cmvn_file = FLAGS_cmvn_file; opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk; - opts.to_float32 = FLAGS_to_float32; kaldi::FrameExtractionOptions frame_opts; frame_opts.dither = 0.0; frame_opts.frame_shift_ms = 10; opts.use_fbank = FLAGS_use_fbank; if (opts.use_fbank) { + opts.to_float32 = false; frame_opts.window_type = "povey"; frame_opts.frame_length_ms = 25; opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; opts.fbank_opts.fbank_opts.frame_opts = frame_opts; } else { + opts.to_float32 = true; frame_opts.remove_dc_offset = false; frame_opts.frame_length_ms = 20; frame_opts.window_type = "hanning"; diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h index 4868d37e..1acf62a9 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -28,15 +28,15 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file; - bool to_float32; + bool to_float32; // true, only for linear feature bool use_fbank; LinearSpectrogramOptions linear_spectrogram_opts; FbankOptions fbank_opts; FeatureCacheOptions feature_cache_opts; FeaturePipelineOptions() : cmvn_file(""), - 
to_float32(false), - use_fbank(false), + to_float32(false), // true, only for linear feature + use_fbank(true), linear_spectrogram_opts(), fbank_opts(), feature_cache_opts() {} From 8522b8299971e1d86ae6e474f656ea69c25f0060 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 6 May 2022 02:40:21 +0000 Subject: [PATCH 11/11] format --- demos/streaming_asr_server/README.md | 2 +- demos/streaming_asr_server/README_cn.md | 2 +- paddlespeech/cli/vector/infer.py | 4 ++- paddlespeech/server/README_cn.md | 2 +- paddlespeech/server/engine/vector/__init__.py | 13 ++++++++++ .../server/engine/vector/python/__init__.py | 13 ++++++++++ .../engine/vector/python/vector_engine.py | 2 +- .../ds2_ol/decoder/recognizer_test_main.cc | 2 +- speechx/speechx/decoder/param.h | 25 ++++++++++--------- speechx/speechx/frontend/audio/fbank.cc | 11 +++++--- .../frontend/audio/feature_pipeline.cc | 10 ++++---- .../speechx/frontend/audio/feature_pipeline.h | 6 ++--- 12 files changed, 62 insertions(+), 30 deletions(-) diff --git a/demos/streaming_asr_server/README.md b/demos/streaming_asr_server/README.md index d693dc41..6808de5e 100644 --- a/demos/streaming_asr_server/README.md +++ b/demos/streaming_asr_server/README.md @@ -630,4 +630,4 @@ bash server.sh [2022-05-02 18:29:26,566] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康。 ``` - \ No newline at end of file + diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md index b768c435..5fa81d4b 100644 --- a/demos/streaming_asr_server/README_cn.md +++ b/demos/streaming_asr_server/README_cn.md @@ -638,4 +638,4 @@ bash server.sh [2022-05-02 18:29:26,566] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康。 ``` - \ No newline at end of file + diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 3111badf..0a169f8b 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -437,7 +437,9 @@ class VectorExecutor(BaseExecutor): if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error( "invalid sample rate, please input --sr 8000 or --sr 16000") - logger.error(f"The model sample rate: {self.sample_rate}, the external sample rate is: {sample_rate}") + logger.error( + f"The model sample rate: {self.sample_rate}, the external sample rate is: {sample_rate}" + ) return False if isinstance(audio_file, (str, os.PathLike)): diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index 010d3d51..a974d40f 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -82,4 +82,4 @@ paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input ``` paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 123456789.wav --test 85236145389.wav -``` \ No newline at end of file +``` diff --git a/paddlespeech/server/engine/vector/__init__.py b/paddlespeech/server/engine/vector/__init__.py index e69de29b..97043fd7 100644 --- a/paddlespeech/server/engine/vector/__init__.py +++ b/paddlespeech/server/engine/vector/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/vector/python/__init__.py b/paddlespeech/server/engine/vector/python/__init__.py
index e69de29b..97043fd7 100644
--- a/paddlespeech/server/engine/vector/python/__init__.py
+++ b/paddlespeech/server/engine/vector/python/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py
index 2fd8dec6..85430370 100644
--- a/paddlespeech/server/engine/vector/python/vector_engine.py
+++ b/paddlespeech/server/engine/vector/python/vector_engine.py
@@ -16,9 +16,9 @@ from collections import OrderedDict
 
 import numpy as np
 import paddle
-
 from paddleaudio.backends import load as load_audio
 from paddleaudio.compliance.librosa import melspectrogram
+
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.vector.infer import VectorExecutor
 from paddlespeech.server.engine.base_engine import BaseEngine
diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
index 476fac05..7aef73f7 100644
--- a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
+++ b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
@@ -31,7 +31,7 @@ int main(int argc, char* argv[]) {
     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
         FLAGS_wav_rspecifier);
     kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
-    
+
     int sample_rate = FLAGS_sample_rate;
     float streaming_chunk = FLAGS_streaming_chunk;
     int chunk_sample_size = streaming_chunk * sample_rate;
diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h
index 9905bc6e..b2bf1890 100644
--- a/speechx/speechx/decoder/param.h
+++ b/speechx/speechx/decoder/param.h
@@ -21,7 +21,8 @@
 
 // feature
 DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
-// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear feature, or fbank");
+// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
+// feature, or fbank");
 DEFINE_int32(num_bins, 161, "num bins of mel");
 DEFINE_string(cmvn_file, "", "read cmvn");
 DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
@@ -67,18 +68,18 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
     frame_opts.frame_shift_ms = 10;
     opts.use_fbank = FLAGS_use_fbank;
     if (opts.use_fbank) {
-      opts.to_float32 = false;
-      frame_opts.window_type = "povey";
-      frame_opts.frame_length_ms = 25;
-      opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
-      opts.fbank_opts.fbank_opts.frame_opts = frame_opts;
+        opts.to_float32 = false;
+        frame_opts.window_type = "povey";
+        frame_opts.frame_length_ms = 25;
+        opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
+        opts.fbank_opts.fbank_opts.frame_opts = frame_opts;
     } else {
-      opts.to_float32 = true;
-      frame_opts.remove_dc_offset = false;
-      frame_opts.frame_length_ms = 20;
-      frame_opts.window_type = "hanning";
-      frame_opts.preemph_coeff = 0.0;
-      opts.linear_spectrogram_opts.frame_opts = frame_opts;
+        opts.to_float32 = true;
+        frame_opts.remove_dc_offset = false;
+        frame_opts.frame_length_ms = 20;
+        frame_opts.window_type = "hanning";
+        frame_opts.preemph_coeff = 0.0;
+        opts.linear_spectrogram_opts.frame_opts = frame_opts;
     }
     opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
     opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
diff --git a/speechx/speechx/frontend/audio/fbank.cc b/speechx/speechx/frontend/audio/fbank.cc
index a865db59..fea9032a 100644
--- a/speechx/speechx/frontend/audio/fbank.cc
+++ b/speechx/speechx/frontend/audio/fbank.cc
@@ -102,13 +102,16 @@ bool Fbank::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
         // note: this online feature-extraction code does not support VTLN.
         RealFft(&window, true);
         kaldi::ComputePowerSpectrum(&window);
-      const kaldi::MelBanks &mel_bank = *(computer_.GetMelBanks(1.0));
-      SubVector<BaseFloat> power_spectrum(window, 0, window.Dim() / 2 + 1);
+        const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0));
+        SubVector<BaseFloat> power_spectrum(window, 0, window.Dim() / 2 + 1);
         if (!opts_.fbank_opts.use_power) {
             power_spectrum.ApplyPow(0.5);
         }
-      int32 mel_offset = ((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1 : 0);
-      SubVector<BaseFloat> mel_energies(this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins);
+        int32 mel_offset =
+            ((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1
+                 : 0);
+        SubVector<BaseFloat> mel_energies(
+            this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins);
         mel_bank.Compute(power_spectrum, &mel_energies);
         mel_energies.ApplyFloor(1e-07);
         mel_energies.ApplyLog();
diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc
index 40891871..087de0f0 100644
--- a/speechx/speechx/frontend/audio/feature_pipeline.cc
+++ b/speechx/speechx/frontend/audio/feature_pipeline.cc
@@ -23,13 +23,13 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
         new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
 
     unique_ptr<FrontendInterface> base_feature;
-    
+
     if (opts.use_fbank) {
-        base_feature.reset(new ppspeech::Fbank(opts.fbank_opts,
-                                               std::move(data_source)));
+        base_feature.reset(
+            new ppspeech::Fbank(opts.fbank_opts, std::move(data_source)));
     } else {
-        base_feature.reset(new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
-                                                           std::move(data_source)));
+        base_feature.reset(new ppspeech::LinearSpectrogram(
+            opts.linear_spectrogram_opts, std::move(data_source)));
     }
 
     unique_ptr<FrontendInterface> cmvn(
diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h
index 1acf62a9..6b9b4795 100644
--- a/speechx/speechx/frontend/audio/feature_pipeline.h
+++ b/speechx/speechx/frontend/audio/feature_pipeline.h
@@ -18,24 +18,24 @@
 
 #include "frontend/audio/audio_cache.h"
 #include "frontend/audio/data_cache.h"
+#include "frontend/audio/fbank.h"
 #include "frontend/audio/feature_cache.h"
 #include "frontend/audio/frontend_itf.h"
 #include "frontend/audio/linear_spectrogram.h"
-#include "frontend/audio/fbank.h"
 #include "frontend/audio/normalizer.h"
 
 namespace ppspeech {
 
 struct FeaturePipelineOptions {
     std::string cmvn_file;
-    bool to_float32;   // true, only for linear feature
+    bool to_float32;  // true, only for linear feature
     bool use_fbank;
     LinearSpectrogramOptions linear_spectrogram_opts;
     FbankOptions fbank_opts;
     FeatureCacheOptions feature_cache_opts;
     FeaturePipelineOptions()
         : cmvn_file(""),
-          to_float32(false),   // true, only for linear feature
+          to_float32(false),  // true, only for linear feature
           use_fbank(true),
           linear_spectrogram_opts(),
          fbank_opts(),
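Note: the fbank.cc hunk above is line re-wrapping only, but it makes the per-frame order easy to read: FFT power spectrum, optional `ApplyPow(0.5)` back to magnitude when `use_power` is false, mel-filterbank weighting, floor at 1e-07, then log. A self-contained sketch of that post-FFT order on plain vectors — the names below are illustrative stand-ins, not the kaldi API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One fbank frame from a power spectrum; mirrors the order in Fbank::Compute.
std::vector<float> FbankFrame(
    std::vector<float> power_spectrum,                   // |FFT|^2 bins
    const std::vector<std::vector<float>>& mel_weights,  // num_bins x fft bins
    bool use_power) {
    if (!use_power) {
        for (float& p : power_spectrum) p = std::sqrt(p);  // ApplyPow(0.5)
    }
    // Weighted sum per mel bin, as kaldi::MelBanks::Compute does.
    std::vector<float> mel_energies(mel_weights.size(), 0.0f);
    for (std::size_t b = 0; b < mel_weights.size(); ++b) {
        for (std::size_t k = 0; k < power_spectrum.size(); ++k) {
            mel_energies[b] += mel_weights[b][k] * power_spectrum[k];
        }
    }
    for (float& e : mel_energies) {
        e = std::log(std::max(e, 1e-07f));  // ApplyFloor(1e-07); ApplyLog()
    }
    return mel_energies;
}
```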