diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 185a92b8..42537b15 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -11,3 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + + + + + + + + + + diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 7a7a6a5d..7f648b4c 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -291,7 +291,8 @@ class ASRExecutor(BaseExecutor): """ audio_file = input - logger.info("Preprocess audio_file:" + audio_file) + if isinstance(audio_file, (str, os.PathLike)): + logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: @@ -412,13 +413,13 @@ class ASRExecutor(BaseExecutor): def _check(self, audio_file: str, sample_rate: int, force_yes: bool): self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: - logger.error("please input --sr 8000 or --sr 16000") - raise Exception("invalid sample rate") - sys.exit(-1) + logger.error("invalid sample rate, please input --sr 8000 or --sr 16000") + return False - if not os.path.isfile(audio_file): - logger.error("Please input the right audio file path") - sys.exit(-1) + if isinstance(audio_file, (str, os.PathLike)): + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + return False logger.info("checking the audio file format......") try: @@ -435,7 +436,7 @@ class ASRExecutor(BaseExecutor): sample rate: 8k \n \ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ ") - sys.exit(-1) + return False logger.info("The sample rate is %d" % audio_sample_rate) if audio_sample_rate != self.sample_rate: logger.warning("The sample rate of the input file is not {}.\n \ @@ -469,6 +470,8 @@ class ASRExecutor(BaseExecutor): logger.info("The audio file format is right") self.change_format = False + return True + def execute(self, argv: List[str]) -> bool: """ Command line entry. @@ -523,7 +526,8 @@ class ASRExecutor(BaseExecutor): Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - self._check(audio_file, sample_rate, force_yes) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) paddle.set_device(device) self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path) diff --git a/paddlespeech/s2t/io/utility.py b/paddlespeech/s2t/io/utility.py index 1a90e3d0..ce5e7723 100644 --- a/paddlespeech/s2t/io/utility.py +++ b/paddlespeech/s2t/io/utility.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List +from io import BytesIO import numpy as np @@ -88,6 +89,10 @@ def pad_sequence(sequences: List[np.ndarray], def feat_type(filepath): + # deal with Byteio type for paddlespeech server + if isinstance(filepath, BytesIO): + return 'sound' + suffix = filepath.split(":")[0].split('.')[-1].lower() if suffix == 'ark': return 'mat' diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py index af51f3f2..dda0bbd7 100644 --- a/paddlespeech/server/bin/main.py +++ b/paddlespeech/server/bin/main.py @@ -16,10 +16,9 @@ import uvicorn import yaml from fastapi import FastAPI -from paddlespeech.server.engine.engine_factory import EngineFactory +from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.restful.api import setup_router from paddlespeech.server.utils.config import get_config -from paddlespeech.server.utils.log import logger app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") @@ -39,12 +38,8 @@ def init(config): api_router = setup_router(api_list) app.include_router(api_router) - # init engine - engine_pool = [] - for engine in config.engine_backend: - engine_pool.append(EngineFactory.get_engine(engine_name=engine)) - if not engine_pool[-1].init(config_file=config.engine_backend[engine]): - return False + if not init_engine_pool(config): + return False return True diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index 1b06fe9d..3730d607 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -241,4 +241,4 @@ class ASRClientExecutor(BaseExecutor): print(r.json()) print("time cost %f s." % (time_end - time_start)) except: - print("Failed to speech recognition.") + print("Failed to speech recognition.") \ No newline at end of file diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index 80e65cb4..7c88d8a0 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -51,10 +51,8 @@ class ServerExecutor(BaseExecutor): def init(self, config) -> bool: """system initialization - Args: config (CfgNode): config object - Returns: bool: """ diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml index 67cc3b34..154ef9af 100644 --- a/paddlespeech/server/conf/application.yaml +++ b/paddlespeech/server/conf/application.yaml @@ -9,9 +9,12 @@ port: 8090 ################################################################## # CONFIG FILE # ################################################################## -# add engine type (Options: asr, tts) and config file here. +# add engine type (Options: python, inference) +engine_type: + asr: 'inference' + # tts: 'inference' +# add engine backend type (Options: asr, tts) and config file here. engine_backend: - asr: 'conf/asr/asr.yaml' - tts: 'conf/tts/tts_pd.yaml' - + asr: 'conf/asr/asr_pd.yaml' + #tts: 'conf/tts/tts_pd.yaml' diff --git a/paddlespeech/server/conf/asr/asr.yaml b/paddlespeech/server/conf/asr/asr.yaml index 4c3b0a67..50e55a3c 100644 --- a/paddlespeech/server/conf/asr/asr.yaml +++ b/paddlespeech/server/conf/asr/asr.yaml @@ -1,7 +1,7 @@ model: 'conformer_wenetspeech' lang: 'zh' sample_rate: 16000 -cfg_path: -ckpt_path: +cfg_path: # [optional] +ckpt_path: # [optional] decode_method: 'attention_rescoring' -force_yes: False +force_yes: True diff --git a/paddlespeech/server/conf/asr/asr_pd.yaml b/paddlespeech/server/conf/asr/asr_pd.yaml new file mode 100644 index 00000000..43a63f1b --- /dev/null +++ b/paddlespeech/server/conf/asr/asr_pd.yaml @@ -0,0 +1,25 @@ +# This is the parameter configuration file for ASR server. +# These are the static models that support paddle inference. + +################################################################## +# ACOUSTIC MODEL SETTING # +# am choices=['deepspeech2offline_aishell'] TODO +################################################################## +model_type: 'deepspeech2offline_aishell' +am_model: # the pdmodel file of am static model [optional] +am_params: # the pdiparams file of am static model [optional] +lang: 'zh' +sample_rate: 16000 +cfg_path: +decode_method: +force_yes: True + +am_predictor_conf: + use_gpu: True + enable_mkldnn: True + switch_ir_optim: True + + +################################################################## +# OTHERS # +################################################################## diff --git a/paddlespeech/server/engine/asr/paddleinference/__init__.py b/paddlespeech/server/engine/asr/paddleinference/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/engine/asr/paddleinference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py new file mode 100644 index 00000000..6d072322 --- /dev/null +++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py @@ -0,0 +1,244 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +from typing import List +from typing import Optional +from typing import Union + +import librosa +import paddle +import soundfile +from yacs.config import CfgNode + +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.paddle_predictor import init_predictor +from paddlespeech.server.utils.paddle_predictor import run_model +from paddlespeech.server.engine.base_engine import BaseEngine + +__all__ = ['ASREngine'] + + +pretrained_models = { + "deepspeech2offline_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'md5': + '932c3593d62fe5c741b59b31318aa314', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'model': + 'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, +} + + +class ASRServerExecutor(ASRExecutor): + def __init__(self): + super().__init__() + pass + + def _init_from_path(self, + model_type: str='wenetspeech', + am_model: Optional[os.PathLike]=None, + am_params: Optional[os.PathLike]=None, + lang: str='zh', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='attention_rescoring', + am_predictor_conf: dict=None): + """ + Init model and other resources from a specific path. + """ + + if cfg_path is None or am_model is None or am_params is None: + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '-' + lang + '-' + sample_rate_str + res_path = self._get_pretrained_path(tag) # wenetspeech_zh + self.res_path = res_path + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) + + self.am_model = os.path.join(res_path, + pretrained_models[tag]['model']) + self.am_params = os.path.join(res_path, + pretrained_models[tag]['params']) + logger.info(res_path) + logger.info(self.cfg_path) + logger.info(self.am_model) + logger.info(self.am_params) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.am_model = os.path.abspath(am_model) + self.am_params = os.path.abspath(am_params) + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + + with UpdateConfig(self.config): + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + from paddlespeech.s2t.io.collator import SpeechCollator + self.vocab = self.config.vocab_filepath + self.config.decode.lang_model_path = os.path.join( + MODEL_HOME, 'language_model', + self.config.decode.lang_model_path) + self.collate_fn_test = SpeechCollator.from_config(self.config) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, vocab=self.vocab) + + lm_url = pretrained_models[tag]['lm_url'] + lm_md5 = pretrained_models[tag]['lm_md5'] + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + raise Exception("wrong type") + else: + raise Exception("wrong type") + + # AM predictor + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + # decoder + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + + @paddle.no_grad() + def infer(self, model_type: str): + """ + Model inference and result stored in self.output. + """ + cfg = self.config.decode + audio = self._inputs["audio"] + audio_len = self._inputs["audio_len"] + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + decode_batch_size = audio.shape[0] + # init once + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + output_data = run_model( + self.am_predictor, + [audio.numpy(), audio_len.numpy()]) + + probs = output_data[0] + eouts_len = output_data[1] + + batch_size = probs.shape[0] + self.decoder.reset_decoder(batch_size=batch_size) + self.decoder.next(probs, eouts_len) + trans_best, trans_beam = self.decoder.decode() + + # self.model.decoder.del_decoder() + self._outputs["result"] = trans_best[0] + + elif "conformer" in model_type or "transformer" in model_type: + raise Exception("invalid model name") + else: + raise Exception("invalid model name") + + +class ASREngine(BaseEngine): + """ASR server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self): + super(ASREngine, self).__init__() + + def init(self, config_file: str) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ + self.input = None + self.output = None + self.executor = ASRServerExecutor() + self.config = get_config(config_file) + + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + model_type=self.config.model_type, + am_model=self.config.am_model, + am_params=self.config.am_params, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + am_predictor_conf=self.config.am_predictor_conf) + + logger.info("Initialize ASR server engine successfully.") + return True + + def run(self, audio_data): + """engine run + + Args: + audio_data (bytes): base64.b64decode + """ + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start running asr engine") + self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data)) + self.executor.infer(self.config.model_type) + self.output = self.executor.postprocess() # Retrieve result of asr. + logger.info("end inferring asr engine") + else: + logger.info("file check failed!") + self.output = None + + def postprocess(self): + """postprocess + """ + return self.output diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py index b1154983..fd67b029 100644 --- a/paddlespeech/server/engine/asr/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -38,101 +38,6 @@ class ASRServerExecutor(ASRExecutor): super().__init__() pass - def _check(self, audio_file: str, sample_rate: int, force_yes: bool): - self.sample_rate = sample_rate - if self.sample_rate != 16000 and self.sample_rate != 8000: - logger.error("please input --sr 8000 or --sr 16000") - return False - - logger.info("checking the audio file format......") - try: - audio, audio_sample_rate = soundfile.read( - audio_file, dtype="int16", always_2d=True) - except Exception as e: - logger.exception(e) - logger.error( - "can not open the audio file, please check the audio file format is 'wav'. \n \ - you can try to use sox to change the file format.\n \ - For example: \n \ - sample rate: 16k \n \ - sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ - sample rate: 8k \n \ - sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ - ") - - logger.info("The sample rate is %d" % audio_sample_rate) - if audio_sample_rate != self.sample_rate: - logger.warning("The sample rate of the input file is not {}.\n \ - The program will resample the wav file to {}.\n \ - If the result does not meet your expectations,\n \ - Please input the 16k 16 bit 1 channel wav file. \ - ".format(self.sample_rate, self.sample_rate)) - self.change_format = True - else: - logger.info("The audio file format is right") - self.change_format = False - - return True - - def preprocess(self, model_type: str, input: Union[str, os.PathLike]): - """ - Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). - """ - - audio_file = input - - # Get the object for feature extraction - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - audio, _ = self.collate_fn_test.process_utterance( - audio_file=audio_file, transcript=" ") - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - # vocab_list = collate_fn_test.vocab_list - self._inputs["audio"] = audio - self._inputs["audio_len"] = audio_len - logger.info(f"audio feat shape: {audio.shape}") - - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - logger.info("get the preprocess conf") - preprocess_conf = self.config.preprocess_config - preprocess_args = {"train": False} - preprocessing = Transformation(preprocess_conf) - logger.info("read the audio file") - audio, audio_sample_rate = soundfile.read( - audio_file, dtype="int16", always_2d=True) - - if self.change_format: - if audio.shape[1] >= 2: - audio = audio.mean(axis=1, dtype=np.int16) - else: - audio = audio[:, 0] - # pcm16 -> pcm 32 - audio = self._pcm16to32(audio) - audio = librosa.resample(audio, audio_sample_rate, - self.sample_rate) - audio_sample_rate = self.sample_rate - # pcm32 -> pcm 16 - audio = self._pcm32to16(audio) - else: - audio = audio[:, 0] - - logger.info(f"audio shape: {audio.shape}") - # fbank - audio = preprocessing(audio, **preprocess_args) - - audio_len = paddle.to_tensor(audio.shape[0]) - audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) - - self._inputs["audio"] = audio - self._inputs["audio_len"] = audio_len - logger.info(f"audio feat shape: {audio.shape}") - - else: - raise Exception("wrong type") - class ASREngine(BaseEngine): """ASR server engine @@ -157,16 +62,12 @@ class ASREngine(BaseEngine): self.output = None self.executor = ASRServerExecutor() - try: - self.config = get_config(config_file) - paddle.set_device(paddle.get_device()) - self.executor._init_from_path( - self.config.model, self.config.lang, self.config.sample_rate, - self.config.cfg_path, self.config.decode_method, - self.config.ckpt_path) - except: - logger.info("Initialize ASR server engine Failed.") - return False + self.config = get_config(config_file) + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + self.config.model, self.config.lang, self.config.sample_rate, + self.config.cfg_path, self.config.decode_method, + self.config.ckpt_path) logger.info("Initialize ASR server engine successfully.") return True diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 79319fd9..05f13568 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -13,20 +13,24 @@ # limitations under the License. from typing import Text -from paddlespeech.server.engine.asr.python.asr_engine import ASREngine -#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine -from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine - __all__ = ['EngineFactory'] class EngineFactory(object): @staticmethod - def get_engine(engine_name: Text): - if engine_name == 'asr': + def get_engine(engine_name: Text, engine_type: Text): + if engine_name == 'asr' and engine_type == 'inference': + from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine + return ASREngine() + elif engine_name == 'asr' and engine_type == 'python': + from paddlespeech.server.engine.asr.python.asr_engine import ASREngine return ASREngine() - elif engine_name == 'tts': + elif engine_name == 'tts' and engine_type == 'inference': + from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine + return TTSEngine() + elif engine_name == 'tts' and engine_type == 'python': + from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine return TTSEngine() else: return None diff --git a/paddlespeech/server/engine/engine_pool.py b/paddlespeech/server/engine/engine_pool.py new file mode 100644 index 00000000..0198bd80 --- /dev/null +++ b/paddlespeech/server/engine/engine_pool.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddlespeech.server.engine.engine_factory import EngineFactory + +# global value +ENGINE_POOL = {} + + +def get_engine_pool() -> dict: + """ Get engine pool + """ + global ENGINE_POOL + return ENGINE_POOL + + +def init_engine_pool(config) -> bool: + """ Init engine pool + """ + global ENGINE_POOL + for engine in config.engine_backend: + ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine]) + if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]): + return False + + return True diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py index 9ce5cad0..2d69dee8 100644 --- a/paddlespeech/server/restful/api.py +++ b/paddlespeech/server/restful/api.py @@ -22,7 +22,14 @@ _router = APIRouter() def setup_router(api_list: List): + """setup router for fastapi + Args: + api_list (List): [asr, tts] + + Returns: + APIRouter + """ for api_name in api_list: if api_name == 'asr': _router.include_router(asr_router) diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py index fcdb2f41..4806c042 100644 --- a/paddlespeech/server/restful/asr_api.py +++ b/paddlespeech/server/restful/asr_api.py @@ -16,7 +16,7 @@ import traceback from typing import Union from fastapi import APIRouter -from paddlespeech.server.engine.asr.python.asr_engine import ASREngine +from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.restful.request import ASRRequest from paddlespeech.server.restful.response import ASRResponse from paddlespeech.server.restful.response import ErrorResponse @@ -61,9 +61,12 @@ def asr(request_body: ASRRequest): json: [description] """ try: - # single audio_data = base64.b64decode(request_body.audio) - asr_engine = ASREngine() + + # get single engine from engine pool + engine_pool = get_engine_pool() + asr_engine = engine_pool['asr'] + asr_engine.run(audio_data) asr_results = asr_engine.postprocess() diff --git a/paddlespeech/server/tests/16_audio.wav b/paddlespeech/server/tests/16_audio.wav deleted file mode 100644 index 3cfa5074..00000000 Binary files a/paddlespeech/server/tests/16_audio.wav and /dev/null differ diff --git a/paddlespeech/server/tests/http_client.py b/paddlespeech/server/tests/asr/http_client.py similarity index 100% rename from paddlespeech/server/tests/http_client.py rename to paddlespeech/server/tests/asr/http_client.py