# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import Optional

import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode

from .pretrained_models import pretrained_models
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor

__all__ = ['ASREngine']


# ASR server connection process class
class PaddleASRConnectionHanddler:
    """Per-connection streaming ASR state.

    Holds the audio/feature buffers, model caches and decoding results of one
    websocket connection. It reuses the model resources owned by the shared
    ``ASREngine`` executor.
    """

    def __init__(self, asr_engine):
        """Init a Paddle ASR Connection Handler instance

        Args:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
        logger.info(
            "create an paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config
        self.model_config = asr_engine.executor.config
        self.asr_engine = asr_engine

        self.init()
        self.reset()

    def init(self):
        """Bind the executor's model resources and feature settings to this connection.

        Sets ``win_length`` / ``n_shift`` (in samples) and creates the
        per-connection decoder (deepspeech2) or CTC prefix beam searcher
        (conformer/transformer).
        """
        # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
        self.model_type = self.asr_engine.executor.model_type
        self.sample_rate = self.asr_engine.executor.sample_rate
        # tokens to text
        self.text_feature = self.asr_engine.executor.text_feature

        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            from paddlespeech.s2t.io.collator import SpeechCollator
            self.am_predictor = self.asr_engine.executor.am_predictor

            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
            self.decoder = CTCDecoder(
                odim=self.model_config.output_dim,  # <blank> is in vocab
                enc_n_units=self.model_config.rnn_layer_size * 2,
                blank_id=self.model_config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.model_config.get('ctc_grad_norm_type',
                                                     None))

            cfg = self.model_config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)

            # frame window samples length and frame shift samples length
            self.win_length = int(self.model_config.window_ms / 1000 *
                                  self.sample_rate)
            self.n_shift = int(self.model_config.stride_ms / 1000 *
                               self.sample_rate)

        elif "conformer" in self.model_type or "transformer" in self.model_type:
            # acoustic model
            self.model = self.asr_engine.executor.model

            # ctc decoding config
            self.ctc_decode_config = self.asr_engine.executor.config.decode
            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

            # extract feat, new only fbank in conformer model
            self.preprocess_conf = self.model_config.preprocess_config
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)

            # frame window samples length and frame shift samples length
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']

    def extract_feat(self, samples):
        """Append one package of audio to the connection and extract its features.

        Args:
            samples (bytes): raw int16 PCM data of this package (decoded with
                ``np.frombuffer(..., dtype=np.int16)``).

        The new features are concatenated to ``self.cached_feat`` along the
        time axis. The samples consumed by the new frames are dropped from
        ``self.remained_wav``.

        Returns:
            0 when the conformer buffer is still shorter than one window,
            otherwise None.
        """
        # we compute the elapsed time of first char occuring
        # and we record the start time at the first pcm sample arraving
        # if self.first_char_occur_elapsed is not None:
        #     self.first_char_occur_elapsed = time.time()

        if "deepspeech2online" in self.model_type:
            # self.reamined_wav stores all the samples,
            # include the original remained_wav and this package samples
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            # pcm16 -> pcm 32
            # pcm2float will change the orignal samples,
            # so we shoule do pcm2float before concatenate
            samples = pcm2float(samples)

            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection remain the audio samples: {self.remained_wav.shape}"
            )

            # read audio
            speech_segment = SpeechSegment.from_pcm(
                self.remained_wav, self.sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            # audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)

            if self.cached_feat is None:
                self.cached_feat = audio
            else:
                assert (len(audio.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, audio], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            self.num_frames += audio_len
            # drop the samples already covered by the extracted frames
            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
            )
        elif "conformer_online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            logger.info(f"This package receive {samples.shape[0]} pcm data")
            self.num_samples += samples.shape[0]

            # self.reamined_wav stores all the samples,
            # include the original remained_wav and this package samples
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection remain the audio samples: {self.remained_wav.shape}"
            )
            # not enough samples for a single analysis window yet
            if len(self.remained_wav) < self.win_length:
                return 0

            # fbank
            x_chunk = self.preprocessing(self.remained_wav,
                                         **self.preprocess_args)
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
                assert (len(x_chunk.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, x_chunk], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            num_frames = x_chunk.shape[1]
            self.num_frames += num_frames
            # drop the samples already covered by the extracted frames
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
            )
            # logger.info(f"accumulate samples: {self.num_samples}")

    def reset(self):
        """Reset all per-utterance state of this connection.

        Clears the model state boxes/caches, the audio and feature buffers,
        the decoder or searcher state, and the decoding results.
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            # for deepspeech2
            self.chunk_state_h_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_h_box)
            self.chunk_state_c_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_c_box)
            self.decoder.reset_decoder(batch_size=1)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            # the prefix beam search keeps hypotheses across calls, so it has
            # to be cleared together with the encoder caches below
            self.searcher.reset()

        # for conformer online
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_out = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.device = None
        self.hyps = []
        self.num_frames = 0
        self.chunk_num = 0
        self.global_frame_offset = 0
        self.result_transcripts = ['']
        self.result_tokenids = []
        self.word_time_stamp = []
        self.time_stamp = []
        self.first_char_occur_elapsed = None

    def decode(self, is_finished=False):
        """Decode the cached features of this connection.

        Args:
            is_finished (bool): True when the client has sent the last
                package. Then the trailing frames (at least ``context``) are
                also processed.

        Returns:
            ``(None, None)`` when there are not enough frames to decode.
            Results are stored on the handler (``result_transcripts``) rather
            than returned.
        """
        if "deepspeech2online" in self.model_type:
            # x_chunk is the feature data
            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
            context = 7  # context=7 in deepspeech2 model
            subsampling = 4  # subsampling=4 in deepspeech2 model
            stride = subsampling * decoding_chunk_size
            cached_feature_num = context - subsampling
            # decoding window for model
            decoding_window = (decoding_chunk_size - 1) * subsampling + context

            if self.cached_feat is None:
                logger.info("no audio feat, please input more pcm data")
                return None, None

            num_frames = self.cached_feat.shape[1]
            logger.info(
                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
            )

            # the cached feat must be larger decoding_window
            if num_frames < decoding_window and not is_finished:
                logger.info(
                    f"frame feat num is less than {decoding_window}, please input more pcm data"
                )
                return None, None

            # if is_finished=True, we need at least context frames
            if num_frames < context:
                logger.info(
                    f"the last {num_frames} frames is less than context {context} frames, and we cannot do model forward"
                )
                return None, None
            logger.info("start to do model forward")
            # num_frames - context + 1 ensure that current frame can get context window
            if is_finished:
                # if get the finished chunk, we need process the last context
                left_frames = context
            else:
                # we only process decoding_window frames for one chunk
                left_frames = decoding_window

            # the checks above guarantee num_frames >= left_frames, so the loop
            # body runs at least once and `end` / `trans_best` are bound
            for cur in range(0, num_frames - left_frames + 1, stride):
                end = min(cur + decoding_window, num_frames)
                # extract the audio
                x_chunk = self.cached_feat[:, cur:end, :].numpy()
                x_chunk_lens = np.array([x_chunk.shape[1]])
                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)

            self.result_transcripts = [trans_best]

            # keep the frames still needed as left context for the next chunk
            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
            # return trans_best[0]
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advance_decoding(is_finished)
                self.update_result()

            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    @paddle.no_grad()
    def decode_one_chunk(self, x_chunk, x_chunk_lens):
        """Run one feature chunk through the deepspeech2 predictor and CTC decoder.

        Args:
            x_chunk (numpy.ndarray): chunk features, shape [B, T, D].
            x_chunk_lens (numpy.ndarray): chunk lengths, shape [B].

        Returns:
            str: the current one-best transcription.

        Updates ``chunk_state_h_box`` / ``chunk_state_c_box`` with the
        predictor's output states.
        """
        logger.info("start to decoce one chunk with deepspeech2 model")
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
        c_box_handle = self.am_predictor.get_input_handle(input_names[3])

        audio_handle.reshape(x_chunk.shape)
        audio_handle.copy_from_cpu(x_chunk)

        audio_len_handle.reshape(x_chunk_lens.shape)
        audio_len_handle.copy_from_cpu(x_chunk_lens)

        h_box_handle.reshape(self.chunk_state_h_box.shape)
        h_box_handle.copy_from_cpu(self.chunk_state_h_box)

        c_box_handle.reshape(self.chunk_state_c_box.shape)
        c_box_handle.copy_from_cpu(self.chunk_state_c_box)

        output_names = self.am_predictor.get_output_names()
        output_handle = self.am_predictor.get_output_handle(output_names[0])
        output_lens_handle = self.am_predictor.get_output_handle(
            output_names[1])
        output_state_h_handle = self.am_predictor.get_output_handle(
            output_names[2])
        output_state_c_handle = self.am_predictor.get_output_handle(
            output_names[3])

        self.am_predictor.run()

        output_chunk_probs = output_handle.copy_to_cpu()
        output_chunk_lens = output_lens_handle.copy_to_cpu()
        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

        self.decoder.next(output_chunk_probs, output_chunk_lens)
        trans_best, trans_beam = self.decoder.decode()
        logger.info(f"decode one best result: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()
    def advance_decoding(self, is_finished=False):
        """Forward the cached features chunk by chunk through the streaming encoder.

        The encoder output is appended to ``self.encoder_out``. Its CTC
        log-probabilities feed the prefix beam searcher, and ``self.hyps`` is
        updated with the current one-best hypothesis.

        Args:
            is_finished (bool): True for the final call of an utterance.

        Returns:
            ``(None, None)`` when there are not enough frames to decode.
        """
        logger.info("start to decode with advanced_decoding method")
        cfg = self.ctc_decode_config
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size
        cached_feature_num = context - subsampling  # processed chunk feature cached for next chunk

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        if self.cached_feat is None:
            logger.info("no audio feat, please input more pcm data")
            return None, None

        num_frames = self.cached_feat.shape[1]
        logger.info(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )

        # the cached feat must be larger decoding_window
        if num_frames < decoding_window and not is_finished:
            logger.info(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # if is_finished=True, we need at least context frames
        if num_frames < context:
            logger.info(
                f"the last {num_frames} frames is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

        logger.info("start to do model forward")
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        outputs = []

        # num_frames - context + 1 ensure that current frame can get context window
        if is_finished:
            # if get the finished chunk, we need process the last context
            left_frames = context
        else:
            # we only process decoding_window frames for one chunk
            left_frames = decoding_window

        # record the end for removing the processed feat
        end = None
        for cur in range(0, num_frames - left_frames + 1, stride):
            end = min(cur + decoding_window, num_frames)

            self.chunk_num += 1
            chunk_xs = self.cached_feat[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)

            # update the offset
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        if self.encoder_out is None:
            self.encoder_out = ys
        else:
            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)

        # get the ctc probs
        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        self.searcher.search(ctc_probs, self.cached_feat.place)

        self.hyps = self.searcher.get_one_best_hyps()
        assert self.cached_feat.shape[0] == 1
        assert end >= cached_feature_num
        # keep the frames still needed as left context for the next chunk
        self.cached_feat = self.cached_feat[0, end -
                                            cached_feature_num:, :].unsqueeze(0)
        assert len(
            self.cached_feat.shape
        ) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

        logger.info(
            f"This connection handler encoder out shape: {self.encoder_out.shape}"
        )

    def update_result(self):
        """Convert ``self.hyps`` token ids into text transcripts."""
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def get_result(self):
        """Return the current best transcript, or '' when there is none."""
        if len(self.result_transcripts) > 0:
            return self.result_transcripts[0]
        else:
            return ''

    def get_word_time_stamp(self):
        """Return the per-token time stamps computed by ``rescoring``."""
        return self.word_time_stamp

    @paddle.no_grad()
    def rescoring(self):
        """Finalize the prefix beam search and rescore it with the attention decoder.

        This only applies to the ``attention_rescoring`` decoding method of
        conformer/transformer models. It picks the best hypothesis by decoder
        score plus ``ctc_weight`` times the CTC score. It then fills
        ``self.word_time_stamp`` with ``{"w", "bg", "ed"}`` entries, where
        ``bg``/``ed`` are in seconds.
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            return

        logger.info("rescoring the final result")
        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
            return

        self.searcher.finalize_search()
        self.update_result()

        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return
        if self.encoder_out is None:
            # no chunk was forwarded through the encoder, nothing to attend to
            return

        # The searcher may hold fewer hypotheses than the configured beam_size,
        # so the decoder batch is sized by the actual number of hypotheses.
        num_hyps = len(hyps)

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp is empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (num_hyps,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at begining

        encoder_out = self.encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = paddle.ones(
            (num_hyps, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=num_hyps
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # last decoder output token is `eos`, for laste decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which in ln domain)
            score += hyp[1] * self.ctc_decode_config.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        # hyps stored the beam results and each fields is:

        logger.info(f"best index: {best_index}")
        # logger.info(f'best result: {hyps[best_index]}')
        # the field of the hyps is:
        # hyps[0][0]: the sentence word-id in the vocab with a tuple
        # hyps[0][1]: the sentence decoding probability with all paths
        # hyps[0][2]: viterbi_blank ending probability
        # hyps[0][3]: viterbi_non_blank probability
        # hyps[0][4]: current_token_prob,
        # hyps[0][5]: times_viterbi_blank,
        # hyps[0][6]: times_titerbi_non_blank
        self.hyps = [hyps[best_index][0]]

        # update the hyps time stamp
        self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
            best_index][3] else hyps[best_index][6]
        logger.info(f"time stamp: {self.time_stamp}")

        self.update_result()

        # update each word start and end time stamp
        # n_shift is a sample count, so this is the duration in seconds of one
        # encoder output frame (after subsampling)
        frame_shift_in_sec = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
        logger.info(f"frame shift in seconds: {frame_shift_in_sec}")
        word_time_stamp = []
        for idx, _ in enumerate(self.time_stamp):
            # boundaries are the midpoints between neighbouring token frames
            start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
                     ) / 2.0 if idx > 0 else 0
            start = start * frame_shift_in_sec

            end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
                   ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
            end = end * frame_shift_in_sec
            word_time_stamp.append({
                "w": self.result_transcripts[0][idx],
                "bg": start,
                "ed": end
            })
            # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
        self.word_time_stamp = word_time_stamp
        logger.info(f"word time stamp: {self.word_time_stamp}")


class ASRServerExecutor(ASRExecutor):
    """ASR executor holding the model resources shared by all connections."""

    def __init__(self):
        super().__init__()
        self.pretrained_models = pretrained_models

    def _init_from_path(self,
                        model_type: str='deepspeech2online_aishell',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.

        When any of ``cfg_path``, ``am_model`` or ``am_params`` is None, the
        pretrained resources registered under the
        ``{model_type}-{lang}-{16k|8k}`` tag are used instead.
        """
        self.model_type = model_type
        self.sample_rate = sample_rate
        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

        if cfg_path is None or am_model is None or am_params is None:
            logger.info(f"Load the pretrained model, tag = {tag}")
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path
            self.cfg_path = os.path.join(
                res_path, self.pretrained_models[tag]['cfg_path'])

            self.am_model = os.path.join(res_path,
                                         self.pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          self.pretrained_models[tag]['params'])
            logger.info(res_path)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(self.cfg_path)
        logger.info(self.am_model)
        logger.info(self.am_params)

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                lm_url = self.pretrained_models[tag]['lm_url']
                lm_md5 = self.pretrained_models[tag]['lm_md5']
                logger.info(f"Start to load language model {lm_url}")
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
            elif "conformer" in model_type or "transformer" in model_type:
                logger.info("start to create the stream conformer asr engine")
                if self.config.spm_model_prefix:
                    self.config.spm_model_prefix = os.path.join(
                        self.res_path, self.config.spm_model_prefix)
                self.vocab = self.config.vocab_filepath
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type,
                    vocab=self.config.vocab_filepath,
                    spm_model_prefix=self.config.spm_model_prefix)
                # update the decoding method
                if decode_method:
                    self.config.decode.decoding_method = decode_method

                # we only support ctc_prefix_beam_search and attention_rescoring dedoding method
                # Generally we set the decoding_method to attention_rescoring
                if self.config.decode.decoding_method not in [
                        "ctc_prefix_beam_search", "attention_rescoring"
                ]:
                    logger.info(
                        "we set the decoding_method to attention_rescoring")
                    self.config.decode.decoding_method = "attention_rescoring"
                assert self.config.decode.decoding_method in [
                    "ctc_prefix_beam_search", "attention_rescoring"
                ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
            else:
                raise Exception("wrong type")
        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
            # AM predictor
            logger.info("ASR engine start to init the am predictor")
            self.am_predictor_conf = am_predictor_conf
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,
                predictor_conf=self.am_predictor_conf)

            # decoder
            logger.info("ASR engine start to create the ctc decoder instance")
            self.decoder = CTCDecoder(
                odim=self.config.output_dim,  # <blank> is in vocab
                enc_n_units=self.config.rnn_layer_size * 2,
                blank_id=self.config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.config.get('ctc_grad_norm_type', None))

            # init decoder
            logger.info("ASR engine start to init the ctc decoder")
            cfg = self.config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)

            # init state box
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in model_type or "transformer" in model_type:
            model_name = model_type[:model_type.rindex(
                '_')]  # model_type: {model_name}_{dataset}
            logger.info(f"model name: {model_name}")
            model_class = dynamic_import(model_name, self.model_alias)
            model_conf = self.config
            model = model_class.from_config(model_conf)
            self.model = model
            self.model.eval()

            # load model
            model_dict = paddle.load(self.am_model)
            self.model.set_state_dict(model_dict)
            logger.info("create the transformer like model success")

            # update the ctc decoding
            self.searcher = CTCPrefixBeamSearch(self.config.decode)
            self.transformer_decode_reset()

    def reset_decoder_and_chunk(self):
        """reset decoder and chunk state for an new audio
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            self.decoder.reset_decoder(batch_size=1)
            # init state box, for new audio request
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            self.transformer_decode_reset()

    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
        """decode one chunk

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            model_type (str): online model type

        Returns:
            str: one best result. For conformer/transformer models None is
            returned when decoding raises (the exception is logged).
        """
        logger.info("start to decoce chunk by chunk")
        if "deepspeech2online" in model_type:
            input_names = self.am_predictor.get_input_names()
            audio_handle = self.am_predictor.get_input_handle(input_names[0])
            audio_len_handle = self.am_predictor.get_input_handle(
                input_names[1])
            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
            c_box_handle = self.am_predictor.get_input_handle(input_names[3])

            audio_handle.reshape(x_chunk.shape)
            audio_handle.copy_from_cpu(x_chunk)

            audio_len_handle.reshape(x_chunk_lens.shape)
            audio_len_handle.copy_from_cpu(x_chunk_lens)

            h_box_handle.reshape(self.chunk_state_h_box.shape)
            h_box_handle.copy_from_cpu(self.chunk_state_h_box)

            c_box_handle.reshape(self.chunk_state_c_box.shape)
            c_box_handle.copy_from_cpu(self.chunk_state_c_box)

            output_names = self.am_predictor.get_output_names()
            output_handle = self.am_predictor.get_output_handle(output_names[0])
            output_lens_handle = self.am_predictor.get_output_handle(
                output_names[1])
            output_state_h_handle = self.am_predictor.get_output_handle(
                output_names[2])
            output_state_c_handle = self.am_predictor.get_output_handle(
                output_names[3])

            self.am_predictor.run()

            output_chunk_probs = output_handle.copy_to_cpu()
            output_chunk_lens = output_lens_handle.copy_to_cpu()
            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

            self.decoder.next(output_chunk_probs, output_chunk_lens)
            trans_best, trans_beam = self.decoder.decode()
            logger.info(f"decode one best result: {trans_best[0]}")
            return trans_best[0]

        elif "conformer" in model_type or "transformer" in model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advanced_decoding(x_chunk, x_chunk_lens)
                self.update_result()

                return self.result_transcripts[0]
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
        """Encode ``xs`` and run CTC prefix beam search on it.

        Also applies attention rescoring when that is the configured
        decoding method.
        """
        logger.info("start to decode with advanced_decoding method")
        encoder_out, encoder_mask = self.encoder_forward(xs)
        ctc_probs = self.model.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        self.searcher.search(ctc_probs, xs.place)
        # update the one best result
        self.hyps = self.searcher.get_one_best_hyps()

        # now we supprot ctc_prefix_beam_search and attention_rescoring
        if "attention_rescoring" in self.config.decode.decoding_method:
            self.rescoring(encoder_out, xs.place)

    def encoder_forward(self, xs):
        """Forward ``xs`` (shape [1, T, D]) chunk by chunk through the streaming encoder.

        Returns:
            tuple: the concatenated encoder output and an all-True mask of
            shape [1, 1, T'].
        """
        logger.info("get the model out from the feat")
        cfg = self.config.decode
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.shape[1]
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        logger.info("start to do model forward")
        outputs = []

        # num_frames - context + 1 ensure that current frame can get context window
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        masks = masks.unsqueeze(1)
        return ys, masks

    def rescoring(self, encoder_out, device):
        """Rescore the searcher's hypotheses with the attention decoder.

        Args:
            encoder_out (paddle.Tensor): encoder output of the current chunk.
            device: placement used for the hypothesis tensors.

        Returns:
            The token ids of the best hypothesis, or None when the searcher
            holds no hypotheses (``self.hyps`` is then left unchanged).
        """
        logger.info("start to rescoring the hyps")
        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return None
        # The searcher may hold fewer hypotheses than the configured beam_size,
        # so the decoder batch is sized by the actual number of hypotheses.
        num_hyps = len(hyps)

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp is empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=device,
            dtype=paddle.long)  # (num_hyps,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at begining

        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = paddle.ones(
            (num_hyps, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=num_hyps
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # last decoder output token is `eos`, for laste decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which in ln domain)
            score += hyp[1] * self.config.decode.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        self.hyps = [hyps[best_index][0]]
        return hyps[best_index][0]

    def transformer_decode_reset(self):
        """Clear the streaming encoder caches, the offset and the searcher state."""
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.offset = 0
        # decoding reset
        self.searcher.reset()

    def update_result(self):
        """Convert ``self.hyps`` token ids into text transcripts."""
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def extract_feat(self, samples, sample_rate):
        """extract feat

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        # NOTE(review): the deepspeech2 branch passes `samples` through
        # pcm2float, which suggests int16 PCM input there — confirm callers.
        if "deepspeech2online" in self.model_type:
            # pcm16 -> pcm 32
            samples = pcm2float(samples)
            # read audio
            speech_segment = SpeechSegment.from_pcm(
                samples, sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            # audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)

            x_chunk = audio.numpy()
            x_chunk_lens = np.array([audio_len])

            return x_chunk, x_chunk_lens
        elif "conformer_online" in self.model_type:

            if sample_rate != self.sample_rate:
                logger.info(f"audio sample rate {sample_rate} is not match, "
                            f"the model sample_rate is {self.sample_rate}")
            logger.info(f"ASR Engine use the {self.model_type} to process")
            logger.info("Create the preprocess instance")
            preprocess_conf = self.config.preprocess_config
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)

            logger.info("Read the audio file")
            logger.info(f"audio shape: {samples.shape}")
            # fbank
            x_chunk = preprocessing(samples, **preprocess_args)
            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            logger.info(
                f"process the audio feature success, feat shape: {x_chunk.shape}"
            )
            return x_chunk, x_chunk_lens


class ASREngine(BaseEngine):
    """ASR server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()
        logger.info("create the online asr engine instance")

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config_file (str): config file

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = ""
        self.executor = ASRServerExecutor()
        self.config = config
        try:
            if self.config.get("device", None):
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            logger.info(f"paddlespeech_server set the device: {self.device}")
            paddle.set_device(self.device)
        except Exception as e:
            # best effort: keep going with paddle's current device, but do not
            # swallow KeyboardInterrupt/SystemExit and keep the traceback
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            logger.exception(e)

        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True

    def preprocess(self,
                   samples,
                   sample_rate,
                   model_type="deepspeech2online_aishell-zh-16k"):
        """preprocess

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        # if "deepspeech" in model_type:
        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
        return x_chunk, x_chunk_lens

    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
        """run online engine

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            decoder_chunk_size(int)
        """
        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
                                                     self.config.model_type)

    def postprocess(self):
        """postprocess
        """
        return self.output

    def reset(self):
        """reset engine decoder and inference state
        """
        self.executor.reset_decoder_and_chunk()
        self.output = ""