convert websockert results to str from bytest, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 23a6534119
commit af484fc980

@ -35,9 +35,9 @@ __all__ = ['ASREngine']
pretrained_models = { pretrained_models = {
"deepspeech2online_aishell-zh-16k": { "deepspeech2online_aishell-zh-16k": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5': 'md5':
'd5e076217cf60486519f72c217d21b9b', '23e16c69730a1cb5d735c98c83c21e16',
'cfg_path': 'cfg_path':
'model.yaml', 'model.yaml',
'ckpt_path': 'ckpt_path':
@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor):
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor):
self.am_params = os.path.join(res_path, self.am_params = os.path.join(res_path,
pretrained_models[tag]['params']) pretrained_models[tag]['params'])
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model) self.am_model = os.path.abspath(am_model)
@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor):
lm_url = pretrained_models[tag]['lm_url'] lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5'] lm_md5 = pretrained_models[tag]['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
raise Exception("wrong type") # 开发 conformer 的流式模型
logger.info("start to create the stream conformer asr engine")
# 复用cli里面的代码
else: else:
raise Exception("wrong type") raise Exception("wrong type")
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor):
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder # decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2, enc_n_units=self.config.rnn_layer_size * 2,
@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor):
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder # init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode cfg = self.config.decode
decode_batch_size = 1 # for online decode_batch_size = 1 # for online
self.decoder.init_decoder( self.decoder.init_decoder(
@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor):
self.decoder.next(output_chunk_probs, output_chunk_lens) self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode() trans_best, trans_beam = self.decoder.decode()
return trans_best[0] return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
@ -273,6 +281,7 @@ class ASREngine(BaseEngine):
def __init__(self): def __init__(self):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
logger.info("create the online asr engine instache")
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
"""init engine resource """init engine resource

@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
import argparse import argparse
import asyncio import asyncio
import codecs
import json import json
import logging import logging
import os
import numpy as np import numpy as np
import soundfile import soundfile
@ -54,12 +56,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str): async def run(self, wavfile_path: str):
logging.info("send a message to the server") logging.info("send a message to the server")
# 读取音频
# self.read_wave() # self.read_wave()
# 发送 websocket 的 handshake 协议头 # send websocket handshake protocal
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头 # server has already received handshake protocal
# 发送开始指令 # client start to send the command
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
@ -77,8 +78,9 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path): for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
@ -91,16 +93,36 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
return result
def main(args): def main(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091) handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished") # support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__": if __name__ == "__main__":
@ -110,6 +132,8 @@ if __name__ == "__main__":
action="store", action="store",
help="wav file path ", help="wav file path ",
default="./16_audio.wav") default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

Loading…
Cancel
Save