commit
3b0004345c
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import OrderedDict
|
||||
|
||||
import paddle
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.text.infer import TextExecutor
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
|
||||
|
||||
class PaddleTextConnectionHandler:
    """Per-request handler that runs one text through a text engine.

    The handler copies the task name, model, tokenizer and punctuation list
    from ``text_engine.executor`` and keeps its own ``_inputs`` / ``_outputs``
    dictionaries for the intermediate results of a single request.
    """

    def __init__(self, text_engine):
        """The PaddleSpeech Text Server Connection Handler

        Args:
            text_engine (TextEngine): The Text engine whose executor provides
                the task, model, tokenizer and punctuation list.
        """
        super().__init__()
        logger.info(
            "Create PaddleTextConnectionHandler to process the text request")
        self.text_engine = text_engine
        self.task = self.text_engine.executor.task
        self.model = self.text_engine.executor.model
        self.tokenizer = self.text_engine.executor.tokenizer
        self._punc_list = self.text_engine.executor._punc_list
        # Filled by preprocess() and infer() respectively.
        self._inputs = OrderedDict()
        self._outputs = OrderedDict()

    @paddle.no_grad()
    def run(self, text):
        """Process one request text: preprocess, infer, then postprocess.

        Args:
            text (str): the request text

        Returns:
            str: the punctuation text
        """
        self.preprocess(text)
        self.infer()
        return self.postprocess()

    @paddle.no_grad()
    def preprocess(self, text):
        """Clean and tokenize the request text into ``self._inputs``.

        Stores ``input_ids``, ``seg_ids`` and ``seq_len`` taken from the
        tokenizer output.

        Args:
            text (str): the request text

        Raises:
            ValueError: if nothing is left of ``text`` after cleaning.
            NotImplementedError: if the task is not ``'punc'``.
        """
        if self.task != 'punc':
            raise NotImplementedError

        clean_text = self.text_engine.executor._clean_text(text)
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if len(clean_text) == 0:
            raise ValueError(f'Invalid input string: {text}')

        tokenized_input = self.tokenizer(
            list(clean_text), return_length=True, is_split_into_words=True)

        self._inputs['input_ids'] = tokenized_input['input_ids']
        self._inputs['seg_ids'] = tokenized_input['token_type_ids']
        self._inputs['seq_len'] = tokenized_input['seq_len']

    @paddle.no_grad()
    def infer(self):
        """Run the model on ``self._inputs`` and store ``preds`` in ``self._outputs``.

        Raises:
            NotImplementedError: if the task is not ``'punc'``.
        """
        if self.task != 'punc':
            raise NotImplementedError

        # unsqueeze(0) adds a leading axis before the model call; it is
        # removed again from the argmax result with squeeze(0).
        input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
        seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
        logits, _ = self.model(input_ids, seg_ids)
        preds = paddle.argmax(logits, axis=-1).squeeze(0)

        self._outputs['preds'] = preds

    def postprocess(self):
        """Rebuild the text from the tokens and the predicted labels.

        Returns:
            str: The punctuation text

        Raises:
            NotImplementedError: if the task is not ``'punc'``.
        """
        if self.task != 'punc':
            raise NotImplementedError

        input_ids = self._inputs['input_ids']
        seq_len = self._inputs['seq_len']
        preds = self._outputs['preds']

        # The first and last positions are dropped from both sequences
        # (presumably the tokenizer's special tokens -- TODO confirm).
        tokens = self.tokenizer.convert_ids_to_tokens(
            input_ids[1:seq_len - 1])
        labels = preds[1:seq_len - 1].tolist()
        assert len(tokens) == len(labels)

        # Label 0 adds nothing after the token; any other label indexes
        # into the punctuation list.
        return ''.join(
            token + (self._punc_list[label] if label != 0 else '')
            for token, label in zip(tokens, labels))
|
||||
|
||||
|
||||
class TextServerExecutor(TextExecutor):
    """Server-side wrapper around TextExecutor; adds no behavior of its own."""

    def __init__(self):
        """Initialize the underlying TextExecutor."""
        super().__init__()
|
||||
|
||||
|
||||
class TextEngine(BaseEngine):
    """Text engine that selects the paddle device and owns the text executor."""

    def __init__(self):
        """The Text Engine
        """
        super(TextEngine, self).__init__()
        logger.info("Create the TextEngine Instance")

    def init(self, config: dict):
        """Init the Text Engine

        Selects the device, then creates a ``TextServerExecutor`` and loads
        the model described by ``config``.

        Args:
            config (dict): The server configuration. Fields are read by
                attribute (``device``, ``task``, ``model_type``, ``lang``,
                ``cfg_path``, ``ckpt_path``, ``vocab_file``), so this is
                presumably an attribute-style config object -- TODO confirm.

        Returns:
            bool: True when the engine is ready, False when the device could
                not be set.
        """
        logger.info("Init the text engine")
        # Kept in a local so the failure log below never reads an attribute
        # that was not assigned yet.
        device = None
        try:
            self.config = config
            if self.config.device:
                device = self.config.device
            else:
                device = paddle.get_device()
            self.device = device

            paddle.set_device(self.device)
            logger.info(f"Text Engine set the device: {self.device}")
        except Exception as e:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            logger.error("Initialize Text server engine Failed on device: %s." %
                         (device))
            logger.error(f"Text Engine device error: {e}")
            return False

        self.executor = TextServerExecutor()
        self.executor._init_from_path(
            task=config.task,
            model_type=config.model_type,
            lang=config.lang,
            cfg_path=config.cfg_path,
            ckpt_path=config.ckpt_path,
            vocab_file=config.vocab_file)

        logger.info("Init the text engine successfully")
        return True
