Check the chunk window processing, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent d2640c1406
commit 10e825d9b2

@ -145,6 +145,8 @@ class PaddleASRConnectionHanddler:
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert(len(x_chunk.shape) == 3)
assert(len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
num_frames = x_chunk.shape[1]
@ -170,6 +172,7 @@ class PaddleASRConnectionHanddler:
self.num_samples = 0
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result = []
@ -210,23 +213,24 @@ class PaddleASRConnectionHanddler:
if num_frames < decoding_window and not is_finished:
return None, None
# logger.info("start to do model forward")
# required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# outputs = []
logger.info("start to do model forward")
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = []
# # num_frames - context + 1 ensure that current frame can get context window
# if is_finished:
# # if get the finished chunk, we need process the last context
# left_frames = context
# else:
# # we only process decoding_window frames for one chunk
# left_frames = decoding_window
# num_frames - context + 1 ensures that the current frame can get its context window
if is_finished:
# if we get the finished chunk, we need to process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
# logger.info(f"")
# end = None
# for cur in range(0, num_frames - left_frames + 1, stride):
# end = min(cur + decoding_window, num_frames)
# print(f"cur: {cur}, end: {end}")
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
self.chunk_num += 1
# chunk_xs = self.cached_feat[:, cur:end, :]
# (y, self.subsampling_cache, self.elayers_output_cache,
# self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
@ -236,7 +240,14 @@ class PaddleASRConnectionHanddler:
# outputs.append(y)
# update the offset
# self.offset += y.shape[1]
# self.cached_feat = self.cached_feat[end:]
# remove the processed feat
if end == num_frames:
self.cached_feat = None
else:
assert self.cached_feat.shape[0] == 1
self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
# ys = paddle.cat(outputs, 1)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
@ -309,9 +320,9 @@ class ASRServerExecutor(ASRExecutor):
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
# self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
# self.cfg_path = os.path.join(res_path,
# pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path,
pretrained_models[tag]['model'])

Loading…
Cancel
Save