added engine factory and config, test=doc

pull/1399/head
WilliamZhang06 3 years ago
commit 36ea686b7e

@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)

@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)

@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)

@ -6,3 +6,9 @@
host: '0.0.0.0'
port: 8090
##################################################################
# CONFIG FILE #
##################################################################
# add engine type (Options: asr, tts) and config file here.
engine_backend:
asr: 'conf/asr/asr.yaml'

@ -0,0 +1,5 @@
model: 'conformer_wenetspeech'
lang: 'conformer_wenetspeech'
lang: 'zh'
sample_rate: 16000
decode_method: 'attention_rescoring'

@ -14,18 +14,21 @@
from engine.base_engine import BaseEngine
from utils.log import logger
from utils.config import get_config
__all__ = ['ASREngine']
class ASREngine(BaseEngine):
def __init__(self, name=None):
def __init__(self):
super(ASREngine, self).__init__()
self.executor = name
def init(self, config_file: str):
self.config_file = config_file
self.executor = None
self.input = None
self.output = None
def init(self):
config = get_config(self.config_file)
pass
def postprocess(self):
@ -34,12 +37,3 @@ class ASREngine(BaseEngine):
def run(self):
logger.info("start run asr engine")
return "hello world"
if __name__ == "__main__":
# test Singleton
class1 = ASREngine("ASREngine")
class2 = ASREngine()
print(class1 is class2)
print(id(class1))
print(id(class2))

@ -0,0 +1,26 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from engine.asr.python.asr_engine import ASREngine
from engine.tts.python.tts_engine import TTSEngine
class EngineFactory(object):
@staticmethod
def get_engine(engine_name):
if engine_name == 'asr':
return ASREngine()
elif engine_name == 'tts':
return TTSEngine()
else:
return None

@ -12,43 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
import yaml
from engine.tts.python.tts_engine import TTSEngine
from fastapi import FastAPI
from restful.api import router as api_router
from paddlespeech.cli.log import logger
from restful.api import setup_router
from utils.log import logger
from utils.config import get_config
from engine.engine_factory import EngineFactory
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
def init(args):
""" 系统初始化
def init(config):
""" system initialization
"""
# init api
api_list = list(config.engine_backend)
api_router = setup_router(api_list)
app.include_router(api_router)
# engine single
TTS_ENGINE = TTSEngine()
# todo others
# init engine
engine_list = []
for engine in config.engine_backend:
engine_list.append(EngineFactory.get_engine(engine_name=engine))
engine_list[-1].init(config_file=config.engine_backend[engine])
return True
def main(args):
"""主程序入口"""
"""main function"""
#TODO configuration
from yacs.config import CfgNode
with open(args.config_file, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
config = get_config(args.config_file)
if init(args):
if init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)
@ -58,7 +57,7 @@ if __name__ == "__main__":
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/tts/tts.yaml")
default="./conf/application.yaml")
parser.add_argument(
"--log_file",

@ -14,8 +14,19 @@
from fastapi import APIRouter
from .tts_api import router as tts_router
#from .asr_api import router as asr_router
from .asr_api import router as asr_router
_router = APIRouter()
def setup_router(api_list: list):
for api_name in api_list:
if api_name == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
_router.include_router(tts_router)
else:
pass
return _router
router = APIRouter()
#router.include_router(asr_router)
router.include_router(tts_router)

@ -14,13 +14,12 @@
from fastapi import APIRouter
import base64
from engine.asr.python.asr_engine import ASREngine
from .response import ASRResponse
from .request import ASRRequest
router = APIRouter()
router = APIRouter()
@router.get('/paddlespeech/asr/help')
def help():
@ -44,8 +43,8 @@ def asr(request_body: ASRRequest):
"""
# single
asr_engine = ASREngine()
print("asr_engine id :" ,id(asr_engine))
asr_engine.init()
asr_results = asr_engine.run()
asr_engine.postprocess()

@ -0,0 +1,30 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from yacs.config import CfgNode
def get_config(config_file):
"""[summary]
Args:
config_file (str): config_file
Returns:
CfgNode:
"""
with open(config_file, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
return config

@ -0,0 +1,35 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import base64
def readwav2base64(wav_file):
"""
read wave file and covert to base64 string
"""
with open(wav_file, 'rb') as f:
base64_bytes = base64.b64encode(f.read())
base64_string = base64_bytes.decode('utf-8')
return base64_string
def readbase64towav(base64_string):
pass
def self_check():
""" self check resource
"""
return True
Loading…
Cancel
Save