add server cls, test=doc

pull/1554/head
lym0302 2 years ago
parent fd20056718
commit 99fa7a8205

@@ -193,7 +193,8 @@ class CLSExecutor(BaseExecutor):
            sr=feat_conf['sample_rate'],
            mono=True,
            dtype='float32')
-       logger.info("Preprocessing audio_file:" + audio_file)
+       if isinstance(audio_file, (str, os.PathLike)):
+           logger.info("Preprocessing audio_file:" + audio_file)

        # Feature extraction
        feature_extractor = LogMelSpectrogram(

@@ -18,6 +18,7 @@ from .base_commands import ClientHelpCommand
from .base_commands import ServerBaseCommand
from .base_commands import ServerHelpCommand
from .bin.paddlespeech_client import ASRClientExecutor
from .bin.paddlespeech_client import CLSClientExecutor
from .bin.paddlespeech_client import TTSClientExecutor
from .bin.paddlespeech_server import ServerExecutor

@@ -31,7 +31,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64

-__all__ = ['TTSClientExecutor', 'ASRClientExecutor']
+__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']


@cli_client_register(
@@ -243,3 +243,71 @@ class ASRClientExecutor(BaseExecutor):
            print("time cost %f s." % (time_end - time_start))
        except BaseException:
            print("Failed to speech recognition.")

@cli_client_register(
    name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
    def __init__(self):
        super(CLSClientExecutor, self).__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_client.cls', add_help=True)
        self.parser.add_argument(
            '--server_ip', type=str, default='127.0.0.1', help='server ip')
        self.parser.add_argument(
            '--port', type=int, default=8090, help='server port')
        self.parser.add_argument(
            '--input',
            type=str,
            default=None,
            help='Audio file to classify.',
            required=True)
        self.parser.add_argument(
            '--topk',
            type=int,
            default=1,
            help='Return the topk scores of the classification result.')

    def execute(self, argv: List[str]) -> bool:
        args = self.parser.parse_args(argv)
        url = 'http://' + args.server_ip + ":" + str(
            args.port) + '/paddlespeech/cls'
        audio = wav2base64(args.input)
        data = {
            "audio": audio,
            "topk": args.topk,
        }
        time_start = time.time()
        try:
            r = requests.post(url=url, data=json.dumps(data))
            # ending timestamp
            time_end = time.time()
            logger.info(r.json())
            logger.info("Response time %f s." % (time_end - time_start))
            return True
        except BaseException:
            logger.error("Failed to classify the input audio.")
            return False

    @stats_wrapper
    def __call__(self,
                 input: str,
                 server_ip: str="127.0.0.1",
                 port: int=8090,
                 topk: int=1):
        """
        Python API to call an executor.
        """
        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
        audio = wav2base64(input)
        data = {"audio": audio, "topk": topk}
        time_start = time.time()
        try:
            r = requests.post(url=url, data=json.dumps(data))
            # ending timestamp
            time_end = time.time()
            print(r.json())
            print("Response time %f s." % (time_end - time_start))
        except BaseException:
            print("Failed to classify the input audio.")

@@ -9,12 +9,16 @@ port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
-engine_list: ['asr_python', 'tts_python']
+#engine_list: ['asr_python', 'tts_python', 'cls_python']
engine_list: ['cls_inference']
#engine_list: ['asr_python', 'cls_python']

#################################################################################
#                                ENGINE CONFIG                                  #
#################################################################################

################################### ASR #########################################
################### speech task: asr; engine_type: python #######################
asr_python:
    model: 'conformer_wenetspeech'

@@ -46,6 +50,7 @@ asr_inference:
        summary: True  # False -> do not show predictor config

################################### TTS #########################################
################### speech task: tts; engine_type: python #######################
tts_python:
    # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
@@ -105,3 +110,30 @@ tts_inference:
    # others
    lang: 'zh'

################################### CLS #########################################
################### speech task: cls; engine_type: python #######################
cls_python:
    # model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
    model: 'panns_cnn14'
    cfg_path:  # [optional] Config of cls task.
    ckpt_path:  # [optional] Checkpoint file of model.
    label_file:  # [optional] Label file of cls task.
    device:  # set 'gpu:id' or 'cpu'

################### speech task: cls; engine_type: inference #######################
cls_inference:
    # model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
    model_type: 'panns_cnn14'
    cfg_path:
    model_path:  # the pdmodel file of the cls static model [optional]
    params_path:  # the pdiparams file of the cls static model [optional]
    label_file:  # [optional] Label file of cls task.

    predictor_conf:
        device:  # set 'gpu:id' or 'cpu'
        switch_ir_optim: True
        glog_info: False  # True -> print glog
        summary: True  # False -> do not show predictor config
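
As a quick sanity check on this config, the engine selection can be previewed with PyYAML; a minimal sketch, assuming the file is saved as conf/application.yaml (the path is an assumption):

import yaml

# Inspect which engines the server would start from engine_list.
with open('conf/application.yaml') as f:
    conf = yaml.safe_load(f)

for engine in conf['engine_list']:          # e.g. 'cls_inference'
    task, engine_type = engine.split('_')   # ('cls', 'inference')
    print(task, engine_type, conf[engine])  # per-engine config block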

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,225 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional

import numpy as np
import paddle
import yaml

from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model

__all__ = ['CLSEngine']
pretrained_models = {
    "panns_cnn6-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
        'md5': 'da087c31046d23281d8ec5188c1967da',
        'cfg_path': 'panns.yaml',
        'model_path': 'inference.pdmodel',
        'params_path': 'inference.pdiparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn10-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
        'md5': '5460cc6eafbfaf0f261cc75b90284ae1',
        'cfg_path': 'panns.yaml',
        'model_path': 'inference.pdmodel',
        'params_path': 'inference.pdiparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn14-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
        'md5': 'ccc80b194821274da79466862b2ab00f',
        'cfg_path': 'panns.yaml',
        'model_path': 'inference.pdmodel',
        'params_path': 'inference.pdiparams',
        'label_file': 'audioset_labels.txt',
    },
}

class CLSServerExecutor(CLSExecutor):
    def __init__(self):
        super().__init__()

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and return pretrained resources path of current task.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
            tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path

    def _init_from_path(
            self,
            model_type: str='panns_cnn14',
            cfg_path: Optional[os.PathLike]=None,
            model_path: Optional[os.PathLike]=None,
            params_path: Optional[os.PathLike]=None,
            label_file: Optional[os.PathLike]=None,
            predictor_conf: dict=None, ):
        """
        Init model and other resources from a specific path.
        """
        if cfg_path is None or model_path is None or params_path is None or label_file is None:
            tag = model_type + '-' + '32k'
            self.res_path = self._get_pretrained_path(tag)
            self.cfg_path = os.path.join(self.res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.model_path = os.path.join(self.res_path,
                                           pretrained_models[tag]['model_path'])
            self.params_path = os.path.join(
                self.res_path, pretrained_models[tag]['params_path'])
            self.label_file = os.path.join(self.res_path,
                                           pretrained_models[tag]['label_file'])
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.model_path = os.path.abspath(model_path)
            self.params_path = os.path.abspath(params_path)
            self.label_file = os.path.abspath(label_file)

        logger.info(self.cfg_path)
        logger.info(self.model_path)
        logger.info(self.params_path)
        logger.info(self.label_file)

        # config
        with open(self.cfg_path, 'r') as f:
            self._conf = yaml.safe_load(f)
        logger.info("Read cfg file successfully.")

        # labels
        self._label_list = []
        with open(self.label_file, 'r') as f:
            for line in f:
                self._label_list.append(line.strip())
        logger.info("Read label file successfully.")

        # Create predictor
        self.predictor_conf = predictor_conf
        self.predictor = init_predictor(
            model_file=self.model_path,
            params_file=self.params_path,
            predictor_conf=self.predictor_conf)
        logger.info("Create predictor successfully.")

    @paddle.no_grad()
    def infer(self):
        """
        Model inference and result stored in self.output.
        """
        output = run_model(self.predictor, [self._inputs['feats'].numpy()])
        self._outputs['logits'] = output[0]


class CLSEngine(BaseEngine):
    """CLS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(CLSEngine, self).__init__()

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config (dict): engine config, parsed from the yaml file

        Returns:
            bool: init failed or success
        """
        self.executor = CLSServerExecutor()
        self.config = config
        self.executor._init_from_path(
            self.config.model_type, self.config.cfg_path,
            self.config.model_path, self.config.params_path,
            self.config.label_file, self.config.predictor_conf)

        logger.info("Initialize CLS server engine successfully.")
        return True

    def run(self, audio_data):
        """engine run

        Args:
            audio_data (bytes): base64.b64decode
        """
        self.executor.preprocess(io.BytesIO(audio_data))
        st = time.time()
        self.executor.infer()
        infer_time = time.time() - st

        logger.info("inference time: {}".format(infer_time))
        logger.info("cls engine type: inference")

    def postprocess(self, topk: int):
        """postprocess
        """
        assert topk <= len(self.executor._label_list
                           ), 'Value of topk is larger than number of labels.'

        result = np.squeeze(self.executor._outputs['logits'], axis=0)
        topk_idx = (-result).argsort()[:topk]
        topk_results = []
        for idx in topk_idx:
            res = {}
            label, score = self.executor._label_list[idx], result[idx]
            res['class_name'] = label
            res['prob'] = score
            topk_results.append(res)

        return topk_results
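
The postprocess above selects the top-k labels with a descending argsort over the logits; a self-contained sketch of the same selection on dummy scores (labels and values here are made up for illustration):

import numpy as np

result = np.array([0.10, 0.70, 0.05, 0.15])   # dummy scores for 4 labels
label_list = ['Speech', 'Music', 'Dog', 'Cat']

topk = 2
topk_idx = (-result).argsort()[:topk]         # indices of the 2 largest scores
print([(label_list[i], float(result[i])) for i in topk_idx])
# [('Music', 0.7), ('Cat', 0.15)]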

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,124 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import time
from typing import List

import paddle

from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine

__all__ = ['CLSEngine']


class CLSServerExecutor(CLSExecutor):
    def __init__(self):
        super().__init__()

    def get_topk_results(self, topk: int) -> List:
        assert topk <= len(
            self._label_list), 'Value of topk is larger than number of labels.'

        result = self._outputs['logits'].squeeze(0).numpy()
        topk_idx = (-result).argsort()[:topk]
        topk_results = []
        for idx in topk_idx:
            res = {}  # a fresh dict per entry; a shared dict would make every item the same
            label, score = self._label_list[idx], result[idx]
            res['class_name'] = label
            res['prob'] = score
            topk_results.append(res)
        return topk_results


class CLSEngine(BaseEngine):
    """CLS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(CLSEngine, self).__init__()

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config (dict): engine config, parsed from the yaml file

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = None
        self.executor = CLSServerExecutor()
        self.config = config
        try:
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except BaseException:
            logger.error(
                "Failed to set device. Please check whether the device is already in use and whether the parameter 'device' in the yaml file is valid."
            )

        try:
            self.executor._init_from_path(
                self.config.model, self.config.cfg_path, self.config.ckpt_path,
                self.config.label_file)
        except BaseException:
            logger.error("Failed to initialize CLS server engine.")
            return False

        logger.info("Initialize CLS server engine successfully on device: %s."
                    % (self.device))
        return True

    def run(self, audio_data):
        """engine run

        Args:
            audio_data (bytes): base64.b64decode
        """
        self.executor.preprocess(io.BytesIO(audio_data))
        st = time.time()
        self.executor.infer()
        infer_time = time.time() - st

        logger.info("inference time: {}".format(infer_time))
        logger.info("cls engine type: python")

    def postprocess(self, topk: int):
        """postprocess
        """
        assert topk <= len(self.executor._label_list
                           ), 'Value of topk is larger than number of labels.'

        result = self.executor._outputs['logits'].squeeze(0).numpy()
        topk_idx = (-result).argsort()[:topk]
        topk_results = []
        for idx in topk_idx:
            res = {}
            label, score = self.executor._label_list[idx], result[idx]
            res['class_name'] = label
            res['prob'] = score
            topk_results.append(res)

        return topk_results
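
A sketch of driving this engine directly, without the HTTP layer (illustrative only; it assumes a config object with attribute access, so SimpleNamespace stands in for the parsed yaml, and './cat.wav' is a hypothetical wav file):

from types import SimpleNamespace

config = SimpleNamespace(
    model='panns_cnn14', cfg_path=None, ckpt_path=None,
    label_file=None, device=None)           # None -> use defaults

engine = CLSEngine()
if engine.init(config):
    with open('./cat.wav', 'rb') as f:      # raw wav bytes, as run() expects
        engine.run(f.read())
    print(engine.postprocess(topk=3))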

@@ -31,5 +31,11 @@ class EngineFactory(object):
        elif engine_name == 'tts' and engine_type == 'python':
            from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
            return TTSEngine()
        elif engine_name == 'cls' and engine_type == 'inference':
            from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
            return CLSEngine()
        elif engine_name == 'cls' and engine_type == 'python':
            from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine
            return CLSEngine()
        else:
            return None

@@ -250,27 +250,21 @@ class TTSServerExecutor(TTSExecutor):
            self.frontend = English(phone_vocab_path=self.phones_dict)
        logger.info("frontend done!")

-       try:
-           # am predictor
-           self.am_predictor_conf = am_predictor_conf
-           self.am_predictor = init_predictor(
-               model_file=self.am_model,
-               params_file=self.am_params,
-               predictor_conf=self.am_predictor_conf)
-           logger.info("Create AM predictor successfully.")
-       except BaseException:
-           logger.error("Failed to create AM predictor.")
-
-       try:
-           # voc predictor
-           self.voc_predictor_conf = voc_predictor_conf
-           self.voc_predictor = init_predictor(
-               model_file=self.voc_model,
-               params_file=self.voc_params,
-               predictor_conf=self.voc_predictor_conf)
-           logger.info("Create Vocoder predictor successfully.")
-       except BaseException:
-           logger.error("Failed to create Vocoder predictor.")
+       # Create am predictor
+       self.am_predictor_conf = am_predictor_conf
+       self.am_predictor = init_predictor(
+           model_file=self.am_model,
+           params_file=self.am_params,
+           predictor_conf=self.am_predictor_conf)
+       logger.info("Create AM predictor successfully.")
+
+       # Create voc predictor
+       self.voc_predictor_conf = voc_predictor_conf
+       self.voc_predictor = init_predictor(
+           model_file=self.voc_model,
+           params_file=self.voc_params,
+           predictor_conf=self.voc_predictor_conf)
+       logger.info("Create Vocoder predictor successfully.")

    @paddle.no_grad()
    def infer(self,
@@ -359,27 +353,22 @@ class TTSEngine(BaseEngine):
    def init(self, config: dict) -> bool:
        self.executor = TTSServerExecutor()

-       try:
-           self.config = config
-           self.executor._init_from_path(
-               am=self.config.am,
-               am_model=self.config.am_model,
-               am_params=self.config.am_params,
-               am_sample_rate=self.config.am_sample_rate,
-               phones_dict=self.config.phones_dict,
-               tones_dict=self.config.tones_dict,
-               speaker_dict=self.config.speaker_dict,
-               voc=self.config.voc,
-               voc_model=self.config.voc_model,
-               voc_params=self.config.voc_params,
-               voc_sample_rate=self.config.voc_sample_rate,
-               lang=self.config.lang,
-               am_predictor_conf=self.config.am_predictor_conf,
-               voc_predictor_conf=self.config.voc_predictor_conf, )
-       except BaseException:
-           logger.error("Initialize TTS server engine Failed.")
-           return False
+       self.config = config
+       self.executor._init_from_path(
+           am=self.config.am,
+           am_model=self.config.am_model,
+           am_params=self.config.am_params,
+           am_sample_rate=self.config.am_sample_rate,
+           phones_dict=self.config.phones_dict,
+           tones_dict=self.config.tones_dict,
+           speaker_dict=self.config.speaker_dict,
+           voc=self.config.voc,
+           voc_model=self.config.voc_model,
+           voc_params=self.config.voc_params,
+           voc_sample_rate=self.config.voc_sample_rate,
+           lang=self.config.lang,
+           am_predictor_conf=self.config.am_predictor_conf,
+           voc_predictor_conf=self.config.voc_predictor_conf, )

        logger.info("Initialize TTS server engine successfully.")
        return True

@@ -16,6 +16,7 @@ from typing import List
from fastapi import APIRouter

from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.tts_api import router as tts_router

_router = APIRouter()
@@ -25,7 +26,7 @@ def setup_router(api_list: List):
    """setup router for fastapi

    Args:
-       api_list (List): [asr, tts]
+       api_list (List): [asr, tts, cls]

    Returns:
        APIRouter
@@ -35,6 +36,8 @@ def setup_router(api_list: List):
            _router.include_router(asr_router)
        elif api_name == 'tts':
            _router.include_router(tts_router)
        elif api_name == 'cls':
            _router.include_router(cls_router)
        else:
            pass
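
For context, a sketch of how the assembled router might be mounted on a FastAPI app (the app setup below is illustrative, not the server's actual startup code):

from fastapi import FastAPI

app = FastAPI()
app.include_router(setup_router(['asr', 'tts', 'cls']))
# serving `app` (e.g. with uvicorn) then exposes POST /paddlespeech/cls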

@@ -0,0 +1,92 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import traceback
from typing import Union

from fastapi import APIRouter

from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import CLSRequest
from paddlespeech.server.restful.response import CLSResponse
from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException

router = APIRouter()


@router.get('/paddlespeech/cls/help')
def help():
    """help

    Returns:
        json: the usage description of the cls service
    """
    response = {
        "success": True,
        "code": 200,
        "message": {
            "global": "success"
        },
        "result": {
            "description": "cls server",
            "input": "base64 string of wavfile",
            "output": "classification result"
        }
    }
    return response


@router.post(
    "/paddlespeech/cls", response_model=Union[CLSResponse, ErrorResponse])
def cls(request_body: CLSRequest):
    """cls api

    Args:
        request_body (CLSRequest): the cls request body

    Returns:
        json: the cls result
    """
    try:
        audio_data = base64.b64decode(request_body.audio)

        # get single engine from engine pool
        engine_pool = get_engine_pool()
        cls_engine = engine_pool['cls']

        cls_engine.run(audio_data)
        cls_results = cls_engine.postprocess(request_body.topk)

        response = {
            "success": True,
            "code": 200,
            "message": {
                "description": "success"
            },
            "result": {
                "topk": request_body.topk,
                "results": cls_results
            }
        }

    except ServerBaseException as e:
        response = failed_response(e.error_code, e.msg)
    except BaseException:
        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
        traceback.print_exc()

    return response
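
An illustrative end-to-end request against this route, using only requests and base64 (the wav path and server address are assumptions):

import base64
import json

import requests

# Read a local wav and base64-encode it, matching the CLSRequest.audio field.
with open('./cat.wav', 'rb') as f:          # hypothetical input file
    audio = base64.b64encode(f.read()).decode('utf8')

data = {"audio": audio, "topk": 2}
r = requests.post(
    url='http://127.0.0.1:8090/paddlespeech/cls', data=json.dumps(data))
print(r.json())  # e.g. {"success": true, ..., "result": {"topk": 2, "results": [...]}}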

@@ -15,7 +15,7 @@ from typing import Optional
from pydantic import BaseModel

-__all__ = ['ASRRequest', 'TTSRequest']
+__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest']

#****************************************************************************************/

@@ -63,3 +63,18 @@ class TTSRequest(BaseModel):
    volume: float = 1.0
    sample_rate: int = 0
    save_path: str = None


#****************************************************************************************/
#************************************ CLS request ***************************************/
#****************************************************************************************/
class CLSRequest(BaseModel):
    """
    request body example
    {
        "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
        "topk": 1
    }
    """
    audio: str
    topk: int = 1

@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from pydantic import BaseModel

-__all__ = ['ASRResponse', 'TTSResponse']
+__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']


class Message(BaseModel):
class Message(BaseModel): class Message(BaseModel):
@@ -85,6 +87,45 @@ class TTSResponse(BaseModel):
    result: TTSResult


#****************************************************************************************/
#************************************ CLS response **************************************/
#****************************************************************************************/
class CLSResults(BaseModel):
    class_name: str
    prob: float


class CLSResult(BaseModel):
    topk: int
    results: List[CLSResults]


class CLSResponse(BaseModel):
    """
    response example
    {
        "success": true,
        "code": 0,
        "message": {
            "description": "success"
        },
        "result": {
            "topk": 1,
            "results": [
                {
                    "class_name": "Speech",
                    "prob": 0.9027184844017029
                }
            ]
        }
    }
    """
    success: bool
    code: int
    message: Message
    result: CLSResult


#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/

@@ -35,10 +35,12 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
    Returns:
        predictor (PaddleInferPredictor): created predictor
    """
    if model_dir is not None:
        assert os.path.isdir(model_dir), 'Please check model dir.'
        config = Config(model_dir)
    else:
        assert os.path.isfile(model_file) and os.path.isfile(
            params_file), 'Please check model and parameter files.'
        config = Config(model_file, params_file)

    # set device

@@ -66,7 +68,6 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
    config.enable_memory_optim()
    predictor = create_predictor(config)
    return predictor

@@ -84,10 +85,8 @@ def run_model(predictor, input: List) -> List:
    for i, name in enumerate(input_names):
        input_handle = predictor.get_input_handle(name)
        input_handle.copy_from_cpu(input[i])

    # do the inference
    predictor.run()

    results = []
    # get out data from output tensor
    output_names = predictor.get_output_names()
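
A sketch of using these helpers directly, assuming an exported static model (inference.pdmodel / inference.pdiparams) is on disk and a log-mel feature batch is already computed; the paths, predictor_conf values, and feature shape below are assumptions:

import numpy as np

from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model

predictor = init_predictor(
    model_file='./inference.pdmodel',       # hypothetical paths
    params_file='./inference.pdiparams',
    predictor_conf={
        'device': 'cpu',
        'switch_ir_optim': True,
        'glog_info': False,
        'summary': False,
    })

feats = np.random.randn(1, 100, 64).astype('float32')  # dummy feature batch
logits = run_model(predictor, [feats])[0]   # list of inputs in, list of outputs out
print(logits.shape)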
