[Fix] panns predict.py

pull/3914/head
megemini 10 months ago
parent afa9466c89
commit 1c540d8771

@ -15,6 +15,7 @@ import argparse
import os import os
import numpy as np import numpy as np
import paddle
from paddle import inference from paddle import inference
from paddle.audio.datasets import ESC50 from paddle.audio.datasets import ESC50
from paddle.audio.features import LogMelSpectrogram from paddle.audio.features import LogMelSpectrogram
@ -74,6 +75,11 @@ class Predictor(object):
self.batch_size = batch_size self.batch_size = batch_size
model_file = os.path.join(model_dir, "inference.pdmodel") model_file = os.path.join(model_dir, "inference.pdmodel")
if not os.path.exists(model_file):
model_file = os.path.join(model_dir, "inference.json")
if not os.path.exists(model_file):
raise ValueError("Inference model file not exists!")
params_file = os.path.join(model_dir, "inference.pdiparams") params_file = os.path.join(model_dir, "inference.pdiparams")
assert os.path.isfile(model_file) and os.path.isfile( assert os.path.isfile(model_file) and os.path.isfile(

Loading…
Cancel
Save