asr client add punctuatjion server, test=doc

pull/1784/head
xiongxinlei 3 years ago
parent 119143d0f1
commit 833900a8b4

@ -16,7 +16,6 @@ import asyncio
import base64 import base64
import io import io
import json import json
import logging
import os import os
import random import random
import time import time
@ -36,7 +35,7 @@ from paddlespeech.server.utils.util import wav2base64
__all__ = [ __all__ = [
'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor', 'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor',
'ASROnlineClientExecutor', 'CLSClientExecutor' 'CLSClientExecutor'
] ]
@ -288,6 +287,12 @@ class ASRClientExecutor(BaseExecutor):
default=None, default=None,
help='Audio file to be recognized', help='Audio file to be recognized',
required=True) required=True)
self.parser.add_argument(
'--protocol',
type=str,
default="http",
choices=["http", "websocket"],
help='server protocol')
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument( self.parser.add_argument(
@ -295,81 +300,18 @@ class ASRClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format') '--audio_format', type=str, default="wav", help='audio format')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
sample_rate = args.sample_rate
lang = args.lang
audio_format = args.audio_format
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8090,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
"""
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/asr'
audio = wav2base64(input)
data = {
"audio": audio,
"audio_format": audio_format,
"sample_rate": sample_rate,
"lang": lang,
}
res = requests.post(url=url, data=json.dumps(data))
return res
@cli_client_register(
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASROnlineClientExecutor(BaseExecutor):
def __init__(self):
super(ASROnlineClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.asr_online', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8091, help='server port')
self.parser.add_argument( self.parser.add_argument(
'--input', '--punc.server_ip',
type=str, type=str,
default=None, default=None,
help='Audio file to be recognized', dest="punc_server_ip",
required=True) help='Punctuation server ip')
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--punc.port',
self.parser.add_argument( type=int,
'--lang', type=str, default="zh_cn", help='language') default=8091,
self.parser.add_argument( dest="punc_server_port",
'--audio_format', type=str, default="wav", help='audio format') help='Punctuation server port')
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv) args = self.parser.parse_args(argv)
@ -379,6 +321,7 @@ class ASROnlineClientExecutor(BaseExecutor):
sample_rate = args.sample_rate sample_rate = args.sample_rate
lang = args.lang lang = args.lang
audio_format = args.audio_format audio_format = args.audio_format
protocol = args.protocol
try: try:
time_start = time.time() time_start = time.time()
@ -388,9 +331,12 @@ class ASROnlineClientExecutor(BaseExecutor):
port=port, port=port,
sample_rate=sample_rate, sample_rate=sample_rate,
lang=lang, lang=lang,
audio_format=audio_format) audio_format=audio_format,
protocol=protocol,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
time_end = time.time() time_end = time.time()
logger.info(res) logger.info(f"ASR result: {res}")
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
@ -402,21 +348,55 @@ class ASROnlineClientExecutor(BaseExecutor):
def __call__(self, def __call__(self,
input: str, input: str,
server_ip: str="127.0.0.1", server_ip: str="127.0.0.1",
port: int=8091, port: int=8090,
sample_rate: int=16000, sample_rate: int=16000,
lang: str="zh_cn", lang: str="zh_cn",
audio_format: str="wav"): audio_format: str="wav",
""" protocol: str="http",
Python API to call an executor. punc_server_ip: str="127.0.0.1",
punc_server_port: int=8091):
"""Python API to call an executor.
Args:
input (str): The input audio file path
server_ip (str, optional): The ASR server ip. Defaults to "127.0.0.1".
port (int, optional): The ASR server port. Defaults to 8090.
sample_rate (int, optional): The audio sample rate. Defaults to 16000.
lang (str, optional): The audio language type. Defaults to "zh_cn".
audio_format (str, optional): The audio format information. Defaults to "wav".
protocol (str, optional): The ASR server. Defaults to "http".
Returns:
str: The ASR results
""" """
logging.basicConfig(level=logging.INFO) # 1. Firstly, we use the asr server to recognize the audio text content
logging.info("asr websocket client start") if protocol.lower() == "http":
handler = ASRAudioHandler(server_ip, port) from paddlespeech.server.utils.audio_handler import ASRHttpHandler
logger.info("asr http client start")
handler = ASRHttpHandler(server_ip=server_ip, port=port)
res = handler.run(input, audio_format, sample_rate, lang)
res = res['result']['transcription']
logger.info("asr http client finished")
elif protocol.lower() == "websocket":
logger.info("asr websocket client start")
handler = ASRAudioHandler(
server_ip,
port,
punc_server_ip=punc_server_ip,
punc_server_port=punc_server_port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run(input)) res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished") res = res['asr_results']
logger.info("asr websocket client finished")
else:
logger.error(f"Sorry, we have not support protocol: {protocol},"
"please use http or websocket protocol")
sys.exit(-1)
return res['asr_results'] # 2. Secondly, we use the punctuation server to do post process for text
return res
@cli_client_register( @cli_client_register(

@ -24,20 +24,57 @@ import websockets
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import save_audio from paddlespeech.server.utils.audio_process import save_audio
from paddlespeech.server.utils.util import wav2base64
class TextHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090):
super().__init__()
self.server_ip = server_ip
self.port = port
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/text'
def run(self, text):
if self.server_ip is None or self.port is None:
logger.warning(
"No punctuation server, please input valid ip and port")
return text
request = {
"text": text,
}
try:
res = requests.post(url=self.url, data=json.dumps(request))
response_dict = res.json()
punc_text = response_dict["result"]["punc_text"]
except Exception as e:
logger.error(f"Call punctuation {self.url} occurs")
logger.error(e)
punc_text = text
return punc_text
class ASRAudioHandler: class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090): def __init__(self,
url="127.0.0.1",
port=8090,
punc_server_ip="127.0.0.1",
punc_server_port="8091"):
"""PaddleSpeech Online ASR Server Client audio handler """PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal Online asr server use the websocket protocal
Args: Args:
url (str, optional): the server ip. Defaults to "127.0.0.1". url (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090. port (int, optional): the server port. Defaults to 8090.
punc_server_ip(str, optional): the punctuation server ip. Defaults to None.
punc_server_port(int, optional): the punctuation port. Defaults to None
""" """
self.url = url self.url = url
self.port = port self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr" self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
"""read the audio file from specific wavfile path """read the audio file from specific wavfile path
@ -102,6 +139,7 @@ class ASRAudioHandler:
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg) msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"])
logger.info("receive msg={}".format(msg)) logger.info("receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
@ -119,11 +157,35 @@ class ASRAudioHandler:
# 5. decode the bytes to str # 5. decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
msg["asr_results"] = self.punc_server.run(msg["asr_results"])
logger.info("final receive msg={}".format(msg)) logger.info("final receive msg={}".format(msg))
result = msg result = msg
return result return result
class ASRHttpHandler:
def __init__(self, server_ip="127.0.0.1", port=8090):
super().__init__()
self.server_ip = server_ip
self.port = port
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/asr'
def run(self, input, audio_format, sample_rate, lang):
audio = wav2base64(input)
data = {
"audio": audio,
"audio_format": audio_format,
"sample_rate": sample_rate,
"lang": lang,
}
res = requests.post(url=self.url, data=json.dumps(data))
return res.json()
class TTSWsHandler: class TTSWsHandler:
def __init__(self, server="127.0.0.1", port=8092, play: bool=False): def __init__(self, server="127.0.0.1", port=8092, play: bool=False):
"""PaddleSpeech Online TTS Server Client audio handler """PaddleSpeech Online TTS Server Client audio handler

Loading…
Cancel
Save