diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml
index c3a488fb..aa3c208b 100644
--- a/paddlespeech/server/conf/ws_application.yaml
+++ b/paddlespeech/server/conf/ws_application.yaml
@@ -71,15 +71,9 @@ asr_online:
         summary: True # False -> do not show predictor config

    chunk_buffer_conf:
-        frame_duration_ms: 85
-        shift_ms: 40
+        window_n: 7 # frame
+        shift_n: 4 # frame
+        window_ms: 25 # ms
+        shift_ms: 10 # ms
         sample_rate: 16000
-        sample_width: 2
-
-    # vad_conf:
-    #     aggressiveness: 2
-    #     sample_rate: 16000
-    #     frame_duration_ms: 20
-    #     sample_width: 2
-    #     padding_ms: 200
-    #     padding_ratio: 0.9
\ No newline at end of file
+        sample_width: 2
\ No newline at end of file
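
The frame-based values express the same geometry as the old millisecond settings: with a 25 ms analysis window and a 10 ms hop, a 7-frame window with a 4-frame shift covers exactly the removed `frame_duration_ms: 85` and `shift_ms: 40`. A quick sanity check in plain Python (not part of the diff; it assumes `ChunkBuffer` combines frames in the usual sliding-window fashion):

```python
window_n, shift_n = 7, 4              # frames per buffered window / per shift
window_ms, shift_ms = 25, 10          # per-frame analysis window / hop (ms)
sample_rate, sample_width = 16000, 2  # 16 kHz, 16-bit (int16) PCM

window_duration_ms = window_ms + (window_n - 1) * shift_ms  # 25 + 6 * 10 = 85
shift_duration_ms = shift_n * shift_ms                      # 4 * 10 = 40

bytes_per_ms = sample_rate * sample_width // 1000  # 32 bytes of audio per ms
print(window_duration_ms, shift_duration_ms)       # 85 40, the old settings
print(window_duration_ms * bytes_per_ms)           # 2720 bytes per window
```
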
"/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor): # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method + + # we only support ctc_prefix_beam_search and attention_rescoring dedoding method + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding = "attention_rescoring" + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: @@ -232,7 +242,7 @@ class ASRServerExecutor(ASRExecutor): logger.info("create the transformer like model success") # update the ctc decoding - self.searcher = None + self.searcher = CTCPrefixBeamSearch(self.config.decode) self.transformer_decode_reset() def reset_decoder_and_chunk(self): @@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor): def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): logger.info("start to decode with advanced_decoding method") encoder_out, encoder_mask = self.decode_forward(xs) - self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + self.searcher.search(xs, ctc_probs, xs.place) + # update the one best result + self.hyps = self.searcher.get_one_best_hyps() + + # now we supprot ctc_prefix_beam_search and attention_rescoring + if "attention_rescoring" in self.config.decode.decoding_method: + self.rescoring(encoder_out, xs.place) def decode_forward(self, xs): logger.info("get the model out from the feat") @@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor): num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks - logger.info("start to do model forward") outputs = [] @@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor): masks = masks.unsqueeze(1) return ys, masks + def rescoring(self, encoder_out, device): + logger.info("start to rescoring the hyps") + beam_size = self.config.decode.beam_size + hyps = self.searcher.get_hyps() + assert len(hyps) == beam_size + + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + 
@@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor):
 
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-
         logger.info("start to do model forward")
         outputs = []
 
@@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor):
         masks = masks.unsqueeze(1)
         return ys, masks
 
+    def rescoring(self, encoder_out, device):
+        logger.info("start rescoring the hyps")
+        beam_size = self.config.decode.beam_size
+        hyps = self.searcher.get_hyps()
+        assert len(hyps) == beam_size
+
+        hyp_list = []
+        for hyp in hyps:
+            hyp_content = hyp[0]
+            # Prevent the hyp from being empty
+            if len(hyp_content) == 0:
+                hyp_content = (self.model.ctc.blank_id, )
+            hyp_content = paddle.to_tensor(
+                hyp_content, place=device, dtype=paddle.long)
+            hyp_list.append(hyp_content)
+        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=device,
+            dtype=paddle.long)  # (beam_size,)
+        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
+                                  self.model.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
+
+        encoder_out = encoder_out.repeat(beam_size, 1, 1)
+        encoder_mask = paddle.ones(
+            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
+        decoder_out, _ = self.model.decoder(
+            encoder_out, encoder_mask, hyps_pad,
+            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        # decoder score in ln domain
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        decoder_out = decoder_out.numpy()
+
+        # Only use decoder score for rescoring
+        best_score = -float('inf')
+        best_index = 0
+        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
+        for i, hyp in enumerate(hyps):
+            score = 0.0
+            for j, w in enumerate(hyp[0]):
+                score += decoder_out[i][j][w]
+            # the last decoder output token is `eos`, for the last decoder input token
+            score += decoder_out[i][len(hyp[0])][self.model.eos]
+            # add the ctc score (which is in ln domain)
+            score += hyp[1] * self.config.decode.ctc_weight
+            if score > best_score:
+                best_score = score
+                best_index = i
+
+        # update the one best result
+        self.hyps = [hyps[best_index][0]]
+        return hyps[best_index][0]
+
     def transformer_decode_reset(self):
         self.subsampling_cache = None
         self.elayers_output_cache = None
         self.conformer_cnn_cache = None
-        self.hyps = None
         self.offset = 0
-        self.cur_hyps = None
-        self.hyps = None
-
-    def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0):
-        # decode
-        logger.info("start to ctc prefix search")
-
-        device = xs.place
-        cfg = self.config.decode
-        batch_size = xs.shape[0]
-        beam_size = cfg.beam_size
-        maxlen = encoder_out.shape[1]
-
-        ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
-        ctc_probs = ctc_probs.squeeze(0)
-
-        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
-        # blank_ending_score and none_blank_ending_score in ln domain
-        if self.cur_hyps is None:
-            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
-        # 2. CTC beam search step by step
-        for t in range(0, maxlen):
-            logp = ctc_probs[t]  # (vocab_size,)
-            # key: prefix, value (pb, pnb), default value(-inf, -inf)
-            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
-
-            # 2.1 First beam prune: select topk best
-            # do token passing process
-            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
-            for s in top_k_index:
-                s = s.item()
-                ps = logp[s].item()
-                for prefix, (pb, pnb) in self.cur_hyps:
-                    last = prefix[-1] if len(prefix) > 0 else None
-                    if s == blank_id:  # blank
-                        n_pb, n_pnb = next_hyps[prefix]
-                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
-                    elif s == last:
-                        # Update *ss -> *s;
-                        n_pb, n_pnb = next_hyps[prefix]
-                        n_pnb = log_add([n_pnb, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
-                        # Update *s-s -> *ss, - is for blank
-                        n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
-                        n_pnb = log_add([n_pnb, pb + ps])
-                        next_hyps[n_prefix] = (n_pb, n_pnb)
-                    else:
-                        n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
-                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
-                        next_hyps[n_prefix] = (n_pb, n_pnb)
-
-            # 2.2 Second beam prune
-            next_hyps = sorted(
-                next_hyps.items(),
-                key=lambda x: log_add(list(x[1])),
-                reverse=True)
-            self.cur_hyps = next_hyps[:beam_size]
-
-        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
-
-        self.hyps = [hyps[0][0]]
-        logger.info("ctc prefix search success")
-        return hyps, encoder_out
+        # decoding reset
+        self.searcher.reset()
 
     def update_result(self):
         logger.info("update the final result")
+        hyps = self.hyps
         self.result_transcripts = [
-            self.text_feature.defeaturize(hyp) for hyp in self.hyps
+            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
-        self.result_tokenids = [hyp for hyp in self.hyps]
+        self.result_tokenids = [hyp for hyp in hyps]
 
     def extract_feat(self, samples, sample_rate):
         """extract feat
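
The rescoring score is a linear interpolation: for beam entry `i`, the sum of decoder log probabilities over its tokens, plus the log probability of emitting `eos` at the closing step, plus `ctc_weight` times the CTC prefix score carried over from the first pass (`hyp[1]`, in ln domain). A self-contained toy rerun of that ranking loop, with invented numbers (beam of 2, vocab of 4, eos id 3; not from the diff):

```python
import numpy as np

eos, ctc_weight = 3, 0.5
hyps = [((1, 2), -1.2), ((1,), -0.9)]           # (prefix, ctc score in ln domain)
decoder_out = np.log(np.full((2, 3, 4), 0.25))  # uniform toy log posteriors

best_index, best_score = 0, -float('inf')
for i, (prefix, ctc_score) in enumerate(hyps):
    score = sum(decoder_out[i][j][w] for j, w in enumerate(prefix))
    score += decoder_out[i][len(prefix)][eos]   # closing <eos> step
    score += ctc_weight * ctc_score             # interpolate with CTC
    if score > best_score:
        best_index, best_score = i, score

print(best_index)  # 1: with uniform decoder scores, the CTC term decides
```
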
@@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor):
 
         elif "conformer2online" in self.model_type:
             if sample_rate != self.sample_rate:
-                logger.info(f"audio sample rate {sample_rate} is not match," \
-                        "the model sample_rate is {self.sample_rate}")
+                logger.info(f"audio sample rate {sample_rate} does not match "
+                            f"the model sample_rate {self.sample_rate}")
-            logger.info(f"ASR Engine use the {self.model_type} to process")
+            logger.info(f"ASR Engine uses the {self.model_type} to process")
             logger.info("Create the preprocess instance")
             preprocess_conf = self.config.preprocess_config
             preprocess_args = {"train": False}
diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py
new file mode 100644
index 00000000..a91b8a21
--- /dev/null
+++ b/paddlespeech/server/engine/asr/online/ctc_search.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import defaultdict
+
+from paddlespeech.cli.log import logger
+from paddlespeech.s2t.utils.utility import log_add
+
+__all__ = ['CTCPrefixBeamSearch']
+
+
+class CTCPrefixBeamSearch:
+    def __init__(self, config):
+        """Implement the ctc prefix beam search
+
+        Args:
+            config (yacs.config.CfgNode): the ctc decoding config,
+                which provides at least `beam_size`
+        """
+        self.config = config
+        self.reset()
+
+    def search(self, xs, ctc_probs, device, blank_id=0):
+        """decode one chunk of features with the ctc prefix beam search method
+
+        Args:
+            xs (paddle.Tensor): feature data
+            ctc_probs (paddle.Tensor): the ctc log probabilities of all the
+                tokens, shape (maxlen, vocab_size)
+            device: the paddle place to run the decoding on
+            blank_id (int, optional): the blank id in the vocab. Defaults to 0.
+
+        Returns:
+            list: the search result, a list of (prefix, score) pairs
+        """
+        # decode
+        logger.info("start to ctc prefix search")
+
+        # device = xs.place
+        batch_size = xs.shape[0]
+        beam_size = self.config.beam_size
+        maxlen = ctc_probs.shape[0]
+
+        assert len(ctc_probs.shape) == 2
+
+        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
+        # blank_ending_score and none_blank_ending_score in ln domain
+        if self.cur_hyps is None:
+            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
+        # 2. CTC beam search step by step
+        for t in range(0, maxlen):
+            logp = ctc_probs[t]  # (vocab_size,)
+            # key: prefix, value (pb, pnb), default value(-inf, -inf)
+            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
+
+            # 2.1 First beam prune: select topk best
+            #     do token passing process
+            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
+            for s in top_k_index:
+                s = s.item()
+                ps = logp[s].item()
+                for prefix, (pb, pnb) in self.cur_hyps:
+                    last = prefix[-1] if len(prefix) > 0 else None
+                    if s == blank_id:  # blank
+                        n_pb, n_pnb = next_hyps[prefix]
+                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
+                        next_hyps[prefix] = (n_pb, n_pnb)
+                    elif s == last:
+                        # Update *ss -> *s;
+                        n_pb, n_pnb = next_hyps[prefix]
+                        n_pnb = log_add([n_pnb, pnb + ps])
+                        next_hyps[prefix] = (n_pb, n_pnb)
+                        # Update *s-s -> *ss, - is for blank
+                        n_prefix = prefix + (s, )
+                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pnb = log_add([n_pnb, pb + ps])
+                        next_hyps[n_prefix] = (n_pb, n_pnb)
+                    else:
+                        n_prefix = prefix + (s, )
+                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
+                        next_hyps[n_prefix] = (n_pb, n_pnb)
+
+            # 2.2 Second beam prune
+            next_hyps = sorted(
+                next_hyps.items(),
+                key=lambda x: log_add(list(x[1])),
+                reverse=True)
+            self.cur_hyps = next_hyps[:beam_size]
+
+        self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
+        logger.info("ctc prefix search success")
+        return self.hyps
+
+    def get_one_best_hyps(self):
+        """Return the one best result
+
+        Returns:
+            list: the one best result
+        """
+        return [self.hyps[0][0]]
+
+    def get_hyps(self):
+        """Return all the search hypotheses
+
+        Returns:
+            list: the search hypotheses, each one a (prefix, score) pair
+        """
+        return self.hyps
+
+    def reset(self):
+        """Reset the search cache values
+        """
+        self.cur_hyps = None
+        self.hyps = None
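
Because the searcher keeps its beam in `self.cur_hyps` across calls, it can absorb one chunk of posteriors at a time. A hypothetical standalone run with random tensors (only `beam_size` is read from the config here; a real vocabulary and CTC output would come from the model):

```python
import paddle
from yacs.config import CfgNode

from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

config = CfgNode()
config.beam_size = 5

searcher = CTCPrefixBeamSearch(config)
xs = paddle.rand([1, 16, 80])            # (batch, frames, feat_dim), dummy
ctc_probs = paddle.nn.functional.log_softmax(
    paddle.rand([16, 100]), axis=-1)     # (frames, vocab_size), dummy

searcher.search(xs, ctc_probs, xs.place)  # accumulates cur_hyps for the chunk
print(searcher.get_one_best_hyps())       # best token-id prefix so far
searcher.reset()                          # clear the cache between utterances
```
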
diff --git a/paddlespeech/server/tests/__init__.py b/paddlespeech/server/tests/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/tests/asr/__init__.py b/paddlespeech/server/tests/asr/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/tests/asr/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/tests/asr/offline/__init__.py b/paddlespeech/server/tests/asr/offline/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/tests/asr/offline/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/tests/asr/online/__init__.py b/paddlespeech/server/tests/asr/online/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/tests/asr/online/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py
index 4d1013f4..87b43d2c 100644
--- a/paddlespeech/server/ws/asr_socket.py
+++ b/paddlespeech/server/ws/asr_socket.py
@@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket):
     engine_pool = get_engine_pool()
     asr_engine = engine_pool['asr']
     # init buffer
+    # each websocket connection has its own chunk buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
     chunk_buffer = ChunkBuffer(
-        window_n=7,
-        shift_n=4,
-        window_ms=20,
-        shift_ms=10,
-        sample_rate=chunk_buffer_conf['sample_rate'],
-        sample_width=chunk_buffer_conf['sample_width'])
+        window_n=chunk_buffer_conf.window_n,
+        shift_n=chunk_buffer_conf.shift_n,
+        window_ms=chunk_buffer_conf.window_ms,
+        shift_ms=chunk_buffer_conf.shift_ms,
+        sample_rate=chunk_buffer_conf.sample_rate,
+        sample_width=chunk_buffer_conf.sample_width)
+
     # init vad
-    # print(asr_engine.config)
-    # print(type(asr_engine.config))
     vad_conf = asr_engine.config.get('vad_conf', None)
     if vad_conf:
         vad = VADAudio(
@@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 # reset single engine for a new connection
-                # asr_engine.reset()
+                asr_engine.reset()
                 resp = {"status": "ok", "signal": "finished"}
                 await websocket.send_json(resp)
                 break
@@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket):
             engine_pool = get_engine_pool()
             asr_engine = engine_pool['asr']
             asr_results = ""
-            # frames = chunk_buffer.frame_generator(message)
-            # for frame in frames:
-            #     # get the pcm data from the bytes
-            #     samples = np.frombuffer(frame.bytes, dtype=np.int16)
-            #     sample_rate = asr_engine.config.sample_rate
-            #     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-            #                                                   sample_rate)
-            #     asr_engine.run(x_chunk, x_chunk_lens)
-            #     asr_results = asr_engine.postprocess()
-            samples = np.frombuffer(message, dtype=np.int16)
-            sample_rate = asr_engine.config.sample_rate
-            x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                                                          sample_rate)
-            asr_engine.run(x_chunk, x_chunk_lens)
-            # asr_results = asr_engine.postprocess()
+            frames = chunk_buffer.frame_generator(message)
+            for frame in frames:
+                # get the pcm data from the bytes
+                samples = np.frombuffer(frame.bytes, dtype=np.int16)
+                sample_rate = asr_engine.config.sample_rate
+                x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
+                                                              sample_rate)
+                asr_engine.run(x_chunk, x_chunk_lens)
+            # read the partial result once after all frames are fed
+            asr_results = asr_engine.postprocess()
 
             resp = {'asr_results': asr_results}
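
For completeness, a minimal client for this endpoint might look as follows. This is a sketch, not part of the diff: the URL path and port are assumptions, the start/finished json handshake seen in the handler above is elided, and the wav is assumed to be 16 kHz 16-bit mono to match the server config:

```python
import asyncio
import json
import wave

import websockets  # third-party: pip install websockets


async def recognize(wav_path: str, url: str="ws://127.0.0.1:8090/ws/asr"):
    with wave.open(wav_path, 'rb') as f:
        pcm = f.readframes(f.getnframes())  # 16 kHz, 16-bit mono PCM bytes
    async with websockets.connect(url) as ws:
        chunk = 16000 * 2 // 10  # about 100 ms of audio per binary message
        for i in range(0, len(pcm), chunk):
            await ws.send(pcm[i:i + chunk])      # server buffers and decodes
            resp = json.loads(await ws.recv())   # {'asr_results': '...'}
            print(resp['asr_results'])           # growing partial transcript


asyncio.run(recognize("zh.wav"))
```
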