diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py
index 37e19391..1dff6edb 100644
--- a/paddlespeech/cli/vector/infer.py
+++ b/paddlespeech/cli/vector/infer.py
@@ -22,8 +22,6 @@ from typing import Union
 
 import paddle
 import soundfile
-from paddleaudio.backends import load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
 from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
@@ -32,6 +30,8 @@ from ..utils import cli_register
 from ..utils import stats_wrapper
 from .pretrained_models import model_alias
 from .pretrained_models import pretrained_models
+from paddleaudio.backends import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.vector.io.batch import feature_normalize
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py
index f006a089..1cc0a6ab 100644
--- a/paddlespeech/server/bin/paddlespeech_client.py
+++ b/paddlespeech/server/bin/paddlespeech_client.py
@@ -476,3 +476,74 @@ class CLSClientExecutor(BaseExecutor):
 
         res = requests.post(url=url, data=json.dumps(data))
         return res
+
+
+@cli_client_register(
+    name='paddlespeech_client.text', description='visit the text service')
+class TextClientExecutor(BaseExecutor):
+    """Command-line and Python client for the text (punctuation) service."""
+
+    def __init__(self):
+        super(TextClientExecutor, self).__init__()
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech_client.text', add_help=True)
+        self.parser.add_argument(
+            '--server_ip', type=str, default='127.0.0.1', help='server ip')
+        self.parser.add_argument(
+            '--port', type=int, default=8090, help='server port')
+        self.parser.add_argument(
+            '--input',
+            type=str,
+            default=None,
+            help='sentence to be process by text server.',
+            required=True)
+
+    def execute(self, argv: List[str]) -> bool:
+        """Execute the request from the argv.
+
+        Args:
+            argv (List): the request arguments
+
+        Returns:
+            bool: True if the request succeeded, False otherwise
+        """
+        args = self.parser.parse_args(argv)
+        input_ = args.input
+        server_ip = args.server_ip
+        port = args.port
+
+        try:
+            time_start = time.time()
+            res = self(input=input_, server_ip=server_ip, port=port)
+            time_end = time.time()
+            logger.info(f"The punc text: {res}")
+            logger.info("Response time %f s." % (time_end - time_start))
+            return True
+        except Exception as e:
+            logger.error("Failed to Text punctuation.")
+            logger.error(e)
+            return False
+
+    @stats_wrapper
+    def __call__(self, input: str, server_ip: str="127.0.0.1", port: int=8090):
+        """
+        Python API to call text executor.
+
+        Args:
+            input (str): the request sentence text
+            server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
+            port (int, optional): the server port. Defaults to 8090.
+
+        Returns:
+            str: the punctuation text
+        """
+
+        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/text'
+        request = {
+            "text": input,
+        }
+
+        res = requests.post(url=url, data=json.dumps(request))
+        response_dict = res.json()
+        punc_text = response_dict["result"]["punc_text"]
+        return punc_text
diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml
index 849349c2..c8753059 100644
--- a/paddlespeech/server/conf/application.yaml
+++ b/paddlespeech/server/conf/application.yaml
@@ -11,7 +11,7 @@ port: 8090
 
 # protocol = ['websocket', 'http'] (only one can be selected).
 # http only support offline engine type.
