diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index ea1828b6..00216356 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,8 +119,7 @@ class ASRExecutor(BaseExecutor): lang: str='zh', model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None, - device: str='cpu'): + ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ @@ -142,7 +141,6 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path))) #Init body. - paddle.set_device(device) self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" @@ -403,8 +401,9 @@ class ASRExecutor(BaseExecutor): """ audio_file = os.path.abspath(audio_file) self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, - device) + + paddle.set_device(device) + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) self.preprocess(model, audio_file) self.infer(model) res = self.postprocess() # Retrieve result of asr.