diff --git a/paddlespeech/server/engine/text/python/text_engine.py b/paddlespeech/server/engine/text/python/text_engine.py index 6167e7784..cc72c0543 100644 --- a/paddlespeech/server/engine/text/python/text_engine.py +++ b/paddlespeech/server/engine/text/python/text_engine.py @@ -107,11 +107,14 @@ class PaddleTextConnectionHandler: assert len(tokens) == len(labels) text = '' + is_fast_model = 'fast' in self.text_engine.config.model_type for t, l in zip(tokens, labels): text += t if l != 0: # Non punc. - text += self._punc_list[l] - + if is_fast_model: + text += self._punc_list[l - 1] + else: + text += self._punc_list[l] return text else: raise NotImplementedError @@ -160,14 +163,23 @@ class TextEngine(BaseEngine): return False self.executor = TextServerExecutor() - self.executor._init_from_path( - task=config.task, - model_type=config.model_type, - lang=config.lang, - cfg_path=config.cfg_path, - ckpt_path=config.ckpt_path, - vocab_file=config.vocab_file) - + if 'fast' in config.model_type: + self.executor._init_from_path_new( + task=config.task, + model_type=config.model_type, + lang=config.lang, + cfg_path=config.cfg_path, + ckpt_path=config.ckpt_path, + vocab_file=config.vocab_file) + else: + self.executor._init_from_path( + task=config.task, + model_type=config.model_type, + lang=config.lang, + cfg_path=config.cfg_path, + ckpt_path=config.ckpt_path, + vocab_file=config.vocab_file) + logger.info("Using model: %s." % (config.model_type)) logger.info("Initialize Text server engine successfully on device: %s." % (self.device)) return True