diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 3546e598..1f6060e9 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -145,6 +145,8 @@ class PaddleASRConnectionHanddler:
         if self.cached_feat is None:
             self.cached_feat = x_chunk
         else:
+            assert(len(x_chunk.shape) == 3)
+            assert(len(self.cached_feat.shape) == 3)
             self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
 
         num_frames = x_chunk.shape[1]
@@ -170,6 +172,7 @@ class PaddleASRConnectionHanddler:
 
         self.num_samples = 0
         self.num_frames = 0
+        self.chunk_num = 0
         self.global_frame_offset = 0
         self.result = []
 
@@ -210,23 +213,24 @@ class PaddleASRConnectionHanddler:
         if num_frames < decoding_window and not is_finished:
             return None, None
 
-        # logger.info("start to do model forward")
-        # required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        # outputs = []
-
-        # # num_frames - context + 1 ensure that current frame can get context window
-        # if is_finished:
-        #     # if get the finished chunk, we need process the last context
-        #     left_frames = context
-        # else:
-        #     # we only process decoding_window frames for one chunk
-        #     left_frames = decoding_window
+        logger.info("start to do model forward")
+        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+        outputs = []
+
+        # num_frames - context + 1 ensure that current frame can get context window
+        if is_finished:
+            # if get the finished chunk, we need process the last context
+            left_frames = context
+        else:
+            # we only process decoding_window frames for one chunk
+            left_frames = decoding_window
 
         # logger.info(f"")
-        # end = None
-        # for cur in range(0, num_frames - left_frames + 1, stride):
-        #     end = min(cur + decoding_window, num_frames)
-        #     print(f"cur: {cur}, end: {end}")
+        end = None
+        for cur in range(0, num_frames - left_frames + 1, stride):
+            end = min(cur + decoding_window, num_frames)
+            print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
+            self.chunk_num += 1
         #     chunk_xs = self.cached_feat[:, cur:end, :]
         #     (y, self.subsampling_cache, self.elayers_output_cache,
         #      self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
@@ -236,7 +240,14 @@
         #     outputs.append(y)
         #     # update the offset
         #     self.offset += y.shape[1]
-        #     self.cached_feat = self.cached_feat[end:]
+
+        # remove the processed feat
+        if end == num_frames:
+            self.cached_feat = None
+        else:
+            assert self.cached_feat.shape[0] == 1
+            self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
+            assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
         # ys = paddle.cat(outputs, 1)
         # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
         # masks = masks.unsqueeze(1)
@@ -309,9 +320,9 @@ class ASRServerExecutor(ASRExecutor):
         logger.info(f"Load the pretrained model, tag = {tag}")
         res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
         self.res_path = res_path
-        # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
-        self.cfg_path = os.path.join(res_path,
-                                     pretrained_models[tag]['cfg_path'])
+        self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
+        # self.cfg_path = os.path.join(res_path,
+        #                              pretrained_models[tag]['cfg_path'])
 
         self.am_model = os.path.join(res_path,
                                      pretrained_models[tag]['model'])
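
For reference, below is a minimal standalone sketch of the windowing and cache-trimming arithmetic that this patch enables (the un-commented for loop plus the new "remove the processed feat" block). It reuses the names from the patch (context, stride, decoding_window, num_frames, cached_feat), but the concrete values are illustrative assumptions, not values read from any model config:

import paddle

# Illustrative values only (assumption): in the engine these derive from the
# encoder's subsampling rate, receptive field, and decoding_chunk_size.
context = 7
stride = 64
decoding_window = 71

num_frames = 200
cached_feat = paddle.randn([1, num_frames, 80])  # (batch, time, feat_dim), rank 3

is_finished = False
# The final chunk only needs `context` frames of lookahead; otherwise a full
# decoding_window must be buffered before the encoder is forwarded.
left_frames = context if is_finished else decoding_window

end = None
for cur in range(0, num_frames - left_frames + 1, stride):
    end = min(cur + decoding_window, num_frames)
    chunk_xs = cached_feat[:, cur:end, :]  # the window handed to forward_chunk
    print(f"cur: {cur}, end: {end}, window: {chunk_xs.shape[1]} frames")

# Drop the consumed frames while preserving the rank-3 (batch, time, feat)
# layout, mirroring the new trimming block in the diff.
if end == num_frames:
    cached_feat = None
else:
    assert cached_feat.shape[0] == 1
    cached_feat = cached_feat[0, end:, :].unsqueeze(0)
    assert len(cached_feat.shape) == 3

Keeping the trimmed cache at rank 3 is what lets the next call concatenate fresh features with paddle.concat([cached_feat, x_chunk], axis=1) without reshaping, which is also what the two new asserts in the first hunk guard against.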