add text punc server, test=doc

pull/1772/head
xiongxinlei 3 years ago
parent 87ef68f127
commit ba62b85e9b

@ -22,8 +22,6 @@ from typing import Union
import paddle import paddle
import soundfile import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
@ -32,6 +30,8 @@ from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification

@ -366,3 +366,77 @@ class CLSClientExecutor(BaseExecutor):
res = requests.post(url=url, data=json.dumps(data)) res = requests.post(url=url, data=json.dumps(data))
return res return res
@cli_client_register(
name='paddlespeech_client.text', description='visit the text service')
class TextClientExecutor(BaseExecutor):
def __init__(self):
super(TextClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.text', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8090, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='sentence to be process by text server.',
required=True)
self.parser.add_argument(
'--output',
type=str,
default=None,
help='Return punctuation sentence.')
def execute(self, argv: List[str]) -> bool:
"""Execute the request from the argv.
Args:
argv (List): the request arguments
Returns:
str: the request flag
"""
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
output = args.output
try:
time_start = time.time()
res = self(input=input_, server_ip=server_ip, port=port)
time_end = time.time()
logger.info(f"The punc text: {res}")
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to Text punctuation.")
return False
@stats_wrapper
def __call__(self, input: str, server_ip: str="127.0.0.1", port: int=8090):
"""
Python API to call text executor.
Args:
input (str): the request sentence text
server_ip (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
Returns:
str: the punctuation text
"""
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/text'
request = {
"text": input,
}
res = requests.post(url=url, data=json.dumps(request))
response_dict = res.json()
punc_text = response_dict["result"]["punc_text"]
return punc_text

@ -0,0 +1,35 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8190
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
# protocol = ['websocket', 'http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['text_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: python #######################
text_python:
task: punc
model_type: 'ernie_linear_p7_wudao'
lang: 'zh'
sample_rate: 16000
cfg_path: # [optional]
ckpt_path: # [optional]
vocab_file: # [optional]
device: # set 'gpu:id' or 'cpu'

@ -46,5 +46,8 @@ class EngineFactory(object):
elif engine_name == 'cls' and engine_type == 'python': elif engine_name == 'cls' and engine_type == 'python':
from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine
return CLSEngine() return CLSEngine()
elif engine_name.lower() == 'text' and engine_type.lower() == 'python':
from paddlespeech.server.engine.text.python.text_engine import TextEngine
return TextEngine()
else: else:
return None return None

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,172 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.cli.text.infer import TextExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
class PaddleTextConnectionHandler:
def __init__(self, text_engine):
"""The PaddleSpeech Text Server Connection Handler
This connection process every server request
Args:
text_engine (TextEngine): The Text engine
"""
super().__init__()
logger.info(
"Create PaddleTextConnectionHandler to process the text request")
self.text_engine = text_engine
self.task = self.text_engine.executor.task
self.model = self.text_engine.executor.model
self.tokenizer = self.text_engine.executor.tokenizer
self._punc_list = self.text_engine.executor._punc_list
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@paddle.no_grad()
def run(self, text):
"""The connection process the request text
Args:
text (str): the request text
Returns:
str: the punctuation text
"""
self.preprocess(text)
self.infer()
res = self.postprocess()
return res
@paddle.no_grad()
def preprocess(self, text):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
Args:
text (str): the request text
"""
if self.task == 'punc':
clean_text = self.text_engine.executor._clean_text(text)
assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = self.tokenizer(
list(clean_text), return_length=True, is_split_into_words=True)
self._inputs['input_ids'] = tokenized_input['input_ids']
self._inputs['seg_ids'] = tokenized_input['token_type_ids']
self._inputs['seq_len'] = tokenized_input['seq_len']
else:
raise NotImplementedError
@paddle.no_grad()
def infer(self):
"""Model inference and result stored in self.output.
"""
if self.task == 'punc':
input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
logits, _ = self.model(input_ids, seg_ids)
preds = paddle.argmax(logits, axis=-1).squeeze(0)
self._outputs['preds'] = preds
else:
raise NotImplementedError
def postprocess(self):
"""Output postprocess and return human-readable results such as texts and audio files.
Returns:
str: The punctuation text
"""
if self.task == 'punc':
input_ids = self._inputs['input_ids']
seq_len = self._inputs['seq_len']
preds = self._outputs['preds']
tokens = self.tokenizer.convert_ids_to_tokens(
input_ids[1:seq_len - 1])
labels = preds[1:seq_len - 1].tolist()
assert len(tokens) == len(labels)
text = ''
for t, l in zip(tokens, labels):
text += t
if l != 0: # Non punc.
text += self._punc_list[l]
return text
else:
raise NotImplementedError
class TextServerExecutor(TextExecutor):
def __init__(self):
"""The wrapper for TextEcutor
"""
super().__init__()
pass
class TextEngine(BaseEngine):
def __init__(self):
"""The Text Engine
"""
super(TextEngine, self).__init__()
logger.info("Create the TextEngine Instance")
def init(self, config: dict):
"""Init the Text Engine
Args:
config (dict): The server configuation
Returns:
bool: The engine instance flag
"""
logger.info("Init the text engine")
try:
self.config = config
if self.config.device:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
logger.info(f"Text Engine set the device: {self.device}")
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize Text server engine Failed on device: %s." %
(self.device))
return False
self.executor = TextServerExecutor()
self.executor._init_from_path(
task=config.task,
model_type=config.model_type,
lang=config.lang,
cfg_path=config.cfg_path,
ckpt_path=config.ckpt_path,
vocab_file=config.vocab_file)
logger.info("Init the text engine successfully")
return True

@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
from typing import List from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.restful.asr_api import router as asr_router from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router from paddlespeech.server.restful.tts_api import router as tts_router
_router = APIRouter() _router = APIRouter()
@ -38,7 +41,11 @@ def setup_router(api_list: List):
_router.include_router(tts_router) _router.include_router(tts_router)
elif api_name == 'cls': elif api_name == 'cls':
_router.include_router(cls_router) _router.include_router(cls_router)
elif api_name == 'text':
_router.include_router(text_router)
else: else:
pass logger.error(
f"PaddleSpeech has not support such service: {api_name}")
sys.exit(-1)
return _router return _router

@ -78,3 +78,10 @@ class CLSRequest(BaseModel):
""" """
audio: str audio: str
topk: int = 1 topk: int = 1
#****************************************************************************************/
#************************************ Text request **************************************/
#****************************************************************************************/
class TextRequest(BaseModel):
text: str

@ -129,6 +129,30 @@ class CLSResponse(BaseModel):
result: CLSResult result: CLSResult
class TextResult(BaseModel):
punc_text: str
class TextResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
"punc_text": "你好,飞桨"
}
}
"""
success: bool
code: int
message: Message
result: TextResult
#****************************************************************************************/ #****************************************************************************************/
#********************************** Error response **************************************/ #********************************** Error response **************************************/
#****************************************************************************************/ #****************************************************************************************/

