added engine framework, test=asr

pull/1383/head
WilliamZhang06 4 years ago
parent 2a530d49ff
commit 63a1799fa9

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -14,7 +14,6 @@
import argparse import argparse
def init(args): def init(args):
""" 系统初始化 """ 系统初始化
""" """
@ -27,13 +26,18 @@ def main(args):
app.run(host='0.0.0.0', port=conf.port) app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config_file", action="store", parser.add_argument(
help="yaml file of the app", default="./conf/application.yaml") "--config_file",
parser.add_argument("--log_file", action="store", action="store",
help="log file", default="./log/paddlespeech.log") help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -14,7 +14,6 @@
import argparse import argparse
def init(args): def init(args):
""" 系统初始化 """ 系统初始化
""" """
@ -27,13 +26,18 @@ def main(args):
app.run(host='0.0.0.0', port=conf.port) app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config_file", action="store", parser.add_argument(
help="yaml file of the app", default="./conf/application.yaml") "--config_file",
parser.add_argument("--log_file", action="store", action="store",
help="log file", default="./log/paddlespeech.log") help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -14,7 +14,6 @@
import argparse import argparse
def init(args): def init(args):
""" 系统初始化 """ 系统初始化
""" """
@ -27,13 +26,18 @@ def main(args):
app.run(host='0.0.0.0', port=conf.port) app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config_file", action="store", parser.add_argument(
help="yaml file of the app", default="./conf/application.yaml") "--config_file",
parser.add_argument("--log_file", action="store", action="store",
help="log file", default="./log/paddlespeech.log") help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -0,0 +1,8 @@
# This is the parameter configuration file for PaddleSpeech Serving.
##################################################################
# SERVER SETTING #
##################################################################
host: '0.0.0.0'
port: 8090

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from engine import BaseEngine from engine.base_engine import BaseEngine
from utils.log import logger
__all__ = ['ASREngine'] __all__ = ['ASREngine']
class ASREngine(BaseEngine):
def __init__(self, name): class ASREngine(BaseEngine):
def __init__(self, name=None):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
self.executor = name
self.input = None
self.output = None
def init(self): def init(self):
pass pass
@ -28,8 +32,8 @@ class ASREngine(BaseEngine):
pass pass
def run(self): def run(self):
pass logger.info("start run asr engine")
return "hello world"
if __name__ == "__main__": if __name__ == "__main__":
@ -39,4 +43,3 @@ if __name__ == "__main__":
print(class1 is class2) print(class1 is class2)
print(id(class1)) print(id(class1))
print(id(class2)) print(id(class2))

@ -18,6 +18,7 @@ from typing import Union
from pattern_singleton import Singleton from pattern_singleton import Singleton
class BaseEngine(metaclass=Singleton): class BaseEngine(metaclass=Singleton):
""" """
An base engine class An base engine class

@ -13,31 +13,55 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import asr_api as api_run import uvicorn
import tts_api as api_run import yaml
from engine.asr.python.asr_engine import ASREngine
from fastapi import FastAPI
from restful.api import router as api_router
from utils.log import logger
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
def init(args): def init(args):
""" 系统初始化 """ 系统初始化
""" """
app.include_router(api_router)
# engine single
ASR_ENGINE = ASREngine("asr")
# todo others
return True
def main(args): def main(args):
"""主程序入口""" """主程序入口"""
if init(args): #TODO configuration
api_run.run() from yacs.config import CfgNode
app.run(host='0.0.0.0', port=conf.port) with open(args.config_file, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
if init(args):
uvicorn.run(app, host=config.host, port=config.port, debug=True)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config_file", action="store", parser.add_argument(
help="yaml file of the app", default="./conf/application.yaml") "--config_file",
parser.add_argument("--log_file", action="store", action="store",
help="log file", default="./log/paddlespeech.log") help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -13,19 +13,9 @@
# limitations under the License. # limitations under the License.
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter() from .asr_api import router as asr_router
from .tts_api import router as tts_router
router.include_router(auth_router)
router.include_router(user_router)
router.include_router(profile_router)
router.include_router(comment_router)
router.include_router(article_router)
router.include_router(tag_router)
router = APIRouter()
router.include_router(asr_router)
def init_app(app): router.include_router(tts_router)
app.include_router(router)

@ -0,0 +1,63 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter
import base64
from engine.asr.python.asr_engine import ASREngine
from .response import ASRResponse
from .request import ASRRequest
router = APIRouter()
@router.get('/paddlespeech/asr/help')
def help():
"""help
Returns:
json: [description]
"""
return {'hello': 'world'}
@router.post("/paddlespeech/asr", response_model=ASRResponse)
def asr(request_body: ASRRequest):
"""asr api
Args:
request_body (ASRRequest): [description]
Returns:
json: [description]
"""
# single
asr_engine = ASREngine()
asr_engine.init()
asr_results = asr_engine.run()
asr_engine.postprocess()
json_body = {
"success": True,
"code": 0,
"message": {
"description": "success"
},
"result": {
"transcription": asr_results
}
}
return json_body

