diff --git a/paddlespeech/cls/exps/panns/deploy/predict.py b/paddlespeech/cls/exps/panns/deploy/predict.py index 2888590c9..56fb0756d 100644 --- a/paddlespeech/cls/exps/panns/deploy/predict.py +++ b/paddlespeech/cls/exps/panns/deploy/predict.py @@ -74,11 +74,10 @@ class Predictor(object): enable_mkldnn=False): self.batch_size = batch_size - model_file = os.path.join(model_dir, "inference.pdmodel") - if not os.path.exists(model_file): + if os.path.exists(os.path.join(model_dir, "inference.json")): model_file = os.path.join(model_dir, "inference.json") - if not os.path.exists(model_file): - raise ValueError("Inference model file not exists!") + else: + model_file = os.path.join(model_dir, "inference.pdmodel") params_file = os.path.join(model_dir, "inference.pdiparams")