update the asr server api, test=doc

pull/1784/head
xiongxinlei 3 years ago
parent 833900a8b4
commit 7007b0ecac

@ -20,19 +20,23 @@ import logging
import os import os
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
def main(args): def main(args):
logger.info("asr websocket client start") logger.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8090) handler = ASRWsAudioHandler(
args.server_ip,
args.port,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# support to process single audio file # support to process single audio file
if args.wavfile and os.path.exists(args.wavfile): if args.wavfile and os.path.exists(args.wavfile):
logger.info(f"start to process the wavscp: {args.wavfile}") logger.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile)) result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"] result = result["final_result"]
logger.info(f"asr websocket client finished : {result}") logger.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp # support to process batch audios from wav.scp
@ -43,13 +47,29 @@ def main(args):
for line in f: for line in f:
utt_name, utt_path = line.strip().split() utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path)) result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"] result = result["final_result"]
w.write(f"{utt_name} {result}\n") w.write(f"{utt_name} {result}\n")
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Start to do streaming asr client") logger.info("Start to do streaming asr client")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
parser.add_argument('--port', type=int, default=8090, help='server port')
parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
parser.add_argument(
'--punc.port',
type=int,
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument( parser.add_argument(
"--wavfile", "--wavfile",
action="store", action="store",

@ -21,8 +21,6 @@ from typing import Union
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
@ -30,6 +28,8 @@ from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']

@ -29,7 +29,7 @@ from ..executor import BaseExecutor
from ..util import cli_client_register from ..util import cli_client_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
@ -369,7 +369,7 @@ class ASRClientExecutor(BaseExecutor):
Returns: Returns:
str: The ASR results str: The ASR results
""" """
# 1. Firstly, we use the asr server to recognize the audio text content # we use the asr server to recognize the audio text content
if protocol.lower() == "http": if protocol.lower() == "http":
from paddlespeech.server.utils.audio_handler import ASRHttpHandler from paddlespeech.server.utils.audio_handler import ASRHttpHandler
logger.info("asr http client start") logger.info("asr http client start")
@ -380,22 +380,20 @@ class ASRClientExecutor(BaseExecutor):
elif protocol.lower() == "websocket": elif protocol.lower() == "websocket":
logger.info("asr websocket client start") logger.info("asr websocket client start")
handler = ASRAudioHandler( handler = ASRWsAudioHandler(
server_ip, server_ip,
port, port,
punc_server_ip=punc_server_ip, punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port) punc_server_port=punc_server_port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input)) res = loop.run_until_complete(handler.run(input))
res = res['asr_results'] res = res['final_result']
logger.info("asr websocket client finished") logger.info("asr websocket client finished")
else: else:
logger.error(f"Sorry, we have not support protocol: {protocol}," logger.error(f"Sorry, we have not support protocol: {protocol},"
"please use http or websocket protocol") "please use http or websocket protocol")
sys.exit(-1) sys.exit(-1)
# 2. Secondly, we use the punctuation server to do post process for text
return res return res

@ -26,7 +26,7 @@ import pyaudio
import websockets import websockets
class ASRAudioHandler(threading.Thread): class ASRWsAudioHandler(threading.Thread):
def __init__(self, url="127.0.0.1", port=8091): def __init__(self, url="127.0.0.1", port=8091):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.url = url self.url = url
@ -148,7 +148,7 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091) handler = ASRWsAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(handler.run()) main_task = asyncio.ensure_future(handler.run())
for signal in [SIGINT, SIGTERM]: for signal in [SIGINT, SIGTERM]:

@ -29,13 +29,30 @@ from paddlespeech.server.utils.util import wav2base64
class TextHttpHandler: class TextHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090): def __init__(self, server_ip="127.0.0.1", port=8090):
"""Text http client request
Args:
server_ip (str, optional): the text server ip. Defaults to "127.0.0.1".
port (int, optional): the text server port. Defaults to 8090.
"""
super().__init__() super().__init__()
self.server_ip = server_ip self.server_ip = server_ip
self.port = port self.port = port
self.url = 'http://' + self.server_ip + ":" + str( if server_ip is None or port is None:
self.port) + '/paddlespeech/text' self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/text'
def run(self, text): def run(self, text):
"""Call the text server to process the specific text
Args:
text (str): the text to be processed
Returns:
str: punctuation text
"""
if self.server_ip is None or self.port is None: if self.server_ip is None or self.port is None:
logger.warning( logger.warning(
"No punctuation server, please input valid ip and port") "No punctuation server, please input valid ip and port")
@ -55,24 +72,29 @@ class TextHttpHandler:
return punc_text return punc_text
class ASRAudioHandler: class ASRWsAudioHandler:
def __init__(self, def __init__(self,
url="127.0.0.1", url=None,
port=8090, port=None,
punc_server_ip="127.0.0.1", endpoint="/paddlespeech/asr/streaming",
punc_server_port="8091"): punc_server_ip=None,
punc_server_port=None):
"""PaddleSpeech Online ASR Server Client audio handler """PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal Online asr server use the websocket protocal
Args: Args:
url (str, optional): the server ip. Defaults to "127.0.0.1". url (str, optional): the server ip. Defaults to None.
port (int, optional): the server port. Defaults to 8090. port (int, optional): the server port. Defaults to None.
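endpoint (str, optional): the websocket endpoint path, configurable for compatibility with both the python server and the c++ server. Defaults to "/paddlespeech/asr/streaming".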
punc_server_ip(str, optional): the punctuation server ip. Defaults to None. punc_server_ip(str, optional): the punctuation server ip. Defaults to None.
punc_server_port(int, optional): the punctuation port. Defaults to None punc_server_port(int, optional): the punctuation port. Defaults to None
""" """
self.url = url self.url = url
self.port = port self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr" if url is None or port is None or endpoint is None:
self.url = None
else:
self.url = "ws://" + self.url + ":" + str(
self.port) + endpoint
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port) self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
@ -117,6 +139,11 @@ class ASRAudioHandler:
""" """
logging.info("send a message to the server") logging.info("send a message to the server")
if self.url is None:
logger.error(
"No punctuation server, please input valid ip and port")
return ""
# 1. send websocket handshake protocal # 1. send websocket handshake protocal
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal # 2. server has already received handshake protocal
@ -125,7 +152,7 @@ class ASRAudioHandler:
{ {
"name": "test.wav", "name": "test.wav",
"signal": "start", "signal": "start",
"nbest": 5 "nbest": 1
}, },
sort_keys=True, sort_keys=True,
indent=4, indent=4,
@ -139,7 +166,9 @@ class ASRAudioHandler:
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg) msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"]) if self.punc_server and len(msg["partial_result"]) > 0:
msg["partial_result"] = self.punc_server.run(
msg["partial_result"])
logger.info("receive msg={}".format(msg)) logger.info("receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
@ -157,7 +186,8 @@ class ASRAudioHandler:
# 5. decode the bytes to str # 5. decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"]) if self.punc_server:
msg["final_result"] = self.punc_server.run(msg["final_result"])
logger.info("final receive msg={}".format(msg)) logger.info("final receive msg={}".format(msg))
result = msg result = msg
@ -165,14 +195,39 @@ class ASRAudioHandler:
class ASRHttpHandler: class ASRHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090): def __init__(self, server_ip=None, port=None):
"""The ASR client http request
Args:
server_ip (str, optional): the http asr server ip. Defaults to None.
port (int, optional): the http asr server port. Defaults to None.
"""
super().__init__() super().__init__()
self.server_ip = server_ip self.server_ip = server_ip
self.port = port self.port = port
self.url = 'http://' + self.server_ip + ":" + str( if server_ip is None or port is None:
self.port) + '/paddlespeech/asr' self.url = None
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/asr'
def run(self, input, audio_format, sample_rate, lang): def run(self, input, audio_format, sample_rate, lang):
"""Call the http asr to process the audio
Args:
input (str): the audio file path
audio_format (str): the audio format
sample_rate (str): the audio sample rate
lang (str): the audio language type
Returns:
str: the final asr result
"""
if self.url is None:
logger.error(
"No punctuation server, please input valid ip and port")
return ""
audio = wav2base64(input) audio = wav2base64(input)
data = { data = {
"audio": audio, "audio": audio,

@ -24,7 +24,7 @@ from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter() router = APIRouter()
@router.websocket('/ws/asr') @router.websocket('/paddlespeech/asr/streaming')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api """PaddleSpeech Online ASR Server api
@ -83,7 +83,7 @@ async def websocket_endpoint(websocket: WebSocket):
resp = { resp = {
"status": "ok", "status": "ok",
"signal": "finished", "signal": "finished",
'asr_results': asr_results 'final_result': asr_results
} }
await websocket.send_json(resp) await websocket.send_json(resp)
break break
@ -102,7 +102,7 @@ async def websocket_endpoint(websocket: WebSocket):
# return the current period result # return the current period result
# if the engine create the vad instance, this connection will have many period results # if the engine create the vad instance, this connection will have many period results
resp = {'asr_results': asr_results} resp = {'partial_result': asr_results}
await websocket.send_json(resp) await websocket.send_json(resp)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

Loading…
Cancel
Save