Merge pull request #1784 from Honei/v0.3

[asr][server]asr client add punctuation server
pull/1788/head
Hui Zhang 3 years ago committed by GitHub
commit b6d0db0ddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -93,7 +93,7 @@
function parseResult(data) { function parseResult(data) {
var data = JSON.parse(data) var data = JSON.parse(data)
var result = data.asr_results var result = data.result
console.log(result) console.log(result)
$("#resultPanel").html(result) $("#resultPanel").html(result)
} }

@ -20,19 +20,23 @@ import logging
import os import os
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
def main(args): def main(args):
logger.info("asr websocket client start") logger.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8090) handler = ASRWsAudioHandler(
args.server_ip,
args.port,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# support to process single audio file # support to process single audio file
if args.wavfile and os.path.exists(args.wavfile): if args.wavfile and os.path.exists(args.wavfile):
logger.info(f"start to process the wavscp: {args.wavfile}") logger.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile)) result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"] result = result["final_result"]
logger.info(f"asr websocket client finished : {result}") logger.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp # support to process batch audios from wav.scp
@ -43,13 +47,29 @@ def main(args):
for line in f: for line in f:
utt_name, utt_path = line.strip().split() utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path)) result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"] result = result["final_result"]
w.write(f"{utt_name} {result}\n") w.write(f"{utt_name} {result}\n")
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Start to do streaming asr client") logger.info("Start to do streaming asr client")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
parser.add_argument('--port', type=int, default=8090, help='server port')
parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
parser.add_argument(
'--punc.port',
type=int,
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument( parser.add_argument(
"--wavfile", "--wavfile",
action="store", action="store",

@ -21,8 +21,6 @@ from typing import Union
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
@ -30,6 +28,8 @@ from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']

@ -16,7 +16,6 @@ import asyncio
import base64 import base64
import io import io
import json import json
import logging
import os import os
import random import random
import time import time
@ -30,13 +29,13 @@ from ..executor import BaseExecutor
from ..util import cli_client_register from ..util import cli_client_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRAudioHandler from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
__all__ = [ __all__ = [
'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor', 'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor',
'ASROnlineClientExecutor', 'CLSClientExecutor' 'CLSClientExecutor'
] ]
@ -288,6 +287,12 @@ class ASRClientExecutor(BaseExecutor):
default=None, default=None,
help='Audio file to be recognized', help='Audio file to be recognized',
required=True) required=True)
self.parser.add_argument(
'--protocol',
type=str,
default="http",
choices=["http", "websocket"],
help='server protocol')
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument( self.parser.add_argument(
@ -295,81 +300,18 @@ class ASRClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format') '--audio_format', type=str, default="wav", help='audio format')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
sample_rate = args.sample_rate
lang = args.lang
audio_format = args.audio_format
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8090,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
"""
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/asr'
audio = wav2base64(input)
data = {
"audio": audio,
"audio_format": audio_format,
"sample_rate": sample_rate,
"lang": lang,
}
res = requests.post(url=url, data=json.dumps(data))
return res
@cli_client_register(
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASROnlineClientExecutor(BaseExecutor):
def __init__(self):
super(ASROnlineClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.asr_online', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument( self.parser.add_argument(
'--port', type=int, default=8091, help='server port') '--punc.server_ip',
self.parser.add_argument(
'--input',
type=str, type=str,
default=None, default=None,
help='Audio file to be recognized', dest="punc_server_ip",
required=True) help='Punctuation server ip')
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--punc.port',
self.parser.add_argument( type=int,
'--lang', type=str, default="zh_cn", help='language') default=8091,
self.parser.add_argument( dest="punc_server_port",
'--audio_format', type=str, default="wav", help='audio format') help='Punctuation server port')
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
@ -379,6 +321,7 @@ class ASROnlineClientExecutor(BaseExecutor):
sample_rate = args.sample_rate sample_rate = args.sample_rate
lang = args.lang lang = args.lang
audio_format = args.audio_format audio_format = args.audio_format
protocol = args.protocol
try: try:
time_start = time.time() time_start = time.time()
@ -388,9 +331,12 @@ class ASROnlineClientExecutor(BaseExecutor):
port=port, port=port,
sample_rate=sample_rate, sample_rate=sample_rate,
lang=lang, lang=lang,
audio_format=audio_format) audio_format=audio_format,
protocol=protocol,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
time_end = time.time() time_end = time.time()
logger.info(res) logger.info(f"ASR result: {res}")
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
@ -402,21 +348,53 @@ class ASROnlineClientExecutor(BaseExecutor):
def __call__(self, def __call__(self,
input: str, input: str,
server_ip: str="127.0.0.1", server_ip: str="127.0.0.1",
port: int=8091, port: int=8090,
sample_rate: int=16000, sample_rate: int=16000,
lang: str="zh_cn", lang: str="zh_cn",
audio_format: str="wav"): audio_format: str="wav",
""" protocol: str="http",
Python API to call an executor. punc_server_ip: str="127.0.0.1",
punc_server_port: int=8091):
"""Python API to call an executor.
Args:
input (str): The input audio file path
server_ip (str, optional): The ASR server ip. Defaults to "127.0.0.1".
port (int, optional): The ASR server port. Defaults to 8090.
sample_rate (int, optional): The audio sample rate. Defaults to 16000.
lang (str, optional): The audio language type. Defaults to "zh_cn".
audio_format (str, optional): The audio format information. Defaults to "wav".
protocol (str, optional): The ASR server. Defaults to "http".
Returns:
str: The ASR results
""" """
logging.basicConfig(level=logging.INFO) # we use the asr server to recognize the audio text content
logging.info("asr websocket client start") if protocol.lower() == "http":
handler = ASRAudioHandler(server_ip, port) from paddlespeech.server.utils.audio_handler import ASRHttpHandler
loop = asyncio.get_event_loop() logger.info("asr http client start")
res = loop.run_until_complete(handler.run(input)) handler = ASRHttpHandler(server_ip=server_ip, port=port)
logging.info("asr websocket client finished") res = handler.run(input, audio_format, sample_rate, lang)
res = res['result']['transcription']
return res['asr_results'] logger.info("asr http client finished")
elif protocol.lower() == "websocket":
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(
server_ip,
port,
punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input))
res = res['result']
logger.info("asr websocket client finished")
else:
logger.error(f"Sorry, we have not support protocol: {protocol},"
"please use http or websocket protocol")
sys.exit(-1)
return res
@cli_client_register( @cli_client_register(

@ -26,7 +26,7 @@ import pyaudio
import websockets import websockets
class ASRAudioHandler(threading.Thread): class ASRWsAudioHandler(threading.Thread):
def __init__(self, url="127.0.0.1", port=8091): def __init__(self, url="127.0.0.1", port=8091):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.url = url self.url = url
@ -148,7 +148,7 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091) handler = ASRWsAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(handler.run()) main_task = asyncio.ensure_future(handler.run())
for signal in [SIGINT, SIGTERM]: for signal in [SIGINT, SIGTERM]:

@ -24,22 +24,76 @@ import websockets
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import save_audio from paddlespeech.server.utils.audio_process import save_audio
from paddlespeech.server.utils.util import wav2base64
class TextHttpHandler:
    """HTTP client for the text (punctuation) service at ``/paddlespeech/text``.

    When either the ip or the port is None the handler is disabled and
    ``run`` returns its input unchanged.
    """

    def __init__(self, server_ip="127.0.0.1", port=8090):
        """Text http client request

        Args:
            server_ip (str, optional): the text server ip. Defaults to "127.0.0.1".
                Pass None to disable the punctuation request.
            port (int, optional): the text server port. Defaults to 8090.
                Pass None to disable the punctuation request.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            # No server configured: run() acts as a pass-through.
            self.url = None
        else:
            self.url = f"http://{self.server_ip}:{self.port}/paddlespeech/text"

    def run(self, text):
        """Call the text server to process the specific text

        Args:
            text (str): the text to be processed

        Returns:
            str: punctuation text, or the original ``text`` when no server is
                configured, the text is empty, or the request fails.
        """
        if self.server_ip is None or self.port is None:
            return text

        # Nothing to punctuate: skip the HTTP round-trip entirely.
        if not text:
            return text

        request = {
            "text": text,
        }
        try:
            res = requests.post(url=self.url, data=json.dumps(request))
            response_dict = res.json()
            punc_text = response_dict["result"]["punc_text"]
        except Exception as e:
            # Deliberate best-effort: punctuation is optional, so any failure
            # (network, bad JSON, unexpected schema) falls back to the raw text.
            logger.error(f"Call punctuation server {self.url} failed")
            logger.error(e)
            punc_text = text

        return punc_text
class ASRWsAudioHandler:
def __init__(self, def __init__(self,
url="127.0.0.1", url=None,
port=8090, port=None,
endopoint='/paddlespeech/asr/streaming'): endpoint="/paddlespeech/asr/streaming",
punc_server_ip=None,
punc_server_port=None):
"""PaddleSpeech Online ASR Server Client audio handler """PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal Online asr server use the websocket protocal
Args: Args:
url (str, optional): the server ip. Defaults to "127.0.0.1". url (str, optional): the server ip. Defaults to None.
port (int, optional): the server port. Defaults to 8090. port (int, optional): the server port. Defaults to None.
endpoint(str, optional): to compatiable with python server and c++ server.
punc_server_ip(str, optional): the punctuation server ip. Defaults to None.
punc_server_port(int, optional): the punctuation port. Defaults to None
""" """
self.url = url self.url = url
self.port = port self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + endopoint if url is None or port is None or endpoint is None:
self.url = None
else:
self.url = "ws://" + self.url + ":" + str(
self.port) + endpoint
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
logger.info(f"endpoint: {self.url}") logger.info(f"endpoint: {self.url}")
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
@ -84,6 +138,11 @@ class ASRAudioHandler:
""" """
logging.info("send a message to the server") logging.info("send a message to the server")
if self.url is None:
logger.error(
"No asr server, please input valid ip and port")
return ""
# 1. send websocket handshake protocal # 1. send websocket handshake protocal
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal # 2. server has already received handshake protocal
@ -92,7 +151,7 @@ class ASRAudioHandler:
{ {
"name": "test.wav", "name": "test.wav",
"signal": "start", "signal": "start",
"nbest": 5 "nbest": 1
}, },
sort_keys=True, sort_keys=True,
indent=4, indent=4,
@ -106,6 +165,10 @@ class ASRAudioHandler:
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg) msg = json.loads(msg)
if self.punc_server and len(msg["result"]) > 0:
msg["result"] = self.punc_server.run(
msg["result"])
logger.info("client receive msg={}".format(msg)) logger.info("client receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
@ -123,11 +186,63 @@ class ASRAudioHandler:
# 5. decode the bytes to str # 5. decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
if self.punc_server:
msg["result"] = self.punc_server.run(msg["result"])
logger.info("client final receive msg={}".format(msg)) logger.info("client final receive msg={}".format(msg))
result = msg result = msg
return result return result
class ASRHttpHandler:
    """HTTP client for the non-streaming ASR service at ``/paddlespeech/asr``."""

    def __init__(self, server_ip=None, port=None):
        """The ASR client http request

        Args:
            server_ip (str, optional): the http asr server ip. Defaults to None.
            port (int, optional): the http asr server port. Defaults to None.
        """
        super().__init__()
        self.server_ip = server_ip
        self.port = port
        if server_ip is None or port is None:
            # Without both ip and port there is no endpoint to call.
            self.url = None
        else:
            self.url = f"http://{self.server_ip}:{self.port}/paddlespeech/asr"

    def run(self, input, audio_format, sample_rate, lang):
        """Call the http asr to process the audio

        Args:
            input (str): the audio file path
            audio_format (str): the audio format
            sample_rate (int): the audio sample rate
            lang (str): the audio language type

        Returns:
            dict: the decoded JSON response of the asr server. An empty
                string is returned instead when no server is configured,
                so callers must check before indexing into the result.
        """
        if self.url is None:
            logger.error("No asr server, please input valid ip and port")
            return ""

        # The server expects the audio payload base64-encoded inside JSON.
        audio = wav2base64(input)
        data = {
            "audio": audio,
            "audio_format": audio_format,
            "sample_rate": sample_rate,
            "lang": lang,
        }

        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()
class TTSWsHandler: class TTSWsHandler:
def __init__(self, server="127.0.0.1", port=8092, play: bool=False): def __init__(self, server="127.0.0.1", port=8092, play: bool=False):
"""PaddleSpeech Online TTS Server Client audio handler """PaddleSpeech Online TTS Server Client audio handler

@ -24,7 +24,7 @@ from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter() router = APIRouter()
@router.websocket('/ws/asr') @router.websocket('/paddlespeech/asr/streaming')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api """PaddleSpeech Online ASR Server api
@ -83,7 +83,7 @@ async def websocket_endpoint(websocket: WebSocket):
resp = { resp = {
"status": "ok", "status": "ok",
"signal": "finished", "signal": "finished",
'asr_results': asr_results 'result': asr_results
} }
await websocket.send_json(resp) await websocket.send_json(resp)
break break
@ -102,7 +102,7 @@ async def websocket_endpoint(websocket: WebSocket):
# return the current period result # return the current period result
# if the engine create the vad instance, this connection will have many period results # if the engine create the vad instance, this connection will have many period results
resp = {'asr_results': asr_results} resp = {'result': asr_results}
await websocket.send_json(resp) await websocket.send_json(resp)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

Loading…
Cancel
Save