add server cls, test=doc

pull/1554/head
lym0302 3 years ago
parent fd20056718
commit 99fa7a8205

@@ -193,7 +193,8 @@ class CLSExecutor(BaseExecutor):
sr=feat_conf['sample_rate'],
mono=True,
dtype='float32')
logger.info("Preprocessing audio_file:" + audio_file)
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file)
# Feature extraction
feature_extractor = LogMelSpectrogram(

@@ -18,6 +18,7 @@ from .base_commands import ClientHelpCommand
from .base_commands import ServerBaseCommand
from .base_commands import ServerHelpCommand
from .bin.paddlespeech_client import ASRClientExecutor
from .bin.paddlespeech_client import CLSClientExecutor
from .bin.paddlespeech_client import TTSClientExecutor
from .bin.paddlespeech_server import ServerExecutor

@@ -31,7 +31,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64
__all__ = ['TTSClientExecutor', 'ASRClientExecutor']
__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']
@cli_client_register(
@@ -243,3 +243,71 @@ class ASRClientExecutor(BaseExecutor):
print("time cost %f s." % (time_end - time_start))
except BaseException:
print("Failed to speech recognition.")
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
def __init__(self):
super(CLSClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.cls', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8090, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='Audio file to classify.',
required=True)
self.parser.add_argument(
'--topk',
type=int,
default=1,
help='Return topk scores of classification result.')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
url = 'http://' + args.server_ip + ":" + str(
args.port) + '/paddlespeech/cls'
audio = wav2base64(args.input)
data = {
"audio": audio,
"topk": args.topk,
}
time_start = time.time()
try:
r = requests.post(url=url, data=json.dumps(data))
# ending timestamp
time_end = time.time()
logger.info(r.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except BaseException:
logger.error("Failed to speech classification.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8090,
topk: int=1):
"""
Python API to call an executor.
"""
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
audio = wav2base64(input)
data = {"audio": audio, "topk": topk}
time_start = time.time()
try:
r = requests.post(url=url, data=json.dumps(data))
# ending timestamp
time_end = time.time()
print(r.json())
print("Response time %f s." % (time_end - time_start))
except BaseException:
print("Failed to speech classification.")

@@ -9,12 +9,16 @@ port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference', 'cls_python', 'cls_inference']
engine_list: ['asr_python', 'tts_python']
#engine_list: ['asr_python', 'tts_python', 'cls_python']
engine_list: ['cls_inference']
#engine_list: ['asr_python', 'cls_python']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: python #######################
asr_python:
model: 'conformer_wenetspeech'
@@ -46,6 +50,7 @@ asr_inference:
summary: True # False -> do not show predictor config
################################### TTS #########################################
################### speech task: tts; engine_type: python #######################
tts_python:
# am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
@@ -105,3 +110,30 @@ tts_inference:
# others
lang: 'zh'
################################### CLS #########################################
################### speech task: cls; engine_type: python #######################
cls_python:
# model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
model: 'panns_cnn14'
cfg_path: # [optional] Config of cls task.
ckpt_path: # [optional] Checkpoint file of model.
label_file: # [optional] Label file of cls task.
device: # set 'gpu:id' or 'cpu'
################### speech task: cls; engine_type: inference #######################
cls_inference:
# model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
model_type: 'panns_cnn14'
cfg_path:
model_path: # the pdmodel file of the cls static model [optional]
params_path: # the pdiparams file of the cls static model [optional]
label_file: # [optional] Label file of cls task.
predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
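To sanity-check which engines a given config enables, the file can be loaded directly; conf/application.yaml is an assumed save location for the config above:

import yaml

with open('conf/application.yaml') as f:      # assumed config location
    conf = yaml.safe_load(f)
print(conf['engine_list'])                    # e.g. ['cls_inference']
print(conf['cls_inference']['model_type'])    # 'panns_cnn14'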

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,225 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import numpy as np
import paddle
import yaml
from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['CLSEngine']
pretrained_models = {
"panns_cnn6-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
'md5':
'da087c31046d23281d8ec5188c1967da',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
'md5':
'5460cc6eafbfaf0f261cc75b90284ae1',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
'md5':
'ccc80b194821274da79466862b2ab00f',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
}
class CLSServerExecutor(CLSExecutor):
def __init__(self):
super().__init__()
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and return the path of the pretrained resources for the current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use is not supported. Please choose one of the following models:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
print("aaaaaaaaaaaaa: ", decompressed_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
model_type: str='panns_cnn14',
cfg_path: Optional[os.PathLike]=None,
model_path: Optional[os.PathLike]=None,
params_path: Optional[os.PathLike]=None,
label_file: Optional[os.PathLike]=None,
predictor_conf: dict=None, ):
"""
Init model and other resources from a specific path.
"""
if cfg_path is None or model_path is None or params_path is None or label_file is None:
tag = model_type + '-' + '32k'
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.model_path = os.path.join(self.res_path,
pretrained_models[tag]['model_path'])
self.params_path = os.path.join(
self.res_path, pretrained_models[tag]['params_path'])
self.label_file = os.path.join(self.res_path,
pretrained_models[tag]['label_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.model_path = os.path.abspath(model_path)
self.params_path = os.path.abspath(params_path)
self.label_file = os.path.abspath(label_file)
logger.info(self.cfg_path)
logger.info(self.model_path)
logger.info(self.params_path)
logger.info(self.label_file)
# config
with open(self.cfg_path, 'r') as f:
self._conf = yaml.safe_load(f)
logger.info("Read cfg file successfully.")
# labels
self._label_list = []
with open(self.label_file, 'r') as f:
for line in f:
self._label_list.append(line.strip())
logger.info("Read label file successfully.")
# Create predictor
self.predictor_conf = predictor_conf
self.predictor = init_predictor(
model_file=self.model_path,
params_file=self.params_path,
predictor_conf=self.predictor_conf)
logger.info("Create predictor successfully.")
@paddle.no_grad()
def infer(self):
"""
Model inference and result stored in self.output.
"""
output = run_model(self.predictor, [self._inputs['feats'].numpy()])
self._outputs['logits'] = output[0]
class CLSEngine(BaseEngine):
"""CLS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(CLSEngine, self).__init__()
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config (dict): engine config
Returns:
bool: True on success, False on failure
"""
self.executor = CLSServerExecutor()
self.config = config
self.executor._init_from_path(
self.config.model_type, self.config.cfg_path,
self.config.model_path, self.config.params_path,
self.config.label_file, self.config.predictor_conf)
logger.info("Initialize CLS server engine successfully.")
return True
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64-decoded audio data
"""
self.executor.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
logger.info("cls engine type: inference")
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
result = np.squeeze(self.executor._outputs['logits'], axis=0)
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)
return topk_results
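A usage sketch for the inference engine, assuming a config object with attribute access (the server normally builds this from application.yaml; the SimpleNamespace and the predictor_conf dict layout are illustrative, not mandated by the code above):

import types

config = types.SimpleNamespace(
    model_type='panns_cnn14', cfg_path=None, model_path=None,
    params_path=None, label_file=None,
    predictor_conf={'device': None, 'switch_ir_optim': True,
                    'glog_info': False, 'summary': True})

engine = CLSEngine()
engine.init(config)
with open('dog.wav', 'rb') as f:     # placeholder wav file
    engine.run(f.read())
print(engine.postprocess(topk=3))    # [{'class_name': ..., 'prob': ...}, ...]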

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,124 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import time
from typing import List
import paddle
from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['CLSEngine']
class CLSServerExecutor(CLSExecutor):
def __init__(self):
super().__init__()
def get_topk_results(self, topk: int) -> List:
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = self._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}  # fresh dict per result; reusing one dict would make every entry alias the last
label, score = self._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)
return topk_results
class CLSEngine(BaseEngine):
"""CLS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(CLSEngine, self).__init__()
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config (dict): engine config
Returns:
bool: True on success, False on failure
"""
self.input = None
self.output = None
self.executor = CLSServerExecutor()
self.config = config
try:
if self.config.device:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException:
logger.error(
"Failed to set device. Please check whether the device is already in use "
"and whether the 'device' parameter in the yaml file is valid.")
try:
self.executor._init_from_path(
self.config.model, self.config.cfg_path, self.config.ckpt_path,
self.config.label_file)
except BaseException:
logger.error("Initialize CLS server engine Failed.")
return False
logger.info("Initialize CLS server engine successfully on device: %s." %
(self.device))
return True
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64-decoded audio data
"""
self.executor.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
logger.info("cls engine type: python")
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
result = self.executor._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)
return topk_results
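The python engine can be exercised the same way; here the None values fall back to the auto-downloaded checkpoint and the default paddle device (the SimpleNamespace stands in for the parsed yaml config):

import types

config = types.SimpleNamespace(model='panns_cnn14', cfg_path=None,
                               ckpt_path=None, label_file=None, device=None)
engine = CLSEngine()
if engine.init(config):
    with open('cat.wav', 'rb') as f:   # placeholder wav file
        engine.run(f.read())
    print(engine.postprocess(topk=1))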

@@ -31,5 +31,11 @@ class EngineFactory(object):
elif engine_name == 'tts' and engine_type == 'python':
from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
return TTSEngine()
elif engine_name == 'cls' and engine_type == 'inference':
from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
return CLSEngine()
elif engine_name == 'cls' and engine_type == 'python':
from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine
return CLSEngine()
else:
return None
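A sketch of how an engine_list entry maps onto these branches, assuming get_engine(engine_name, engine_type) is the factory method that contains them:

# 'cls_python' -> task name 'cls', engine type 'python'
engine_name, engine_type = 'cls_python'.split('_')
engine = EngineFactory.get_engine(engine_name, engine_type)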

@@ -250,27 +250,21 @@ class TTSServerExecutor(TTSExecutor):
self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!")
try:
# am predictor
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
logger.info("Create AM predictor successfully.")
except BaseException:
logger.error("Failed to create AM predictor.")
try:
# voc predictor
self.voc_predictor_conf = voc_predictor_conf
self.voc_predictor = init_predictor(
model_file=self.voc_model,
params_file=self.voc_params,
predictor_conf=self.voc_predictor_conf)
logger.info("Create Vocoder predictor successfully.")
except BaseException:
logger.error("Failed to create Vocoder predictor.")
# Create am predictor
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
logger.info("Create AM predictor successfully.")
# Create voc predictor
self.voc_predictor_conf = voc_predictor_conf
self.voc_predictor = init_predictor(
model_file=self.voc_model,
params_file=self.voc_params,
predictor_conf=self.voc_predictor_conf)
logger.info("Create Vocoder predictor successfully.")
@paddle.no_grad()
def infer(self,
@@ -359,27 +353,22 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
try:
self.config = config
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except BaseException:
logger.error("Initialize TTS server engine Failed.")
return False
self.config = config
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
logger.info("Initialize TTS server engine successfully.")
return True

@@ -16,6 +16,7 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.tts_api import router as tts_router
_router = APIRouter()
@@ -25,7 +26,7 @@ def setup_router(api_list: List):
"""setup router for fastapi
Args:
api_list (List): [asr, tts]
api_list (List): [asr, tts, cls]
Returns:
APIRouter
@@ -35,6 +36,8 @@ def setup_router(api_list: List):
_router.include_router(asr_router)
elif api_name == 'tts':
_router.include_router(tts_router)
elif api_name == 'cls':
_router.include_router(cls_router)
else:
pass
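A sketch of wiring the router into an app; the import path assumes setup_router lives in paddlespeech/server/restful/api.py, consistent with the asr_api/cls_api imports above:

from fastapi import FastAPI

from paddlespeech.server.restful.api import setup_router  # assumed module path

app = FastAPI()
app.include_router(setup_router(['asr', 'tts', 'cls']))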

@@ -0,0 +1,92 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import CLSRequest
from paddlespeech.server.restful.response import CLSResponse
from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.errors import failed_response
from paddlespeech.server.utils.exception import ServerBaseException
router = APIRouter()
@router.get('/paddlespeech/cls/help')
def help():
"""help
Returns:
json: a static help message describing the cls service
"""
response = {
"success": "True",
"code": 200,
"message": {
"global": "success"
},
"result": {
"description": "cls server",
"input": "base64 string of wavfile",
"output": "classification result"
}
}
return response
@router.post(
"/paddlespeech/cls", response_model=Union[CLSResponse, ErrorResponse])
def cls(request_body: CLSRequest):
"""cls api
Args:
request_body (CLSRequest): base64-encoded audio and topk
Returns:
json: classification result, or an error response on failure
"""
try:
audio_data = base64.b64decode(request_body.audio)
# get single engine from engine pool
engine_pool = get_engine_pool()
cls_engine = engine_pool['cls']
cls_engine.run(audio_data)
cls_results = cls_engine.postprocess(request_body.topk)
response = {
"success": True,
"code": 200,
"message": {
"description": "success"
},
"result": {
"topk": request_body.topk,
"results": cls_results
}
}
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
return response
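The endpoint can also be hit directly over HTTP; this mirrors what the client executor does (placeholder wav path, default host and port):

import base64
import json

import requests

with open('zh.wav', 'rb') as f:     # placeholder wav file
    audio = base64.b64encode(f.read()).decode('utf8')
r = requests.post('http://127.0.0.1:8090/paddlespeech/cls',
                  data=json.dumps({'audio': audio, 'topk': 1}))
print(r.json())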

@@ -15,7 +15,7 @@ from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRRequest', 'TTSRequest']
__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest']
#****************************************************************************************/
@@ -63,3 +63,18 @@ class TTSRequest(BaseModel):
volume: float = 1.0
sample_rate: int = 0
save_path: str = None
#****************************************************************************************/
#************************************ CLS request ***************************************/
#****************************************************************************************/
class CLSRequest(BaseModel):
"""
request body example
{
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"topk": 1
}
"""
audio: str
topk: int = 1
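Pydantic fills the default, so topk may be omitted; a quick sketch (the import path assumes this file is paddlespeech/server/restful/request.py, and the audio string is a dummy payload):

from paddlespeech.server.restful.request import CLSRequest

req = CLSRequest(audio='ZXhhbXBsZQ==')   # dummy base64 payload
print(req.topk)                          # 1 (default)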

@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse']
__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']
class Message(BaseModel):
@@ -85,6 +87,45 @@ class TTSResponse(BaseModel):
result: TTSResult
#****************************************************************************************/
#************************************ CLS response **************************************/
#****************************************************************************************/
class CLSResults(BaseModel):
class_name: str
prob: float
class CLSResult(BaseModel):
topk: int
results: List[CLSResults]
class CLSResponse(BaseModel):
"""
response example
{
"success": true,
"code": 0,
"message": {
"description": "success"
},
"result": {
topk: 1
results: [
{
"class":"Speech",
"prob": 0.9027184844017029
}
]
}
}
"""
success: bool
code: int
message: Message
result: CLSResult
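A sketch of validating a server reply against these models, assuming Message exposes a description field and the import path paddlespeech/server/restful/response.py:

from paddlespeech.server.restful.response import CLSResponse

payload = {
    'success': True,
    'code': 200,
    'message': {'description': 'success'},
    'result': {'topk': 1,
               'results': [{'class_name': 'Speech', 'prob': 0.9027}]},
}
resp = CLSResponse(**payload)
print(resp.result.results[0].class_name)   # 'Speech'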
#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/

@@ -35,10 +35,12 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
Returns:
predictor (PaddleInferPredictor): created predictor
"""
if model_dir is not None:
assert os.path.isdir(model_dir), 'Please check model dir.'
config = Config(model_dir)  # was args.model_dir, but no 'args' exists in this scope
else:
assert os.path.isfile(model_file) and os.path.isfile(
params_file), 'Please check model and parameter files.'
config = Config(model_file, params_file)
# set device
@@ -66,7 +68,6 @@
config.enable_memory_optim()
predictor = create_predictor(config)
return predictor
@@ -84,10 +85,8 @@ def run_model(predictor, input: List) -> List:
for i, name in enumerate(input_names):
input_handle = predictor.get_input_handle(name)
input_handle.copy_from_cpu(input[i])
# do the inference
predictor.run()
results = []
# get out data from output tensor
output_names = predictor.get_output_names()

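A usage sketch for the predictor helpers; the file names, the predictor_conf dict layout, and the dummy feature shape are assumptions, not values mandated by the code above:

import numpy as np

from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model

predictor = init_predictor(
    model_file='inference.pdmodel',       # assumed exported model files
    params_file='inference.pdiparams',
    predictor_conf={'device': None, 'switch_ir_optim': True,
                    'glog_info': False, 'summary': True})
feats = np.zeros((1, 624, 64), dtype='float32')   # dummy log-mel batch
logits = run_model(predictor, [feats])[0]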