From af484fc980e9df51e6411a13d9d280a6447f0c26 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Thu, 14 Apr 2022 19:57:52 +0800 Subject: [PATCH] convert websockert results to str from bytest, test=doc --- .../server/engine/asr/online/asr_engine.py | 23 +++++++---- .../tests/asr/online/websocket_client.py | 40 +++++++++++++++---- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index ca82b615..cd5300fc 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -35,9 +35,9 @@ __all__ = ['ASREngine'] pretrained_models = { "deepspeech2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor): if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path self.cfg_path = os.path.join(res_path, @@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor): self.am_params = os.path.join(res_path, pretrained_models[tag]['params']) logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) + logger.info(self.cfg_path) + logger.info(self.am_model) + logger.info(self.am_params) + #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor): lm_url = pretrained_models[tag]['lm_url'] lm_md5 = pretrained_models[tag]['lm_md5'] + logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - raise Exception("wrong type") + # 开发 conformer 的流式模型 + logger.info("start to create the stream conformer asr engine") + # 复用cli里面的代码 + else: raise Exception("wrong type") # AM predictor + logger.info("ASR engine start to init the am predictor") self.am_predictor_conf = am_predictor_conf self.am_predictor = init_predictor( model_file=self.am_model, @@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor): predictor_conf=self.am_predictor_conf) # decoder + logger.info("ASR engine start to create the ctc decoder instance") self.decoder = CTCDecoder( odim=self.config.output_dim, # is in vocab enc_n_units=self.config.rnn_layer_size * 2, @@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor): grad_norm_type=self.config.get('ctc_grad_norm_type', None)) # init decoder + logger.info("ASR engine start to init the ctc decoder") cfg = self.config.decode decode_batch_size = 1 # for online self.decoder.init_decoder( @@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: @@ -273,6 +281,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() + logger.info("create the online asr engine instache") def init(self, config: dict) -> bool: """init engine resource diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 58b1a452..049d707e 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -15,8 +15,10 @@ # -*- coding: UTF-8 -*- import argparse import asyncio +import codecs import json import logging +import os import numpy as np import soundfile @@ -54,12 +56,11 @@ class ASRAudioHandler: async def run(self, wavfile_path: str): logging.info("send a message to the server") - # 读取音频 # self.read_wave() - # 发送 websocket 的 handshake 协议头 + # send websocket handshake protocal async with websockets.connect(self.url) as ws: - # server 端已经接收到 handshake 协议头 - # 发送开始指令 + # server has already received handshake protocal + # client start to send the command audio_info = json.dumps( { "name": "test.wav", @@ -77,8 +78,9 @@ class ASRAudioHandler: for chunk_data in self.read_wave(wavfile_path): await ws.send(chunk_data.tobytes()) msg = await ws.recv() + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - + result = msg # finished audio_info = json.dumps( { @@ -91,16 +93,36 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() + # decode the bytes to str + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) + return result + def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8091) + handler = ASRAudioHandler("127.0.0.1", 8090) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(args.wavfile)) - logging.info("asr websocket client finished") + + # support to process single audio file + if args.wavfile and os.path.exists(args.wavfile): + logging.info(f"start to process the wavscp: {args.wavfile}") + result = loop.run_until_complete(handler.run(args.wavfile)) + result = result["asr_results"] + logging.info(f"asr websocket client finished : {result}") + + # support to process batch audios from wav.scp + if args.wavscp and os.path.exists(args.wavscp): + logging.info(f"start to process the wavscp: {args.wavscp}") + with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\ + codecs.open("result.txt", 'w', encoding='utf-8') as w: + for line in f: + utt_name, utt_path = line.strip().split() + result = loop.run_until_complete(handler.run(utt_path)) + result = result["asr_results"] + w.write(f"{utt_name} {result}\n") if __name__ == "__main__": @@ -110,6 +132,8 @@ if __name__ == "__main__": action="store", help="wav file path ", default="./16_audio.wav") + parser.add_argument( + "--wavscp", type=str, default=None, help="The batch audios dict text") args = parser.parse_args() main(args)