[Hackathon 7th] 修复 `panns` 中 `predict.py` 对于 pir 的 json 模型路径 (#3914)

* [Fix] panns predict.py

* [Update] path exists

* [Fix] disable mkldnn and transpose dimension

* [Update] model_file check json first

* [Update] satisty version

* [Update] satisty version

* [Update] satisty version

* [Update] config disable_mkldnn

* [Update] unsqueeze
pull/3932/head
megemini 9 months ago committed by GitHub
parent 67ae7c8dd2
commit f582cb6299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -15,12 +15,15 @@ import argparse
import os import os
import numpy as np import numpy as np
import paddle
from paddle import inference from paddle import inference
from paddle.audio.datasets import ESC50 from paddle.audio.datasets import ESC50
from paddle.audio.features import LogMelSpectrogram from paddle.audio.features import LogMelSpectrogram
from paddleaudio.backends import soundfile_load as load_audio from paddleaudio.backends import soundfile_load as load_audio
from scipy.special import softmax from scipy.special import softmax
import paddlespeech.utils
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.") parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
@ -56,7 +59,6 @@ def extract_features(files: str, **kwargs):
feature_extractor = LogMelSpectrogram(sr, **kwargs) feature_extractor = LogMelSpectrogram(sr, **kwargs)
feat = feature_extractor(paddle.to_tensor(waveforms[i])) feat = feature_extractor(paddle.to_tensor(waveforms[i]))
feat = paddle.transpose(feat, perm=[1, 0]).unsqueeze(0) feat = paddle.transpose(feat, perm=[1, 0]).unsqueeze(0)
feats.append(feat) feats.append(feat)
return np.stack(feats, axis=0) return np.stack(feats, axis=0)
@ -73,13 +75,18 @@ class Predictor(object):
enable_mkldnn=False): enable_mkldnn=False):
self.batch_size = batch_size self.batch_size = batch_size
model_file = os.path.join(model_dir, "inference.pdmodel") if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
params_file = os.path.join(model_dir, "inference.pdiparams") config = inference.Config(model_dir, 'inference')
config.disable_mkldnn()
else:
model_file = os.path.join(model_dir, 'inference.pdmodel')
params_file = os.path.join(model_dir, "inference.pdiparams")
assert os.path.isfile(model_file) and os.path.isfile(
params_file), 'Please check model and parameter files.'
assert os.path.isfile(model_file) and os.path.isfile( config = inference.Config(model_file, params_file)
params_file), 'Please check model and parameter files.'
config = inference.Config(model_file, params_file)
if device == "gpu": if device == "gpu":
# set GPU configs accordingly # set GPU configs accordingly
# such as intialize the gpu memory, enable tensorrt # such as intialize the gpu memory, enable tensorrt

@ -39,7 +39,8 @@ if __name__ == '__main__':
input_spec=[ input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, None, 64], dtype=paddle.float32) shape=[None, None, 64], dtype=paddle.float32)
]) ],
full_graph=True)
# Save in static graph model. # Save in static graph model.
paddle.jit.save(model, os.path.join(args.output_dir, "inference")) paddle.jit.save(model, os.path.join(args.output_dir, "inference"))

@ -11,3 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from packaging.version import Version
def satisfy_version(source: str, target: str, dev_allowed: bool=True) -> bool:
if dev_allowed and source.startswith('0.0.0'):
target_version = Version('0.0.0')
else:
target_version = Version(target)
source_version = Version(source)
return source_version >= target_version
def satisfy_paddle_version(target: str, dev_allowed: bool=True) -> bool:
import paddle
return satisfy_version(paddle.__version__, target, dev_allowed)

Loading…
Cancel
Save