convert websockert results to str from bytest, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 23a6534119
commit af484fc980

@ -35,9 +35,9 @@ __all__ = ['ASREngine']
pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'd5e076217cf60486519f72c217d21b9b',
'23e16c69730a1cb5d735c98c83c21e16',
'cfg_path':
'model.yaml',
'ckpt_path':
@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor):
if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor):
self.am_params = os.path.join(res_path,
pretrained_models[tag]['params'])
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor):
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
raise Exception("wrong type")
# 开发 conformer 的流式模型
logger.info("start to create the stream conformer asr engine")
# 复用cli里面的代码
else:
raise Exception("wrong type")
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor):
predictor_conf=self.am_predictor_conf)
# decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor):
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor):
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
@ -273,6 +281,7 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
logger.info("create the online asr engine instache")
def init(self, config: dict) -> bool:
"""init engine resource

@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import codecs
import json
import logging
import os
import numpy as np
import soundfile
@ -54,12 +56,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
# send websocket handshake protocal
async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
# server has already received handshake protocal
# client start to send the command
audio_info = json.dumps(
{
"name": "test.wav",
@ -77,8 +78,9 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,16 +93,36 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
@ -110,6 +132,8 @@ if __name__ == "__main__":
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()
main(args)

Loading…
Cancel
Save