[Update] satisty version

pull/3914/head
megemini 10 months ago
parent 2d9650662f
commit 42b0572362

@ -22,6 +22,8 @@ from paddle.audio.features import LogMelSpectrogram
from paddleaudio.backends import soundfile_load as load_audio from paddleaudio.backends import soundfile_load as load_audio
from scipy.special import softmax from scipy.special import softmax
import paddlespeech.utils
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.") parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
@ -74,17 +76,17 @@ class Predictor(object):
enable_mkldnn=False): enable_mkldnn=False):
self.batch_size = batch_size self.batch_size = batch_size
if os.path.exists(os.path.join(model_dir, "inference.json")): if paddlespeech.utils.satisfy_paddle_version('2.6.0'):
model_file = os.path.join(model_dir, "inference.json") config = inference.Config(model_dir, 'inference')
else: else:
model_file = os.path.join(model_dir, "inference.pdmodel") model_file = os.path.join(model_dir, 'inference.pdmodel')
params_file = os.path.join(model_dir, "inference.pdiparams") params_file = os.path.join(model_dir, "inference.pdiparams")
assert os.path.isfile(model_file) and os.path.isfile( assert os.path.isfile(model_file) and os.path.isfile(
params_file), 'Please check model and parameter files.' params_file), 'Please check model and parameter files.'
config = inference.Config(model_file, params_file) config = inference.Config(model_file, params_file)
config.disable_mkldnn() config.disable_mkldnn()
if device == "gpu": if device == "gpu":
# set GPU configs accordingly # set GPU configs accordingly

@ -11,3 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from packaging.version import Version
def satisfy_version(source: str, target: str, dev_allowed: bool=True) -> bool:
if dev_allowed and source.startswith('0.0.0'):
target_version = Version('0.0.0')
else:
target_version = Version(target)
source_version = Version(source)
return source_version >= target_version
def satisfy_paddle_version(target: str, dev_allowed: bool=True) -> bool:
import paddle
return satisfy_version(paddle.__version__, target, dev_allowed)

Loading…
Cancel
Save