|
|
|
@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
f"we will use the transformer like model : {self.model_type}"
|
|
|
|
|
)
|
|
|
|
|
self.advance_decoding(is_finished)
|
|
|
|
|
# self.update_result()
|
|
|
|
|
self.update_result()
|
|
|
|
|
|
|
|
|
|
# return self.result_transcripts[0]
|
|
|
|
|
return self.result_transcripts[0]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.exception(e)
|
|
|
|
|
else:
|
|
|
|
@ -225,22 +225,36 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
# we only process decoding_window frames for one chunk
|
|
|
|
|
left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
# logger.info(f"")
|
|
|
|
|
# record the end for removing the processed feat
|
|
|
|
|
end = None
|
|
|
|
|
for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
end = min(cur + decoding_window, num_frames)
|
|
|
|
|
print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
|
|
|
|
|
|
|
|
|
|
self.chunk_num += 1
|
|
|
|
|
# chunk_xs = self.cached_feat[:, cur:end, :]
|
|
|
|
|
# (y, self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
# self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
|
|
|
|
|
# chunk_xs, self.offset, required_cache_size,
|
|
|
|
|
# self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
# self.conformer_cnn_cache)
|
|
|
|
|
# outputs.append(y)
|
|
|
|
|
chunk_xs = self.cached_feat[:, cur:end, :]
|
|
|
|
|
(y, self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
|
|
|
|
|
chunk_xs, self.offset, required_cache_size,
|
|
|
|
|
self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
self.conformer_cnn_cache)
|
|
|
|
|
outputs.append(y)
|
|
|
|
|
|
|
|
|
|
# update the offset
|
|
|
|
|
# self.offset += y.shape[1]
|
|
|
|
|
self.offset += y.shape[1]
|
|
|
|
|
|
|
|
|
|
ys = paddle.cat(outputs, 1)
|
|
|
|
|
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
|
masks = masks.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
# get the ctc probs
|
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
# self.searcher.search(xs, ctc_probs, xs.place)
|
|
|
|
|
|
|
|
|
|
self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
|
|
|
|
|
|
self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
|
|
|
|
|
# remove the processed feat
|
|
|
|
|
if end == num_frames:
|
|
|
|
|
self.cached_feat = None
|
|
|
|
@ -248,19 +262,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
|
self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
|
|
|
|
|
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
|
# ys = paddle.cat(outputs, 1)
|
|
|
|
|
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
|
# masks = masks.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
# # get the ctc probs
|
|
|
|
|
# ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
|
# ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
# # self.searcher.search(xs, ctc_probs, xs.place)
|
|
|
|
|
|
|
|
|
|
# self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
|
|
|
|
|
|
# self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
|
|
|
|
|
# ys for rescoring
|
|
|
|
|
# return ys, masks
|
|
|
|
|
|
|
|
|
|