|
|
|
@ -16,7 +16,6 @@ import asyncio
|
|
|
|
|
import base64
|
|
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import time
|
|
|
|
@ -30,13 +29,13 @@ from ..executor import BaseExecutor
|
|
|
|
|
from ..util import cli_client_register
|
|
|
|
|
from ..util import stats_wrapper
|
|
|
|
|
from paddlespeech.cli.log import logger
|
|
|
|
|
from paddlespeech.server.utils.audio_handler import ASRAudioHandler
|
|
|
|
|
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
|
|
|
|
|
from paddlespeech.server.utils.audio_process import wav2pcm
|
|
|
|
|
from paddlespeech.server.utils.util import wav2base64
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'TTSClientExecutor', 'TTSOnlineClientExecutor', 'ASRClientExecutor',
|
|
|
|
|
'ASROnlineClientExecutor', 'CLSClientExecutor'
|
|
|
|
|
'CLSClientExecutor'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -288,6 +287,12 @@ class ASRClientExecutor(BaseExecutor):
|
|
|
|
|
default=None,
|
|
|
|
|
help='Audio file to be recognized',
|
|
|
|
|
required=True)
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--protocol',
|
|
|
|
|
type=str,
|
|
|
|
|
default="http",
|
|
|
|
|
choices=["http", "websocket"],
|
|
|
|
|
help='server protocol')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--sample_rate', type=int, default=16000, help='audio sample rate')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
@ -295,81 +300,18 @@ class ASRClientExecutor(BaseExecutor):
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--audio_format', type=str, default="wav", help='audio format')
|
|
|
|
|
|
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
|
args = self.parser.parse_args(argv)
|
|
|
|
|
input_ = args.input
|
|
|
|
|
server_ip = args.server_ip
|
|
|
|
|
port = args.port
|
|
|
|
|
sample_rate = args.sample_rate
|
|
|
|
|
lang = args.lang
|
|
|
|
|
audio_format = args.audio_format
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
time_start = time.time()
|
|
|
|
|
res = self(
|
|
|
|
|
input=input_,
|
|
|
|
|
server_ip=server_ip,
|
|
|
|
|
port=port,
|
|
|
|
|
sample_rate=sample_rate,
|
|
|
|
|
lang=lang,
|
|
|
|
|
audio_format=audio_format)
|
|
|
|
|
time_end = time.time()
|
|
|
|
|
logger.info(res.json())
|
|
|
|
|
logger.info("Response time %f s." % (time_end - time_start))
|
|
|
|
|
return True
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error("Failed to speech recognition.")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
@stats_wrapper
|
|
|
|
|
def __call__(self,
|
|
|
|
|
input: str,
|
|
|
|
|
server_ip: str="127.0.0.1",
|
|
|
|
|
port: int=8090,
|
|
|
|
|
sample_rate: int=16000,
|
|
|
|
|
lang: str="zh_cn",
|
|
|
|
|
audio_format: str="wav"):
|
|
|
|
|
"""
|
|
|
|
|
Python API to call an executor.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/asr'
|
|
|
|
|
audio = wav2base64(input)
|
|
|
|
|
data = {
|
|
|
|
|
"audio": audio,
|
|
|
|
|
"audio_format": audio_format,
|
|
|
|
|
"sample_rate": sample_rate,
|
|
|
|
|
"lang": lang,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
res = requests.post(url=url, data=json.dumps(data))
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cli_client_register(
|
|
|
|
|
name='paddlespeech_client.asr_online',
|
|
|
|
|
description='visit asr online service')
|
|
|
|
|
class ASROnlineClientExecutor(BaseExecutor):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(ASROnlineClientExecutor, self).__init__()
|
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
|
prog='paddlespeech_client.asr_online', add_help=True)
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--server_ip', type=str, default='127.0.0.1', help='server ip')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--port', type=int, default=8091, help='server port')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--input',
|
|
|
|
|
'--punc.server_ip',
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help='Audio file to be recognized',
|
|
|
|
|
required=True)
|
|
|
|
|
dest="punc_server_ip",
|
|
|
|
|
help='Punctuation server ip')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--sample_rate', type=int, default=16000, help='audio sample rate')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--lang', type=str, default="zh_cn", help='language')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--audio_format', type=str, default="wav", help='audio format')
|
|
|
|
|
'--punc.port',
|
|
|
|
|
type=int,
|
|
|
|
|
default=8091,
|
|
|
|
|
dest="punc_server_port",
|
|
|
|
|
help='Punctuation server port')
|
|
|
|
|
|
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
|
args = self.parser.parse_args(argv)
|
|
|
|
@ -379,6 +321,7 @@ class ASROnlineClientExecutor(BaseExecutor):
|
|
|
|
|
sample_rate = args.sample_rate
|
|
|
|
|
lang = args.lang
|
|
|
|
|
audio_format = args.audio_format
|
|
|
|
|
protocol = args.protocol
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
time_start = time.time()
|
|
|
|
@ -388,9 +331,12 @@ class ASROnlineClientExecutor(BaseExecutor):
|
|
|
|
|
port=port,
|
|
|
|
|
sample_rate=sample_rate,
|
|
|
|
|
lang=lang,
|
|
|
|
|
audio_format=audio_format)
|
|
|
|
|
audio_format=audio_format,
|
|
|
|
|
protocol=protocol,
|
|
|
|
|
punc_server_ip=args.punc_server_ip,
|
|
|
|
|
punc_server_port=args.punc_server_port)
|
|
|
|
|
time_end = time.time()
|
|
|
|
|
logger.info(res)
|
|
|
|
|
logger.info(f"ASR result: {res}")
|
|
|
|
|
logger.info("Response time %f s." % (time_end - time_start))
|
|
|
|
|
return True
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -402,21 +348,53 @@ class ASROnlineClientExecutor(BaseExecutor):
|
|
|
|
|
def __call__(self,
|
|
|
|
|
input: str,
|
|
|
|
|
server_ip: str="127.0.0.1",
|
|
|
|
|
port: int=8091,
|
|
|
|
|
port: int=8090,
|
|
|
|
|
sample_rate: int=16000,
|
|
|
|
|
lang: str="zh_cn",
|
|
|
|
|
audio_format: str="wav"):
|
|
|
|
|
"""
|
|
|
|
|
Python API to call an executor.
|
|
|
|
|
audio_format: str="wav",
|
|
|
|
|
protocol: str="http",
|
|
|
|
|
punc_server_ip: str="127.0.0.1",
|
|
|
|
|
punc_server_port: int=8091):
|
|
|
|
|
"""Python API to call an executor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input (str): The input audio file path
|
|
|
|
|
server_ip (str, optional): The ASR server ip. Defaults to "127.0.0.1".
|
|
|
|
|
port (int, optional): The ASR server port. Defaults to 8090.
|
|
|
|
|
sample_rate (int, optional): The audio sample rate. Defaults to 16000.
|
|
|
|
|
lang (str, optional): The audio language type. Defaults to "zh_cn".
|
|
|
|
|
audio_format (str, optional): The audio format information. Defaults to "wav".
|
|
|
|
|
protocol (str, optional): The ASR server. Defaults to "http".
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
str: The ASR results
|
|
|
|
|
"""
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
logging.info("asr websocket client start")
|
|
|
|
|
handler = ASRAudioHandler(server_ip, port)
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
res = loop.run_until_complete(handler.run(input))
|
|
|
|
|
logging.info("asr websocket client finished")
|
|
|
|
|
|
|
|
|
|
return res['asr_results']
|
|
|
|
|
# we use the asr server to recognize the audio text content
|
|
|
|
|
if protocol.lower() == "http":
|
|
|
|
|
from paddlespeech.server.utils.audio_handler import ASRHttpHandler
|
|
|
|
|
logger.info("asr http client start")
|
|
|
|
|
handler = ASRHttpHandler(server_ip=server_ip, port=port)
|
|
|
|
|
res = handler.run(input, audio_format, sample_rate, lang)
|
|
|
|
|
res = res['result']['transcription']
|
|
|
|
|
logger.info("asr http client finished")
|
|
|
|
|
|
|
|
|
|
elif protocol.lower() == "websocket":
|
|
|
|
|
logger.info("asr websocket client start")
|
|
|
|
|
handler = ASRWsAudioHandler(
|
|
|
|
|
server_ip,
|
|
|
|
|
port,
|
|
|
|
|
punc_server_ip=punc_server_ip,
|
|
|
|
|
punc_server_port=punc_server_port)
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
res = loop.run_until_complete(handler.run(input))
|
|
|
|
|
res = res['result']
|
|
|
|
|
logger.info("asr websocket client finished")
|
|
|
|
|
else:
|
|
|
|
|
logger.error(f"Sorry, we have not support protocol: {protocol},"
|
|
|
|
|
"please use http or websocket protocol")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cli_client_register(
|
|
|
|
|