From cc4904b67affeca660a1e6897d696c8943dc975e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:20:24 +0800 Subject: [PATCH] Update infer.py --- paddlespeech/cli/ssl/infer.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py index 0bdd818dc..a7ae8c4ac 100644 --- a/paddlespeech/cli/ssl/infer.py +++ b/paddlespeech/cli/ssl/infer.py @@ -119,6 +119,7 @@ class SSLExecutor(BaseExecutor): '--verbose', action='store_true', help='Increase logger verbosity of current task.') + self.last_call_params = None def _init_from_path(self, model_type: str=None, @@ -287,8 +288,8 @@ class SSLExecutor(BaseExecutor): f"we will use the {model_type} like model to extract audio feature." ) try: - out_feature = self.model.extract_features(audio[:, :, 0]) - self._outputs["result"] = out_feature + out_feature = self.model(audio[:, :, 0]) + self._outputs["result"] = out_feature[0] except Exception as e: logger.exception(e) @@ -453,6 +454,22 @@ class SSLExecutor(BaseExecutor): Python API to call an executor. """ + current_call_params = { + "model": model, + "task": task, + "lang": lang, + "sample_rate": sample_rate, + "config": config, + "ckpt_path": ckpt_path, + "decode_method": decode_method, + "force_yes": force_yes, + "rtf": rtf, + "device": device + } + if self.last_call_params is not None and self.last_call_params != current_call_params and hasattr(self, 'model'): + del self.model + self.last_call_params = current_call_params + audio_file = os.path.abspath(audio_file) paddle.set_device(device) self._init_from_path(model, task, lang, sample_rate, config,