|
|
|
@ -13,7 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import os
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
|
from numpy import float32
|
|
|
|
@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
)
|
|
|
|
|
self.config = asr_engine.config
|
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
|
self.model = asr_engine.executor.model
|
|
|
|
|
# self.model = asr_engine.executor.model
|
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
|
|
|
|
|
|
self.init()
|
|
|
|
@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
def init(self):
|
|
|
|
|
self.model_type = self.asr_engine.executor.model_type
|
|
|
|
|
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
from paddlespeech.s2t.io.collator import SpeechCollator
|
|
|
|
|
self.sample_rate = self.asr_engine.executor.sample_rate
|
|
|
|
|
self.am_predictor = self.asr_engine.executor.am_predictor
|
|
|
|
|
self.text_feature = self.asr_engine.executor.text_feature
|
|
|
|
|
self.collate_fn_test = SpeechCollator.from_config(self.model_config)
|
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
|
odim=self.model_config.output_dim, # <blank> is in vocab
|
|
|
|
|
enc_n_units=self.model_config.rnn_layer_size * 2,
|
|
|
|
|
blank_id=self.model_config.blank_id,
|
|
|
|
|
dropout_rate=0.0,
|
|
|
|
|
reduction=True, # sum
|
|
|
|
|
batch_average=True, # sum / batch_size
|
|
|
|
|
grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
|
|
|
|
|
|
|
|
|
|
cfg = self.model_config.decode
|
|
|
|
|
decode_batch_size = 1 # for online
|
|
|
|
|
self.decoder.init_decoder(
|
|
|
|
|
decode_batch_size, self.text_feature.vocab_list,
|
|
|
|
|
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
|
|
|
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
|
|
|
|
cfg.num_proc_bsearch)
|
|
|
|
|
# frame window samples length and frame shift samples length
|
|
|
|
|
|
|
|
|
|
self.win_length = int(self.model_config.window_ms * self.sample_rate)
|
|
|
|
|
self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
|
|
|
|
|
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
|
|
|
|
|
self.sample_rate = self.asr_engine.executor.sample_rate
|
|
|
|
|
|
|
|
|
@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
def extract_feat(self, samples):
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
# self.reamined_wav stores all the samples,
|
|
|
|
|
# include the original remained_wav and this package samples
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
|
assert samples.ndim == 1
|
|
|
|
|
|
|
|
|
|
if self.remained_wav is None:
|
|
|
|
|
self.remained_wav = samples
|
|
|
|
|
else:
|
|
|
|
|
assert self.remained_wav.ndim == 1
|
|
|
|
|
self.remained_wav = np.concatenate([self.remained_wav, samples])
|
|
|
|
|
logger.info(
|
|
|
|
|
f"The connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# pcm16 -> pcm 32
|
|
|
|
|
samples = pcm2float(self.remained_wav)
|
|
|
|
|
# read audio
|
|
|
|
|
speech_segment = SpeechSegment.from_pcm(
|
|
|
|
|
samples, self.sample_rate, transcript=" ")
|
|
|
|
|
# audio augment
|
|
|
|
|
self.collate_fn_test.augmentation.transform_audio(speech_segment)
|
|
|
|
|
|
|
|
|
|
# extract speech feature
|
|
|
|
|
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
|
|
|
|
|
speech_segment, self.collate_fn_test.keep_transcription_text)
|
|
|
|
|
# CMVN spectrum
|
|
|
|
|
if self.collate_fn_test._normalizer:
|
|
|
|
|
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
|
|
|
|
|
|
|
|
|
|
# spectrum augment
|
|
|
|
|
audio = self.collate_fn_test.augmentation.transform_feature(
|
|
|
|
|
spectrum)
|
|
|
|
|
|
|
|
|
|
audio_len = audio.shape[0]
|
|
|
|
|
audio = paddle.to_tensor(audio, dtype='float32')
|
|
|
|
|
# audio_len = paddle.to_tensor(audio_len)
|
|
|
|
|
audio = paddle.unsqueeze(audio, axis=0)
|
|
|
|
|
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
self.cached_feat = audio
|
|
|
|
|
else:
|
|
|
|
|
assert (len(audio.shape) == 3)
|
|
|
|
|
assert (len(self.cached_feat.shape) == 3)
|
|
|
|
|
self.cached_feat = paddle.concat(
|
|
|
|
|
[self.cached_feat, audio], axis=1)
|
|
|
|
|
|
|
|
|
|
# set the feat device
|
|
|
|
|
if self.device is None:
|
|
|
|
|
self.device = self.cached_feat.place
|
|
|
|
|
|
|
|
|
|
self.num_frames += audio_len
|
|
|
|
|
self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
elif "conformer2online" in self.model_type:
|
|
|
|
|
logger.info("Online ASR extract the feat")
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
# logger.info(f"accumulate samples: {self.num_samples}")
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
self.subsampling_cache = None
|
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
|
self.encoder_out = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.offset = 0
|
|
|
|
|
self.num_samples = 0
|
|
|
|
|
self.device = None
|
|
|
|
|
self.hyps = []
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
self.chunk_num = 0
|
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
|
|
|
|
|
# for deepspeech2
|
|
|
|
|
self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
|
|
|
|
|
self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
|
|
|
|
|
self.decoder.reset_decoder(batch_size=1)
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
|
|
|
|
|
# for conformer online
|
|
|
|
|
self.subsampling_cache = None
|
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
|
self.encoder_out = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.offset = 0
|
|
|
|
|
self.num_samples = 0
|
|
|
|
|
self.device = None
|
|
|
|
|
self.hyps = []
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
self.chunk_num = 0
|
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
|
|
|
|
|
def decode(self, is_finished=False):
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
pass
|
|
|
|
|
# x_chunk 是特征数据
|
|
|
|
|
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
|
|
|
|
|
context = 7 # context=7 in deepspeech2 model
|
|
|
|
|
subsampling = 4 # subsampling=4 in deepspeech2 model
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
cached_feature_num = context - subsampling
|
|
|
|
|
# decoding window for model
|
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
logger.info("no audio feat, please input more pcm data")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
|
|
|
|
|
)
|
|
|
|
|
# the cached feat must be larger decoding_window
|
|
|
|
|
if num_frames < decoding_window and not is_finished:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"frame feat num is less than {decoding_window}, please input more pcm data"
|
|
|
|
|
)
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
# if is_finished=True, we need at least context frames
|
|
|
|
|
if num_frames < context:
|
|
|
|
|
logger.info(
|
|
|
|
|
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
|
|
|
|
|
)
|
|
|
|
|
return None, None
|
|
|
|
|
logger.info("start to do model forward")
|
|
|
|
|
# num_frames - context + 1 ensure that current frame can get context window
|
|
|
|
|
if is_finished:
|
|
|
|
|
# if get the finished chunk, we need process the last context
|
|
|
|
|
left_frames = context
|
|
|
|
|
else:
|
|
|
|
|
# we only process decoding_window frames for one chunk
|
|
|
|
|
left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
end = min(cur + decoding_window, num_frames)
|
|
|
|
|
# extract the audio
|
|
|
|
|
x_chunk = self.cached_feat[:, cur:end, :].numpy()
|
|
|
|
|
x_chunk_lens = np.array([x_chunk.shape[1]])
|
|
|
|
|
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
|
|
|
|
|
|
|
|
|
|
self.result_transcripts = [trans_best]
|
|
|
|
|
|
|
|
|
|
self.cached_feat = self.cached_feat[:, end -
|
|
|
|
|
cached_feature_num:, :]
|
|
|
|
|
# return trans_best[0]
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
try:
|
|
|
|
|
logger.info(
|
|
|
|
@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("invalid model name")
|
|
|
|
|
|
|
|
|
|
def decode_one_chunk(self, x_chunk, x_chunk_lens):
|
|
|
|
|
logger.info("start to decoce one chunk with deepspeech2 model")
|
|
|
|
|
input_names = self.am_predictor.get_input_names()
|
|
|
|
|
audio_handle = self.am_predictor.get_input_handle(input_names[0])
|
|
|
|
|
audio_len_handle = self.am_predictor.get_input_handle(
|
|
|
|
|
input_names[1])
|
|
|
|
|
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
|
|
|
|
|
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
|
|
|
|
|
|
|
|
|
|
audio_handle.reshape(x_chunk.shape)
|
|
|
|
|
audio_handle.copy_from_cpu(x_chunk)
|
|
|
|
|
|
|
|
|
|
audio_len_handle.reshape(x_chunk_lens.shape)
|
|
|
|
|
audio_len_handle.copy_from_cpu(x_chunk_lens)
|
|
|
|
|
|
|
|
|
|
h_box_handle.reshape(self.chunk_state_h_box.shape)
|
|
|
|
|
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
|
|
|
|
|
|
|
|
|
|
c_box_handle.reshape(self.chunk_state_c_box.shape)
|
|
|
|
|
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
|
|
|
|
|
|
|
|
|
|
output_names = self.am_predictor.get_output_names()
|
|
|
|
|
output_handle = self.am_predictor.get_output_handle(output_names[0])
|
|
|
|
|
output_lens_handle = self.am_predictor.get_output_handle(
|
|
|
|
|
output_names[1])
|
|
|
|
|
output_state_h_handle = self.am_predictor.get_output_handle(
|
|
|
|
|
output_names[2])
|
|
|
|
|
output_state_c_handle = self.am_predictor.get_output_handle(
|
|
|
|
|
output_names[3])
|
|
|
|
|
|
|
|
|
|
self.am_predictor.run()
|
|
|
|
|
|
|
|
|
|
output_chunk_probs = output_handle.copy_to_cpu()
|
|
|
|
|
output_chunk_lens = output_lens_handle.copy_to_cpu()
|
|
|
|
|
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
|
|
|
|
|
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
|
|
|
|
|
|
|
|
|
|
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
|
|
|
|
trans_best, trans_beam = self.decoder.decode()
|
|
|
|
|
logger.info(f"decode one one best result: {trans_best[0]}")
|
|
|
|
|
return trans_best[0]
|
|
|
|
|
|
|
|
|
|
def advance_decoding(self, is_finished=False):
|
|
|
|
|
logger.info("start to decode with advanced_decoding method")
|
|
|
|
|
cfg = self.ctc_decode_config
|
|
|
|
@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
)
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
# if is_finished=True, we need at least context frames
|
|
|
|
|
if num_frames < context:
|
|
|
|
|
logger.info(
|
|
|
|
|
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
|
|
|
|
@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
return ''
|
|
|
|
|
|
|
|
|
|
def rescoring(self):
|
|
|
|
|
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logger.info("rescoring the final result")
|
|
|
|
|
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
|
|
|
|
|
return
|
|
|
|
|