[Fix] panns predict.py

pull/3914/head
megemini 10 months ago
parent afa9466c89
commit 1c540d8771

@ -15,6 +15,7 @@ import argparse
import os
import numpy as np
import paddle
from paddle import inference
from paddle.audio.datasets import ESC50
from paddle.audio.features import LogMelSpectrogram
@ -74,6 +75,11 @@ class Predictor(object):
self.batch_size = batch_size
model_file = os.path.join(model_dir, "inference.pdmodel")
if not os.path.exists(model_file):
model_file = os.path.join(model_dir, "inference.json")
if not os.path.exists(model_file):
raise ValueError("Inference model file not exists!")
params_file = os.path.join(model_dir, "inference.pdiparams")
assert os.path.isfile(model_file) and os.path.isfile(

Loading…
Cancel
Save