|
@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||
from paddlespeech.server.engine.text.python.text_engine import PaddleTextConnectionHandler
|
||||
from paddlespeech.server.restful.request import TextRequest
|
||||
from paddlespeech.server.restful.response import ErrorResponse
|
||||
from paddlespeech.server.restful.response import TextResponse
|
||||
from paddlespeech.server.utils.errors import ErrorCode
|
||||
from paddlespeech.server.utils.errors import failed_response
|
||||
from paddlespeech.server.utils.exception import ServerBaseException
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('/paddlespeech/text/help')
def help():
    """Describe the shape of a /paddlespeech/text response.

    Returns:
        json: A static example of the /paddlespeech/text api response content
    """
    message = {"global": "success"}
    result = {"punc_text": "The punctuation text content"}
    return {
        "success": "True",
        "code": 200,
        "message": message,
        "result": result,
    }
|
||||
|
||||
|
||||
@router.post(
    "/paddlespeech/text", response_model=Union[TextResponse, ErrorResponse])
def asr(request_body: TextRequest):
    """Punctuation api: restore punctuation for the text in the request.

    Args:
        request_body (TextRequest): the punctuation request body

    Returns:
        json: the punctuation response body, or a failed response when the
            request could not be processed
    """
    try:
        # Read the sentence from the request.
        text = request_body.text
        logger.info(f"Text service receive the {text}")

        # Take the text engine from the pool; every request gets its own
        # connection handler to process the text.
        text_engine = get_engine_pool()['text']
        handler = PaddleTextConnectionHandler(text_engine)
        punc_text = handler.run(text)
        logger.info(f"Get the Text Connection result {punc_text}")

        # Fall back to the raw input when the handler produced nothing.
        final_text = text if punc_text is None else punc_text
        response = {
            "success": True,
            "code": 200,
            "message": {
                "description": "success"
            },
            "result": {
                "punc_text": final_text
            },
        }

        logger.info(f"The Text Service final response: {response}")
    except ServerBaseException as e:
        response = failed_response(e.error_code, e.msg)
    except BaseException:
        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
        traceback.print_exc()

    return response
|
@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
|
||||
|
||||
# Request and response
|
||||
def text_client(args):
    """Send one sentence to the PaddleSpeech Text Server and return the result.

    Args:
        args: Parsed command-line namespace. Reads ``server`` and ``port``
            (server address), ``text`` (the sentence to process) and
            ``output`` (path of the file to write the result to; skipped
            when empty).

    Returns:
        str: The punctuation text taken from ``result.punc_text`` of the
            server response.
    """
    url = "http://" + str(args.server) + ":" + str(
        args.port) + "/paddlespeech/text"
    request = {
        "text": args.text,
    }

    # ``json=`` serializes the payload and sets the JSON content type.
    response = requests.post(url, json=request)
    response_dict = response.json()
    punc_text = response_dict["result"]["punc_text"]

    # Save the punctuation text; UTF-8 is explicit so non-ASCII text does not
    # depend on the platform's default encoding.
    outfile = args.output
    if outfile:
        with open(outfile, 'w', encoding='utf-8') as w:
            w.write(punc_text + "\n")

    logger.info(f"The punc text is: {punc_text}")
    return punc_text
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: send --text to the text server and report
    # timing for the round trip.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--text',
        type=str,
        default="今天的天气真不错啊你下午有空吗我想约你一起去吃饭",
        help='A sentence to be punctuated')
    parser.add_argument(
        '--output', type=str, default="./punc_text", help='Punc text file')
    parser.add_argument(
        "--server", type=str, help="server ip", default="127.0.0.1")
    parser.add_argument("--port", type=int, help="server port", default=8090)
    args = parser.parse_args()

    st = time.time()
    try:
        punc_text = text_client(args)
        time_consume = time.time() - st
        text_length = len(args.text)
        # Guard against an empty --text so the report cannot divide by zero.
        time_per_word = time_consume / text_length if text_length else 0.0
        print("Text Process successfully.")
        print("Inference time: %f" % (time_consume))
        print("The text length: %d" % (text_length))
        print("The time per word is: %f" % (time_per_word))
    except Exception as e:
        logger.error("Failed to Process text.")
        logger.error(e)
|
Loading…
Reference in new issue