|
|
@ -83,8 +83,10 @@ pretrained_models = {
|
|
|
|
class PaddleASRConnectionHanddler:
|
|
|
|
class PaddleASRConnectionHanddler:
|
|
|
|
def __init__(self, asr_engine):
|
|
|
|
def __init__(self, asr_engine):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
logger.info("create an paddle asr connection handler to process the websocket connection")
|
|
|
|
self.config = asr_engine.config
|
|
|
|
self.config = asr_engine.config
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
|
|
|
|
self.model = asr_engine.executor.model
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
|
|
|
|
|
|
|
|
self.init()
|
|
|
|
self.init()
|
|
|
@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler:
|
|
|
|
assert(len(self.cached_feat.shape) == 3)
|
|
|
|
assert(len(self.cached_feat.shape) == 3)
|
|
|
|
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
|
|
|
|
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# set the feat device
|
|
|
|
|
|
|
|
if self.device is None:
|
|
|
|
|
|
|
|
self.device = self.cached_feat.place
|
|
|
|
|
|
|
|
|
|
|
|
num_frames = x_chunk.shape[1]
|
|
|
|
num_frames = x_chunk.shape[1]
|
|
|
|
self.num_frames += num_frames
|
|
|
|
self.num_frames += num_frames
|
|
|
|
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
|
|
|
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
|
|
@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler:
|
|
|
|
self.subsampling_cache = None
|
|
|
|
self.subsampling_cache = None
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
self.encoder_outs_ = None
|
|
|
|
self.encoder_out = None
|
|
|
|
self.cached_feat = None
|
|
|
|
self.cached_feat = None
|
|
|
|
self.remained_wav = None
|
|
|
|
self.remained_wav = None
|
|
|
|
self.offset = 0
|
|
|
|
self.offset = 0
|
|
|
|
self.num_samples = 0
|
|
|
|
self.num_samples = 0
|
|
|
|
|
|
|
|
self.device = None
|
|
|
|
|
|
|
|
self.hyps = []
|
|
|
|
self.num_frames = 0
|
|
|
|
self.num_frames = 0
|
|
|
|
self.chunk_num = 0
|
|
|
|
self.chunk_num = 0
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
self.result = []
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
|
|
|
|
|
|
|
def decode(self, is_finished=False):
|
|
|
|
def decode(self, is_finished=False):
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler:
|
|
|
|
self.advance_decoding(is_finished)
|
|
|
|
self.advance_decoding(is_finished)
|
|
|
|
self.update_result()
|
|
|
|
self.update_result()
|
|
|
|
|
|
|
|
|
|
|
|
return self.result_transcripts[0]
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.exception(e)
|
|
|
|
logger.exception(e)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -203,14 +209,24 @@ class PaddleASRConnectionHanddler:
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
|
|
|
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
|
|
|
|
|
|
|
|
|
|
|
|
# decoding window for model
|
|
|
|
# decoding window for model
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
|
|
|
logger.info("no audio feat, please input more pcm data")
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
|
|
|
|
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
|
|
|
|
|
|
|
|
|
|
|
|
# the cached feat must be larger decoding_window
|
|
|
|
# the cached feat must be larger decoding_window
|
|
|
|
if num_frames < decoding_window and not is_finished:
|
|
|
|
if num_frames < decoding_window and not is_finished:
|
|
|
|
|
|
|
|
logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
|
|
|
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if num_frames < context:
|
|
|
|
|
|
|
|
logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward")
|
|
|
|
return None, None
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("start to do model forward")
|
|
|
|
logger.info("start to do model forward")
|
|
|
@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler:
|
|
|
|
# update the offset
|
|
|
|
# update the offset
|
|
|
|
self.offset += y.shape[1]
|
|
|
|
self.offset += y.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"output size: {len(outputs)}")
|
|
|
|
ys = paddle.cat(outputs, 1)
|
|
|
|
ys = paddle.cat(outputs, 1)
|
|
|
|
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
if self.encoder_out is None:
|
|
|
|
masks = masks.unsqueeze(1)
|
|
|
|
self.encoder_out = ys
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
|
|
|
|
|
|
|
|
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
|
|
|
|
# masks = masks.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
# get the ctc probs
|
|
|
|
# get the ctc probs
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
# self.searcher.search(xs, ctc_probs, xs.place)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
|
|
|
|
|
|
|
@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler:
|
|
|
|
self.cached_feat = None
|
|
|
|
self.cached_feat = None
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
|
|
|
|
self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0)
|
|
|
|
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
|
|
|
|
|
|
|
|
# ys for rescoring
|
|
|
|
# ys for rescoring
|
|
|
@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler:
|
|
|
|
]
|
|
|
|
]
|
|
|
|
self.result_tokenids = [hyp for hyp in hyps]
|
|
|
|
self.result_tokenids = [hyp for hyp in hyps]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_result(self):
|
|
|
|
|
|
|
|
if len(self.result_transcripts) > 0:
|
|
|
|
|
|
|
|
return self.result_transcripts[0]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return ''
|
|
|
|
|
|
|
|
|
|
|
|
def rescoring(self):
|
|
|
|
def rescoring(self):
|
|
|
|
pass
|
|
|
|
logger.info("rescoring the final result")
|
|
|
|
|
|
|
|
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.searcher.finalize_search()
|
|
|
|
|
|
|
|
self.update_result()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
beam_size = self.ctc_decode_config.beam_size
|
|
|
|
|
|
|
|
hyps = self.searcher.get_hyps()
|
|
|
|
|
|
|
|
if hyps is None or len(hyps) == 0:
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# assert len(hyps) == beam_size
|
|
|
|
|
|
|
|
paddle.save(self.encoder_out, "encoder.out")
|
|
|
|
|
|
|
|
hyp_list = []
|
|
|
|
|
|
|
|
for hyp in hyps:
|
|
|
|
|
|
|
|
hyp_content = hyp[0]
|
|
|
|
|
|
|
|
# Prevent the hyp is empty
|
|
|
|
|
|
|
|
if len(hyp_content) == 0:
|
|
|
|
|
|
|
|
hyp_content = (self.model.ctc.blank_id, )
|
|
|
|
|
|
|
|
hyp_content = paddle.to_tensor(
|
|
|
|
|
|
|
|
hyp_content, place=self.device, dtype=paddle.long)
|
|
|
|
|
|
|
|
hyp_list.append(hyp_content)
|
|
|
|
|
|
|
|
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
|
|
|
|
|
|
|
|
hyps_lens = paddle.to_tensor(
|
|
|
|
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=self.device,
|
|
|
|
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
|
|
|
|
|
|
|
|
self.model.ignore_id)
|
|
|
|
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
|
|
|
|
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
|
|
|
|
|
|
|
|
decoder_out, _ = self.model.decoder(
|
|
|
|
|
|
|
|
encoder_out, encoder_mask, hyps_pad,
|
|
|
|
|
|
|
|
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
|
|
|
|
best_score = -float('inf')
|
|
|
|
|
|
|
|
best_index = 0
|
|
|
|
|
|
|
|
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
|
|
|
|
|
|
|
|
for i, hyp in enumerate(hyps):
|
|
|
|
|
|
|
|
score = 0.0
|
|
|
|
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.model.eos]
|
|
|
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
|
|
|
|
score += hyp[1] * self.ctc_decode_config.ctc_weight
|
|
|
|
|
|
|
|
if score > best_score:
|
|
|
|
|
|
|
|
best_score = score
|
|
|
|
|
|
|
|
best_index = i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# update the one best result
|
|
|
|
|
|
|
|
logger.info(f"best index: {best_index}")
|
|
|
|
|
|
|
|
self.hyps = [hyps[best_index][0]]
|
|
|
|
|
|
|
|
self.update_result()
|
|
|
|
|
|
|
|
# return hyps[best_index][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|