diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 1f6060e9..c13b2f6d 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler:
                     f"we will use the transformer like model : {self.model_type}"
                 )
                 self.advance_decoding(is_finished)
-                # self.update_result()
+                self.update_result()
 
-                # return self.result_transcripts[0]
+                return self.result_transcripts[0]
             except Exception as e:
                 logger.exception(e)
         else:
@@ -225,22 +225,36 @@ class PaddleASRConnectionHanddler:
 
         # we only process decoding_window frames for one chunk
         left_frames = decoding_window
-        # logger.info(f"")
+        # record the end for removing the processed feat
        end = None
         for cur in range(0, num_frames - left_frames + 1, stride):
             end = min(cur + decoding_window, num_frames)
-            print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
+
             self.chunk_num += 1
 
-            # chunk_xs = self.cached_feat[:, cur:end, :]
-            # (y, self.subsampling_cache, self.elayers_output_cache,
-            #  self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
-            #     chunk_xs, self.offset, required_cache_size,
-            #     self.subsampling_cache, self.elayers_output_cache,
-            #     self.conformer_cnn_cache)
-            # outputs.append(y)
+            chunk_xs = self.cached_feat[:, cur:end, :]
+            (y, self.subsampling_cache, self.elayers_output_cache,
+             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
+                 chunk_xs, self.offset, required_cache_size,
+                 self.subsampling_cache, self.elayers_output_cache,
+                 self.conformer_cnn_cache)
+            outputs.append(y)
+
             # update the offset
-            # self.offset += y.shape[1]
+            self.offset += y.shape[1]
+        ys = paddle.cat(outputs, 1)
+        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
+        masks = masks.unsqueeze(1)
+
+        # get the ctc probs
+        ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
+        ctc_probs = ctc_probs.squeeze(0)
+        # self.searcher.search(xs, ctc_probs, xs.place)
+
+        self.searcher.search(None, ctc_probs, self.cached_feat.place)
+
+        self.hyps = self.searcher.get_one_best_hyps()
+
         # remove the processed feat
         if end == num_frames:
             self.cached_feat = None
@@ -248,19 +262,7 @@ class PaddleASRConnectionHanddler:
 
             assert self.cached_feat.shape[0] == 1
             self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
             assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
-            # ys = paddle.cat(outputs, 1)
-            # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-            # masks = masks.unsqueeze(1)
-
-            # # get the ctc probs
-            # ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
-            # ctc_probs = ctc_probs.squeeze(0)
-            # # self.searcher.search(xs, ctc_probs, xs.place)
-
-            # self.searcher.search(None, ctc_probs, self.cached_feat.place)
-            # self.hyps = self.searcher.get_one_best_hyps()
-
         # ys for rescoring
         # return ys, masks
 
diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py
index a91b8a21..bf4c4b30 100644
--- a/paddlespeech/server/engine/asr/online/ctc_search.py
+++ b/paddlespeech/server/engine/asr/online/ctc_search.py
@@ -46,7 +46,7 @@ class CTCPrefixBeamSearch:
         logger.info("start to ctc prefix search")
 
         # device = xs.place
-        batch_size = xs.shape[0]
+        batch_size = 1
         beam_size = self.config.beam_size
         maxlen = ctc_probs.shape[0]
 
diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py
index 12b1f0e5..d4e6cd49 100644
--- a/paddlespeech/server/utils/buffer.py
+++ b/paddlespeech/server/utils/buffer.py
@@ -63,12 +63,12 @@ class ChunkBuffer(object):
         the sample rate.
 
         Yields Frames of the requested duration.
         """
+        audio = self.remained_audio + audio
         self.remained_audio = b''
 
         offset = 0
         timestamp = 0.0
-
         while offset + self.window_bytes <= len(audio):
             yield Frame(audio[offset:offset + self.window_bytes], timestamp,
                         self.window_sec)
diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py
index 87b43d2c..04807e5c 100644
--- a/paddlespeech/server/ws/asr_socket.py
+++ b/paddlespeech/server/ws/asr_socket.py
@@ -22,6 +22,7 @@ from starlette.websockets import WebSocketState as WebSocketState
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.utils.buffer import ChunkBuffer
 from paddlespeech.server.utils.vad import VADAudio
+from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
 
 router = APIRouter()
 
@@ -33,6 +34,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
     engine_pool = get_engine_pool()
     asr_engine = engine_pool['asr']
+    connection_handler = None
     # init buffer
     # each websocekt connection has its own chunk buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
@@ -67,13 +69,17 @@ async def websocket_endpoint(websocket: WebSocket):
             if message['signal'] == 'start':
                 resp = {"status": "ok", "signal": "server_ready"}
                 # do something at begining here
+                # create the instance to process the audio
+                connection_handler = PaddleASRConnectionHanddler(asr_engine)
                 await websocket.send_json(resp)
             elif message['signal'] == 'end':
                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 # reset single engine for an new connection
+                asr_results = connection_handler.decode(is_finished=True)
+                connection_handler.reset()
                 asr_engine.reset()
-                resp = {"status": "ok", "signal": "finished"}
+                resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results}
                 await websocket.send_json(resp)
                 break
             else:
@@ -81,23 +87,27 @@ async def websocket_endpoint(websocket: WebSocket):
                 await websocket.send_json(resp)
         elif "bytes" in message:
             message = message["bytes"]
-            engine_pool = get_engine_pool()
             asr_engine = engine_pool['asr']
             asr_results = ""
-            frames = chunk_buffer.frame_generator(message)
-            for frame in frames:
-                # get the pcm data from the bytes
-                samples = np.frombuffer(frame.bytes, dtype=np.int16)
-                sample_rate = asr_engine.config.sample_rate
-                x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                                                              sample_rate)
-                asr_engine.run(x_chunk, x_chunk_lens)
-                asr_results = asr_engine.postprocess()
+            connection_handler.extract_feat(message)
+            asr_results = connection_handler.decode(is_finished=False)
+            # connection_handler.
+            # frames = chunk_buffer.frame_generator(message)
+            # for frame in frames:
+            #     # get the pcm data from the bytes
+            #     samples = np.frombuffer(frame.bytes, dtype=np.int16)
+            #     sample_rate = asr_engine.config.sample_rate
+            #     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
+            #                                                   sample_rate)
+            #     asr_engine.run(x_chunk, x_chunk_lens)
+            #     asr_results = asr_engine.postprocess()
 
-            asr_results = asr_engine.postprocess()
+            # # connection accept the sample data frame by frame
+
+            # asr_results = asr_engine.postprocess()
             resp = {'asr_results': asr_results}
-
+            print("\n")
             await websocket.send_json(resp)
     except WebSocketDisconnect:
         pass
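
For reference, a minimal sketch of how a caller is expected to drive the new per-connection flow end to end, mirroring the 'start', bytes, and 'end' branches of asr_socket.py. Only PaddleASRConnectionHanddler, extract_feat, decode, reset, and the engine pool come from this patch; the driver function transcribe_stream, its name, and the chunk source are illustrative assumptions, and the engine pool is assumed to have been initialized by the server beforehand.

# usage sketch -- hypothetical driver, not part of this patch
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler


def transcribe_stream(pcm_chunks):
    """pcm_chunks: an iterable of int16 PCM byte buffers, as sent over the websocket."""
    asr_engine = get_engine_pool()['asr']
    # one handler per connection, as in the 'start' signal branch
    connection_handler = PaddleASRConnectionHanddler(asr_engine)

    partial = ""
    for chunk in pcm_chunks:
        connection_handler.extract_feat(chunk)  # accumulate features for this chunk
        partial = connection_handler.decode(is_finished=False)  # partial transcript so far
    # as in the 'end' signal branch: flush the remaining frames, then reset
    final = connection_handler.decode(is_finished=True)
    connection_handler.reset()
    return final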