|
|
|
@ -78,6 +78,194 @@ pretrained_models = {
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# ASR server connection process class
|
|
|
|
|
|
|
|
|
|
class PaddleASRConnectionHanddler:
|
|
|
|
|
def __init__(self, asr_engine):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.config = asr_engine.config
|
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
|
|
|
|
|
|
self.init()
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
def init(self):
|
|
|
|
|
self.model_type = self.asr_engine.executor.model_type
|
|
|
|
|
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
|
|
|
|
|
self.sample_rate = self.asr_engine.executor.sample_rate
|
|
|
|
|
|
|
|
|
|
# acoustic model
|
|
|
|
|
self.model = self.asr_engine.executor.model
|
|
|
|
|
|
|
|
|
|
# tokens to text
|
|
|
|
|
self.text_feature = self.asr_engine.executor.text_feature
|
|
|
|
|
|
|
|
|
|
# ctc decoding
|
|
|
|
|
self.ctc_decode_config = self.asr_engine.executor.config.decode
|
|
|
|
|
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
|
|
|
|
|
|
|
|
|
|
# extract fbank
|
|
|
|
|
self.preprocess_conf = self.model_config.preprocess_config
|
|
|
|
|
self.preprocess_args = {"train": False}
|
|
|
|
|
self.preprocessing = Transformation(self.preprocess_conf)
|
|
|
|
|
self.win_length = self.preprocess_conf.process[0]['win_length']
|
|
|
|
|
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
|
|
|
|
|
|
|
|
|
def extract_feat(self, samples):
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
elif "conformer2online" in self.model_type:
|
|
|
|
|
logger.info("Online ASR extract the feat")
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
|
assert samples.ndim == 1
|
|
|
|
|
|
|
|
|
|
logger.info(f"This package receive {samples.shape[0]} pcm data")
|
|
|
|
|
self.num_samples += samples.shape[0]
|
|
|
|
|
|
|
|
|
|
# self.reamined_wav stores all the samples,
|
|
|
|
|
# include the original remained_wav and this package samples
|
|
|
|
|
if self.remained_wav is None:
|
|
|
|
|
self.remained_wav = samples
|
|
|
|
|
else:
|
|
|
|
|
assert self.remained_wav.ndim == 1
|
|
|
|
|
self.remained_wav = np.concatenate([self.remained_wav, samples])
|
|
|
|
|
logger.info(
|
|
|
|
|
f"The connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
if len(self.remained_wav) < self.win_length:
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
# fbank
|
|
|
|
|
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
|
|
|
|
|
x_chunk = paddle.to_tensor(
|
|
|
|
|
x_chunk, dtype="float32").unsqueeze(axis=0)
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
self.cached_feat = x_chunk
|
|
|
|
|
else:
|
|
|
|
|
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
|
|
|
|
|
|
|
|
|
|
num_frames = x_chunk.shape[1]
|
|
|
|
|
self.num_frames += num_frames
|
|
|
|
|
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
# logger.info(f"accumulate samples: {self.num_samples}")
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
self.subsampling_cache = None
|
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
|
self.encoder_outs_ = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.offset = 0
|
|
|
|
|
self.num_samples = 0
|
|
|
|
|
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
|
self.result = []
|
|
|
|
|
|
|
|
|
|
def decode(self, is_finished=False):
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
try:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"we will use the transformer like model : {self.model_type}"
|
|
|
|
|
)
|
|
|
|
|
self.advance_decoding(is_finished)
|
|
|
|
|
# self.update_result()
|
|
|
|
|
|
|
|
|
|
# return self.result_transcripts[0]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.exception(e)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("invalid model name")
|
|
|
|
|
|
|
|
|
|
def advance_decoding(self, is_finished=False):
|
|
|
|
|
logger.info("start to decode with advanced_decoding method")
|
|
|
|
|
cfg = self.ctc_decode_config
|
|
|
|
|
decoding_chunk_size = cfg.decoding_chunk_size
|
|
|
|
|
num_decoding_left_chunks = cfg.num_decoding_left_chunks
|
|
|
|
|
|
|
|
|
|
assert decoding_chunk_size > 0
|
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
|
|
|
|
|
# decoding window for model
|
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
|
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
|
|
|
|
|
|
|
|
|
|
# the cached feat must be larger decoding_window
|
|
|
|
|
if num_frames < decoding_window and not is_finished:
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
# logger.info("start to do model forward")
|
|
|
|
|
# required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
|
|
|
|
# outputs = []
|
|
|
|
|
|
|
|
|
|
# # num_frames - context + 1 ensure that current frame can get context window
|
|
|
|
|
# if is_finished:
|
|
|
|
|
# # if get the finished chunk, we need process the last context
|
|
|
|
|
# left_frames = context
|
|
|
|
|
# else:
|
|
|
|
|
# # we only process decoding_window frames for one chunk
|
|
|
|
|
# left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
# logger.info(f"")
|
|
|
|
|
# end = None
|
|
|
|
|
# for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
# end = min(cur + decoding_window, num_frames)
|
|
|
|
|
# print(f"cur: {cur}, end: {end}")
|
|
|
|
|
# chunk_xs = self.cached_feat[:, cur:end, :]
|
|
|
|
|
# (y, self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
# self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
|
|
|
|
|
# chunk_xs, self.offset, required_cache_size,
|
|
|
|
|
# self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
|
# self.conformer_cnn_cache)
|
|
|
|
|
# outputs.append(y)
|
|
|
|
|
# update the offset
|
|
|
|
|
# self.offset += y.shape[1]
|
|
|
|
|
# self.cached_feat = self.cached_feat[end:]
|
|
|
|
|
# ys = paddle.cat(outputs, 1)
|
|
|
|
|
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
|
# masks = masks.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
# # get the ctc probs
|
|
|
|
|
# ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
|
# ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
# # self.searcher.search(xs, ctc_probs, xs.place)
|
|
|
|
|
|
|
|
|
|
# self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
|
|
|
|
|
|
# self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
|
|
|
|
|
# ys for rescoring
|
|
|
|
|
# return ys, masks
|
|
|
|
|
|
|
|
|
|
def update_result(self):
|
|
|
|
|
logger.info("update the final result")
|
|
|
|
|
hyps = self.hyps
|
|
|
|
|
self.result_transcripts = [
|
|
|
|
|
self.text_feature.defeaturize(hyp) for hyp in hyps
|
|
|
|
|
]
|
|
|
|
|
self.result_tokenids = [hyp for hyp in hyps]
|
|
|
|
|
|
|
|
|
|
def rescoring(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -492,7 +680,7 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
if sample_rate != self.sample_rate:
|
|
|
|
|
logger.info(f"audio sample rate {sample_rate} is not match,"
|
|
|
|
|
"the model sample_rate is {self.sample_rate}")
|
|
|
|
|
logger.info("ASR Engine use the {self.model_type} to process")
|
|
|
|
|
logger.info(f"ASR Engine use the {self.model_type} to process")
|
|
|
|
|
logger.info("Create the preprocess instance")
|
|
|
|
|
preprocess_conf = self.config.preprocess_config
|
|
|
|
|
preprocess_args = {"train": False}
|
|
|
|
|