diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index f00b1b3b..fff5f745 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -82,6 +82,7 @@ class CommonTaskResource: self.model_tag = model_tag self.version = version self.res_dict = self.pretrained_models[model_tag][version] + self.format_path(self.res_dict) self.res_dir = self._fetch(self.res_dict, self._get_model_dir(model_type)) else: @@ -89,9 +90,19 @@ class CommonTaskResource: self.voc_model_tag = model_tag self.voc_version = version self.voc_res_dict = self.pretrained_models[model_tag][version] + self.format_path(self.voc_res_dict) self.voc_res_dir = self._fetch(self.voc_res_dict, self._get_model_dir(model_type)) + @staticmethod + def format_path(res_dict: Dict[str, str]): + for k, v in res_dict.items(): + if '/' in v: + if v.startswith('https://') or v.startswith('http://'): + continue + else: + res_dict[k] = os.path.join(*(v.split('/'))) + @staticmethod def get_model_class(model_name) -> List[object]: """Dynamic import model class.