# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import os import sys from typing import Optional import numpy as np import paddle from numpy import float32 from yacs.config import CfgNode from .pretrained_models import pretrained_models from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor __all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] # ASR server connection process class class PaddleASRConnectionHanddler: def __init__(self, asr_engine): """Init a Paddle ASR Connection Handler instance Args: asr_engine (ASREngine): the global asr engine """ super().__init__() logger.info( "create an paddle asr connection handler to process the websocket connection" ) self.config = asr_engine.config self.model_config = asr_engine.executor.config self.asr_engine = asr_engine self.init() self.reset() def init(self): # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = self.asr_engine.executor.model_type self.sample_rate = self.asr_engine.executor.sample_rate # tokens to text self.text_feature = self.asr_engine.executor.text_feature if "deepspeech2" in self.model_type: from paddlespeech.s2t.io.collator import SpeechCollator self.am_predictor = self.asr_engine.executor.am_predictor self.collate_fn_test = SpeechCollator.from_config(self.model_config) self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab enc_n_units=self.model_config.rnn_layer_size * 2, blank_id=self.model_config.blank_id, dropout_rate=0.0, reduction=True, # sum batch_average=True, # sum / batch_size grad_norm_type=self.model_config.get('ctc_grad_norm_type', None)) cfg = self.model_config.decode decode_batch_size = 1 # for online self.decoder.init_decoder( decode_batch_size, self.text_feature.vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) # frame window and frame shift, in samples unit self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate) self.n_shift = int(self.model_config.stride_ms / 1000 * self.sample_rate) elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) # extract feat, new only fbank in conformer model self.preprocess_conf = self.model_config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) # frame window and frame shift, in samples unit self.win_length = self.preprocess_conf.process[0]['win_length'] self.n_shift = self.preprocess_conf.process[0]['n_shift'] else: raise ValueError(f"Not supported: {self.model_type}") def extract_feat(self, samples): # we compute the elapsed time of first char occuring # and we record the start time at the first pcm sample arraving if "deepspeech2online" in self.model_type: # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 # pcm16 -> pcm 32 # pcm2float will change the orignal samples, # so we shoule do pcm2float before concatenate samples = pcm2float(samples) if self.remained_wav is None: self.remained_wav = samples else: assert self.remained_wav.ndim == 1 self.remained_wav = np.concatenate([self.remained_wav, samples]) logger.info( f"The connection remain the audio samples: {self.remained_wav.shape}" ) # read audio speech_segment = SpeechSegment.from_pcm( self.remained_wav, self.sample_rate, transcript=" ") # audio augment self.collate_fn_test.augmentation.transform_audio(speech_segment) # extract speech feature spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( speech_segment, self.collate_fn_test.keep_transcription_text) # CMVN spectrum if self.collate_fn_test._normalizer: spectrum = self.collate_fn_test._normalizer.apply(spectrum) # spectrum augment feat = self.collate_fn_test.augmentation.transform_feature( spectrum) # audio_len is frame num frame_num = feat.shape[0] feat = paddle.to_tensor(feat, dtype='float32') feat = paddle.unsqueeze(feat, axis=0) if self.cached_feat is None: self.cached_feat = feat else: assert (len(feat.shape) == 3) assert (len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat( [self.cached_feat, feat], axis=1) # set the feat device if self.device is None: self.device = self.cached_feat.place self.num_frames += frame_num self.remained_wav = self.remained_wav[self.n_shift * frame_num:] logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" ) logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) elif "conformer_online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 self.num_samples += samples.shape[0] logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}") # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: self.remained_wav = samples else: assert self.remained_wav.ndim == 1 # (T,) self.remained_wav = np.concatenate([self.remained_wav, samples]) logger.info( f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" ) if len(self.remained_wav) < self.win_length: # samples not enough for feature window return 0 # fbank x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) x_chunk = paddle.to_tensor( x_chunk, dtype="float32").unsqueeze(axis=0) # feature cache if self.cached_feat is None: self.cached_feat = x_chunk else: assert (len(x_chunk.shape) == 3) # (B,T,D) assert (len(self.cached_feat.shape) == 3) # (B,T,D) self.cached_feat = paddle.concat( [self.cached_feat, x_chunk], axis=1) # set the feat device if self.device is None: self.device = self.cached_feat.place # cur frame step num_frames = x_chunk.shape[1] # global frame step self.num_frames += num_frames # update remained wav self.remained_wav = self.remained_wav[self.n_shift * num_frames:] logger.info( f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" ) logger.info( f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" ) logger.info(f"global samples: {self.num_samples}") logger.info(f"global frames: {self.num_frames}") else: raise ValueError(f"not supported: {self.model_type}") def reset(self): if "deepspeech2" in self.model_type: # for deepspeech2 self.chunk_state_h_box = copy.deepcopy( self.asr_engine.executor.chunk_state_h_box) self.chunk_state_c_box = copy.deepcopy( self.asr_engine.executor.chunk_state_c_box) self.decoder.reset_decoder(batch_size=1) self.device = None ## common # global sample and frame step self.num_samples = 0 self.num_frames = 0 # cache for audio and feat self.remained_wav = None self.cached_feat = None # partial/ending decoding results self.result_transcripts = [''] ## conformer # cache for conformer online self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None self.encoder_out = None # conformer decoding state self.chunk_num = 0 # globa decoding chunk num self.offset = 0 # global offset in decoding frame unit self.hyps = [] # token timestamp result self.word_time_stamp = [] # one best timestamp viterbi prob is large. self.time_stamp = [] def decode(self, is_finished=False): """advance decoding Args: is_finished (bool, optional): Is last frame or not. Defaults to False. Raises: Exception: when not support model. Returns: None: nothing """ if "deepspeech2online" in self.model_type: decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit context = 7 # context=7, in audio frame unit subsampling = 4 # subsampling=4, in audio frame unit cached_feature_num = context - subsampling # decoding window for model, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context # decoding stride for model, in audio frame unit stride = subsampling * decoding_chunk_size if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" ) # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( f"frame feat num is less than {decoding_window}, please input more pcm data" ) return None, None # if is_finished=True, we need at least context frames if num_frames < context: logger.info( "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None logger.info("start to do model forward") # num_frames - context + 1 ensure that current frame can get context window if is_finished: # if get the finished chunk, we need process the last context left_frames = context else: # we only process decoding_window frames for one chunk left_frames = decoding_window end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) # extract the audio x_chunk = self.cached_feat[:, cur:end, :].numpy() x_chunk_lens = np.array([x_chunk.shape[1]]) trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) self.result_transcripts = [trans_best] # update feat cache self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: logger.info( f"we will use the transformer like model : {self.model_type}" ) self.advance_decoding(is_finished) self.update_result() except Exception as e: logger.exception(e) else: raise Exception("invalid model name") @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): """forward one chunk frames Args: x_chunk (np.ndarray): (B,T,D), audio frames. x_chunk_lens ([type]): (B,), audio frame lens Returns: logprob: poster probability. """ logger.info("start to decoce one chunk for deepspeech2") input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) audio_handle.reshape(x_chunk.shape) audio_handle.copy_from_cpu(x_chunk) audio_len_handle.reshape(x_chunk_lens.shape) audio_len_handle.copy_from_cpu(x_chunk_lens) h_box_handle.reshape(self.chunk_state_h_box.shape) h_box_handle.copy_from_cpu(self.chunk_state_h_box) c_box_handle.reshape(self.chunk_state_c_box.shape) c_box_handle.copy_from_cpu(self.chunk_state_c_box) output_names = self.am_predictor.get_output_names() output_handle = self.am_predictor.get_output_handle(output_names[0]) output_lens_handle = self.am_predictor.get_output_handle( output_names[1]) output_state_h_handle = self.am_predictor.get_output_handle( output_names[2]) output_state_c_handle = self.am_predictor.get_output_handle( output_names[3]) self.am_predictor.run() output_chunk_probs = output_handle.copy_to_cpu() output_chunk_lens = output_lens_handle.copy_to_cpu() self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") return trans_best[0] @paddle.no_grad() def advance_decoding(self, is_finished=False): logger.info("Conformer/Transformer: start to decode with advanced_decoding method") cfg = self.ctc_decode_config # cur chunk size, in decoding frame unit decoding_chunk_size = cfg.decoding_chunk_size # using num of history chunks num_decoding_left_chunks = cfg.num_decoding_left_chunks assert decoding_chunk_size > 0 subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 # processed chunk feature cached for next chunk cached_feature_num = context - subsampling # decoding stride, in audio frame unit stride = subsampling * decoding_chunk_size # decoding window, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" ) # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( f"frame feat num is less than {decoding_window}, please input more pcm data" ) return None, None # if is_finished=True, we need at least context frames if num_frames < context: logger.info( "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None logger.info("start to do model forward") # hist of chunks, in deocding frame unit required_cache_size = decoding_chunk_size * num_decoding_left_chunks outputs = [] # num_frames - context + 1 ensure that current frame can get context window if is_finished: # if get the finished chunk, we need process the last context left_frames = context else: # we only process decoding_window frames for one chunk left_frames = decoding_window # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) # global chunk_num self.chunk_num += 1 # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] # forward chunk (y, self.subsampling_cache, self.elayers_output_cache, self.conformer_cnn_cache) = self.model.encoder.forward_chunk( chunk_xs, self.offset, required_cache_size, self.subsampling_cache, self.elayers_output_cache, self.conformer_cnn_cache) outputs.append(y) # update the global offset, in decoding frame unit self.offset += y.shape[1] ys = paddle.cat(outputs, 1) if self.encoder_out is None: self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) # advance decoding self.searcher.search(ctc_probs, self.cached_feat.place) # get one best hyps self.hyps = self.searcher.get_one_best_hyps() assert self.cached_feat.shape[0] == 1 assert end >= cached_feature_num # advance cache of feat self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0) assert len( self.cached_feat.shape ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" logger.info( f"This connection handler encoder out shape: {self.encoder_out.shape}" ) def update_result(self): """Conformer/Transformer hyps to result. """ logger.info("update the final result") hyps = self.hyps # output results and tokenids self.result_transcripts = [ self.text_feature.defeaturize(hyp) for hyp in hyps ] self.result_tokenids = [hyp for hyp in hyps] def get_result(self): """return partial/ending asr result. Returns: str: one best result of partial/ending. """ if len(self.result_transcripts) > 0: return self.result_transcripts[0] else: return '' def get_word_time_stamp(self): """return token timestamp result. Returns: list: List of ('w':token, 'bg':time, 'ed':time) """ return self.word_time_stamp @paddle.no_grad() def rescoring(self): """Second-Pass Decoding, only for conformer and transformer model. """ if "deepspeech2" in self.model_type: logger.info("deepspeech2 not support rescoring decoding.") return if "attention_rescoring" != self.ctc_decode_config.decoding_method: logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring") return logger.info("rescoring the final result") # last decoding for last audio self.searcher.finalize_search() # update beam search results self.update_result() beam_size = self.ctc_decode_config.beam_size hyps = self.searcher.get_hyps() if hyps is None or len(hyps) == 0: logger.info("No Hyps!") return # rescore by decoder post probability # assert len(hyps) == beam_size # list of Tensor hyp_list = [] for hyp in hyps: hyp_content = hyp[0] # Prevent the hyp is empty if len(hyp_content) == 0: hyp_content = (self.model.ctc.blank_id, ) hyp_content = paddle.to_tensor( hyp_content, place=self.device, dtype=paddle.long) hyp_list.append(hyp_content) hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id) hyps_lens = paddle.to_tensor( [len(hyp[0]) for hyp in hyps], place=self.device, dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, self.model.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining encoder_out = self.encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) decoder_out, _ = self.model.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) # ctc score in ln domain decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) decoder_out = decoder_out.numpy() # Only use decoder score for rescoring best_score = -float('inf') best_index = 0 # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size for i, hyp in enumerate(hyps): score = 0.0 for j, w in enumerate(hyp[0]): score += decoder_out[i][j][w] # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.model.eos] # add ctc score (which in ln domain) score += hyp[1] * self.ctc_decode_config.ctc_weight if score > best_score: best_score = score best_index = i # update the one best result # hyps stored the beam results and each fields is: logger.info(f"best hyp index: {best_index}") # logger.info(f'best result: {hyps[best_index]}') # the field of the hyps is: ## asr results # hyps[0][0]: the sentence word-id in the vocab with a tuple # hyps[0][1]: the sentence decoding probability with all paths ## timestamp # hyps[0][2]: viterbi_blank ending probability # hyps[0][3]: viterbi_non_blank dending probability # hyps[0][4]: current_token_prob, # hyps[0][5]: times_viterbi_blank ending timestamp, # hyps[0][6]: times_titerbi_non_blank encding timestamp. self.hyps = [hyps[best_index][0]] logger.info(f"best hyp ids: {self.hyps}") # update the hyps time stamp self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[ best_index][3] else hyps[best_index][6] logger.info(f"time stamp: {self.time_stamp}") # update one best result self.update_result() # update each word start and end time stamp # decoding frame to audio frame frame_shift = self.model.encoder.embed.subsampling_rate frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate) logger.info(f"frame shift sec: {frame_shift_in_sec}") word_time_stamp = [] for idx, _ in enumerate(self.time_stamp): start = (self.time_stamp[idx - 1] + self.time_stamp[idx] ) / 2.0 if idx > 0 else 0 start = start * frame_shift_in_sec end = (self.time_stamp[idx] + self.time_stamp[idx + 1] ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset end = end * frame_shift_in_sec word_time_stamp.append({ "w": self.result_transcripts[0][idx], "bg": start, "ed": end }) # logger.info(f"{word_time_stamp[-1]}") self.word_time_stamp = word_time_stamp logger.info(f"word time stamp: {self.word_time_stamp}") class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() self.pretrained_models = pretrained_models def _init_from_path(self, model_type: str=None, am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, decode_method: str='attention_rescoring', am_predictor_conf: dict=None): """ Init model and other resources from a specific path. """ if not model_type or not lang or not sample_rate: logger.error( "The model type or lang or sample rate is None, please input an valid server parameter yaml" ) return False self.model_type = model_type self.sample_rate = sample_rate sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str if cfg_path is None or am_model is None or am_params is None: logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path self.cfg_path = os.path.join( res_path, self.pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, self.pretrained_models[tag]['model']) self.am_params = os.path.join(res_path, self.pretrained_models[tag]['params']) logger.info(res_path) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) self.am_params = os.path.abspath(am_params) self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) logger.info(self.cfg_path) logger.info(self.am_model) logger.info(self.am_params) #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): if "deepspeech2" in model_type: from paddlespeech.s2t.io.collator import SpeechCollator self.vocab = self.config.vocab_filepath self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) self.collate_fn_test = SpeechCollator.from_config(self.config) self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.vocab) lm_url = self.pretrained_models[tag]['lm_url'] lm_md5 = self.pretrained_models[tag]['lm_md5'] logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type: logger.info("start to create the stream conformer asr engine") if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( self.res_path, self.config.spm_model_prefix) self.vocab = self.config.vocab_filepath self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method # we only support ctc_prefix_beam_search and attention_rescoring dedoding method # Generally we set the decoding_method to attention_rescoring if self.config.decode.decoding_method not in [ "ctc_prefix_beam_search", "attention_rescoring" ]: logger.info( "we set the decoding_method to attention_rescoring") self.config.decode.decoding_method = "attention_rescoring" assert self.config.decode.decoding_method in [ "ctc_prefix_beam_search", "attention_rescoring" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") if "deepspeech2" in model_type: # AM predictor logger.info("ASR engine start to init the am predictor") self.am_predictor_conf = am_predictor_conf self.am_predictor = init_predictor( model_file=self.am_model, params_file=self.am_params, predictor_conf=self.am_predictor_conf) # decoder logger.info("ASR engine start to create the ctc decoder instance") self.decoder = CTCDecoder( odim=self.config.output_dim, # is in vocab enc_n_units=self.config.rnn_layer_size * 2, blank_id=self.config.blank_id, dropout_rate=0.0, reduction=True, # sum batch_average=True, # sum / batch_size grad_norm_type=self.config.get('ctc_grad_norm_type', None)) # init decoder logger.info("ASR engine start to init the ctc decoder") cfg = self.config.decode decode_batch_size = 1 # for online self.decoder.init_decoder( decode_batch_size, self.text_feature.vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) # init state box self.chunk_state_h_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) elif "conformer" in model_type or "transformer" in model_type: model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config model = model_class.from_config(model_conf) self.model = model self.model.eval() # load model model_dict = paddle.load(self.am_model) self.model.set_state_dict(model_dict) logger.info("create the transformer like model success") # update the ctc decoding self.searcher = CTCPrefixBeamSearch(self.config.decode) self.transformer_decode_reset() else: raise ValueError(f"Not support: {model_type}") return True class ASREngine(BaseEngine): """ASR server resource Args: metaclass: Defaults to Singleton. """ def __init__(self): super(ASREngine, self).__init__() logger.info("create the online asr engine resource instance") def init(self, config: dict) -> bool: """init engine resource Args: config_file (str): config file Returns: bool: init failed or success """ self.config = config self.executor = ASRServerExecutor() try: default_dev = paddle.get_device() paddle.set_device(self.config.get("device", default_dev)) except BaseException as e: logger.error( f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" ) logger.error( "If all GPU or XPU is used, you can set the server to 'cpu'") sys.exit(-1) logger.info(f"paddlespeech_server set the device: {self.device}") if not self.executor._init_from_path( model_type=self.config.model_type, am_model=self.config.am_model, am_params=self.config.am_params, lang=self.config.lang, sample_rate=self.config.sample_rate, cfg_path=self.config.cfg_path, decode_method=self.config.decode_method, am_predictor_conf=self.config.am_predictor_conf): logger.error( "Init the ASR server occurs error, please check the server configuration yaml" ) return False logger.info("Initialize ASR server engine successfully.") return True def preprocess(self, *args, **kwargs): raise NotImplementedError("Online not using this.") def run(self, *args, **kwargs): raise NotImplementedError("Online not using this.") def postprocess(self): raise NotImplementedError("Online not using this.")