From af484fc980e9df51e6411a13d9d280a6447f0c26 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 14 Apr 2022 19:57:52 +0800
Subject: [PATCH 01/16] convert websocket results to str from bytes, test=doc

---
 .../server/engine/asr/online/asr_engine.py | 23 +++++++----
 .../tests/asr/online/websocket_client.py   | 40 +++++++++++++++----
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index ca82b615..cd5300fc 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -35,9 +35,9 @@ __all__ = ['ASREngine']
 pretrained_models = {
     "deepspeech2online_aishell-zh-16k": {
         'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
         'md5':
-        'd5e076217cf60486519f72c217d21b9b',
+        '23e16c69730a1cb5d735c98c83c21e16',
         'cfg_path':
         'model.yaml',
         'ckpt_path':
@@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor):
         if cfg_path is None or am_model is None or am_params is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             tag = model_type + '-' + lang + '-' + sample_rate_str
+            logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
             self.cfg_path = os.path.join(res_path,
@@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor):
             self.am_params = os.path.join(res_path,
                                           pretrained_models[tag]['params'])
             logger.info(res_path)
-            logger.info(self.cfg_path)
-            logger.info(self.am_model)
-            logger.info(self.am_params)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.am_model = os.path.abspath(am_model)
@@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor):
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
 
+        logger.info(self.cfg_path)
+        logger.info(self.am_model)
+        logger.info(self.am_params)
+
         #Init body.
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor): lm_url = pretrained_models[tag]['lm_url'] lm_md5 = pretrained_models[tag]['lm_md5'] + logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - raise Exception("wrong type") + # 开发 conformer 的流式模型 + logger.info("start to create the stream conformer asr engine") + # 复用cli里面的代码 + else: raise Exception("wrong type") # AM predictor + logger.info("ASR engine start to init the am predictor") self.am_predictor_conf = am_predictor_conf self.am_predictor = init_predictor( model_file=self.am_model, @@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor): predictor_conf=self.am_predictor_conf) # decoder + logger.info("ASR engine start to create the ctc decoder instance") self.decoder = CTCDecoder( odim=self.config.output_dim, # is in vocab enc_n_units=self.config.rnn_layer_size * 2, @@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor): grad_norm_type=self.config.get('ctc_grad_norm_type', None)) # init decoder + logger.info("ASR engine start to init the ctc decoder") cfg = self.config.decode decode_batch_size = 1 # for online self.decoder.init_decoder( @@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: @@ -273,6 +281,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() + logger.info("create the online asr engine instache") def init(self, config: dict) -> bool: """init engine resource diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 58b1a452..049d707e 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -15,8 +15,10 @@ # -*- coding: UTF-8 -*- import argparse import asyncio +import codecs import json import logging +import os import numpy as np import soundfile @@ -54,12 +56,11 @@ class ASRAudioHandler: async def run(self, wavfile_path: str): logging.info("send a message to the server") - # 读取音频 # self.read_wave() - # 发送 websocket 的 handshake 协议头 + # send websocket handshake protocal async with websockets.connect(self.url) as ws: - # server 端已经接收到 handshake 协议头 - # 发送开始指令 + # server has already received handshake protocal + # client start to send the command audio_info = json.dumps( { "name": "test.wav", @@ -77,8 +78,9 @@ class ASRAudioHandler: for chunk_data in self.read_wave(wavfile_path): await ws.send(chunk_data.tobytes()) msg = await ws.recv() + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - + result = msg # finished audio_info = json.dumps( { @@ -91,16 +93,36 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() + # decode the bytes to str + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) + return result + def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8091) + handler = ASRAudioHandler("127.0.0.1", 8090) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(args.wavfile)) - logging.info("asr 
websocket client finished") + + # support to process single audio file + if args.wavfile and os.path.exists(args.wavfile): + logging.info(f"start to process the wavscp: {args.wavfile}") + result = loop.run_until_complete(handler.run(args.wavfile)) + result = result["asr_results"] + logging.info(f"asr websocket client finished : {result}") + + # support to process batch audios from wav.scp + if args.wavscp and os.path.exists(args.wavscp): + logging.info(f"start to process the wavscp: {args.wavscp}") + with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\ + codecs.open("result.txt", 'w', encoding='utf-8') as w: + for line in f: + utt_name, utt_path = line.strip().split() + result = loop.run_until_complete(handler.run(utt_path)) + result = result["asr_results"] + w.write(f"{utt_name} {result}\n") if __name__ == "__main__": @@ -110,6 +132,8 @@ if __name__ == "__main__": action="store", help="wav file path ", default="./16_audio.wav") + parser.add_argument( + "--wavscp", type=str, default=None, help="The batch audios dict text") args = parser.parse_args() main(args) From d21ccd02875fea5d8c90483a31cd8b6f4a148d2e Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Fri, 15 Apr 2022 18:42:46 +0800 Subject: [PATCH 02/16] add conformer online server, test=doc --- paddlespeech/cli/asr/infer.py | 56 +++-- paddlespeech/s2t/models/u2/u2.py | 8 +- paddlespeech/s2t/modules/ctc.py | 3 +- paddlespeech/s2t/modules/encoder.py | 2 + paddlespeech/server/conf/ws_application.yaml | 54 +++- .../server/engine/asr/online/asr_engine.py | 231 ++++++++++++------ .../tests/asr/online/websocket_client.py | 2 +- paddlespeech/server/ws/asr_socket.py | 29 ++- 8 files changed, 272 insertions(+), 113 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index b12b9f6f..53f71a70 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -91,6 +91,20 @@ pretrained_models = { 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' + }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'config.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_30', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, "deepspeech2offline_librispeech-en-16k": { 'url': @@ -115,6 +129,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer2online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": @@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + logger.info("start to init the model") if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor): self.ckpt_path = os.path.join( res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.ckpt_path) + else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + #Init body. 
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method - else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( @@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor): else: raise Exception("wrong type") + logger.info("audio feat process success") + @paddle.no_grad() def infer(self, model_type: str): """ Model inference and result stored in self.output. """ - + logger.info("start to infer the model to get the output") cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] @@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - self._outputs["result"] = result_transcripts[0][0] + logger.info(f"we will use the transformer like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: raise Exception("invalid model name") diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6a98607b..f0d2711d 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer): num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: + print("offline decode from the asr") encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) + print("offline decode success") return encoder_out, encoder_mask def recognize( @@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer): List[List[int]]: transcripts. 
""" batch_size = feats.shape[0] + print("start to decode the audio feat") if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: - logger.fatal( + logger.error( f'decoding mode {decoding_method} must be running with batch_size == 1' ) + logger.error(f"current batch_size is {batch_size}") sys.exit(1) - + print(f"use the {decoding_method} to decode the audio feat") if decoding_method == 'attention': hyps = self.recognize( feats, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472d..bd1219b1 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + + from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index c843c0e2..347035cd 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer): outputs = [] offset = 0 # Feed forward overlap input step by step + print(f"context: {context}") + print(f"stride: {stride}") for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index ef23593e..6b82edcb 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8091 +port: 8096 # The task format in the engin_list is: _ # task choices = ['asr_online', 'tts_online'] @@ -18,10 +18,44 @@ engine_list: ['asr_online'] # ENGINE CONFIG # ################################################################################# +# ################################### ASR ######################################### +# ################### speech task: asr; engine_type: online ####################### +# asr_online: +# model_type: 'deepspeech2online_aishell' +# am_model: # the pdmodel file of am static model [optional] +# am_params: # the pdiparams file of am static model [optional] +# lang: 'zh' +# sample_rate: 16000 +# cfg_path: +# decode_method: +# force_yes: True + +# am_predictor_conf: +# device: # set 'gpu:id' or 'cpu' +# switch_ir_optim: True +# glog_info: False # True -> print glog +# summary: True # False -> do not show predictor config + +# chunk_buffer_conf: +# frame_duration_ms: 80 +# shift_ms: 40 +# sample_rate: 16000 +# sample_width: 2 + +# vad_conf: +# aggressiveness: 2 +# sample_rate: 16000 +# frame_duration_ms: 20 +# sample_width: 2 +# padding_ms: 200 +# padding_ratio: 0.9 + + + ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'deepspeech2online_aishell' + model_type: 'conformer2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -37,15 +71,15 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: - frame_duration_ms: 80 + frame_duration_ms: 85 shift_ms: 40 sample_rate: 16000 sample_width: 2 - 
vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 + # vad_conf: + # aggressiveness: 2 + # sample_rate: 16000 + # frame_duration_ms: 20 + # sample_width: 2 + # padding_ms: 200 + # padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index cd5300fc..a5b9ab48 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -20,11 +20,15 @@ from numpy import float32 from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.asr.infer import model_alias +from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float @@ -51,6 +55,24 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'exp/chunk_conformer//conf/config.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_30/', + 'model': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, } @@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. 
""" - + self.model_type = model_type + self.sample_rate = sample_rate if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" + # self.cfg_path = os.path.join(res_path, + # pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor): lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - # 开发 conformer 的流式模型 logger.info("start to create the stream conformer asr engine") - # 复用cli里面的代码 - + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.config.vocab_filepath = os.path.join( + self.res_path, self.config.vocab_filepath) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + # update the decoding method + if decode_method: + self.config.decode.decoding_method = decode_method else: raise Exception("wrong type") - - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - # decoder - logger.info("ASR engine start to create the ctc decoder instance") - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - logger.info("ASR engine start to init the ctc decoder") - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + # decoder + logger.info("ASR engine start to create the ctc decoder instance") + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + 
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + # init decoder + logger.info("ASR engine start to init the ctc decoder") + cfg = self.config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + # init state box + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + logger.info(f"model name: {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + logger.info("create the transformer like model success") def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio @@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor): Returns: [type]: [description] """ + logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) @@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: - raise Exception("invalid model name") + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + cfg = self.config.decode + result_transcripts = self.model.decode( + x_chunk, + x_chunk_lens, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + + return result_transcripts[0][0] + except Exception as e: + logger.exception(e) else: raise Exception("invalid model name") @@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor): """ # pcm16 -> pcm 32 samples = pcm2float(samples) - - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature(spectrum) - - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) - - return x_chunk, x_chunk_lens + if "deepspeech2online" in self.model_type: + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, sample_rate, transcript=" ") + # audio augment + 
self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + x_chunk = audio.numpy() + x_chunk_lens = np.array([audio_len]) + + return x_chunk, x_chunk_lens + elif "conformer2online" in self.model_type: + + if sample_rate != self.sample_rate: + logger.info(f"audio sample rate {sample_rate} is not match," \ + "the model sample_rate is {self.sample_rate}") + logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("Create the preprocess instance") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + + logger.info("Read the audio file") + logger.info(f"audio shape: {samples.shape}") + # fbank + x_chunk = preprocessing(samples, **preprocess_args) + x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + logger.info( + f"process the audio feature success, feat shape: {x_chunk.shape}" + ) + return x_chunk, x_chunk_lens class ASREngine(BaseEngine): @@ -310,7 +395,10 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, samples, sample_rate): + def preprocess(self, + samples, + sample_rate, + model_type="deepspeech2online_aishell-zh-16k"): """preprocess Args: @@ -321,6 +409,7 @@ class ASREngine(BaseEngine): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ + # if "deepspeech" in model_type: x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) return x_chunk, x_chunk_lens diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 049d707e..a26838f8 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -103,7 +103,7 @@ class ASRAudioHandler: def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8090) + handler = ASRAudioHandler("127.0.0.1", 8096) loop = asyncio.get_event_loop() # support to process single audio file diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ea19816b..442f26cb 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -14,6 +14,7 @@ import json import numpy as np +import json from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect @@ -28,7 +29,7 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - + print("websocket protocal receive the dataset") await websocket.accept() engine_pool = get_engine_pool() @@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket): # init buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( + 
frame_duration_ms=chunk_buffer_conf['frame_duration_ms'], sample_rate=chunk_buffer_conf['sample_rate'], sample_width=chunk_buffer_conf['sample_width']) # init vad - vad_conf = asr_engine.config.vad_conf - vad = VADAudio( - aggressiveness=vad_conf['aggressiveness'], - rate=vad_conf['sample_rate'], - frame_duration_ms=vad_conf['frame_duration_ms']) + # print(asr_engine.config) + # print(type(asr_engine.config)) + vad_conf = asr_engine.config.get('vad_conf', None) + if vad_conf: + vad = VADAudio( + aggressiveness=vad_conf['aggressiveness'], + rate=vad_conf['sample_rate'], + frame_duration_ms=vad_conf['frame_duration_ms']) try: while True: @@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_engine.reset() + # asr_engine.reset() resp = {"status": "ok", "signal": "finished"} await websocket.send_json(resp) break @@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket): elif "bytes" in message: message = message["bytes"] - # vad for input bytes audio - vad.add_audio(message) - message = b''.join(f for f in vad.vad_collector() - if f is not None) - + # # vad for input bytes audio + # vad.add_audio(message) + # message = b''.join(f for f in vad.vad_collector() + # if f is not None) engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" frames = chunk_buffer.frame_generator(message) for frame in frames: + # get the pcm data from the bytes samples = np.frombuffer(frame.bytes, dtype=np.int16) sample_rate = asr_engine.config.sample_rate x_chunk, x_chunk_lens = asr_engine.preprocess(samples, From 0c5dbbee5bdb784e44d7f6ad1f7a7d911c833e06 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sat, 16 Apr 2022 21:37:46 +0800 Subject: [PATCH 03/16] add conformer ctc prefix beam search decoding method, test=doc --- .../server/engine/asr/online/asr_engine.py | 213 +++++++++++++++--- paddlespeech/server/ws/asr_socket.py | 26 ++- 2 files changed, 195 insertions(+), 44 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a5b9ab48..e1e4a7ad 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
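Before the patch 03 changes below, it helps to see the end-to-end flow that patches 01 and 02 put in place between websocket_client.py and asr_socket.py: the client sends a JSON "start" signal, streams raw int16 PCM as binary frames, gets back a partial transcript under "asr_results" after each chunk, and closes with a JSON "end" signal. The sketch below is only an illustration of that exchange, not the shipped client; the function name stream_wav, the 85 ms chunk size, and the ws://127.0.0.1:8096/ws/asr endpoint are assumptions that must match ws_application.yaml on your machine.

import asyncio
import json

import soundfile
import websockets


async def stream_wav(url: str, wav_path: str, chunk_ms: int = 85) -> str:
    # load 16 kHz, 16-bit mono audio as int16 samples
    samples, sample_rate = soundfile.read(wav_path, dtype='int16')
    chunk_size = int(sample_rate * chunk_ms / 1000)

    async with websockets.connect(url) as ws:
        # 1. tell the server a new utterance starts
        await ws.send(json.dumps({"name": wav_path, "signal": "start", "nbest": 1}))
        print(await ws.recv())  # server acknowledges the start signal

        # 2. stream fixed-size PCM chunks; the server answers each one
        #    with the transcript decoded so far
        transcript = ""
        for start in range(0, len(samples), chunk_size):
            await ws.send(samples[start:start + chunk_size].tobytes())
            partial = json.loads(await ws.recv())
            transcript = partial["asr_results"]

        # 3. close the utterance; the server replies with a "finished" signal
        await ws.send(json.dumps({"name": wav_path, "signal": "end", "nbest": 1}))
        print(await ws.recv())
    return transcript


if __name__ == "__main__":
    result = asyncio.run(stream_wav("ws://127.0.0.1:8096/ws/asr", "./16_audio.wav"))
    print("final transcript:", result)
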
import os +from paddlespeech.s2t.utils.utility import log_add from typing import Optional - +from collections import defaultdict import numpy as np import paddle from numpy import float32 @@ -23,10 +24,14 @@ from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import model_alias from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.modules.mask import mask_finished_preds +from paddlespeech.s2t.modules.mask import mask_finished_scores +from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig @@ -57,17 +62,17 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', 'md5': - '4814e52e0fc2fd48899373f95c84b0c9', + '7989b3248c898070904cf042fd656003', 'cfg_path': - 'exp/chunk_conformer//conf/config.yaml', + 'model.yaml', 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_30/', + 'exp/chunk_conformer/checkpoints/multi_cn', 'model': - 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', 'params': - 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': @@ -81,6 +86,23 @@ class ASRServerExecutor(ASRExecutor): super().__init__() pass + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. 
+ """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + def _init_from_path(self, model_type: str='wenetspeech', am_model: Optional[os.PathLike]=None, @@ -101,7 +123,7 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" # self.cfg_path = os.path.join(res_path, # pretrained_models[tag]['cfg_path']) @@ -147,8 +169,7 @@ class ASRServerExecutor(ASRExecutor): if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( self.res_path, self.config.spm_model_prefix) - self.config.vocab_filepath = os.path.join( - self.res_path, self.config.vocab_filepath) + self.vocab = self.config.vocab_filepath self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.config.vocab_filepath, @@ -203,19 +224,31 @@ class ASRServerExecutor(ASRExecutor): model_conf = self.config model = model_class.from_config(model_conf) self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.am_model) + self.model.set_state_dict(model_dict) logger.info("create the transformer like model success") + # update the ctc decoding + self.searcher = None + self.transformer_decode_reset() + def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio """ - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + self.decoder.reset_decoder(batch_size=1) + # init state box, for new audio request + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): """decode one chunk @@ -275,24 +308,137 @@ class ASRServerExecutor(ASRExecutor): logger.info( f"we will use the transformer like model : {self.model_type}" ) - cfg = self.config.decode - result_transcripts = self.model.decode( - x_chunk, - x_chunk_lens, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - 
simulate_streaming=cfg.simulate_streaming) - - return result_transcripts[0][0] + self.advanced_decoding(x_chunk, x_chunk_lens) + self.update_result() + + return self.result_transcripts[0] except Exception as e: logger.exception(e) else: raise Exception("invalid model name") + def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): + logger.info("start to decode with advanced_decoding method") + encoder_out, encoder_mask = self.decode_forward(xs) + self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask) + + def decode_forward(self, xs): + logger.info("get the model out from the feat") + cfg = self.config.decode + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.shape[1] + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + + logger.info("start to do model forward") + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + return ys, masks + + def transformer_decode_reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.hyps = None + self.offset = 0 + self.cur_hyps = None + self.hyps = None + + def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0): + # decode + logger.info("start to ctc prefix search") + + device = xs.place + cfg = self.config.decode + batch_size = xs.shape[0] + beam_size = cfg.beam_size + maxlen = encoder_out.shape[1] + + ctc_probs = self.model.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + + self.hyps = [hyps[0][0]] + logger.info("ctc prefix search success") + return hyps, encoder_out + + def update_result(self): + logger.info("update the final result") + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in self.hyps + ] + self.result_tokenids = [hyp for hyp in self.hyps] + def extract_feat(self, samples, sample_rate): """extract feat @@ -304,9 +450,10 @@ class ASRServerExecutor(ASRExecutor): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ - # pcm16 -> pcm 32 - samples = pcm2float(samples) + if "deepspeech2online" in self.model_type: + # pcm16 -> pcm 32 + samples = pcm2float(samples) # read audio speech_segment = SpeechSegment.from_pcm( samples, sample_rate, transcript=" ") diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 14254928..4d1013f4 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -14,7 +14,6 @@ import json import numpy as np -import json from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect @@ -86,16 +85,21 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - frames = chunk_buffer.frame_generator(message) - for frame in frames: - # get the pcm data from the bytes - samples = np.frombuffer(frame.bytes, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - asr_results = asr_engine.postprocess() - + # frames = chunk_buffer.frame_generator(message) + # for frame in frames: + # # get the pcm data from the bytes + # samples = np.frombuffer(frame.bytes, dtype=np.int16) + # sample_rate = asr_engine.config.sample_rate + # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + # sample_rate) + # asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() + samples = np.frombuffer(message, dtype=np.int16) + sample_rate = asr_engine.config.sample_rate + x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + 
sample_rate) + asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} From 97d31f9aacc37e936d70f0a10bccf1622fd69323 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sun, 17 Apr 2022 16:27:35 +0800 Subject: [PATCH 04/16] update the attention_rescoring method, test=doc --- paddlespeech/server/conf/ws_application.yaml | 16 +- .../server/engine/asr/online/asr_engine.py | 177 +++++++++--------- .../server/engine/asr/online/ctc_search.py | 119 ++++++++++++ paddlespeech/server/tests/__init__.py | 13 ++ paddlespeech/server/tests/asr/__init__.py | 13 ++ .../server/tests/asr/offline/__init__.py | 13 ++ .../server/tests/asr/online/__init__.py | 13 ++ paddlespeech/server/ws/asr_socket.py | 43 ++--- 8 files changed, 287 insertions(+), 120 deletions(-) create mode 100644 paddlespeech/server/engine/asr/online/ctc_search.py create mode 100644 paddlespeech/server/tests/__init__.py create mode 100644 paddlespeech/server/tests/asr/__init__.py create mode 100644 paddlespeech/server/tests/asr/offline/__init__.py create mode 100644 paddlespeech/server/tests/asr/online/__init__.py diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index c3a488fb..aa3c208b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -71,15 +71,9 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: - frame_duration_ms: 85 - shift_ms: 40 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms sample_rate: 16000 - sample_width: 2 - - # vad_conf: - # aggressiveness: 2 - # sample_rate: 16000 - # frame_duration_ms: 20 - # sample_width: 2 - # padding_ms: 200 - # padding_ratio: 0.9 \ No newline at end of file + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index e1e4a7ad..e292f9cf 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
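For reference, the chunk_buffer_conf values above appear to line up with the 85 ms / 40 ms chunking used in the earlier patches: a window of window_n frames with a hop of shift_n frames, at 25 ms per frame window and 10 ms per frame shift, spans (window_n - 1) * shift_ms + window_ms milliseconds of audio. A small sanity-check sketch (the byte sizes assume 16 kHz, 16-bit mono PCM; treat this as an interpretation of the config, not a statement about ChunkBuffer internals):

# Rough sanity check of the chunk buffer geometry configured above.
window_n, shift_n = 7, 4              # frames per window / frames per hop
window_ms, shift_ms = 25, 10          # per-frame window and shift in ms
sample_rate, sample_width = 16000, 2  # 16 kHz, 16-bit PCM

window_duration_ms = (window_n - 1) * shift_ms + window_ms  # 85 ms
shift_duration_ms = shift_n * shift_ms                      # 40 ms

window_bytes = int(sample_rate * window_duration_ms / 1000) * sample_width
shift_bytes = int(sample_rate * shift_duration_ms / 1000) * sample_width

print(window_duration_ms, shift_duration_ms)  # 85 40
print(window_bytes, shift_bytes)              # 2720 1280
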
import os -from paddlespeech.s2t.utils.utility import log_add from typing import Optional -from collections import defaultdict + import numpy as np import paddle from numpy import float32 @@ -22,19 +21,18 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import model_alias -from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.modules.mask import mask_finished_preds -from paddlespeech.s2t.modules.mask import mask_finished_scores -from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.tensor_utils import add_sos_eos +from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor @@ -62,9 +60,9 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz', 'md5': - '7989b3248c898070904cf042fd656003', + 'b450d5dfaea0ac227c595ce58d18b637', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - # self.cfg_path = os.path.join(res_path, - # pretrained_models[tag]['cfg_path']) + # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor): # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method + + # we only support ctc_prefix_beam_search and attention_rescoring dedoding method + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding = "attention_rescoring" + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: @@ -232,7 +242,7 @@ class 
ASRServerExecutor(ASRExecutor): logger.info("create the transformer like model success") # update the ctc decoding - self.searcher = None + self.searcher = CTCPrefixBeamSearch(self.config.decode) self.transformer_decode_reset() def reset_decoder_and_chunk(self): @@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor): def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): logger.info("start to decode with advanced_decoding method") encoder_out, encoder_mask = self.decode_forward(xs) - self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + self.searcher.search(xs, ctc_probs, xs.place) + # update the one best result + self.hyps = self.searcher.get_one_best_hyps() + + # now we supprot ctc_prefix_beam_search and attention_rescoring + if "attention_rescoring" in self.config.decode.decoding_method: + self.rescoring(encoder_out, xs.place) def decode_forward(self, xs): logger.info("get the model out from the feat") @@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor): num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks - logger.info("start to do model forward") outputs = [] @@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor): masks = masks.unsqueeze(1) return ys, masks + def rescoring(self, encoder_out, device): + logger.info("start to rescoring the hyps") + beam_size = self.config.decode.beam_size + hyps = self.searcher.get_hyps() + assert len(hyps) == beam_size + + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. 
+ score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.config.decode.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + self.hyps = [hyps[best_index][0]] + return hyps[best_index][0] + def transformer_decode_reset(self): self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None - self.hyps = None self.offset = 0 - self.cur_hyps = None - self.hyps = None - - def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0): - # decode - logger.info("start to ctc prefix search") - - device = xs.place - cfg = self.config.decode - batch_size = xs.shape[0] - beam_size = cfg.beam_size - maxlen = encoder_out.shape[1] - - ctc_probs = self.model.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) - ctc_probs = ctc_probs.squeeze(0) - - # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) - # blank_ending_score and none_blank_ending_score in ln domain - if self.cur_hyps is None: - self.cur_hyps = [(tuple(), (0.0, -float('inf')))] - # 2. CTC beam search step by step - for t in range(0, maxlen): - logp = ctc_probs[t] # (vocab_size,) - # key: prefix, value (pb, pnb), default value(-inf, -inf) - next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) - - # 2.1 First beam prune: select topk best - # do token passing process - top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) - for s in top_k_index: - s = s.item() - ps = logp[s].item() - for prefix, (pb, pnb) in self.cur_hyps: - last = prefix[-1] if len(prefix) > 0 else None - if s == blank_id: # blank - n_pb, n_pnb = next_hyps[prefix] - n_pb = log_add([n_pb, pb + ps, pnb + ps]) - next_hyps[prefix] = (n_pb, n_pnb) - elif s == last: - # Update *ss -> *s; - n_pb, n_pnb = next_hyps[prefix] - n_pnb = log_add([n_pnb, pnb + ps]) - next_hyps[prefix] = (n_pb, n_pnb) - # Update *s-s -> *ss, - is for blank - n_prefix = prefix + (s, ) - n_pb, n_pnb = next_hyps[n_prefix] - n_pnb = log_add([n_pnb, pb + ps]) - next_hyps[n_prefix] = (n_pb, n_pnb) - else: - n_prefix = prefix + (s, ) - n_pb, n_pnb = next_hyps[n_prefix] - n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) - next_hyps[n_prefix] = (n_pb, n_pnb) - - # 2.2 Second beam prune - next_hyps = sorted( - next_hyps.items(), - key=lambda x: log_add(list(x[1])), - reverse=True) - self.cur_hyps = next_hyps[:beam_size] - - hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] - - self.hyps = [hyps[0][0]] - logger.info("ctc prefix search success") - return hyps, encoder_out + # decoding reset + self.searcher.reset() def update_result(self): logger.info("update the final result") + hyps = self.hyps self.result_transcripts = [ - self.text_feature.defeaturize(hyp) for hyp in self.hyps + self.text_feature.defeaturize(hyp) for hyp in hyps ] - self.result_tokenids = [hyp for hyp in self.hyps] + self.result_tokenids = [hyp for hyp in hyps] def extract_feat(self, samples, sample_rate): """extract feat @@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor): elif "conformer2online" in self.model_type: if sample_rate != self.sample_rate: - logger.info(f"audio sample rate {sample_rate} is not match," \ + logger.info(f"audio sample rate {sample_rate} is not match," "the model sample_rate is {self.sample_rate}") - logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("ASR Engine use the {self.model_type} to process") logger.info("Create the preprocess instance") preprocess_conf = 
self.config.preprocess_config preprocess_args = {"train": False} diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py new file mode 100644 index 00000000..a91b8a21 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +from paddlespeech.cli.log import logger +from paddlespeech.s2t.utils.utility import log_add + +__all__ = ['CTCPrefixBeamSearch'] + + +class CTCPrefixBeamSearch: + def __init__(self, config): + """Implement the ctc prefix beam search + + Args: + config (_type_): _description_ + """ + self.config = config + self.reset() + + def search(self, xs, ctc_probs, device, blank_id=0): + """ctc prefix beam search method decode a chunk feature + + Args: + xs (paddle.Tensor): feature data + ctc_probs (paddle.Tensor): the ctc probability of all the tokens + encoder_out (paddle.Tensor): _description_ + encoder_mask (_type_): _description_ + blank_id (int, optional): the blank id in the vocab. Defaults to 0. + + Returns: + list: the search result + """ + # decode + logger.info("start to ctc prefix search") + + # device = xs.place + batch_size = xs.shape[0] + beam_size = self.config.beam_size + maxlen = ctc_probs.shape[0] + + assert len(ctc_probs.shape) == 2 + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + logger.info("ctc prefix search success") + return self.hyps + + def get_one_best_hyps(self): + """Return the one best result + + Returns: + list: the one best result + """ + return [self.hyps[0][0]] + + def get_hyps(self): + return self.hyps + + def reset(self): + """Rest the search cache value + """ + self.cur_hyps = None + self.hyps = None diff --git a/paddlespeech/server/tests/__init__.py b/paddlespeech/server/tests/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/__init__.py b/paddlespeech/server/tests/asr/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
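
[Editor's note, not part of the diff] The new ctc_search.py introduced above is the core of the streaming decoder, so a minimal usage sketch may help readers of this patch series. Everything below is illustrative: the SimpleNamespace config, the random log-probabilities, the vocabulary size and the chunk length are assumptions, and in the real engine the probabilities come from the acoustic model's CTC head (model.ctc.log_softmax on each chunk's encoder output) rather than from paddle.randn.

    from types import SimpleNamespace

    import paddle

    from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

    # beam_size is the only decode option read by search(); the value is illustrative
    config = SimpleNamespace(beam_size=10)
    searcher = CTCPrefixBeamSearch(config)

    vocab_size, blank_id = 5000, 0
    for _ in range(3):  # pretend three chunks arrive from the streaming encoder
        # (num_frames, vocab_size) log-probabilities for one chunk; in the engine
        # these come from the CTC head applied to the chunk encoder output
        ctc_probs = paddle.nn.functional.log_softmax(
            paddle.randn([16, vocab_size]), axis=-1)
        xs = ctc_probs.unsqueeze(0)  # stand-in feature tensor, only its batch dim is read
        searcher.search(xs, ctc_probs, ctc_probs.place, blank_id=blank_id)
        partial = searcher.get_one_best_hyps()[0]  # token ids of the current best prefix

    best_token_ids = searcher.get_one_best_hyps()[0]
    searcher.reset()  # clear cur_hyps/hyps before decoding the next utterance

Because cur_hyps is kept on the instance between calls, each search() call extends the prefixes accumulated from earlier chunks, so a partial transcript can be reported to the websocket client after every chunk, and reset() has to be called between utterances (or connections) to avoid leaking state.
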
diff --git a/paddlespeech/server/tests/asr/offline/__init__.py b/paddlespeech/server/tests/asr/offline/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/offline/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/online/__init__.py b/paddlespeech/server/tests/asr/online/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/online/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 4d1013f4..87b43d2c 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # init buffer + # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( - window_n=7, - shift_n=4, - window_ms=20, - shift_ms=10, - sample_rate=chunk_buffer_conf['sample_rate'], - sample_width=chunk_buffer_conf['sample_width']) + window_n=chunk_buffer_conf.window_n, + shift_n=chunk_buffer_conf.shift_n, + window_ms=chunk_buffer_conf.window_ms, + shift_ms=chunk_buffer_conf.shift_ms, + sample_rate=chunk_buffer_conf.sample_rate, + sample_width=chunk_buffer_conf.sample_width) + # init vad - # print(asr_engine.config) - # print(type(asr_engine.config)) vad_conf = asr_engine.config.get('vad_conf', None) if vad_conf: vad = VADAudio( @@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection - # asr_engine.reset() + asr_engine.reset() resp = {"status": "ok", "signal": "finished"} await websocket.send_json(resp) break @@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - # frames = chunk_buffer.frame_generator(message) - # for frame in frames: - # # get the pcm data from the bytes - # samples = np.frombuffer(frame.bytes, dtype=np.int16) - # sample_rate = asr_engine.config.sample_rate - # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - # sample_rate) - # 
asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() - samples = np.frombuffer(message, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() + frames = chunk_buffer.frame_generator(message) + for frame in frames: + # get the pcm data from the bytes + samples = np.frombuffer(frame.bytes, dtype=np.int16) + sample_rate = asr_engine.config.sample_rate + x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + sample_rate) + asr_engine.run(x_chunk, x_chunk_lens) + asr_results = asr_engine.postprocess() + asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} From d2640c14064058c5283830fd2046d1788e800046 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 12:58:40 +0800 Subject: [PATCH 05/16] add mult sesssion process, test=doc --- .../server/engine/asr/online/asr_engine.py | 190 +++++++++++++++++- 1 file changed, 189 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index e292f9cf..3546e598 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -78,6 +78,194 @@ pretrained_models = { }, } +# ASR server connection process class + +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + super().__init__() + self.config = asr_engine.config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + self.init() + self.reset() + + def init(self): + self.model_type = self.asr_engine.executor.model_type + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + pass + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + self.sample_rate = self.asr_engine.executor.sample_rate + + # acoustic model + self.model = self.asr_engine.executor.model + + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + # ctc decoding + self.ctc_decode_config = self.asr_engine.executor.config.decode + self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) + + # extract fbank + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + def extract_feat(self, samples): + if "deepspeech2online" in self.model_type: + pass + elif "conformer2online" in self.model_type: + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + logger.info(f"This package receive {samples.shape[0]} pcm data") + self.num_samples += samples.shape[0] + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + if len(self.remained_wav) < self.win_length: + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor( + x_chunk, 
dtype="float32").unsqueeze(axis=0) + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) + + num_frames = x_chunk.shape[1] + self.num_frames += num_frames + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) + # logger.info(f"accumulate samples: {self.num_samples}") + + def reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_outs_ = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + + self.num_frames = 0 + self.global_frame_offset = 0 + self.result = [] + + def decode(self, is_finished=False): + if "deepspeech2online" in self.model_type: + pass + elif "conformer" in self.model_type or "transformer" in self.model_type: + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advance_decoding(is_finished) + # self.update_result() + + # return self.result_transcripts[0] + except Exception as e: + logger.exception(e) + else: + raise Exception("invalid model name") + + def advance_decoding(self, is_finished=False): + logger.info("start to decode with advanced_decoding method") + cfg = self.ctc_decode_config + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = self.cached_feat.shape[1] + logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") + + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + return None, None + + # logger.info("start to do model forward") + # required_cache_size = decoding_chunk_size * num_decoding_left_chunks + # outputs = [] + + # # num_frames - context + 1 ensure that current frame can get context window + # if is_finished: + # # if get the finished chunk, we need process the last context + # left_frames = context + # else: + # # we only process decoding_window frames for one chunk + # left_frames = decoding_window + + # logger.info(f"") + # end = None + # for cur in range(0, num_frames - left_frames + 1, stride): + # end = min(cur + decoding_window, num_frames) + # print(f"cur: {cur}, end: {end}") + # chunk_xs = self.cached_feat[:, cur:end, :] + # (y, self.subsampling_cache, self.elayers_output_cache, + # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + # chunk_xs, self.offset, required_cache_size, + # self.subsampling_cache, self.elayers_output_cache, + # self.conformer_cnn_cache) + # outputs.append(y) + # update the offset + # self.offset += y.shape[1] + # self.cached_feat = self.cached_feat[end:] + # ys = paddle.cat(outputs, 1) + # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + # masks = masks.unsqueeze(1) + + # # get the ctc probs + # ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + # ctc_probs = ctc_probs.squeeze(0) + # # self.searcher.search(xs, ctc_probs, xs.place) + + 
# self.searcher.search(None, ctc_probs, self.cached_feat.place) + + # self.hyps = self.searcher.get_one_best_hyps() + + # ys for rescoring + # return ys, masks + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + + def rescoring(self): + pass + + + class ASRServerExecutor(ASRExecutor): def __init__(self): @@ -492,7 +680,7 @@ class ASRServerExecutor(ASRExecutor): if sample_rate != self.sample_rate: logger.info(f"audio sample rate {sample_rate} is not match," "the model sample_rate is {self.sample_rate}") - logger.info("ASR Engine use the {self.model_type} to process") + logger.info(f"ASR Engine use the {self.model_type} to process") logger.info("Create the preprocess instance") preprocess_conf = self.config.preprocess_config preprocess_args = {"train": False} From 10e825d9b2f619b0c8525c3c24491a657ccc9269 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 15:05:48 +0800 Subject: [PATCH 06/16] check chunk window process, test=doc --- .../server/engine/asr/online/asr_engine.py | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 3546e598..1f6060e9 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -145,6 +145,8 @@ class PaddleASRConnectionHanddler: if self.cached_feat is None: self.cached_feat = x_chunk else: + assert(len(x_chunk.shape) == 3) + assert(len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) num_frames = x_chunk.shape[1] @@ -170,6 +172,7 @@ class PaddleASRConnectionHanddler: self.num_samples = 0 self.num_frames = 0 + self.chunk_num = 0 self.global_frame_offset = 0 self.result = [] @@ -210,23 +213,24 @@ class PaddleASRConnectionHanddler: if num_frames < decoding_window and not is_finished: return None, None - # logger.info("start to do model forward") - # required_cache_size = decoding_chunk_size * num_decoding_left_chunks - # outputs = [] - - # # num_frames - context + 1 ensure that current frame can get context window - # if is_finished: - # # if get the finished chunk, we need process the last context - # left_frames = context - # else: - # # we only process decoding_window frames for one chunk - # left_frames = decoding_window + logger.info("start to do model forward") + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window # logger.info(f"") - # end = None - # for cur in range(0, num_frames - left_frames + 1, stride): - # end = min(cur + decoding_window, num_frames) - # print(f"cur: {cur}, end: {end}") + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}") + self.chunk_num += 1 # chunk_xs = self.cached_feat[:, cur:end, :] # (y, self.subsampling_cache, self.elayers_output_cache, # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( @@ -236,7 +240,14 @@ class 
PaddleASRConnectionHanddler: # outputs.append(y) # update the offset # self.offset += y.shape[1] - # self.cached_feat = self.cached_feat[end:] + + # remove the processed feat + if end == num_frames: + self.cached_feat = None + else: + assert self.cached_feat.shape[0] == 1 + self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) + assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" # ys = paddle.cat(outputs, 1) # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) # masks = masks.unsqueeze(1) @@ -309,9 +320,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" + # self.cfg_path = os.path.join(res_path, + # pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) From 68731c61f40d2dc5eb154c5b9cd3faa8f0efd672 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 15:17:29 +0800 Subject: [PATCH 07/16] add multi session result, test=doc --- .../server/engine/asr/online/asr_engine.py | 50 ++++++++++--------- .../server/engine/asr/online/ctc_search.py | 2 +- paddlespeech/server/utils/buffer.py | 2 +- paddlespeech/server/ws/asr_socket.py | 36 ++++++++----- 4 files changed, 51 insertions(+), 39 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 1f6060e9..c13b2f6d 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler: f"we will use the transformer like model : {self.model_type}" ) self.advance_decoding(is_finished) - # self.update_result() + self.update_result() - # return self.result_transcripts[0] + return self.result_transcripts[0] except Exception as e: logger.exception(e) else: @@ -225,22 +225,36 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window - # logger.info(f"") + # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) - print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}") + self.chunk_num += 1 - # chunk_xs = self.cached_feat[:, cur:end, :] - # (y, self.subsampling_cache, self.elayers_output_cache, - # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( - # chunk_xs, self.offset, required_cache_size, - # self.subsampling_cache, self.elayers_output_cache, - # self.conformer_cnn_cache) - # outputs.append(y) + chunk_xs = self.cached_feat[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + # update the offset - # self.offset += y.shape[1] + self.offset += y.shape[1] + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + + # get the ctc probs + ctc_probs 
= self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # self.searcher.search(xs, ctc_probs, xs.place) + + self.searcher.search(None, ctc_probs, self.cached_feat.place) + + self.hyps = self.searcher.get_one_best_hyps() + # remove the processed feat if end == num_frames: self.cached_feat = None @@ -248,19 +262,7 @@ class PaddleASRConnectionHanddler: assert self.cached_feat.shape[0] == 1 self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - # ys = paddle.cat(outputs, 1) - # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - # masks = masks.unsqueeze(1) - - # # get the ctc probs - # ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) - # ctc_probs = ctc_probs.squeeze(0) - # # self.searcher.search(xs, ctc_probs, xs.place) - - # self.searcher.search(None, ctc_probs, self.cached_feat.place) - # self.hyps = self.searcher.get_one_best_hyps() - # ys for rescoring # return ys, masks diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index a91b8a21..bf4c4b30 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -46,7 +46,7 @@ class CTCPrefixBeamSearch: logger.info("start to ctc prefix search") # device = xs.place - batch_size = xs.shape[0] + batch_size = 1 beam_size = self.config.beam_size maxlen = ctc_probs.shape[0] diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index 12b1f0e5..d4e6cd49 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -63,12 +63,12 @@ class ChunkBuffer(object): the sample rate. Yields Frames of the requested duration. 
""" + audio = self.remained_audio + audio self.remained_audio = b'' offset = 0 timestamp = 0.0 - while offset + self.window_bytes <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec) diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 87b43d2c..04807e5c 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -22,6 +22,7 @@ from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler router = APIRouter() @@ -33,6 +34,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] + connection_handler = None # init buffer # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf @@ -67,13 +69,17 @@ async def websocket_endpoint(websocket: WebSocket): if message['signal'] == 'start': resp = {"status": "ok", "signal": "server_ready"} # do something at begining here + # create the instance to process the audio + connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection + asr_results = connection_handler.decode(is_finished=True) + connection_handler.reset() asr_engine.reset() - resp = {"status": "ok", "signal": "finished"} + resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results} await websocket.send_json(resp) break else: @@ -81,23 +87,27 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_json(resp) elif "bytes" in message: message = message["bytes"] - engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - frames = chunk_buffer.frame_generator(message) - for frame in frames: - # get the pcm data from the bytes - samples = np.frombuffer(frame.bytes, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - asr_results = asr_engine.postprocess() + connection_handler.extract_feat(message) + asr_results = connection_handler.decode(is_finished=False) + # connection_handler. 
+ # frames = chunk_buffer.frame_generator(message) + # for frame in frames: + # # get the pcm data from the bytes + # samples = np.frombuffer(frame.bytes, dtype=np.int16) + # sample_rate = asr_engine.config.sample_rate + # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + # sample_rate) + # asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() - asr_results = asr_engine.postprocess() + # # connection accept the sample data frame by frame + + # asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} - + print("\n") await websocket.send_json(resp) except WebSocketDisconnect: pass From 05a8a4b5fccec0fe549132717f24d25c3240b04f Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 17:11:49 +0800 Subject: [PATCH 08/16] add connection stability, test=doc --- .../server/engine/asr/online/asr_engine.py | 109 ++++++++++++++++-- .../server/engine/asr/online/ctc_search.py | 10 ++ paddlespeech/server/ws/asr_socket.py | 37 +++--- 3 files changed, 121 insertions(+), 35 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index c13b2f6d..696d223a 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -83,8 +83,10 @@ pretrained_models = { class PaddleASRConnectionHanddler: def __init__(self, asr_engine): super().__init__() + logger.info("create an paddle asr connection handler to process the websocket connection") self.config = asr_engine.config self.model_config = asr_engine.executor.config + self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() @@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler: assert(len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + num_frames = x_chunk.shape[1] self.num_frames += num_frames self.remained_wav = self.remained_wav[self.n_shift * num_frames:] @@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler: self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None - self.encoder_outs_ = None + self.encoder_out = None self.cached_feat = None self.remained_wav = None self.offset = 0 self.num_samples = 0 - + self.device = None + self.hyps = [] self.num_frames = 0 self.chunk_num = 0 self.global_frame_offset = 0 - self.result = [] + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: @@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler: self.advance_decoding(is_finished) self.update_result() - return self.result_transcripts[0] except Exception as e: logger.exception(e) else: @@ -203,16 +209,26 @@ class PaddleASRConnectionHanddler: subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + num_frames = self.cached_feat.shape[1] logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") # the cached feat must be larger decoding_window if num_frames < decoding_window and not 
is_finished: + logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data") return None, None - + + if num_frames < context: + logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward") + return None, None + logger.info("start to do model forward") required_cache_size = decoding_chunk_size * num_decoding_left_chunks outputs = [] @@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler: # update the offset self.offset += y.shape[1] + logger.info(f"output size: {len(outputs)}") ys = paddle.cat(outputs, 1) - masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - masks = masks.unsqueeze(1) + if self.encoder_out is None: + self.encoder_out = ys + else: + self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + # masks = masks.unsqueeze(1) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - # self.searcher.search(xs, ctc_probs, xs.place) self.searcher.search(None, ctc_probs, self.cached_feat.place) @@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler: self.cached_feat = None else: assert self.cached_feat.shape[0] == 1 - self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0) assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" # ys for rescoring @@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler: ] self.result_tokenids = [hyp for hyp in hyps] + def get_result(self): + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + def rescoring(self): - pass + logger.info("rescoring the final result") + if "attention_rescoring" != self.ctc_decode_config.decoding_method: + return + + self.searcher.finalize_search() + self.update_result() + + beam_size = self.ctc_decode_config.beam_size + hyps = self.searcher.get_hyps() + if hyps is None or len(hyps) == 0: + return + + # assert len(hyps) == beam_size + paddle.save(self.encoder_out, "encoder.out") + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=self.device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=self.device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = self.encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. 
+ score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: + best_score = score + best_index = i + # update the one best result + logger.info(f"best index: {best_index}") + self.hyps = [hyps[best_index][0]] + self.update_result() + # return hyps[best_index][0] @@ -552,7 +639,7 @@ class ASRServerExecutor(ASRExecutor): subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - + # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context num_frames = xs.shape[1] diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index bf4c4b30..c3822b5c 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -110,6 +110,11 @@ class CTCPrefixBeamSearch: return [self.hyps[0][0]] def get_hyps(self): + """Return the search hyps + + Returns: + list: return the search hyps + """ return self.hyps def reset(self): @@ -117,3 +122,8 @@ class CTCPrefixBeamSearch: """ self.cur_hyps = None self.hyps = None + + def finalize_search(self): + """do nothing in ctc_prefix_beam_search + """ + pass diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 04807e5c..ae7c5eb4 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -13,16 +13,15 @@ # limitations under the License. import json -import numpy as np from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio -from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler router = APIRouter() @@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket): connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_results = connection_handler.decode(is_finished=True) + connection_handler.decode(is_finished=True) + connection_handler.rescoring() + asr_results = connection_handler.get_result() connection_handler.reset() - asr_engine.reset() - resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results} + + resp = { + "status": "ok", + "signal": "finished", + 'asr_results': asr_results + } await websocket.send_json(resp) break else: @@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_json(resp) elif "bytes" in message: message = message["bytes"] - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] - asr_results = "" + connection_handler.extract_feat(message) - asr_results = connection_handler.decode(is_finished=False) - # connection_handler. 
- # frames = chunk_buffer.frame_generator(message) - # for frame in frames: - # # get the pcm data from the bytes - # samples = np.frombuffer(frame.bytes, dtype=np.int16) - # sample_rate = asr_engine.config.sample_rate - # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - # sample_rate) - # asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() + connection_handler.decode(is_finished=False) + asr_results = connection_handler.get_result() - # # connection accept the sample data frame by frame - - # asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} print("\n") await websocket.send_json(resp) From 5acb0b5252e77018fdca05435c97638ac48f5d6a Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 21:46:57 +0800 Subject: [PATCH 09/16] fix the websocket chunk edge bug, test=doc --- .../server/engine/asr/online/asr_engine.py | 121 ++++++++++-------- paddlespeech/server/ws/asr_socket.py | 1 - 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 696d223a..a8e25f4b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -60,9 +60,9 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': - 'b450d5dfaea0ac227c595ce58d18b637', + '0ac93d390552336f2a906aec9e33c5fa', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -78,12 +78,19 @@ pretrained_models = { }, } -# ASR server connection process class +# ASR server connection process class class PaddleASRConnectionHanddler: def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ super().__init__() - logger.info("create an paddle asr connection handler to process the websocket connection") + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) self.config = asr_engine.config self.model_config = asr_engine.executor.config self.model = asr_engine.executor.model @@ -98,24 +105,26 @@ class PaddleASRConnectionHanddler: pass elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: self.sample_rate = self.asr_engine.executor.sample_rate - + # acoustic model self.model = self.asr_engine.executor.model - + # tokens to text self.text_feature = self.asr_engine.executor.text_feature - - # ctc decoding + + # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract fbank + # extract feat, new only fbank in conformer model self.preprocess_conf = self.model_config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) + + # frame window samples length and frame shift samples length self.win_length = self.preprocess_conf.process[0]['win_length'] self.n_shift = self.preprocess_conf.process[0]['n_shift'] - + def extract_feat(self, samples): if "deepspeech2online" in self.model_type: pass @@ -123,10 +132,10 @@ class PaddleASRConnectionHanddler: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert 
samples.ndim == 1 - + logger.info(f"This package receive {samples.shape[0]} pcm data") self.num_samples += samples.shape[0] - + # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: @@ -141,19 +150,21 @@ class PaddleASRConnectionHanddler: return 0 # fbank - x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = self.preprocessing(self.remained_wav, + **self.preprocess_args) x_chunk = paddle.to_tensor( x_chunk, dtype="float32").unsqueeze(axis=0) if self.cached_feat is None: self.cached_feat = x_chunk else: - assert(len(x_chunk.shape) == 3) - assert(len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) - + assert (len(x_chunk.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + # set the feat device if self.device is None: - self.device = self.cached_feat.place + self.device = self.cached_feat.place num_frames = x_chunk.shape[1] self.num_frames += num_frames @@ -161,7 +172,7 @@ class PaddleASRConnectionHanddler: logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) + ) logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) @@ -209,24 +220,30 @@ class PaddleASRConnectionHanddler: subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - cached_feature_num = context - subsampling # processed chunk feature cached for next chunk + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return - + num_frames = self.cached_feat.shape[1] - logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") - + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: - logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data") + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) return None, None if num_frames < context: - logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward") + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) return None, None logger.info("start to do model forward") @@ -235,17 +252,17 @@ class PaddleASRConnectionHanddler: # num_frames - context + 1 ensure that current frame can get context window if is_finished: - # if get the finished chunk, we need process the last context + # if get the finished chunk, we need process the last context left_frames = context else: # we only process decoding_window frames for one chunk - left_frames = decoding_window - + left_frames = decoding_window + # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) - + self.chunk_num += 1 chunk_xs = self.cached_feat[:, cur:end, :] (y, self.subsampling_cache, 
self.elayers_output_cache, @@ -257,35 +274,31 @@ class PaddleASRConnectionHanddler: # update the offset self.offset += y.shape[1] - - logger.info(f"output size: {len(outputs)}") + ys = paddle.cat(outputs, 1) if self.encoder_out is None: - self.encoder_out = ys + self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) - # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - # masks = masks.unsqueeze(1) - + # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) self.searcher.search(None, ctc_probs, self.cached_feat.place) - + self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0, end - + cached_feature_num:, :].unsqueeze(0) + assert len( + self.cached_feat.shape + ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - # remove the processed feat - if end == num_frames: - self.cached_feat = None - else: - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num - self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0) - assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - - # ys for rescoring - # return ys, masks + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) def update_result(self): logger.info("update the final result") @@ -304,8 +317,8 @@ class PaddleASRConnectionHanddler: def rescoring(self): logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: - return - + return + self.searcher.finalize_search() self.update_result() @@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler: logger.info(f"best index: {best_index}") self.hyps = [hyps[best_index][0]] self.update_result() - # return hyps[best_index][0] - class ASRServerExecutor(ASRExecutor): @@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - # self.cfg_path = os.path.join(res_path, - # pretrained_models[tag]['cfg_path']) + + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -639,7 +650,7 @@ class ASRServerExecutor(ASRExecutor): subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - + # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context num_frames = xs.shape[1] diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ae7c5eb4..82b05bc5 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -96,7 +96,6 @@ async def websocket_endpoint(websocket: WebSocket): asr_results = connection_handler.get_result() resp = {'asr_results': asr_results} - print("\n") await websocket.send_json(resp) except WebSocketDisconnect: pass From 380afbbc5d828f81204a5b9ab9088d4491ba0b70 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 16:18:42 +0800 Subject: [PATCH 10/16] add ds2 model multi session, test=doc --- paddlespeech/server/conf/ws_application.yaml | 50 +--- 
.../server/conf/ws_conformer_application.yaml | 45 ++++ .../server/engine/asr/online/asr_engine.py | 224 ++++++++++++++++-- 3 files changed, 263 insertions(+), 56 deletions(-) create mode 100644 paddlespeech/server/conf/ws_conformer_application.yaml diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index aa3c208b..dae4a3ff 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -18,44 +18,10 @@ engine_list: ['asr_online'] # ENGINE CONFIG # ################################################################################# -# ################################### ASR ######################################### -# ################### speech task: asr; engine_type: online ####################### -# asr_online: -# model_type: 'deepspeech2online_aishell' -# am_model: # the pdmodel file of am static model [optional] -# am_params: # the pdiparams file of am static model [optional] -# lang: 'zh' -# sample_rate: 16000 -# cfg_path: -# decode_method: -# force_yes: True - -# am_predictor_conf: -# device: # set 'gpu:id' or 'cpu' -# switch_ir_optim: True -# glog_info: False # True -> print glog -# summary: True # False -> do not show predictor config - -# chunk_buffer_conf: -# frame_duration_ms: 80 -# shift_ms: 40 -# sample_rate: 16000 -# sample_width: 2 - -# vad_conf: -# aggressiveness: 2 -# sample_rate: 16000 -# frame_duration_ms: 20 -# sample_width: 2 -# padding_ms: 200 -# padding_ratio: 0.9 - - - ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer2online_aishell' + model_type: 'deepspeech2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -71,9 +37,19 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: + frame_duration_ms: 80 + shift_ms: 40 + sample_rate: 16000 + sample_width: 2 window_n: 7 # frame shift_n: 4 # frame - window_ms: 25 # ms + window_ms: 20 # ms shift_ms: 10 # ms + + vad_conf: + aggressiveness: 2 sample_rate: 16000 - sample_width: 2 \ No newline at end of file + frame_duration_ms: 20 + sample_width: 2 + padding_ms: 200 + padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml new file mode 100644 index 00000000..1a775f85 --- /dev/null +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -0,0 +1,45 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). +# websocket only support online engine type. 
+protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer2online_aishell' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a8e25f4b..77eb5a21 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,7 +13,7 @@ # limitations under the License. import os from typing import Optional - +import copy import numpy as np import paddle from numpy import float32 @@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler: ) self.config = asr_engine.config self.model_config = asr_engine.executor.config - self.model = asr_engine.executor.model + # self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() @@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler: def init(self): self.model_type = self.asr_engine.executor.model_type if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - pass + from paddlespeech.s2t.io.collator import SpeechCollator + self.sample_rate = self.asr_engine.executor.sample_rate + self.am_predictor = self.asr_engine.executor.am_predictor + self.text_feature = self.asr_engine.executor.text_feature + self.collate_fn_test = SpeechCollator.from_config(self.model_config) + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + # frame window samples length and frame shift samples length + + self.win_length = int(self.model_config.window_ms * self.sample_rate) + self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: self.sample_rate = self.asr_engine.executor.sample_rate @@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler: def extract_feat(self, samples): if "deepspeech2online" in self.model_type: - pass + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + 
if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + + # pcm16 -> pcm 32 + samples = pcm2float(self.remained_wav) + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, self.sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + if self.cached_feat is None: + self.cached_feat = audio + else: + assert (len(audio.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, audio], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + self.num_frames += audio_len + self.remained_wav = self.remained_wav[self.n_shift * audio_len:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) elif "conformer2online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) @@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler: # logger.info(f"accumulate samples: {self.num_samples}") def reset(self): - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None - self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + # for deepspeech2 + self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box) + self.decoder.reset_decoder(batch_size=1) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: - pass + # x_chunk 是特征数据 + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling + # decoding window for 
model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + logger.info("start to do model forward") + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + self.cached_feat = self.cached_feat[:, end - + cached_feature_num:, :] + # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: logger.info( @@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + def decode_one_chunk(self, x_chunk, x_chunk_lens): + logger.info("start to decoce one chunk with deepspeech2 model") + input_names = self.am_predictor.get_input_names() + audio_handle = self.am_predictor.get_input_handle(input_names[0]) + audio_len_handle = self.am_predictor.get_input_handle( + input_names[1]) + h_box_handle = self.am_predictor.get_input_handle(input_names[2]) + c_box_handle = self.am_predictor.get_input_handle(input_names[3]) + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(self.chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(self.chunk_state_h_box) + + c_box_handle.reshape(self.chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(self.chunk_state_c_box) + + output_names = self.am_predictor.get_output_names() + output_handle = self.am_predictor.get_output_handle(output_names[0]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.am_predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.am_predictor.get_output_handle( + output_names[3]) + + self.am_predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() + self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one one best result: {trans_best[0]}") + return trans_best[0] + def advance_decoding(self, is_finished=False): logger.info("start to decode with advanced_decoding method") cfg = self.ctc_decode_config @@ -240,6 +422,7 @@ 
class PaddleASRConnectionHanddler: ) return None, None + # if is_finished=True, we need at least context frames if num_frames < context: logger.info( "flast {num_frames} is less than context {context} frames, and we cannot do model forward" @@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler: return '' def rescoring(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + return + logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: return From f56dba0ca7da29aa5ad11f5ad83e4ee62f1a2fa4 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 17:57:50 +0800 Subject: [PATCH 11/16] fix the code format, test=doc --- paddlespeech/cli/asr/infer.py | 2 +- .../server/conf/ws_conformer_application.yaml | 2 +- .../server/engine/asr/online/asr_engine.py | 123 +++++++++--------- 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 53f71a70..f1e46ca1 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -129,7 +129,7 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", - "conformer2online": + "conformer_online": "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 1a775f85..89a861ef 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -21,7 +21,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer2online_aishell' + model_type: 'conformer_online_multi-cn' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 77eb5a21..3c2b066c 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
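As an aside for readers following the deepspeech2online decode() changes in this series: the window constants used by that loop combine as sketched below. This is a stand-alone illustration (the 12-frame feature buffer is a made-up example), not code from the patch itself.

```python
# Frame-window arithmetic of the deepspeech2online streaming decode() path.
decoding_chunk_size = 1   # chunks decoded per forward pass
context = 7               # receptive field (frames) needed by the model
subsampling = 4           # frame-rate reduction inside the model

stride = subsampling * decoding_chunk_size                            # 4: new frames consumed per step
cached_feature_num = context - subsampling                            # 3: frames kept as left context
decoding_window = (decoding_chunk_size - 1) * subsampling + context   # 7: frames fed per forward pass

num_frames = 12                # pretend the connection has cached 12 feature frames
left_frames = decoding_window  # mid-stream, only complete windows are decoded
for cur in range(0, num_frames - left_frames + 1, stride):
    end = min(cur + decoding_window, num_frames)
    print(f"forward on frames [{cur}, {end})")   # [0, 7) then [4, 11)

# afterwards the cache keeps the last cached_feature_num frames as left context
print(f"next chunk starts from cached frame {end - cached_feature_num}")  # 8
```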
+import copy import os from typing import Optional -import copy + import numpy as np import paddle from numpy import float32 @@ -58,7 +59,7 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, - "conformer2online_aishell-zh-16k": { + "conformer_online_multi-cn-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': @@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler: ) self.config = asr_engine.config self.model_config = asr_engine.executor.config - # self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() self.reset() def init(self): + # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: from paddlespeech.s2t.io.collator import SpeechCollator - self.sample_rate = self.asr_engine.executor.sample_rate self.am_predictor = self.asr_engine.executor.am_predictor - self.text_feature = self.asr_engine.executor.text_feature + self.collate_fn_test = SpeechCollator.from_config(self.model_config) self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab @@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler: dropout_rate=0.0, reduction=True, # sum batch_average=True, # sum / batch_size - grad_norm_type=self.model_config.get('ctc_grad_norm_type', None)) + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) cfg = self.model_config.decode decode_batch_size = 1 # for online @@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler: cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window samples length and frame shift samples length - - self.win_length = int(self.model_config.window_ms * self.sample_rate) - self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + # frame window samples length and frame shift samples length - elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: - self.sample_rate = self.asr_engine.executor.sample_rate + self.win_length = int(self.model_config.window_ms * + self.sample_rate) + self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model - # tokens to text - self.text_feature = self.asr_engine.executor.text_feature - # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) @@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler: audio = paddle.to_tensor(audio, dtype='float32') # audio_len = paddle.to_tensor(audio_len) audio = paddle.unsqueeze(audio, axis=0) - + if self.cached_feat is None: self.cached_feat = audio else: @@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler: logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) - elif "conformer2online" in self.model_type: + elif "conformer_online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 @@ -264,41 +265,43 @@ class 
PaddleASRConnectionHanddler: def reset(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: # for deepspeech2 - self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box) - self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box) + self.chunk_state_h_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_c_box) self.decoder.reset_decoder(batch_size=1) - elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: - # for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None - self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] + + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: # x_chunk 是特征数据 - decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model - context = 7 # context=7 in deepspeech2 model - subsampling = 4 # subsampling=4 in deepspeech2 model + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model stride = subsampling * decoding_chunk_size cached_feature_num = context - subsampling # decoding window for model - decoding_window = (decoding_chunk_size - 1) * subsampling + context - + decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") - return - + return + num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" @@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler: # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( - f"frame feat num is less than {decoding_window}, please input more pcm data" + f"frame feat num is less than {decoding_window}, please input more pcm data" ) return None, None # if is_finished=True, we need at least context frames if num_frames < context: logger.info( - "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None logger.info("start to do model forward") @@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler: self.result_transcripts = [trans_best] - self.cached_feat = self.cached_feat[:, end - - cached_feature_num:, :] + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: @@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler: logger.info("start to decoce one chunk with deepspeech2 model") input_names = 
self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle( - input_names[1]) + audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) @@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler: output_names = self.am_predictor.get_output_names() output_handle = self.am_predictor.get_output_handle(output_names[0]) output_lens_handle = self.am_predictor.get_output_handle( - output_names[1]) + output_names[1]) output_state_h_handle = self.am_predictor.get_output_handle( - output_names[2]) + output_names[2]) output_state_c_handle = self.am_predictor.get_output_handle( - output_names[3]) + output_names[3]) self.am_predictor.run() @@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler: self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one one best result: {trans_best[0]}") + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] def advance_decoding(self, is_finished=False): @@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler: def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: return - + logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: return @@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor): return decompressed_path def _init_from_path(self, - model_type: str='wenetspeech', + model_type: str='deepspeech2online_aishell', am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', @@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor): self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: logger.info("start to create the stream conformer asr engine") if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( @@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor): self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") @@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor): self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) - elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + elif "conformer" in self.model_type or "transformer" in self.model_type: self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): @@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor): model_type (str): online model type Returns: - [type]: [description] + str: one best result """ logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: @@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = 
self.decoder.decode() - logger.info(f"decode one one best result: {trans_best[0]}") + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: @@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor): x_chunk_lens = np.array([audio_len]) return x_chunk, x_chunk_lens - elif "conformer2online" in self.model_type: + elif "conformer_online" in self.model_type: if sample_rate != self.sample_rate: logger.info(f"audio sample rate {sample_rate} is not match," @@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine instache") + logger.info("create the online asr engine instance") def init(self, config: dict) -> bool: """init engine resource From babac27a7943b5be254afab8af09e909b0d3151c Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 18:14:30 +0800 Subject: [PATCH 12/16] fix ds2 online edge bug, test=doc --- paddlespeech/cli/asr/pretrained_models.py | 2 ++ .../server/engine/asr/online/asr_engine.py | 20 +++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py index a16c4750..cc52c751 100644 --- a/paddlespeech/cli/asr/pretrained_models.py +++ b/paddlespeech/cli/asr/pretrained_models.py @@ -88,6 +88,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer_online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 3c2b066c..4d15d93b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -130,9 +130,10 @@ class PaddleASRConnectionHanddler: cfg.num_proc_bsearch) # frame window samples length and frame shift samples length - self.win_length = int(self.model_config.window_ms * + self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate) - self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + self.n_shift = int(self.model_config.stride_ms / 1000 * + self.sample_rate) elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model @@ -158,6 +159,11 @@ class PaddleASRConnectionHanddler: samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 + # pcm16 -> pcm 32 + # pcm2float will change the orignal samples, + # so we shoule do pcm2float before concatenate + samples = pcm2float(samples) + if self.remained_wav is None: self.remained_wav = samples else: @@ -167,11 +173,9 @@ class PaddleASRConnectionHanddler: f"The connection remain the audio samples: {self.remained_wav.shape}" ) - # pcm16 -> pcm 32 - samples = pcm2float(self.remained_wav) # read audio speech_segment = SpeechSegment.from_pcm( - samples, self.sample_rate, transcript=" ") + self.remained_wav, self.sample_rate, transcript=" ") # audio augment self.collate_fn_test.augmentation.transform_audio(speech_segment) @@ -474,6 +478,7 @@ class PaddleASRConnectionHanddler: self.hyps = self.searcher.get_one_best_hyps() assert self.cached_feat.shape[0] == 1 assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0) assert len( @@ -515,7 +520,6 @@ class PaddleASRConnectionHanddler: return # 
assert len(hyps) == beam_size - paddle.save(self.encoder_out, "encoder.out") hyp_list = [] for hyp in hyps: hyp_content = hyp[0] @@ -815,7 +819,7 @@ class ASRServerExecutor(ASRExecutor): def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): logger.info("start to decode with advanced_decoding method") - encoder_out, encoder_mask = self.decode_forward(xs) + encoder_out, encoder_mask = self.encoder_forward(xs) ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) @@ -827,7 +831,7 @@ class ASRServerExecutor(ASRExecutor): if "attention_rescoring" in self.config.decode.decoding_method: self.rescoring(encoder_out, xs.place) - def decode_forward(self, xs): + def encoder_forward(self, xs): logger.info("get the model out from the feat") cfg = self.config.decode decoding_chunk_size = cfg.decoding_chunk_size From 48fa84bee90d8fc8b9f5619f8e22e796b8a10aca Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 20:15:18 +0800 Subject: [PATCH 13/16] fix the asr online client bug, return None, test=doc --- paddlespeech/s2t/modules/encoder.py | 2 -- paddlespeech/server/README.md | 13 +++++++++++++ paddlespeech/server/README_cn.md | 14 ++++++++++++++ paddlespeech/server/bin/paddlespeech_client.py | 6 ++++-- .../server/engine/asr/online/asr_engine.py | 4 ++-- .../server/engine/asr/online/ctc_search.py | 8 +++----- .../server/tests/asr/online/websocket_client.py | 11 ++++------- 7 files changed, 40 insertions(+), 18 deletions(-) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 347035cd..c843c0e2 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -317,8 +317,6 @@ class BaseEncoder(nn.Layer): outputs = [] offset = 0 # Feed forward overlap input step by step - print(f"context: {context}") - print(f"stride: {stride}") for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md index 819fe440..3ac68dae 100644 --- a/paddlespeech/server/README.md +++ b/paddlespeech/server/README.md @@ -35,3 +35,16 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + + ## Online ASR Server + +### Lanuch online asr server +``` +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### Access online asr server + +``` +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index c0a4a733..5f235313 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -35,3 +35,17 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## 流式ASR + +### 启动流式语音识别服务 + +``` +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### 访问流式语音识别服务 + +``` +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index cb802ce5..45469178 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor): lang=lang, audio_format=audio_format) time_end = time.time() - logger.info(res.json()) 
+ logger.info(res) logger.info("Response time %f s." % (time_end - time_start)) return True except Exception as e: logger.error("Failed to speech recognition.") + logger.error(e) return False @stats_wrapper @@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor): logging.info("asr websocket client start") handler = ASRAudioHandler(server_ip, port) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(input)) + res = loop.run_until_complete(handler.run(input)) logging.info("asr websocket client finished") + return res['asr_results'] @cli_client_register( name='paddlespeech_client.cls', description='visit cls service') diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 4d15d93b..c79abf1b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -473,7 +473,7 @@ class PaddleASRConnectionHanddler: ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - self.searcher.search(None, ctc_probs, self.cached_feat.place) + self.searcher.search(ctc_probs, self.cached_feat.place) self.hyps = self.searcher.get_one_best_hyps() assert self.cached_feat.shape[0] == 1 @@ -823,7 +823,7 @@ class ASRServerExecutor(ASRExecutor): ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - self.searcher.search(xs, ctc_probs, xs.place) + self.searcher.search(ctc_probs, xs.place) # update the one best result self.hyps = self.searcher.get_one_best_hyps() diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index c3822b5c..b1c80c36 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -24,19 +24,18 @@ class CTCPrefixBeamSearch: """Implement the ctc prefix beam search Args: - config (_type_): _description_ + config (yacs.config.CfgNode): _description_ """ self.config = config self.reset() - def search(self, xs, ctc_probs, device, blank_id=0): + def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature Args: xs (paddle.Tensor): feature data ctc_probs (paddle.Tensor): the ctc probability of all the tokens - encoder_out (paddle.Tensor): _description_ - encoder_mask (_type_): _description_ + device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0). blank_id (int, optional): the blank id in the vocab. Defaults to 0. 
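With the new `search(ctc_probs, device, blank_id)` signature the searcher only needs the chunk's CTC log-probabilities and the device they live on. A minimal usage sketch — the config here only sets `beam_size` (the real decode config carries more fields), and the random tensor is a stand-in for actual model output:

```python
import paddle
import paddle.nn.functional as F
from yacs.config import CfgNode

from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

# minimal decode config; only beam_size is shown here
cfg = CfgNode()
cfg.beam_size = 10
searcher = CTCPrefixBeamSearch(cfg)

# stand-in for model.ctc.log_softmax(encoder_out).squeeze(0): (num_frames, vocab_size)
ctc_probs = F.log_softmax(paddle.randn([16, 5000]), axis=-1)

searcher.search(ctc_probs, ctc_probs.place)  # accumulate prefixes chunk by chunk
hyps = searcher.get_one_best_hyps()          # one-best hypothesis decoded so far
searcher.reset()                             # clear state before the next utterance
```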
Returns: @@ -45,7 +44,6 @@ class CTCPrefixBeamSearch: # decode logger.info("start to ctc prefix search") - # device = xs.place batch_size = 1 beam_size = self.config.beam_size maxlen = ctc_probs.shape[0] diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 62e011ce..49cbd703 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -34,10 +34,9 @@ class ASRAudioHandler: def read_wave(self, wavfile_path: str): samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') x_len = len(samples) - # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz - chunk_size = 80 * 16 #80ms, sample_rate = 16kHz - if x_len % chunk_size != 0: + chunk_size = 85 * 16 #80ms, sample_rate = 16kHz + if x_len % chunk_size!= 0: padding_len_x = chunk_size - x_len % chunk_size else: padding_len_x = 0 @@ -48,7 +47,6 @@ class ASRAudioHandler: assert (x_len + padding_len_x) % chunk_size == 0 num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = int(num_chunk) - for i in range(0, num_chunk): start = i * chunk_size end = start + chunk_size @@ -82,7 +80,6 @@ class ASRAudioHandler: msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - result = msg # finished audio_info = json.dumps( { @@ -98,8 +95,8 @@ class ASRAudioHandler: # decode the bytes to str msg = json.loads(msg) - logging.info("receive msg={}".format(msg)) - + logging.info("final receive msg={}".format(msg)) + result = msg return result From 9c03280ca699dbf9837cdedbc0d93d2c11cc9412 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 21:01:13 +0800 Subject: [PATCH 14/16] remove debug info, test=doc --- paddlespeech/s2t/models/u2/u2.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index f0d2711d..9b66126e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -213,14 +213,12 @@ class U2BaseModel(ASRInterface, nn.Layer): num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: - print("offline decode from the asr") encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) - print("offline decode success") return encoder_out, encoder_mask def recognize( @@ -281,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer): # TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size: break - + # 2.1 Forward decoder step hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( running_size, 1, 1).to(device) # (B*N, i, i) # logp: (B*N, vocab) logp, cache = self.decoder.forward_one_step( encoder_out, encoder_mask, hyps, hyps_mask, cache) - # 2.2 First beam prune: select topk best prob at current time top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp = mask_finished_scores(top_k_logp, end_flag) @@ -708,7 +705,6 @@ class U2BaseModel(ASRInterface, nn.Layer): List[List[int]]: transcripts. 
""" batch_size = feats.shape[0] - print("start to decode the audio feat") if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: logger.error( @@ -716,7 +712,6 @@ class U2BaseModel(ASRInterface, nn.Layer): ) logger.error(f"current batch_size is {batch_size}") sys.exit(1) - print(f"use the {decoding_method} to decode the audio feat") if decoding_method == 'attention': hyps = self.recognize( feats, From ff4ddd229e8798f31fce71f7e096319d6171ed3f Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 23:12:46 +0800 Subject: [PATCH 15/16] fix the unuseful code, test=doc --- paddlespeech/s2t/modules/ctc.py | 1 - paddlespeech/server/conf/ws_application.yaml | 8 -------- paddlespeech/server/conf/ws_conformer_application.yaml | 2 +- paddlespeech/server/engine/asr/online/asr_engine.py | 2 +- paddlespeech/server/ws/asr_socket.py | 1 - 5 files changed, 2 insertions(+), 12 deletions(-) diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index bd1219b1..1bb15873 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -181,7 +181,6 @@ class CTCDecoder(CTCDecoderBase): if self._ext_scorer is not None: return - from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index dae4a3ff..dee8d78b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -45,11 +45,3 @@ asr_online: shift_n: 4 # frame window_ms: 20 # ms shift_ms: 10 # ms - - vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 89a861ef..e14833de 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -21,7 +21,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer_online_multi-cn' + model_type: 'conformer_online_multicn' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index c79abf1b..34a028a3 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -59,7 +59,7 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, - "conformer_online_multi-cn-zh-16k": { + "conformer_online_multicn-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 82b05bc5..a865703d 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -28,7 +28,6 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - print("websocket protocal receive the dataset") await websocket.accept() 
engine_pool = get_engine_pool() From ac9fcf7f4a53026bba8efe235d90a0693a70eae6 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Wed, 20 Apr 2022 00:15:37 +0800 Subject: [PATCH 16/16] fix the asr infernece model, paddle.no_grad, test=doc --- paddlespeech/server/engine/asr/online/asr_engine.py | 3 +++ paddlespeech/server/engine/asr/online/ctc_search.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 34a028a3..758cbaab 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -356,6 +356,7 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): logger.info("start to decoce one chunk with deepspeech2 model") input_names = self.am_predictor.get_input_names() @@ -397,6 +398,7 @@ class PaddleASRConnectionHanddler: logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] + @paddle.no_grad() def advance_decoding(self, is_finished=False): logger.info("start to decode with advanced_decoding method") cfg = self.ctc_decode_config @@ -503,6 +505,7 @@ class PaddleASRConnectionHanddler: else: return '' + @paddle.no_grad() def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: return diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index b1c80c36..8aee0a50 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict - +import paddle from paddlespeech.cli.log import logger from paddlespeech.s2t.utils.utility import log_add @@ -29,6 +29,7 @@ class CTCPrefixBeamSearch: self.config = config self.reset() + @paddle.no_grad() def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature
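The `@paddle.no_grad()` decorators added in this last patch keep autograd from recording the inference-only decode paths; the decorator and the context-manager form are interchangeable. An illustrative sketch with placeholder names (`decode_step`, `model`, `feats`):

```python
import paddle

@paddle.no_grad()
def decode_step(model, feats):
    # nothing computed here is tracked for gradients,
    # which avoids building a graph during pure inference
    return model(feats)

def decode_step_ctx(model, feats):
    # the same effect, written with the context manager
    with paddle.no_grad():
        return model(feats)
```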