@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from typing import List from typing import List
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ['ASRRequest, TTSRequest'] __all__ = ['ASRRequest, TTSRequest']
#****************************************************************************************/ #****************************************************************************************/
#************************************ ASR request ***************************************/ #************************************ ASR request ***************************************/
#****************************************************************************************/ #****************************************************************************************/
@ -29,8 +29,8 @@ class ASRRequest(BaseModel):
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"audio_format": "wav", "audio_format": "wav",
"sample_rate": 16000, "sample_rate": 16000,
"lang ": "zh_cn", "lang": "zh_cn",
"ptt ":false "ptt":false
} }
""" """
audio: str audio: str
@ -53,4 +53,4 @@ class TTSRequest(BaseModel):
"lang ": "zh_cn", "lang ": "zh_cn",
"ptt ":false "ptt ":false
} }
""" """

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from typing import List from typing import List
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -22,12 +22,14 @@ __all__ = ['ASRResponse']
class Message(BaseModel): class Message(BaseModel):
description: str description: str
#****************************************************************************************/ #****************************************************************************************/
#************************************ ASR response **************************************/ #************************************ ASR response **************************************/
#****************************************************************************************/ #****************************************************************************************/
class AsrResult(BaseModel): class AsrResult(BaseModel):
transcription: str transcription: str
class ASRResponse(BaseModel): class ASRResponse(BaseModel):
""" """
response example response example
@ -36,7 +38,7 @@ class ASRResponse(BaseModel):
"code": 0, "code": 0,
"message": { "message": {
"description": "success" "description": "success"
} },
"result": { "result": {
"transcription": "你好,飞桨" "transcription": "你好,飞桨"
} }
@ -47,6 +49,7 @@ class ASRResponse(BaseModel):
message: Message message: Message
result: AsrResult result: AsrResult
#****************************************************************************************/ #****************************************************************************************/
#************************************ TTS response **************************************/ #************************************ TTS response **************************************/
#****************************************************************************************/ #****************************************************************************************/

@ -13,38 +13,17 @@
# limitations under the License. # limitations under the License.
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter()
router.include_router(auth_router)
router.include_router(user_router)
router.include_router(profile_router)
router.include_router(comment_router)
router.include_router(article_router)
router.include_router(tag_router)
def init_app(app):
asr,tts
router = APIRouter()
if asr
backend
dyload(asr)
asr.register_router(router) @router.get('/paddlespeech/tts/help')
if tts def help():
backend """help
dyload(asr)
Returns:
json: [description]
"""
return {'hello': 'world'}
asr.register_router(router)
app.include_router(router)

@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
__all__ = [
'logger',
]
class Logger(object):
def __init__(self, name: str=None):
name = 'PaddleSpeech' if not name else name
self.logger = logging.getLogger(name)
log_config = {
'DEBUG': 10,
'INFO': 20,
'TRAIN': 21,
'EVAL': 22,
'WARNING': 30,
'ERROR': 40,
'CRITICAL': 50,
'EXCEPTION': 100,
}
for key, level in log_config.items():
logging.addLevelName(level, key)
if key == 'EXCEPTION':
self.__dict__[key.lower()] = self.logger.exception
else:
self.__dict__[key.lower()] = functools.partial(self.__call__,
level)
self.format = logging.Formatter(
fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s')
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
def __call__(self, log_level: str, msg: str):
self.logger.log(log_level, msg)
logger = Logger()

@ -0,0 +1,66 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import requests
import json
import time
import base64
import argparse
def readwav2base64(wav_file):
"""
read wave file and covert to base64 string
"""
with open(wav_file, 'rb') as f:
base64_bytes = base64.b64encode(f.read())
base64_string = base64_bytes.decode('utf-8')
return base64_string
def main(args):
"""
main func
"""
url = "http://127.0.0.1:8090/paddlespeech/asr"
# start Timestamp
time_start=time.time()
# test_audio_dir = "test_data/16_audio.wav"
# audio = readwav2base64(test_audio_dir)
data = {
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf",
"audio_format": "wav",
"sample_rate": 16000,
"lang": "zh_cn",
}
r = requests.post(url=url, data=json.dumps(data))
# ending Timestamp
time_end=time.time()
print('time cost',time_end - time_start, 's')
print(r.json())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", action="store",
help="model type: u2, dp2", default="dp2")
args = parser.parse_args()
main(args)
Loading…
Cancel
Save