commit
dd9cc3272c
@ -0,0 +1,47 @@
|
||||
# This is the parameter configuration file for TTS server.
|
||||
# These are the static models that support paddle inference.
|
||||
|
||||
##################################################################
|
||||
# TTS SERVER SETTING #
|
||||
##################################################################
|
||||
host: '0.0.0.0'
|
||||
port: 8692
|
||||
|
||||
##################################################################
|
||||
# ACOUSTIC MODEL SETTING #
|
||||
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
|
||||
##################################################################
|
||||
am: 'fastspeech2_csmsc'
|
||||
am_model: # the pdmodel file of am static model
|
||||
am_params: # the pdiparams file of am static model
|
||||
am_sample_rate: 24000
|
||||
phones_dict:
|
||||
tones_dict:
|
||||
speaker_dict:
|
||||
spk_id: 0
|
||||
|
||||
am_predictor_conf:
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
switch_ir_optim: True
|
||||
|
||||
|
||||
##################################################################
|
||||
# VOCODER SETTING #
|
||||
# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
|
||||
##################################################################
|
||||
voc: 'pwgan_csmsc'
|
||||
voc_model: # the pdmodel file of vocoder static model
|
||||
voc_params: # the pdiparams file of vocoder static model
|
||||
voc_sample_rate: 24000
|
||||
|
||||
voc_predictor_conf:
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
switch_ir_optim: True
|
||||
|
||||
##################################################################
|
||||
# OTHERS #
|
||||
##################################################################
|
||||
lang: 'zh'
|
||||
device: paddle.get_device()
|
@ -0,0 +1,474 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import soundfile as sf
|
||||
import yaml
|
||||
from engine.base_engine import BaseEngine
|
||||
from scipy.io import wavfile
|
||||
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.tts.infer import TTSExecutor
|
||||
from paddlespeech.cli.utils import download_and_decompress
|
||||
from paddlespeech.cli.utils import MODEL_HOME
|
||||
from paddlespeech.t2s.frontend import English
|
||||
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||
from utils.audio_process import change_speed
|
||||
from utils.errors import ErrorCode
|
||||
from utils.exception import ServerBaseException
|
||||
from utils.paddle_predictor import init_predictor
|
||||
from utils.paddle_predictor import run_model
|
||||
|
||||
__all__ = ['TTSEngine']
|
||||
|
||||
# Static model applied on paddle inference
|
||||
pretrained_models = {
|
||||
# speedyspeech
|
||||
"speedyspeech_csmsc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
|
||||
'md5':
|
||||
'f10cbdedf47dc7a9668d2264494e1823',
|
||||
'model':
|
||||
'speedyspeech_csmsc.pdmodel',
|
||||
'params':
|
||||
'speedyspeech_csmsc.pdiparams',
|
||||
'phones_dict':
|
||||
'phone_id_map.txt',
|
||||
'tones_dict':
|
||||
'tone_id_map.txt',
|
||||
'sample_rate':
|
||||
24000,
|
||||
},
|
||||
# fastspeech2
|
||||
"fastspeech2_csmsc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
|
||||
'md5':
|
||||
'9788cd9745e14c7a5d12d32670b2a5a7',
|
||||
'model':
|
||||
'fastspeech2_csmsc.pdmodel',
|
||||
'params':
|
||||
'fastspeech2_csmsc.pdiparams',
|
||||
'phones_dict':
|
||||
'phone_id_map.txt',
|
||||
'sample_rate':
|
||||
24000,
|
||||
},
|
||||
# pwgan
|
||||
"pwgan_csmsc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
|
||||
'md5':
|
||||
'e3504aed9c5a290be12d1347836d2742',
|
||||
'model':
|
||||
'pwgan_csmsc.pdmodel',
|
||||
'params':
|
||||
'pwgan_csmsc.pdiparams',
|
||||
'sample_rate':
|
||||
24000,
|
||||
},
|
||||
# mb_melgan
|
||||
"mb_melgan_csmsc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
|
||||
'md5':
|
||||
'ac6eee94ba483421d750433f4c3b8d36',
|
||||
'model':
|
||||
'mb_melgan_csmsc.pdmodel',
|
||||
'params':
|
||||
'mb_melgan_csmsc.pdiparams',
|
||||
'sample_rate':
|
||||
24000,
|
||||
},
|
||||
# hifigan
|
||||
"hifigan_csmsc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
|
||||
'md5':
|
||||
'7edd8c436b3a5546b3a7cb8cff9d5a0c',
|
||||
'model':
|
||||
'hifigan_csmsc.pdmodel',
|
||||
'params':
|
||||
'hifigan_csmsc.pdiparams',
|
||||
'sample_rate':
|
||||
24000,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TTSServerExecutor(TTSExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.parser = argparse.ArgumentParser(
|
||||
prog='paddlespeech.tts', add_help=True)
|
||||
self.parser.add_argument(
|
||||
'--conf',
|
||||
type=str,
|
||||
default='./conf/tts/tts_pd.yaml',
|
||||
help='Configuration parameters.')
|
||||
|
||||
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
||||
"""
|
||||
Download and returns pretrained resources path of current task.
|
||||
"""
|
||||
assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
|
||||
tag)
|
||||
|
||||
res_path = os.path.join(MODEL_HOME, tag)
|
||||
decompressed_path = download_and_decompress(pretrained_models[tag],
|
||||
res_path)
|
||||
decompressed_path = os.path.abspath(decompressed_path)
|
||||
logger.info(
|
||||
'Use pretrained model stored in: {}'.format(decompressed_path))
|
||||
return decompressed_path
|
||||
|
||||
def _init_from_path(
|
||||
self,
|
||||
am: str='fastspeech2_csmsc',
|
||||
am_model: Optional[os.PathLike]=None,
|
||||
am_params: Optional[os.PathLike]=None,
|
||||
am_sample_rate: int=24000,
|
||||
phones_dict: Optional[os.PathLike]=None,
|
||||
tones_dict: Optional[os.PathLike]=None,
|
||||
speaker_dict: Optional[os.PathLike]=None,
|
||||
voc: str='pwgan_csmsc',
|
||||
voc_model: Optional[os.PathLike]=None,
|
||||
voc_params: Optional[os.PathLike]=None,
|
||||
voc_sample_rate: int=24000,
|
||||
lang: str='zh',
|
||||
am_predictor_conf: dict=None,
|
||||
voc_predictor_conf: dict=None, ):
|
||||
"""
|
||||
Init model and other resources from a specific path.
|
||||
"""
|
||||
if hasattr(self, 'am') and hasattr(self, 'voc'):
|
||||
logger.info('Models had been initialized.')
|
||||
return
|
||||
# am
|
||||
am_tag = am + '-' + lang
|
||||
if am_model is None or am_params is None or phones_dict is None:
|
||||
am_res_path = self._get_pretrained_path(am_tag)
|
||||
self.am_res_path = am_res_path
|
||||
self.am_model = os.path.join(am_res_path,
|
||||
pretrained_models[am_tag]['model'])
|
||||
self.am_params = os.path.join(am_res_path,
|
||||
pretrained_models[am_tag]['params'])
|
||||
# must have phones_dict in acoustic
|
||||
self.phones_dict = os.path.join(
|
||||
am_res_path, pretrained_models[am_tag]['phones_dict'])
|
||||
self.am_sample_rate = pretrained_models[am_tag]['sample_rate']
|
||||
|
||||
logger.info(am_res_path)
|
||||
logger.info(self.am_model)
|
||||
logger.info(self.am_params)
|
||||
else:
|
||||
self.am_model = os.path.abspath(am_model)
|
||||
self.am_params = os.path.abspath(am_params)
|
||||
self.phones_dict = os.path.abspath(phones_dict)
|
||||
self.am_sample_rate = am_sample_rate
|
||||
self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
|
||||
print("self.phones_dict:", self.phones_dict)
|
||||
|
||||
# for speedyspeech
|
||||
self.tones_dict = None
|
||||
if 'tones_dict' in pretrained_models[am_tag]:
|
||||
self.tones_dict = os.path.join(
|
||||
am_res_path, pretrained_models[am_tag]['tones_dict'])
|
||||
if tones_dict:
|
||||
self.tones_dict = tones_dict
|
||||
|
||||
# for multi speaker fastspeech2
|
||||
self.speaker_dict = None
|
||||
if 'speaker_dict' in pretrained_models[am_tag]:
|
||||
self.speaker_dict = os.path.join(
|
||||
am_res_path, pretrained_models[am_tag]['speaker_dict'])
|
||||
if speaker_dict:
|
||||
self.speaker_dict = speaker_dict
|
||||
|
||||
# voc
|
||||
voc_tag = voc + '-' + lang
|
||||
if voc_model is None or voc_params is None:
|
||||
voc_res_path = self._get_pretrained_path(voc_tag)
|
||||
self.voc_res_path = voc_res_path
|
||||
self.voc_model = os.path.join(voc_res_path,
|
||||
pretrained_models[voc_tag]['model'])
|
||||
self.voc_params = os.path.join(voc_res_path,
|
||||
pretrained_models[voc_tag]['params'])
|
||||
self.voc_sample_rate = pretrained_models[voc_tag]['sample_rate']
|
||||
logger.info(voc_res_path)
|
||||
logger.info(self.voc_model)
|
||||
logger.info(self.voc_params)
|
||||
else:
|
||||
self.voc_model = os.path.abspath(voc_model)
|
||||
self.voc_params = os.path.abspath(voc_params)
|
||||
self.voc_sample_rate = voc_sample_rate
|
||||
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_model))
|
||||
|
||||
assert (self.voc_sample_rate == self.am_sample_rate)
|
||||
# Init body.
|
||||
with open(self.phones_dict, "r") as f:
|
||||
phn_id = [line.strip().split() for line in f.readlines()]
|
||||
vocab_size = len(phn_id)
|
||||
print("vocab_size:", vocab_size)
|
||||
|
||||
tone_size = None
|
||||
if self.tones_dict:
|
||||
with open(self.tones_dict, "r") as f:
|
||||
tone_id = [line.strip().split() for line in f.readlines()]
|
||||
tone_size = len(tone_id)
|
||||
print("tone_size:", tone_size)
|
||||
|
||||
spk_num = None
|
||||
if self.speaker_dict:
|
||||
with open(self.speaker_dict, 'rt') as f:
|
||||
spk_id = [line.strip().split() for line in f.readlines()]
|
||||
spk_num = len(spk_id)
|
||||
print("spk_num:", spk_num)
|
||||
|
||||
# frontend
|
||||
if lang == 'zh':
|
||||
self.frontend = Frontend(
|
||||
phone_vocab_path=self.phones_dict,
|
||||
tone_vocab_path=self.tones_dict)
|
||||
|
||||
elif lang == 'en':
|
||||
self.frontend = English(phone_vocab_path=self.phones_dict)
|
||||
print("frontend done!")
|
||||
|
||||
# am predictor
|
||||
self.am_predictor_conf = am_predictor_conf
|
||||
self.am_predictor = init_predictor(
|
||||
model_file=self.am_model,
|
||||
params_file=self.am_params,
|
||||
predictor_conf=self.am_predictor_conf)
|
||||
|
||||
# voc predictor
|
||||
self.voc_predictor_conf = voc_predictor_conf
|
||||
self.voc_predictor = init_predictor(
|
||||
model_file=self.voc_model,
|
||||
params_file=self.voc_params,
|
||||
predictor_conf=self.voc_predictor_conf)
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self,
|
||||
text: str,
|
||||
lang: str='zh',
|
||||
am: str='fastspeech2_csmsc',
|
||||
spk_id: int=0):
|
||||
"""
|
||||
Model inference and result stored in self.output.
|
||||
"""
|
||||
am_name = am[:am.rindex('_')]
|
||||
am_dataset = am[am.rindex('_') + 1:]
|
||||
get_tone_ids = False
|
||||
merge_sentences = False
|
||||
if am_name == 'speedyspeech':
|
||||
get_tone_ids = True
|
||||
if lang == 'zh':
|
||||
input_ids = self.frontend.get_input_ids(
|
||||
text,
|
||||
merge_sentences=merge_sentences,
|
||||
get_tone_ids=get_tone_ids)
|
||||
phone_ids = input_ids["phone_ids"]
|
||||
if get_tone_ids:
|
||||
tone_ids = input_ids["tone_ids"]
|
||||
elif lang == 'en':
|
||||
input_ids = self.frontend.get_input_ids(
|
||||
text, merge_sentences=merge_sentences)
|
||||
phone_ids = input_ids["phone_ids"]
|
||||
else:
|
||||
print("lang should in {'zh', 'en'}!")
|
||||
|
||||
flags = 0
|
||||
for i in range(len(phone_ids)):
|
||||
part_phone_ids = phone_ids[i]
|
||||
# am
|
||||
if am_name == 'speedyspeech':
|
||||
part_tone_ids = tone_ids[i]
|
||||
am_result = run_model(
|
||||
self.am_predictor,
|
||||
[part_phone_ids.numpy(), part_tone_ids.numpy()])
|
||||
mel = am_result[0]
|
||||
|
||||
# fastspeech2
|
||||
else:
|
||||
# multi speaker do not have static model
|
||||
if am_dataset in {"aishell3", "vctk"}:
|
||||
pass
|
||||
else:
|
||||
am_result = run_model(self.am_predictor,
|
||||
[part_phone_ids.numpy()])
|
||||
mel = am_result[0]
|
||||
# voc
|
||||
voc_result = run_model(self.voc_predictor, [mel])
|
||||
wav = voc_result[0]
|
||||
wav = paddle.to_tensor(wav)
|
||||
|
||||
if flags == 0:
|
||||
wav_all = wav
|
||||
flags = 1
|
||||
else:
|
||||
wav_all = paddle.concat([wav_all, wav])
|
||||
self._outputs['wav'] = wav_all
|
||||
|
||||
|
||||
class TTSEngine(BaseEngine):
|
||||
"""TTS server engine
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None):
|
||||
"""Initialize TTS server engine
|
||||
"""
|
||||
super(TTSEngine, self).__init__()
|
||||
self.executor = TTSServerExecutor()
|
||||
|
||||
config_path = self.executor.parser.parse_args().conf
|
||||
with open(config_path, 'rt') as f:
|
||||
self.conf_dict = yaml.safe_load(f)
|
||||
|
||||
self.executor._init_from_path(
|
||||
am=self.conf_dict["am"],
|
||||
am_model=self.conf_dict["am_model"],
|
||||
am_params=self.conf_dict["am_params"],
|
||||
am_sample_rate=self.conf_dict["am_sample_rate"],
|
||||
phones_dict=self.conf_dict["phones_dict"],
|
||||
tones_dict=self.conf_dict["tones_dict"],
|
||||
speaker_dict=self.conf_dict["speaker_dict"],
|
||||
voc=self.conf_dict["voc"],
|
||||
voc_model=self.conf_dict["voc_model"],
|
||||
voc_params=self.conf_dict["voc_params"],
|
||||
voc_sample_rate=self.conf_dict["voc_sample_rate"],
|
||||
lang=self.conf_dict["lang"],
|
||||
am_predictor_conf=self.conf_dict["am_predictor_conf"],
|
||||
voc_predictor_conf=self.conf_dict["voc_predictor_conf"], )
|
||||
|
||||
logger.info("Initialize TTS server engine successfully.")
|
||||
|
||||
def postprocess(self,
|
||||
wav,
|
||||
original_fs: int,
|
||||
target_fs: int=16000,
|
||||
volume: float=1.0,
|
||||
speed: float=1.0,
|
||||
audio_path: str=None):
|
||||
"""Post-processing operations, including speech, volume, sample rate, save audio file
|
||||
|
||||
Args:
|
||||
wav (numpy(float)): Synthesized audio sample points
|
||||
original_fs (int): original audio sample rate
|
||||
target_fs (int): target audio sample rate
|
||||
volume (float): target volume
|
||||
speed (float): target speed
|
||||
"""
|
||||
|
||||
# transform sample_rate
|
||||
if target_fs == 0 or target_fs > original_fs:
|
||||
target_fs = original_fs
|
||||
wav_tar_fs = wav
|
||||
else:
|
||||
wav_tar_fs = librosa.resample(
|
||||
np.squeeze(wav), original_fs, target_fs)
|
||||
|
||||
# transform volume
|
||||
wav_vol = wav_tar_fs * volume
|
||||
|
||||
# transform speed
|
||||
try: # windows not support soxbindings
|
||||
wav_speed = change_speed(wav_vol, speed, target_fs)
|
||||
except:
|
||||
raise ServerBaseException(
|
||||
ErrorCode.SERVER_INTERNAL_ERR,
|
||||
"Can not install soxbindings on your system.")
|
||||
|
||||
# wav to base64
|
||||
buf = io.BytesIO()
|
||||
wavfile.write(buf, target_fs, wav_speed)
|
||||
base64_bytes = base64.b64encode(buf.read())
|
||||
wav_base64 = base64_bytes.decode('utf-8')
|
||||
|
||||
# save audio
|
||||
if audio_path is not None and audio_path.endswith(".wav"):
|
||||
sf.write(audio_path, wav_speed, target_fs)
|
||||
elif audio_path is not None and audio_path.endswith(".pcm"):
|
||||
wav_norm = wav_speed * (32767 / max(0.001,
|
||||
np.max(np.abs(wav_speed))))
|
||||
with open(audio_path, "wb") as f:
|
||||
f.write(wav_norm.astype(np.int16))
|
||||
|
||||
return target_fs, wav_base64
|
||||
|
||||
def run(self,
|
||||
sentence: str,
|
||||
spk_id: int=0,
|
||||
speed: float=1.0,
|
||||
volume: float=1.0,
|
||||
sample_rate: int=0,
|
||||
save_path: str=None):
|
||||
"""get the result of the server response
|
||||
|
||||
Args:
|
||||
sentence (str): sentence to be synthesized
|
||||
spk_id (int, optional): speaker id. Defaults to 0.
|
||||
speed (float, optional): audio speed, 0 < speed <=3.0. Defaults to 1.0.
|
||||
volume (float, optional): The volume relative to the audio synthesized by the model,
|
||||
0 < volume <=3.0. Defaults to 1.0.
|
||||
sample_rate (int, optional): Set the sample rate of the synthesized audio.
|
||||
0 represents the sample rate for model synthesis. Defaults to 0.
|
||||
save_path (str, optional): The save path of the synthesized audio. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ServerBaseException: Exception
|
||||
ServerBaseException: Exception
|
||||
|
||||
Returns:
|
||||
lang, target_sample_rate, wav_base64
|
||||
"""
|
||||
|
||||
lang = self.conf_dict["lang"]
|
||||
|
||||
try:
|
||||
self.executor.infer(
|
||||
text=sentence,
|
||||
lang=lang,
|
||||
am=self.conf_dict["am"],
|
||||
spk_id=spk_id)
|
||||
except:
|
||||
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
|
||||
"tts infer failed.")
|
||||
|
||||
try:
|
||||
target_sample_rate, wav_base64 = self.postprocess(
|
||||
wav=self.executor._outputs['wav'].numpy(),
|
||||
original_fs=self.executor.am_sample_rate,
|
||||
target_fs=sample_rate,
|
||||
volume=volume,
|
||||
speed=speed,
|
||||
audio_path=save_path)
|
||||
except:
|
||||
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
|
||||
"tts postprocess failed.")
|
||||
|
||||
return lang, target_sample_rate, wav_base64
|
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from paddle.inference import Config
|
||||
from paddle.inference import create_predictor
|
||||
|
||||
|
||||
def init_predictor(model_dir: Optional[os.PathLike]=None,
|
||||
model_file: Optional[os.PathLike]=None,
|
||||
params_file: Optional[os.PathLike]=None,
|
||||
predictor_conf: dict=None):
|
||||
"""Create predictor with Paddle inference
|
||||
|
||||
Args:
|
||||
model_dir (Optional[os.PathLike], optional): The path of the static model saved in the model layer. Defaults to None.
|
||||
model_file (Optional[os.PathLike], optional): *.pdmodel file path. Defaults to None.
|
||||
params_file (Optional[os.PathLike], optional): *.pdiparams file path.. Defaults to None.
|
||||
predictor_conf (dict, optional): The configuration parameters of predictor. Defaults to None.
|
||||
|
||||
Returns:
|
||||
predictor (PaddleInferPredictor): created predictor
|
||||
"""
|
||||
|
||||
if model_dir is not None:
|
||||
config = Config(args.model_dir)
|
||||
else:
|
||||
config = Config(model_file, params_file)
|
||||
|
||||
config.enable_memory_optim()
|
||||
if predictor_conf["use_gpu"]:
|
||||
config.enable_use_gpu(1000, 0)
|
||||
if predictor_conf["enable_mkldnn"]:
|
||||
config.enable_mkldnn()
|
||||
if predictor_conf["switch_ir_optim"]:
|
||||
config.switch_ir_optim()
|
||||
|
||||
predictor = create_predictor(config)
|
||||
|
||||
return predictor
|
||||
|
||||
|
||||
def run_model(predictor, input: List) -> List:
|
||||
""" run predictor
|
||||
|
||||
Args:
|
||||
predictor: paddle inference predictor
|
||||
input (list): The input of predictor
|
||||
|
||||
Returns:
|
||||
list: result list
|
||||
"""
|
||||
input_names = predictor.get_input_names()
|
||||
for i, name in enumerate(input_names):
|
||||
input_handle = predictor.get_input_handle(name)
|
||||
input_handle.copy_from_cpu(input[i])
|
||||
|
||||
# do the inference
|
||||
predictor.run()
|
||||
|
||||
results = []
|
||||
# get out data from output tensor
|
||||
output_names = predictor.get_output_names()
|
||||
for i, name in enumerate(output_names):
|
||||
output_handle = predictor.get_output_handle(name)
|
||||
output_data = output_handle.copy_to_cpu()
|
||||
results.append(output_data)
|
||||
|
||||
return results
|
Loading…
Reference in new issue