diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index c13b2f6d3..696d223a6 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -83,8 +83,10 @@ pretrained_models = {
 class PaddleASRConnectionHanddler:
     def __init__(self, asr_engine):
         super().__init__()
+        logger.info("create a paddle asr connection handler to process the websocket connection")
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
+        self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
         self.init()
@@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler:
         assert(len(self.cached_feat.shape) == 3)
         self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)

+        # set the feat device
+        if self.device is None:
+            self.device = self.cached_feat.place
+
         num_frames = x_chunk.shape[1]
         self.num_frames += num_frames
         self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
@@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler:
         self.subsampling_cache = None
         self.elayers_output_cache = None
         self.conformer_cnn_cache = None
-        self.encoder_outs_ = None
+        self.encoder_out = None
         self.cached_feat = None
         self.remained_wav = None
         self.offset = 0
         self.num_samples = 0
-
+        self.device = None
+        self.hyps = []
         self.num_frames = 0
         self.chunk_num = 0
         self.global_frame_offset = 0
-        self.result = []
+        self.result_transcripts = ['']

     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
@@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler:
                 self.advance_decoding(is_finished)
                 self.update_result()

-                return self.result_transcripts[0]
             except Exception as e:
                 logger.exception(e)
         else:
@@ -203,16 +209,26 @@ class PaddleASRConnectionHanddler:
         subsampling = self.model.encoder.embed.subsampling_rate
         context = self.model.encoder.embed.right_context + 1
         stride = subsampling * decoding_chunk_size
+        cached_feature_num = context - subsampling  # processed chunk feature cached for the next chunk

         # decoding window for model
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
+        if self.cached_feat is None:
+            logger.info("no audio feat, please input more pcm data")
+            return
+
         num_frames = self.cached_feat.shape[1]
         logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")

         # the cached feat must be larger decoding_window
         if num_frames < decoding_window and not is_finished:
+            logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
             return None, None
-
+
+        if num_frames < context:
+            logger.info(f"feat frames {num_frames} is less than context {context} frames, cannot do model forward")
+            return None, None
+
         logger.info("start to do model forward")
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
         outputs = []
@@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler:

             # update the offset
             self.offset += y.shape[1]

+        logger.info(f"output size: {len(outputs)}")
         ys = paddle.cat(outputs, 1)
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
+        if self.encoder_out is None:
+            self.encoder_out = ys
+        else:
+            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
+        # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
+        # masks = masks.unsqueeze(1)

         # get the ctc probs
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
-        # self.searcher.search(xs, ctc_probs, xs.place)
         self.searcher.search(None, ctc_probs, self.cached_feat.place)

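The window arithmetic above drives the whole streaming loop: `decoding_window` is the number of feature frames one encoder forward pass consumes, `stride` is how far the window advances per step, and the new `cached_feature_num` is how many trailing frames are kept for the next chunk. A minimal sketch of that arithmetic, assuming illustrative Conformer values (subsampling rate 4, right context 6, chunk size 16; these numbers are assumptions, not values from this diff):

```python
# Sketch of the frame bookkeeping used in advance_decoding().
# All concrete numbers here are illustrative assumptions, not from this PR.
subsampling = 4            # encoder subsampling rate
context = 6 + 1            # right_context + 1
decoding_chunk_size = 16   # encoder frames produced per decoding step

stride = subsampling * decoding_chunk_size                           # input frames consumed per step
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # frames needed for one forward
cached_feature_num = context - subsampling                           # frames carried over to the next chunk

print(stride, decoding_window, cached_feature_num)  # 64 67 3
```

This is why the slice below starts at `end - cached_feature_num` rather than `end`: the right-context frames must be re-fed with the next chunk.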
@@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler:
             self.cached_feat = None
         else:
             assert self.cached_feat.shape[0] == 1
-            self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
+            assert end >= cached_feature_num
+            self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0)
         assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

         # ys for rescoring
@@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler:
         ]
         self.result_tokenids = [hyp for hyp in hyps]

+    def get_result(self):
+        if len(self.result_transcripts) > 0:
+            return self.result_transcripts[0]
+        else:
+            return ''
+
     def rescoring(self):
-        pass
+        logger.info("rescoring the final result")
+        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
+            return
+
+        self.searcher.finalize_search()
+        self.update_result()
+
+        beam_size = self.ctc_decode_config.beam_size
+        hyps = self.searcher.get_hyps()
+        if hyps is None or len(hyps) == 0:
+            return
+
+        # assert len(hyps) == beam_size
+        paddle.save(self.encoder_out, "encoder.out")  # debug dump of the encoder output
+        hyp_list = []
+        for hyp in hyps:
+            hyp_content = hyp[0]
+            # prevent the hyp from being empty
+            if len(hyp_content) == 0:
+                hyp_content = (self.model.ctc.blank_id, )
+            hyp_content = paddle.to_tensor(
+                hyp_content, place=self.device, dtype=paddle.long)
+            hyp_list.append(hyp_content)
+        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=self.device,
+            dtype=paddle.long)  # (beam_size,)
+        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
+                                  self.model.ignore_id)
+        hyps_lens = hyps_lens + 1  # add <sos> at the beginning
+
+        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
+        encoder_mask = paddle.ones(
+            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
+        decoder_out, _ = self.model.decoder(
+            encoder_out, encoder_mask, hyps_pad,
+            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        # decoder scores in ln domain
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        decoder_out = decoder_out.numpy()

+        # Only use decoder score for rescoring
+        best_score = -float('inf')
+        best_index = 0
+        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
+        for i, hyp in enumerate(hyps):
+            score = 0.0
+            for j, w in enumerate(hyp[0]):
+                score += decoder_out[i][j][w]
+            # the last decoder output token is `eos`, for the last decoder input token
+            score += decoder_out[i][len(hyp[0])][self.model.eos]
+            # add the ctc score (which is in ln domain)
+            score += hyp[1] * self.ctc_decode_config.ctc_weight
+            if score > best_score:
+                best_score = score
+                best_index = i
+
+        # update the one best result
+        logger.info(f"best index: {best_index}")
+        self.hyps = [hyps[best_index][0]]
+        self.update_result()
+        # return hyps[best_index][0]
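The rule implemented in `rescoring()` is plain interpolation: each beam hypothesis gets the sum of the attention decoder's log-probs over its tokens (plus `eos`), and the CTC prefix score is added with weight `ctc_weight`. A self-contained numpy sketch of just that scoring rule, with made-up toy values (beam size 2, vocab size 5; nothing here comes from the model in this PR):

```python
import numpy as np

# Toy illustration of the attention-rescoring rule: sum the decoder log-probs
# of each hypothesis (plus eos), then add the weighted CTC prefix score.
eos = 4
ctc_weight = 0.5
# hyps: (token_ids, ctc_score) pairs, as the prefix beam search would return
hyps = [((1, 2), -1.2), ((1, 3), -0.9)]
# decoder_out: (beam_size, max_hyps_len, vocab_size) log-probs after log_softmax
rng = np.random.default_rng(0)
decoder_out = np.log(rng.dirichlet(np.ones(5), size=(2, 3)))

best_index, best_score = 0, -float('inf')
for i, (tokens, ctc_score) in enumerate(hyps):
    score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
    score += decoder_out[i][len(tokens)][eos]  # decoder must also predict eos
    score += ctc_weight * ctc_score            # interpolate with the CTC score
    if score > best_score:
        best_score, best_index = score, i
print(best_index)
```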
@@ -552,7 +639,7 @@ class ASRServerExecutor(ASRExecutor):
         subsampling = self.model.encoder.embed.subsampling_rate
         context = self.model.encoder.embed.right_context + 1
         stride = subsampling * decoding_chunk_size
-
+
         # decoding window for model
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
         num_frames = xs.shape[1]
diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py
index bf4c4b306..c3822b5c9 100644
--- a/paddlespeech/server/engine/asr/online/ctc_search.py
+++ b/paddlespeech/server/engine/asr/online/ctc_search.py
@@ -110,6 +110,11 @@ class CTCPrefixBeamSearch:
         return [self.hyps[0][0]]

     def get_hyps(self):
+        """Return the current search hyps.
+
+        Returns:
+            list: the (token_ids, score) pairs kept by the beam search
+        """
         return self.hyps

     def reset(self):
@@ -117,3 +122,8 @@ class CTCPrefixBeamSearch:
         """
         self.cur_hyps = None
         self.hyps = None
+
+    def finalize_search(self):
+        """Do nothing for ctc_prefix_beam_search; kept for interface compatibility.
+        """
+        pass
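The no-op `finalize_search()` exists so the connection handler can treat every searcher uniformly at end of utterance. A hypothetical sketch (not part of this PR) of the interface the handler assumes a searcher satisfies:

```python
# Hypothetical interface sketch of the searcher contract assumed by
# PaddleASRConnectionHanddler; the real code has no such base class.
class SearcherInterface:
    def search(self, xs, ctc_probs, device):
        """Advance the search with the CTC log-probs of the newest chunk."""
        raise NotImplementedError

    def get_one_best_hyps(self):
        """Return the single best partial hypothesis for streaming output."""
        raise NotImplementedError

    def get_hyps(self):
        """Return all beam hypotheses as (token_ids, score) pairs."""
        raise NotImplementedError

    def finalize_search(self):
        """Hook called once the utterance ends; a no-op for prefix beam search."""
        raise NotImplementedError

    def reset(self):
        """Clear per-utterance state so the object can serve the next connection."""
        raise NotImplementedError
```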
diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py
index 04807e5c9..ae7c5eb4d 100644
--- a/paddlespeech/server/ws/asr_socket.py
+++ b/paddlespeech/server/ws/asr_socket.py
@@ -13,16 +13,15 @@
 # limitations under the License.
 import json

-import numpy as np
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect
 from starlette.websockets import WebSocketState as WebSocketState

+from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.utils.buffer import ChunkBuffer
 from paddlespeech.server.utils.vad import VADAudio
-from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler

 router = APIRouter()
@@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket):
                     connection_handler = PaddleASRConnectionHanddler(asr_engine)
                     await websocket.send_json(resp)
                 elif message['signal'] == 'end':
-                    engine_pool = get_engine_pool()
-                    asr_engine = engine_pool['asr']
                     # reset single engine for an new connection
-                    asr_results = connection_handler.decode(is_finished=True)
+                    connection_handler.decode(is_finished=True)
+                    connection_handler.rescoring()
+                    asr_results = connection_handler.get_result()
                     connection_handler.reset()
-                    asr_engine.reset()
-                    resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results}
+
+                    resp = {
+                        "status": "ok",
+                        "signal": "finished",
+                        'asr_results': asr_results
+                    }
                     await websocket.send_json(resp)
                     break
                 else:
@@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
                 await websocket.send_json(resp)
         elif "bytes" in message:
             message = message["bytes"]
-            engine_pool = get_engine_pool()
-            asr_engine = engine_pool['asr']
-            asr_results = ""
+
             connection_handler.extract_feat(message)
-            asr_results = connection_handler.decode(is_finished=False)
-            # connection_handler.
-            # frames = chunk_buffer.frame_generator(message)
-            # for frame in frames:
-            #     # get the pcm data from the bytes
-            #     samples = np.frombuffer(frame.bytes, dtype=np.int16)
-            #     sample_rate = asr_engine.config.sample_rate
-            #     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-            #                                                   sample_rate)
-            #     asr_engine.run(x_chunk, x_chunk_lens)
-            #     asr_results = asr_engine.postprocess()
+            connection_handler.decode(is_finished=False)
+            asr_results = connection_handler.get_result()

-            #     # connection accept the sample data frame by frame
-
-            #     asr_results = asr_engine.postprocess()
             resp = {'asr_results': asr_results}
             print("\n")
             await websocket.send_json(resp)
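After this change, the protocol the endpoint expects is: a JSON `start` signal, raw PCM binary frames (each answered with a partial transcript), then a JSON `end` signal that triggers the final decode plus attention rescoring. A minimal client sketch under assumed settings; the URL, port, and the extra `start` fields are illustrative, only the start -> bytes -> end flow is taken from the handler above:

```python
# Minimal client sketch for the websocket endpoint above. The server URL and
# the exact "start" payload fields are assumptions for illustration.
import asyncio
import json

import websockets  # pip install websockets


async def run(pcm_path: str, url: str = "ws://127.0.0.1:8090/ws/asr"):
    async with websockets.connect(url) as ws:
        # 1. open the utterance
        await ws.send(json.dumps({"name": pcm_path, "signal": "start", "nbest": 1}))
        print(await ws.recv())

        # 2. stream raw 16 kHz / 16-bit pcm in chunks; each chunk yields a partial result
        with open(pcm_path, "rb") as f:
            while chunk := f.read(8000):
                await ws.send(chunk)
                print(await ws.recv())

        # 3. close the utterance; triggers the final decode plus rescoring
        await ws.send(json.dumps({"signal": "end"}))
        print(await ws.recv())


asyncio.run(run("test.pcm"))
```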