update the asr server api, test=doc

pull/1784/head
xiongxinlei 3 years ago
parent 833900a8b4
commit 7007b0ecac

@ -20,19 +20,23 @@ import logging
import os
from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
def main(args):
logger.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8090)
handler = ASRWsAudioHandler(
args.server_ip,
args.port,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop()
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logger.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
result = result["final_result"]
logger.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
@ -43,13 +47,29 @@ def main(args):
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
result = result["final_result"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
logger.info("Start to do streaming asr client")
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
parser.add_argument('--port', type=int, default=8090, help='server port')
parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
parser.add_argument(
'--punc.port',
type=int,
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument(
"--wavfile",
action="store",

@ -21,8 +21,6 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor
from ..log import logger
@ -30,6 +28,8 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']

@ -29,7 +29,7 @@ from ..executor import BaseExecutor
from ..util import cli_client_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64
@ -369,7 +369,7 @@ class ASRClientExecutor(BaseExecutor):
Returns:
str: The ASR results
"""
# 1. Firstly, we use the asr server to recognize the audio text content
# we use the asr server to recognize the audio text content
if protocol.lower() == "http":
from paddlespeech.server.utils.audio_handler import ASRHttpHandler
logger.info("asr http client start")
@ -380,22 +380,20 @@ class ASRClientExecutor(BaseExecutor):
elif protocol.lower() == "websocket":
logger.info("asr websocket client start")
handler = ASRAudioHandler(
handler = ASRWsAudioHandler(
server_ip,
port,
punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input))
res = res['asr_results']
res = res['final_result']
logger.info("asr websocket client finished")
else:
logger.error(f"Sorry, we have not support protocol: {protocol},"
"please use http or websocket protocol")
sys.exit(-1)
# 2. Secondly, we use the punctuation server to do post process for text
return res

@ -26,7 +26,7 @@ import pyaudio
import websockets
class ASRAudioHandler(threading.Thread):
class ASRWsAudioHandler(threading.Thread):
def __init__(self, url="127.0.0.1", port=8091):
threading.Thread.__init__(self)
self.url = url
@ -148,7 +148,7 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRWsAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(handler.run())
for signal in [SIGINT, SIGTERM]:

@ -29,13 +29,30 @@ from paddlespeech.server.utils.util import wav2base64
class TextHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090):
"""Text http client request
Args:
server_ip (str, optional): the text server ip. Defaults to "127.0.0.1".
port (int, optional): the text server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/text'
def run(self, text):
"""Call the text server to process the specific text
Args:
text (str): the text to be processed
Returns:
str: punctuation text
"""
if self.server_ip is None or self.port is None:
logger.warning(
"No punctuation server, please input valid ip and port")
@ -55,24 +72,29 @@ class TextHttpHandler:
return punc_text
class ASRAudioHandler:
class ASRWsAudioHandler:
def __init__(self,
url="127.0.0.1",
port=8090,
punc_server_ip="127.0.0.1",
punc_server_port="8091"):
url=None,
port=None,
endpoint="/paddlespeech/asr/streaming",
punc_server_ip=None,
punc_server_port=None):
"""PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal
Args:
url (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
url (str, optional): the server ip. Defaults to None.
port (int, optional): the server port. Defaults to None.
endpoint(str, optional): to compatiable with python server and c++ server.
punc_server_ip(str, optional): the punctuation server ip. Defaults to None.
punc_server_port(int, optional): the punctuation port. Defaults to None
"""
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
if url is None or port is None or endpoint is None:
self.url = None
else:
self.url = "ws://" + self.url + ":" + str(
self.port) + endpoint
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
def read_wave(self, wavfile_path: str):
@ -117,6 +139,11 @@ class ASRAudioHandler:
"""
logging.info("send a message to the server")
if self.url is None:
logger.error(
"No punctuation server, please input valid ip and port")
return ""
# 1. send websocket handshake protocal
async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal
@ -125,7 +152,7 @@ class ASRAudioHandler:
{
"name": "test.wav",
"signal": "start",
"nbest": 5
"nbest": 1
},
sort_keys=True,
indent=4,
@ -139,7 +166,9 @@ class ASRAudioHandler:
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"])
if self.punc_server and len(msg["partial_result"]) > 0:
msg["partial_result"] = self.punc_server.run(
msg["partial_result"])
logger.info("receive msg={}".format(msg))
# 4. we must send finished signal to the server
@ -157,7 +186,8 @@ class ASRAudioHandler:
# 5. decode the bytes to str
msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"])
if self.punc_server:
msg["final_result"] = self.punc_server.run(msg["final_result"])
logger.info("final receive msg={}".format(msg))
result = msg
@ -165,14 +195,39 @@ class ASRAudioHandler:
class ASRHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090):
def __init__(self, server_ip=None, port=None):
"""The ASR client http request
Args:
server_ip (str, optional): the http asr server ip. Defaults to "127.0.0.1".
port (int, optional): the http asr server port. Defaults to 8090.
"""
super().__init__()
self.server_ip = server_ip
self.port = port
if server_ip is None or port is None:
self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/asr'
def run(self, input, audio_format, sample_rate, lang):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
lang (str): the audio language type
Returns:
str: the final asr result
"""
if self.url is None:
logger.error(
"No punctuation server, please input valid ip and port")
return ""
audio = wav2base64(input)
data = {
"audio": audio,

@ -24,7 +24,7 @@ from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter()
@router.websocket('/ws/asr')
@router.websocket('/paddlespeech/asr/streaming')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api
@ -83,7 +83,7 @@ async def websocket_endpoint(websocket: WebSocket):
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
'final_result': asr_results
}
await websocket.send_json(resp)
break
@ -102,7 +102,7 @@ async def websocket_endpoint(websocket: WebSocket):
# return the current period result
# if the engine create the vad instance, this connection will have many period results
resp = {'asr_results': asr_results}
resp = {'partial_result': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

Loading…
Cancel
Save