@@ -153,8 +153,7 @@ class PaddleASRConnectionHanddler:
         spectrum = self.collate_fn_test._normalizer.apply(spectrum)
 
         # spectrum augment
-        feat = self.collate_fn_test.augmentation.transform_feature(
-            spectrum)
+        feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
 
         # audio_len is frame num
         frame_num = feat.shape[0]
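frame_num above is the number of feature frames the frontend produced, one per hop. A minimal sketch of that arithmetic, assuming a sliding-window frontend with window length win_length and hop size n_shift (the same hop that self.n_shift refers to later in this diff; both names are illustrative here, not exact config keys):

    def frame_count(num_samples: int, win_length: int, n_shift: int) -> int:
        """How many full analysis windows fit into num_samples samples."""
        if num_samples < win_length:
            return 0
        return 1 + (num_samples - win_length) // n_shift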
@@ -189,14 +188,16 @@ class PaddleASRConnectionHanddler:
         assert samples.ndim == 1
 
         self.num_samples += samples.shape[0]
-        logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}")
+        logger.info(
+            f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
+        )
 
         # self.reamined_wav stores all the samples,
         # include the original remained_wav and this package samples
         if self.remained_wav is None:
             self.remained_wav = samples
         else:
-            assert self.remained_wav.ndim == 1 # (T,)
+            assert self.remained_wav.ndim == 1  # (T,)
             self.remained_wav = np.concatenate([self.remained_wav, samples])
         logger.info(
             f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
@@ -216,8 +217,8 @@ class PaddleASRConnectionHanddler:
         if self.cached_feat is None:
             self.cached_feat = x_chunk
         else:
-            assert (len(x_chunk.shape) == 3) # (B,T,D)
-            assert (len(self.cached_feat.shape) == 3) # (B,T,D)
+            assert (len(x_chunk.shape) == 3)  # (B,T,D)
+            assert (len(self.cached_feat.shape) == 3)  # (B,T,D)
             self.cached_feat = paddle.concat(
                 [self.cached_feat, x_chunk], axis=1)
 
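Both tensors are (B, T, D), so paddle.concat with axis=1 appends the new frames after the cached ones along the time axis. A runnable sketch with made-up shapes (one utterance, 80-dim features):

    import paddle

    cached = paddle.zeros([1, 10, 80])  # 10 cached frames
    chunk = paddle.zeros([1, 4, 80])    # 4 frames from this package
    cached = paddle.concat([cached, chunk], axis=1)
    print(cached.shape)                 # [1, 14, 80]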
@@ -234,7 +235,6 @@ class PaddleASRConnectionHanddler:
         # update remained wav
         self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
 
-
         logger.info(
             f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
         )
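Consuming self.n_shift * num_frames samples keeps the raw-sample buffer aligned with the feature cache: each emitted frame advances the frontend by exactly one hop. A toy example with assumed values (10 ms hop at 16 kHz):

    import numpy as np

    n_shift, num_frames = 160, 4                     # assumed: 10 ms hop at 16 kHz
    remained_wav = np.zeros(1000, dtype=np.float32)  # toy buffer
    remained_wav = remained_wav[n_shift * num_frames:]
    print(remained_wav.shape)                        # (360,): 640 samples consumed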
@@ -246,7 +246,6 @@ class PaddleASRConnectionHanddler:
         else:
             raise ValueError(f"not supported: {self.model_type}")
 
-
     def reset(self):
         if "deepspeech2" in self.model_type:
             # for deepspeech2
@@ -268,7 +267,6 @@ class PaddleASRConnectionHanddler:
         self.remained_wav = None
         self.cached_feat = None
 
-
         # partial/ending decoding results
         self.result_transcripts = ['']
 
@@ -280,8 +278,8 @@ class PaddleASRConnectionHanddler:
         self.conformer_cnn_cache = None
         self.encoder_out = None
         # conformer decoding state
-        self.chunk_num = 0 # globa decoding chunk num
-        self.offset = 0 # global offset in decoding frame unit
+        self.chunk_num = 0  # globa decoding chunk num
+        self.offset = 0  # global offset in decoding frame unit
         self.hyps = []
 
         # token timestamp result
@@ -290,7 +288,6 @@ class PaddleASRConnectionHanddler:
         # one best timestamp viterbi prob is large.
         self.time_stamp = []
 
-
    def decode(self, is_finished=False):
        """advance decoding
 
@@ -373,7 +370,6 @@ class PaddleASRConnectionHanddler:
         else:
             raise Exception("invalid model name")
 
-
     @paddle.no_grad()
     def decode_one_chunk(self, x_chunk, x_chunk_lens):
         """forward one chunk frames
@@ -425,10 +421,11 @@ class PaddleASRConnectionHanddler:
         logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
         return trans_best[0]
 
-
     @paddle.no_grad()
     def advance_decoding(self, is_finished=False):
-        logger.info("Conformer/Transformer: start to decode with advanced_decoding method")
+        logger.info(
+            "Conformer/Transformer: start to decode with advanced_decoding method"
+        )
         cfg = self.ctc_decode_config
 
         # cur chunk size, in decoding frame unit
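A decoding frame here is a frame after the encoder's convolutional subsampling, not a raw feature frame. A hedged sketch of the usual relation in WeNet-style streaming encoders, with illustrative names (subsampling_rate and context are assumptions, not the exact attributes in this file):

    def feature_frames_needed(decoding_chunk_size: int,
                              subsampling_rate: int,
                              context: int) -> int:
        """Feature frames required to emit decoding_chunk_size encoder frames."""
        return (decoding_chunk_size - 1) * subsampling_rate + context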
@@ -563,7 +560,6 @@ class PaddleASRConnectionHanddler:
         """
         return self.word_time_stamp
 
-
     @paddle.no_grad()
     def rescoring(self):
         """Second-Pass Decoding,
@@ -574,7 +570,9 @@ class PaddleASRConnectionHanddler:
             return
 
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
-            logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
+            logger.info(
+                f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
+            )
             return
 
         logger.info("rescoring the final result")
@@ -605,7 +603,8 @@ class PaddleASRConnectionHanddler:
                 hyp_content, place=self.device, dtype=paddle.long)
             hyp_list.append(hyp_content)
 
-        hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id)
+        hyps_pad = pad_sequence(
+            hyp_list, batch_first=True, padding_value=self.model.ignore_id)
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=self.device,
             dtype=paddle.long)  # (beam_size,)
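pad_sequence right-pads the beam's variable-length hypotheses with self.model.ignore_id so they stack into one (beam_size, max_len) batch for attention rescoring. A numpy sketch of that padding behaviour (illustrative, not the library implementation):

    import numpy as np

    def pad_hyps(hyp_list, padding_value):
        """Right-pad variable-length token id sequences into one batch."""
        max_len = max(len(h) for h in hyp_list)
        out = np.full((len(hyp_list), max_len), padding_value, dtype=np.int64)
        for i, h in enumerate(hyp_list):
            out[i, :len(h)] = h
        return out

    print(pad_hyps([[7, 8, 9], [5]], padding_value=-1))
    # [[ 7  8  9]
    #  [ 5 -1 -1]]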
@@ -694,7 +693,6 @@ class PaddleASRConnectionHanddler:
         logger.info(f"word time stamp: {self.word_time_stamp}")
 
 
-
 class ASRServerExecutor(ASRExecutor):
     def __init__(self):
         super().__init__()