protocol: 'http' -engine_list: ['asr_python', 'tts_python', 'cls_python'] +engine_list: ['asr_python', 'tts_python', 'cls_python', 'text_python'] ################################################################################# @@ -155,3 +155,15 @@ cls_inference: glog_info: False # True -> print glog summary: True # False -> do not show predictor config + +################################### Text ######################################### +################### text task: punc; engine_type: python ####################### +text_python: + task: punc + model_type: 'ernie_linear_p3_wudao' + lang: 'zh' + sample_rate: 16000 + cfg_path: # [optional] + ckpt_path: # [optional] + vocab_file: # [optional] + device: # set 'gpu:id' or 'cpu' \ No newline at end of file diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 9ebf137d..30e48de7 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -46,5 +46,8 @@ class EngineFactory(object): elif engine_name == 'cls' and engine_type == 'python': from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine return CLSEngine() + elif engine_name.lower() == 'text' and engine_type.lower() == 'python': + from paddlespeech.server.engine.text.python.text_engine import TextEngine + return TextEngine() else: return None diff --git a/paddlespeech/server/engine/text/__init__.py b/paddlespeech/server/engine/text/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/text/python/__init__.py b/paddlespeech/server/engine/text/python/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/text/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/text/python/text_engine.py b/paddlespeech/server/engine/text/python/text_engine.py new file mode 100644 index 00000000..73cf8737 --- /dev/null +++ b/paddlespeech/server/engine/text/python/text_engine.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+
+import paddle
+
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.text.infer import TextExecutor
+from paddlespeech.server.engine.base_engine import BaseEngine
+
+
+class PaddleTextConnectionHandler:
+    """Per-request handler that runs punctuation restoration on one text."""
+
+    def __init__(self, text_engine):
+        """Bind the handler to the model and tokenizer held by the engine.
+        Args:
+            text_engine (TextEngine): The initialized Text engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleTextConnectionHandler to process the text request")
+        self.text_engine = text_engine
+        self.task = self.text_engine.executor.task
+        self.model = self.text_engine.executor.model
+        self.tokenizer = self.text_engine.executor.tokenizer
+        self._punc_list = self.text_engine.executor._punc_list
+        self._inputs = OrderedDict()
+        self._outputs = OrderedDict()
+
+    @paddle.no_grad()
+    def run(self, text):
+        """Run preprocess, inference and postprocess for one request text.
+
+        Args:
+            text (str): the request text
+
+        Returns:
+            str: the punctuation text
+        """
+        self.preprocess(text)
+        self.infer()
+        return self.postprocess()
+
+    @paddle.no_grad()
+    def preprocess(self, text):
+        """Clean and tokenize the request text; results go to self._inputs.
+
+        Args:
+            text (str): the request text
+
+        Raises:
+            ValueError: if nothing is left of the text after cleaning
+            NotImplementedError: if the task is not 'punc'
+        """
+        if self.task == 'punc':
+            clean_text = self.text_engine.executor._clean_text(text)
+            if not clean_text:
+                raise ValueError(f'Invalid input string: {text}')
+
+            tokenized_input = self.tokenizer(
+                list(clean_text), return_length=True, is_split_into_words=True)
+
+            self._inputs['input_ids'] = tokenized_input['input_ids']
+            self._inputs['seg_ids'] = tokenized_input['token_type_ids']
+            self._inputs['seq_len'] = tokenized_input['seq_len']
+        else:
+            raise NotImplementedError
+
+    @paddle.no_grad()
+    def infer(self):
+        """Run the model on self._inputs; predictions go to self._outputs."""
+        if self.task == 'punc':
+            # Add a leading batch axis of size 1; it is squeezed away below.
+            input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
+            seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
+            logits, _ = self.model(input_ids, seg_ids)
+            preds = paddle.argmax(logits, axis=-1).squeeze(0)
+
+            self._outputs['preds'] = preds
+        else:
+            raise NotImplementedError
+
+    def postprocess(self):
+        """Rebuild the text, inserting the predicted punctuation after tokens.
+
+        Returns:
+            str: The punctuation text
+        """
+        if self.task == 'punc':
+            input_ids = self._inputs['input_ids']
+            seq_len = self._inputs['seq_len']
+            preds = self._outputs['preds']
+
+            # Skip the first and last positions (presumably the tokenizer's
+            # special start/end tokens) so only the real characters remain.
+            tokens = self.tokenizer.convert_ids_to_tokens(
+                input_ids[1:seq_len - 1])
+            labels = preds[1:seq_len - 1].tolist()
+            assert len(tokens) == len(labels)
+
+            pieces = []
+            for token, label in zip(tokens, labels):
+                pieces.append(token)
+                if label != 0:  # label 0 means no punctuation follows
+                    pieces.append(self._punc_list[label])
+            return ''.join(pieces)
+        else:
+            raise NotImplementedError
+
+
+class TextServerExecutor(TextExecutor):
+    """Server-side wrapper around TextExecutor."""
+
+    def __init__(self):
+        super().__init__()
+
+
+class TextEngine(BaseEngine):
+    """The Text Engine: selects the device and loads the text executor."""
+
+    def __init__(self):
+        super(TextEngine, self).__init__()
+        logger.info("Create the TextEngine Instance")
+
+    def init(self, config: dict):
+        """Init the Text Engine
+
+        Args:
+            config (dict): The server configuration
+
+        Returns:
+            bool: True once the engine is ready, False if setting the device failed
+        """
+        logger.info("Init the text engine")
+        try:
+            self.config = config
+            self.device = self.config.device or paddle.get_device()
+
+            paddle.set_device(self.device)
+            logger.info(f"Text Engine set the device: {self.device}")
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize Text server engine Failed on device: %s." %
+                         (getattr(self, "device", None)))
+            logger.error(e)
+            return False
+
+        self.executor = TextServerExecutor()
+        self.executor._init_from_path(
+            task=config.task,
+            model_type=config.model_type,
+            lang=config.lang,
+            cfg_path=config.cfg_path,
+            ckpt_path=config.ckpt_path,
+            vocab_file=config.vocab_file)
+
+        logger.info("Init the text engine successfully")
+        return True
diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py
index 3f91a03b..d5e422e3 100644
--- a/paddlespeech/server/restful/api.py
+++ b/paddlespeech/server/restful/api.py
@@ -11,12 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 from typing import List
 
 from fastapi import APIRouter
 
+from paddlespeech.cli.log import logger
 from paddlespeech.server.restful.asr_api import router as asr_router
 from paddlespeech.server.restful.cls_api import router as cls_router
+from paddlespeech.server.restful.text_api import router as text_router
 from paddlespeech.server.restful.tts_api import router as tts_router
 
 _router = APIRouter()
@@ -38,7 +41,11 @@ def setup_router(api_list: List):
             _router.include_router(tts_router)
         elif api_name == 'cls':
             _router.include_router(cls_router)
+        elif api_name == 'text':
+            _router.include_router(text_router)
         else:
-            pass
+            logger.error(
+                f"PaddleSpeech has not support such service: {api_name}")
+            sys.exit(-1)
 
     return _router
diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py
index dbac9dac..50416627 100644
--- a/paddlespeech/server/restful/request.py
+++ b/paddlespeech/server/restful/request.py
@@ -78,3 +78,16 @@ class CLSRequest(BaseModel):
     """
     audio: str
     topk: int = 1
