t2s infernece compatible with PIR api

pull/3923/head
Wang Xin 10 months ago
parent 67ae7c8dd2
commit 63d2cc7ae9

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
import paddle import paddle
@ -120,6 +121,15 @@ def parse_args():
return args return args
def get_model_suffix(inference_dir, model_name):
if os.path.exists(os.path.join(inference_dir, model_name + ".pdmodel")):
return ".pdmodel"
elif os.path.exists(os.path.join(inference_dir, model_name + ".json")):
return ".json"
else:
raise ValueError("model file not found!")
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
@ -133,9 +143,10 @@ def main():
tones_dict=args.tones_dict) tones_dict=args.tones_dict)
# am_predictor # am_predictor
suffix = get_model_suffix(args.inference_dir, args.am)
am_predictor = get_predictor( am_predictor = get_predictor(
model_dir=args.inference_dir, model_dir=args.inference_dir,
model_file=args.am + ".pdmodel", model_file=args.am + suffix,
params_file=args.am + ".pdiparams", params_file=args.am + ".pdiparams",
device=args.device, device=args.device,
use_trt=args.use_trt, use_trt=args.use_trt,
@ -146,9 +157,10 @@ def main():
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor # voc_predictor
suffix = get_model_suffix(args.inference_dir, args.voc)
voc_predictor = get_predictor( voc_predictor = get_predictor(
model_dir=args.inference_dir, model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel", model_file=args.voc + suffix,
params_file=args.voc + ".pdiparams", params_file=args.voc + ".pdiparams",
device=args.device, device=args.device,
use_trt=args.use_trt, use_trt=args.use_trt,

Loading…
Cancel
Save