@ -0,0 +1,98 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler
from paddlespeech.server.restful.request import TextRequest
from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.restful.response import TextResponse
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException
router = APIRouter()
@router.get('/paddlespeech/text/help')
def help():
"""help
Returns:
json: [description]
"""
response = {
"success": "True",
"code": 200,
"message": {
"global": "success"
},
"result": {
"description": "text server",
"input": "text string",
"output": "punctuation text"
}
}
return response
@router.post(
"/paddlespeech/text", response_model=Union[TextResponse, ErrorResponse])
def asr(request_body: TextRequest):
"""asr api
Args:
request_body (TextRequest): the punctuation request body
Returns:
json: the punctuation response body
"""
try:
# 1. we get the sentence content from the request
text = request_body.text
logger.info(f"Text service receive the {text}")
# 2. get single engine from engine pool
# and each request has its own connection to process the text
engine_pool = get_engine_pool()
text_engine = engine_pool['text']
connection_handler = PaddleTextConnectionHandler(text_engine)
punc_text = connection_handler.run(text)
logger.info(f"Get the Text Connection result {punc_text}")
# 3. create the response
if punc_text is None:
punc_text = text
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"punc_text": punc_text
}
}
logger.info(f"The Text Service final response: {response}")
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response

@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import time
import requests
from paddlespeech.cli.log import logger
# Request and response
def text_client(args):
""" Request and response
Args:
text: A sentence to be processed by PaddleSpeech Text Server
outfile: The punctuation text
"""
url = "http://" + str(args.server) + ":" + str(
args.port) + "/paddlespeech/text"
request = {
"text": args.text,
}
response = requests.post(url, json.dumps(request))
response_dict = response.json()
punc_text = response_dict["result"]["punc_text"]
# transform audio
outfile = args.output
if outfile:
with open(outfile, 'w') as w:
w.write(punc_text + "\n")
logger.info(f"The punc text is: {punc_text}")
return punc_text
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
type=str,
default="今天的天气真不错啊你下午有空吗我想约你一起去吃饭",
help='A sentence to be synthesized')
parser.add_argument(
'--output', type=str, default="./punc_text", help='Punc text file')
parser.add_argument(
"--server", type=str, help="server ip", default="127.0.0.1")
parser.add_argument("--port", type=int, help="server port", default=8090)
args = parser.parse_args()
st = time.time()
try:
punc_text = text_client(args)
time_consume = time.time() - st
time_per_word = time_consume / len(args.text)
print("Text Process successfully.")
print("Inference time: %f" % (time_consume))
print("The text length: %f" % (len(args.text)))
print("The time per work is: %f" % (time_per_word))
except BaseException as e:
logger.info("Failed to Process text.")
logger.info(e)

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict from typing import Dict
import paddle import paddle
import paddleaudio
import requests import requests
import yaml import yaml
from paddle.framework import load from paddle.framework import load
import paddleaudio
from . import download from . import download
from .entry import client_commands from .entry import client_commands
from .entry import server_commands from .entry import server_commands

Loading…
Cancel
Save