# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

import paddle

from paddlespeech.cli.log import logger
from paddlespeech.cli.text.infer import TextExecutor
from paddlespeech.server.engine.base_engine import BaseEngine


class PaddleTextConnectionHandler:
    def __init__(self, text_engine):
        """The PaddleSpeech Text Server Connection Handler.

        Each handler processes one server request. It borrows the task name,
        model, tokenizer and punctuation list from the engine's executor and
        keeps its own per-request input/output buffers.

        Args:
            text_engine (TextEngine): The initialized text engine.
        """
        super().__init__()
        logger.info(
            "Create PaddleTextConnectionHandler to process the text request")
        self.text_engine = text_engine
        executor = self.text_engine.executor
        self.task = executor.task
        self.model = executor.model
        self.tokenizer = executor.tokenizer
        self._punc_list = executor._punc_list
        # Per-request scratch space filled by preprocess() and infer().
        self._inputs = OrderedDict()
        self._outputs = OrderedDict()

    @paddle.no_grad()
    def run(self, text):
        """Process the request text: preprocess, infer, then postprocess.

        Args:
            text (str): The request text.

        Returns:
            str: The punctuation text.
        """
        self.preprocess(text)
        self.infer()
        return self.postprocess()

    @paddle.no_grad()
    def preprocess(self, text):
        """Clean and tokenize the request text; results go to self._inputs.

        Stores 'input_ids', 'seg_ids' and 'seq_len' taken from the tokenizer
        output.

        Args:
            text (str): The request text.

        Raises:
            AssertionError: If the cleaned text is empty.
            NotImplementedError: If the task is not 'punc'.
        """
        if self.task != 'punc':
            raise NotImplementedError

        clean_text = self.text_engine.executor._clean_text(text)
        # Raised explicitly (instead of using an `assert` statement) so the
        # check is not stripped under `python -O`; the exception type and
        # message stay the same for existing callers.
        if not len(clean_text) > 0:
            raise AssertionError(f'Invalid input string: {text}')

        # The cleaned text is tokenized character by character.
        tokenized_input = self.tokenizer(
            list(clean_text), return_length=True, is_split_into_words=True)

        self._inputs['input_ids'] = tokenized_input['input_ids']
        self._inputs['seg_ids'] = tokenized_input['token_type_ids']
        self._inputs['seq_len'] = tokenized_input['seq_len']

    @paddle.no_grad()
    def infer(self):
        """Run model inference; predictions go to self._outputs['preds'].

        Raises:
            NotImplementedError: If the task is not 'punc'.
        """
        if self.task != 'punc':
            raise NotImplementedError

        # Add a leading batch axis for the model, then drop it again from
        # the per-position argmax.
        input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
        seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
        logits, _ = self.model(input_ids, seg_ids)
        preds = paddle.argmax(logits, axis=-1).squeeze(0)

        self._outputs['preds'] = preds

    def postprocess(self):
        """Turn the predictions into the punctuated text.

        Returns:
            str: The punctuation text.

        Raises:
            NotImplementedError: If the task is not 'punc'.
        """
        if self.task != 'punc':
            raise NotImplementedError

        input_ids = self._inputs['input_ids']
        seq_len = self._inputs['seq_len']
        preds = self._outputs['preds']

        # Skip the first and last position of the sequence (presumably the
        # tokenizer's special start/end tokens -- confirm against tokenizer).
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[1:seq_len - 1])
        labels = preds[1:seq_len - 1].tolist()
        assert len(tokens) == len(labels)

        pieces = []
        for token, label in zip(tokens, labels):
            pieces.append(token)
            if label != 0:  # Label 0 is the non-punctuation label.
                pieces.append(self._punc_list[label])
        return ''.join(pieces)


class TextServerExecutor(TextExecutor):
    def __init__(self):
        """The wrapper for TextExecutor."""
        super().__init__()


class TextEngine(BaseEngine):
    def __init__(self):
        """The Text Engine."""
        super(TextEngine, self).__init__()
        logger.info("Create the TextEngine Instance")

    def init(self, config: dict):
        """Init the Text Engine.

        Selects the device (the configured one, or paddle's current device
        when none is configured) and loads the executor from the configured
        paths.

        Args:
            config (dict): The server configuration. Its fields are read via
                attribute access (``config.device``, ``config.task``, ...).

        Returns:
            bool: True if the engine was initialized, False if setting the
                device failed.
        """
        logger.info("Init the text engine")
        self.config = config

        # Keep the device in a local so the failure log below never reads an
        # attribute that was not assigned yet.
        device = None
        try:
            if self.config.device:
                device = self.config.device
            else:
                device = paddle.get_device()
            self.device = device
            paddle.set_device(self.device)
            logger.info(f"Text Engine set the device: {self.device}")
        except Exception as e:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            logger.error("Initialize Text server engine Failed on device: %s." %
                         (device, ))
            # Surface the root cause instead of silently dropping it.
            logger.error(e)
            return False

        self.executor = TextServerExecutor()
        self.executor._init_from_path(
            task=config.task,
            model_type=config.model_type,
            lang=config.lang,
            cfg_path=config.cfg_path,
            ckpt_path=config.ckpt_path,
            vocab_file=config.vocab_file)

        logger.info("Init the text engine successfully")
        return True