+
+
+#****************************************************************************************/
+#************************************ Text request **************************************/
+#****************************************************************************************/
+class TextRequest(BaseModel):
+    """
+    request body example
+    {
+        "text": "今天的天气真不错啊你下午有空吗我想约你一起去吃饭"
+    }
+    """
+    text: str
diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py
index a2a207e4..5792959e 100644
--- a/paddlespeech/server/restful/response.py
+++ b/paddlespeech/server/restful/response.py
@@ -129,6 +129,31 @@ class CLSResponse(BaseModel):
     result: CLSResult
 
 
+class TextResult(BaseModel):
+    """The result payload of the text service: the punctuated text."""
+    punc_text: str
+
+
+class TextResponse(BaseModel):
+    """
+    response example
+    {
+        "success": true,
+        "code": 200,
+        "message": {
+            "description": "success"
+        },
+        "result": {
+            "punc_text": "你好,飞桨"
+        }
+    }
+    """
+    success: bool
+    code: int
+    message: Message
+    result: TextResult
+
+
#****************************************************************************************/ #********************************** Error response **************************************/ #****************************************************************************************/ diff --git a/paddlespeech/server/restful/text_api.py b/paddlespeech/server/restful/text_api.py new file mode 100644 index 00000000..696630fb --- /dev/null +++ b/paddlespeech/server/restful/text_api.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import traceback
+from typing import Union
+
+from fastapi import APIRouter
+
+from paddlespeech.cli.log import logger
+from paddlespeech.server.engine.engine_pool import get_engine_pool
+from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler
+from paddlespeech.server.restful.request import TextRequest
+from paddlespeech.server.restful.response import ErrorResponse
+from paddlespeech.server.restful.response import TextResponse
+from paddlespeech.server.utils.errors import ErrorCode
+from paddlespeech.server.utils.errors import failed_response
+from paddlespeech.server.utils.exception import ServerBaseException
+router = APIRouter()
+
+
+@router.get('/paddlespeech/text/help')
+def help():
+    """Describe the response format of the /paddlespeech/text api.
+
+    Returns:
+        json: The /paddlespeech/text api response content
+    """
+    response = {
+        "success": "True",
+        "code": 200,
+        "message": {
+            "global": "success"
+        },
+        "result": {
+            "punc_text": "The punctuation text content"
+        }
+    }
+    return response
+
+
+@router.post(
+    "/paddlespeech/text", response_model=Union[TextResponse, ErrorResponse])
+def asr(request_body: TextRequest):
+    """Text punctuation api: restore the punctuation of the request text.
+
+    Args:
+        request_body (TextRequest): the punctuation request body
+
+    Returns:
+        json: the punctuation response body
+    """
+    try:
+        # 1. we get the sentence content from the request
+        text = request_body.text
+        logger.info(f"Text service receive the {text}")
+
+        # 2. get single engine from engine pool
+        # and each request has its own connection to process the text
+        engine_pool = get_engine_pool()
+        text_engine = engine_pool['text']
+        connection_handler = PaddleTextConnectionHandler(text_engine)
+        punc_text = connection_handler.run(text)
+        logger.info(f"Get the Text Connection result {punc_text}")
+
+        # 3. create the response
+        if punc_text is None:
+            punc_text = text
+        response = {
+            "success": True,
+            "code": 200,
+            "message": {
+                "description": "success"
+            },
+            "result": {
+                "punc_text": punc_text
+            }
+        }
+
+        logger.info(f"The Text Service final response: {response}")
+    except ServerBaseException as e:
+        response = failed_response(e.error_code, e.msg)
+    except Exception:
+        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+        logger.error(traceback.format_exc())
+
+    return response
diff --git a/paddlespeech/server/tests/text/http_client.py b/paddlespeech/server/tests/text/http_client.py
new file mode 100644
index 00000000..c2eb3eb1
--- /dev/null
+++ b/paddlespeech/server/tests/text/http_client.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import time
+
+import requests
+
+from paddlespeech.cli.log import logger
+
+
+def text_client(args):
+    """Send args.text to the text server and return the punctuated text.
+    Args:
+        args: parsed arguments providing server, port, text and output
+    Returns:
+        str: the punctuated text (also written to args.output when it is set)
+    """
+    url = "http://" + str(args.server) + ":" + str(
+        args.port) + "/paddlespeech/text"
+    request = {
+        "text": args.text,
+    }
+
+    response = requests.post(url, json.dumps(request))
+    response_dict = response.json()
+    punc_text = response_dict["result"]["punc_text"]
+
+    # save the punctuation text; utf-8 is explicit because the text is Chinese
+    outfile = args.output
+    if outfile:
+        with open(outfile, 'w', encoding='utf-8') as w:
+            w.write(punc_text + "\n")
+
+    logger.info(f"The punc text is: {punc_text}")
+    return punc_text
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--text',
+        type=str,
+        default="今天的天气真不错啊你下午有空吗我想约你一起去吃饭",
+        help='A sentence to be punctuated')
+    parser.add_argument(
+        '--output', type=str, default="./punc_text", help='Punc text file')
+    parser.add_argument(
+        "--server", type=str, help="server ip", default="127.0.0.1")
+    parser.add_argument("--port", type=int, help="server port", default=8090)
+    args = parser.parse_args()
+
+    st = time.time()
+    try:
+        punc_text = text_client(args)
+        time_consume = time.time() - st
+        time_per_word = time_consume / max(len(args.text), 1)
+        print("Text Process successfully.")
+        print("Inference time: %f" % (time_consume))
+        print("The text length: %d" % (len(args.text)))
+        print("The time per word is: %f" % (time_per_word))
+    except Exception as e:
+        logger.error("Failed to Process text.")
+        logger.error(e)
diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py
index ae3e9c6a..1f1b0be1 100644
--- a/paddlespeech/server/util.py
+++ b/paddlespeech/server/util.py
@@ -24,11 +24,11 @@ from typing import Any
 from typing import Dict
 
 import paddle
-import paddleaudio
 import requests
 import yaml
 from paddle.framework import load
 
+import paddleaudio
from . import download from .entry import client_commands from .entry import server_commands