|
|
|
@ -60,9 +60,9 @@ pretrained_models = {
|
|
|
|
|
},
|
|
|
|
|
"conformer2online_aishell-zh-16k": {
|
|
|
|
|
'url':
|
|
|
|
|
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz',
|
|
|
|
|
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
|
|
|
|
|
'md5':
|
|
|
|
|
'b450d5dfaea0ac227c595ce58d18b637',
|
|
|
|
|
'0ac93d390552336f2a906aec9e33c5fa',
|
|
|
|
|
'cfg_path':
|
|
|
|
|
'model.yaml',
|
|
|
|
|
'ckpt_path':
|
|
|
|
@ -78,12 +78,19 @@ pretrained_models = {
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# ASR server connection process class
|
|
|
|
|
|
|
|
|
|
# ASR server connection process class
|
|
|
|
|
class PaddleASRConnectionHanddler:
|
|
|
|
|
def __init__(self, asr_engine):
|
|
|
|
|
"""Init a Paddle ASR Connection Handler instance
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
asr_engine (ASREngine): the global asr engine
|
|
|
|
|
"""
|
|
|
|
|
super().__init__()
|
|
|
|
|
logger.info("create an paddle asr connection handler to process the websocket connection")
|
|
|
|
|
logger.info(
|
|
|
|
|
"create an paddle asr connection handler to process the websocket connection"
|
|
|
|
|
)
|
|
|
|
|
self.config = asr_engine.config
|
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
|
self.model = asr_engine.executor.model
|
|
|
|
@ -98,24 +105,26 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
pass
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
|
|
|
|
|
self.sample_rate = self.asr_engine.executor.sample_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# acoustic model
|
|
|
|
|
self.model = self.asr_engine.executor.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# tokens to text
|
|
|
|
|
self.text_feature = self.asr_engine.executor.text_feature
|
|
|
|
|
|
|
|
|
|
# ctc decoding
|
|
|
|
|
|
|
|
|
|
# ctc decoding config
|
|
|
|
|
self.ctc_decode_config = self.asr_engine.executor.config.decode
|
|
|
|
|
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
|
|
|
|
|
|
|
|
|
|
# extract fbank
|
|
|
|
|
# extract feat, new only fbank in conformer model
|
|
|
|
|
self.preprocess_conf = self.model_config.preprocess_config
|
|
|
|
|
self.preprocess_args = {"train": False}
|
|
|
|
|
self.preprocessing = Transformation(self.preprocess_conf)
|
|
|
|
|
|
|
|
|
|
# frame window samples length and frame shift samples length
|
|
|
|
|
self.win_length = self.preprocess_conf.process[0]['win_length']
|
|
|
|
|
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_feat(self, samples):
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
@ -123,10 +132,10 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
logger.info("Online ASR extract the feat")
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
|
assert samples.ndim == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"This package receive {samples.shape[0]} pcm data")
|
|
|
|
|
self.num_samples += samples.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# self.reamined_wav stores all the samples,
|
|
|
|
|
# include the original remained_wav and this package samples
|
|
|
|
|
if self.remained_wav is None:
|
|
|
|
@ -141,19 +150,21 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
# fbank
|
|
|
|
|
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
|
|
|
|
|
x_chunk = self.preprocessing(self.remained_wav,
|
|
|
|
|
**self.preprocess_args)
|
|
|
|
|
x_chunk = paddle.to_tensor(
|
|
|
|
|
x_chunk, dtype="float32").unsqueeze(axis=0)
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
self.cached_feat = x_chunk
|
|
|
|
|
else:
|
|
|
|
|
assert(len(x_chunk.shape) == 3)
|
|
|
|
|
assert(len(self.cached_feat.shape) == 3)
|
|
|
|
|
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
|
|
|
|
|
|
|
|
|
|
assert (len(x_chunk.shape) == 3)
|
|
|
|
|
assert (len(self.cached_feat.shape) == 3)
|
|
|
|
|
self.cached_feat = paddle.concat(
|
|
|
|
|
[self.cached_feat, x_chunk], axis=1)
|
|
|
|
|
|
|
|
|
|
# set the feat device
|
|
|
|
|
if self.device is None:
|
|
|
|
|
self.device = self.cached_feat.place
|
|
|
|
|
self.device = self.cached_feat.place
|
|
|
|
|
|
|
|
|
|
num_frames = x_chunk.shape[1]
|
|
|
|
|
self.num_frames += num_frames
|
|
|
|
@ -161,7 +172,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
@ -209,24 +220,30 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
|
|
|
|
|
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
|
|
|
|
|
|
|
|
|
|
# decoding window for model
|
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
logger.info("no audio feat, please input more pcm data")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
|
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# the cached feat must be larger decoding_window
|
|
|
|
|
if num_frames < decoding_window and not is_finished:
|
|
|
|
|
logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
|
|
|
|
|
logger.info(
|
|
|
|
|
f"frame feat num is less than {decoding_window}, please input more pcm data"
|
|
|
|
|
)
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
if num_frames < context:
|
|
|
|
|
logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward")
|
|
|
|
|
logger.info(
|
|
|
|
|
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
|
|
|
|
|
)
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
logger.info("start to do model forward")
|
|
|
|
@ -235,17 +252,17 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
# num_frames - context + 1 ensure that current frame can get context window
|
|
|
|
|
if is_finished:
|
|
|
|
|
# if get the finished chunk, we need process the last context
|
|
|
|
|
# if get the finished chunk, we need process the last context
|
|
|
|
|
left_frames = context
|
|
|
|
|
else:
|
|
|
|
|
# we only process decoding_window frames for one chunk
|
|
|
|
|
left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
# record the end for removing the processed feat
|
|
|
|
|
end = None
|
|
|
|
|
for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
end = min(cur + decoding_window, num_frames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.chunk_num += 1
|
|
|
|
|
chunk_xs = self.cached_feat[:, cur:end, :]
|
|
|
|
|
(y, self.subsampling_cache, self.elayers_output_cache,
|
|
|
|
@ -257,35 +274,31 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
# update the offset
|
|
|
|
|
self.offset += y.shape[1]
|
|
|
|
|
|
|
|
|
|
logger.info(f"output size: {len(outputs)}")
|
|
|
|
|
|
|
|
|
|
ys = paddle.cat(outputs, 1)
|
|
|
|
|
if self.encoder_out is None:
|
|
|
|
|
self.encoder_out = ys
|
|
|
|
|
self.encoder_out = ys
|
|
|
|
|
else:
|
|
|
|
|
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
|
|
|
|
|
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
|
|
|
|
|
# masks = masks.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get the ctc probs
|
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
|
|
|
|
|
self.searcher.search(None, ctc_probs, self.cached_feat.place)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
|
self.cached_feat = self.cached_feat[0, end -
|
|
|
|
|
cached_feature_num:, :].unsqueeze(0)
|
|
|
|
|
assert len(
|
|
|
|
|
self.cached_feat.shape
|
|
|
|
|
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
|
|
|
|
|
|
# remove the processed feat
|
|
|
|
|
if end == num_frames:
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
else:
|
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
|
self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0)
|
|
|
|
|
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
|
|
|
|
|
|
# ys for rescoring
|
|
|
|
|
# return ys, masks
|
|
|
|
|
logger.info(
|
|
|
|
|
f"This connection handler encoder out shape: {self.encoder_out.shape}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def update_result(self):
|
|
|
|
|
logger.info("update the final result")
|
|
|
|
@ -304,8 +317,8 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
def rescoring(self):
|
|
|
|
|
logger.info("rescoring the final result")
|
|
|
|
|
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.searcher.finalize_search()
|
|
|
|
|
self.update_result()
|
|
|
|
|
|
|
|
|
@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
logger.info(f"best index: {best_index}")
|
|
|
|
|
self.hyps = [hyps[best_index][0]]
|
|
|
|
|
self.update_result()
|
|
|
|
|
# return hyps[best_index][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ASRServerExecutor(ASRExecutor):
|
|
|
|
@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
logger.info(f"Load the pretrained model, tag = {tag}")
|
|
|
|
|
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
|
|
|
|
self.res_path = res_path
|
|
|
|
|
self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
|
|
|
|
|
# self.cfg_path = os.path.join(res_path,
|
|
|
|
|
# pretrained_models[tag]['cfg_path'])
|
|
|
|
|
|
|
|
|
|
self.cfg_path = os.path.join(res_path,
|
|
|
|
|
pretrained_models[tag]['cfg_path'])
|
|
|
|
|
|
|
|
|
|
self.am_model = os.path.join(res_path,
|
|
|
|
|
pretrained_models[tag]['model'])
|
|
|
|
@ -639,7 +650,7 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# decoding window for model
|
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
num_frames = xs.shape[1]
|
|
|
|
|