# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import Optional

import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode

from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor

__all__ = ['ASREngine']

# Pretrained resources, keyed by "{model_type}-{lang}-{sample_rate_str}".
pretrained_models = {
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
        'md5':
        '23e16c69730a1cb5d735c98c83c21e16',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'model':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
        'params':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "conformer_online_multicn-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
        'md5':
        '0ac93d390552336f2a906aec9e33c5fa',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/multi_cn',
        'model':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'params':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
}


# ASR server connection process class
class PaddleASRConnectionHanddler:
    """Per-connection streaming ASR state.

    Holds the audio/feature caches and decoder state of one websocket
    connection, while sharing the models owned by the global ``ASREngine``.
    """

    def __init__(self, asr_engine):
        """Init a Paddle ASR Connection Handler instance

        Args:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
        logger.info(
            "create an paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config
        self.model_config = asr_engine.executor.config
        self.asr_engine = asr_engine

        self.init()
        self.reset()

    def init(self):
        """Bind the shared models and build the per-connection decoders/featurizers."""
        # model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
        self.model_type = self.asr_engine.executor.model_type
        self.sample_rate = self.asr_engine.executor.sample_rate
        # tokens to text
        self.text_feature = self.asr_engine.executor.text_feature

        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            from paddlespeech.s2t.io.collator import SpeechCollator
            self.am_predictor = self.asr_engine.executor.am_predictor

            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
            self.decoder = CTCDecoder(
                odim=self.model_config.output_dim,  # presumably includes the blank token
                enc_n_units=self.model_config.rnn_layer_size * 2,
                blank_id=self.model_config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))

            cfg = self.model_config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)
            # frame window samples length and frame shift samples length
            self.win_length = int(self.model_config.window_ms / 1000 *
                                  self.sample_rate)
            self.n_shift = int(self.model_config.stride_ms / 1000 *
                               self.sample_rate)

        elif "conformer" in self.model_type or "transformer" in self.model_type:
            # acoustic model
            self.model = self.asr_engine.executor.model

            # ctc decoding config
            self.ctc_decode_config = self.asr_engine.executor.config.decode
            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

            # extract feat, now only fbank in conformer model
            self.preprocess_conf = self.model_config.preprocess_config
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)

            # frame window samples length and frame shift samples length
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']

    def extract_feat(self, samples):
        """Append one package of pcm16 bytes and extract its features.

        The bytes are decoded as int16 samples and appended to the samples
        left over from earlier packages (``self.remained_wav``). The resulting
        features are appended to ``self.cached_feat`` along axis 1, and the
        samples consumed by the produced frames are dropped from
        ``self.remained_wav``.

        Args:
            samples (bytes): raw pcm16 audio bytes of this package.

        Returns:
            0 when the conformer branch does not yet have ``win_length``
            samples; otherwise None.
        """
        if "deepspeech2online" in self.model_type:
            # self.remained_wav stores all the samples,
            # including the previous remained_wav and this package's samples
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            # pcm16 -> pcm 32
            # pcm2float will change the original samples,
            # so we should do pcm2float before concatenate
            samples = pcm2float(samples)

            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection remain the audio samples: {self.remained_wav.shape}"
            )

            # read audio
            speech_segment = SpeechSegment.from_pcm(
                self.remained_wav, self.sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, _ = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            audio = paddle.unsqueeze(audio, axis=0)

            if self.cached_feat is None:
                self.cached_feat = audio
            else:
                assert (len(audio.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, audio], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            self.num_frames += audio_len
            # drop the samples already covered by the produced frames
            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
            )
        elif "conformer_online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            logger.info(f"This package receive {samples.shape[0]} pcm data")
            self.num_samples += samples.shape[0]

            # self.remained_wav stores all the samples,
            # including the previous remained_wav and this package's samples
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection remain the audio samples: {self.remained_wav.shape}"
            )
            # not enough samples for a single frame window yet
            if len(self.remained_wav) < self.win_length:
                return 0

            # fbank
            x_chunk = self.preprocessing(self.remained_wav,
                                         **self.preprocess_args)
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
                assert (len(x_chunk.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, x_chunk], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            num_frames = x_chunk.shape[1]
            self.num_frames += num_frames
            # drop the samples already covered by the produced frames
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
            )

    def reset(self):
        """Reset all per-connection audio, feature and decoding state."""
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            # for deepspeech2
            self.chunk_state_h_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_h_box)
            self.chunk_state_c_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_c_box)
            self.decoder.reset_decoder(batch_size=1)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            # clear the prefix beam search state as well, mirroring
            # ASRServerExecutor.transformer_decode_reset
            self.searcher.reset()

        # for conformer online
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_out = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.device = None
        self.hyps = []
        self.num_frames = 0
        self.chunk_num = 0
        self.global_frame_offset = 0
        self.result_transcripts = ['']

    def decode(self, is_finished=False):
        """Decode the features cached for this connection.

        Args:
            is_finished (bool): True once the last package has arrived; then a
                tail shorter than a full decoding window (but at least
                ``context`` frames) is decoded too.

        Returns:
            The deepspeech2 branch may return ``(None, None)`` when there are
            not enough frames. The result itself is stored in
            ``self.result_transcripts`` (see ``get_result``).

        Raises:
            Exception: when the model type is neither deepspeech2online nor a
                conformer/transformer model.
        """
        if "deepspeech2online" in self.model_type:
            # x_chunk is the feature data
            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
            context = 7  # context=7 in deepspeech2 model
            subsampling = 4  # subsampling=4 in deepspeech2 model
            stride = subsampling * decoding_chunk_size
            cached_feature_num = context - subsampling
            # decoding window for model
            decoding_window = (decoding_chunk_size - 1) * subsampling + context

            if self.cached_feat is None:
                logger.info("no audio feat, please input more pcm data")
                return

            num_frames = self.cached_feat.shape[1]
            logger.info(
                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
            )

            # the cached feat must be larger than decoding_window
            if num_frames < decoding_window and not is_finished:
                logger.info(
                    f"frame feat num is less than {decoding_window}, please input more pcm data"
                )
                return None, None

            # if is_finished=True, we need at least context frames
            if num_frames < context:
                logger.info(
                    f"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
                )
                return None, None
            logger.info("start to do model forward")

            # num_frames - left_frames + 1 ensures the current frame gets a context window
            if is_finished:
                # for the finished chunk, we need to process the last context
                left_frames = context
            else:
                # we only process decoding_window frames for one chunk
                left_frames = decoding_window

            for cur in range(0, num_frames - left_frames + 1, stride):
                end = min(cur + decoding_window, num_frames)
                # extract the audio
                x_chunk = self.cached_feat[:, cur:end, :].numpy()
                x_chunk_lens = np.array([x_chunk.shape[1]])
                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)

            self.result_transcripts = [trans_best]

            # keep the trailing frames the next window still needs as context
            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]

        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advance_decoding(is_finished)
                self.update_result()

            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    @paddle.no_grad()
    def decode_one_chunk(self, x_chunk, x_chunk_lens):
        """Run the deepspeech2 predictor on one feature chunk and CTC-decode it.

        The RNN state boxes are fed in and updated from the predictor outputs.

        Args:
            x_chunk (numpy.ndarray): feature chunk, shape [B, T, D].
            x_chunk_lens (numpy.ndarray): chunk lengths, shape [B].

        Returns:
            str: the one best result decoded so far.
        """
        logger.info("start to decoce one chunk with deepspeech2 model")
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
        c_box_handle = self.am_predictor.get_input_handle(input_names[3])

        audio_handle.reshape(x_chunk.shape)
        audio_handle.copy_from_cpu(x_chunk)

        audio_len_handle.reshape(x_chunk_lens.shape)
        audio_len_handle.copy_from_cpu(x_chunk_lens)

        h_box_handle.reshape(self.chunk_state_h_box.shape)
        h_box_handle.copy_from_cpu(self.chunk_state_h_box)

        c_box_handle.reshape(self.chunk_state_c_box.shape)
        c_box_handle.copy_from_cpu(self.chunk_state_c_box)

        output_names = self.am_predictor.get_output_names()
        output_handle = self.am_predictor.get_output_handle(output_names[0])
        output_lens_handle = self.am_predictor.get_output_handle(
            output_names[1])
        output_state_h_handle = self.am_predictor.get_output_handle(
            output_names[2])
        output_state_c_handle = self.am_predictor.get_output_handle(
            output_names[3])

        self.am_predictor.run()

        output_chunk_probs = output_handle.copy_to_cpu()
        output_chunk_lens = output_lens_handle.copy_to_cpu()
        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

        self.decoder.next(output_chunk_probs, output_chunk_lens)
        trans_best, trans_beam = self.decoder.decode()
        logger.info(f"decode one best result: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()
    def advance_decoding(self, is_finished=False):
        """Run the streaming encoder over cached features and advance the CTC search.

        Args:
            is_finished (bool): True once the last package has arrived.

        Returns:
            ``(None, None)`` when there are not enough frames, otherwise None.
        """
        logger.info("start to decode with advanced_decoding method")
        cfg = self.ctc_decode_config
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size
        # processed chunk feature cached for next chunk
        cached_feature_num = context - subsampling

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        if self.cached_feat is None:
            logger.info("no audio feat, please input more pcm data")
            return

        num_frames = self.cached_feat.shape[1]
        logger.info(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )

        # the cached feat must be larger than decoding_window
        if num_frames < decoding_window and not is_finished:
            logger.info(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # if is_finished=True, we need at least context frames
        if num_frames < context:
            logger.info(
                f"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

        logger.info("start to do model forward")
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        outputs = []

        # num_frames - left_frames + 1 ensures the current frame gets a context window
        if is_finished:
            # for the finished chunk, we need to process the last context
            left_frames = context
        else:
            # we only process decoding_window frames for one chunk
            left_frames = decoding_window

        # record the end for removing the processed feat
        end = None
        for cur in range(0, num_frames - left_frames + 1, stride):
            end = min(cur + decoding_window, num_frames)

            self.chunk_num += 1
            chunk_xs = self.cached_feat[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)

            # update the offset
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        if self.encoder_out is None:
            self.encoder_out = ys
        else:
            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)

        # get the ctc probs
        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        self.searcher.search(ctc_probs, self.cached_feat.place)

        self.hyps = self.searcher.get_one_best_hyps()
        assert self.cached_feat.shape[0] == 1
        assert end >= cached_feature_num
        # keep the trailing frames the next window still needs as context
        self.cached_feat = self.cached_feat[0, end -
                                            cached_feature_num:, :].unsqueeze(0)
        assert len(
            self.cached_feat.shape
        ) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

        logger.info(
            f"This connection handler encoder out shape: {self.encoder_out.shape}"
        )

    def update_result(self):
        """Convert ``self.hyps`` token ids into transcripts."""
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def get_result(self):
        """Return the current best transcript, or '' when there is none."""
        if len(self.result_transcripts) > 0:
            return self.result_transcripts[0]
        else:
            return ''

    @paddle.no_grad()
    def rescoring(self):
        """Rescore the final CTC hyps with the attention decoder.

        No-op for deepspeech2 models and when the decoding method is not
        ``attention_rescoring``. Updates ``self.hyps`` and the transcripts.
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            return

        logger.info("rescoring the final result")
        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
            return

        self.searcher.finalize_search()
        self.update_result()

        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return

        # the searcher may return fewer hyps than the configured beam_size,
        # so the decoder batch must follow the actual number of hyps
        num_hyps = len(hyps)

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )

            hyp_content = paddle.to_tensor(
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)

        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (num_hyps,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = self.encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = paddle.ones(
            (num_hyps, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)]
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # last decoder output token is `eos`, for the last decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.ctc_decode_config.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        logger.info(f"best index: {best_index}")
        self.hyps = [hyps[best_index][0]]
        self.update_result()


class ASRServerExecutor(ASRExecutor):
    """ASR executor that owns the shared models for the online ASR server."""

    def __init__(self):
        super().__init__()

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and returns pretrained resources path of current task.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
            tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path

    def _init_from_path(self,
                        model_type: str='deepspeech2online_aishell',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.

        When any of ``cfg_path``/``am_model``/``am_params`` is None, the
        pretrained resources for ``{model_type}-{lang}-{16k|8k}`` are
        downloaded and used instead.

        Raises:
            Exception: when ``model_type`` is neither deepspeech2 nor a
                conformer/transformer model.
        """
        self.model_type = model_type
        self.sample_rate = sample_rate
        # the tag is needed both to locate pretrained models and to look up
        # the language model of deepspeech2, so compute it unconditionally
        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str
        if cfg_path is None or am_model is None or am_params is None:
            logger.info(f"Load the pretrained model, tag = {tag}")
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])

            self.am_model = os.path.join(res_path,
                                         pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          pretrained_models[tag]['params'])
            logger.info(res_path)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(self.cfg_path)
        logger.info(self.am_model)
        logger.info(self.am_params)

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                if tag in pretrained_models:
                    lm_url = pretrained_models[tag]['lm_url']
                    lm_md5 = pretrained_models[tag]['lm_md5']
                    logger.info(f"Start to load language model {lm_url}")
                    self.download_lm(
                        lm_url,
                        os.path.dirname(self.config.decode.lang_model_path),
                        lm_md5)
                else:
                    logger.info(
                        f"No pretrained language model registered for {tag}, "
                        f"use {self.config.decode.lang_model_path}")
            elif "conformer" in model_type or "transformer" in model_type:
                logger.info("start to create the stream conformer asr engine")
                if self.config.spm_model_prefix:
                    self.config.spm_model_prefix = os.path.join(
                        self.res_path, self.config.spm_model_prefix)
                self.vocab = self.config.vocab_filepath
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type,
                    vocab=self.config.vocab_filepath,
                    spm_model_prefix=self.config.spm_model_prefix)
                # update the decoding method
                if decode_method:
                    self.config.decode.decoding_method = decode_method

                # we only support ctc_prefix_beam_search and attention_rescoring decoding method
                # Generally we set the decoding_method to attention_rescoring
                if self.config.decode.decoding_method not in [
                        "ctc_prefix_beam_search", "attention_rescoring"
                ]:
                    logger.info(
                        "we set the decoding_method to attention_rescoring")
                    self.config.decode.decoding_method = "attention_rescoring"
                assert self.config.decode.decoding_method in [
                    "ctc_prefix_beam_search", "attention_rescoring"
                ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
            else:
                raise Exception("wrong type")
        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
            # AM predictor
            logger.info("ASR engine start to init the am predictor")
            self.am_predictor_conf = am_predictor_conf
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,
                predictor_conf=self.am_predictor_conf)

            # decoder
            logger.info("ASR engine start to create the ctc decoder instance")
            self.decoder = CTCDecoder(
                odim=self.config.output_dim,  # presumably includes the blank token
                enc_n_units=self.config.rnn_layer_size * 2,
                blank_id=self.config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.config.get('ctc_grad_norm_type', None))

            # init decoder
            logger.info("ASR engine start to init the ctc decoder")
            cfg = self.config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)

            # init state box
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in model_type or "transformer" in model_type:
            # model_type: {model_name}_{dataset}
            model_name = model_type[:model_type.rindex('_')]
            logger.info(f"model name: {model_name}")
            model_class = dynamic_import(model_name, model_alias)
            model_conf = self.config
            model = model_class.from_config(model_conf)
            self.model = model
            self.model.eval()

            # load model
            model_dict = paddle.load(self.am_model)
            self.model.set_state_dict(model_dict)
            logger.info("create the transformer like model success")

            # update the ctc decoding
            self.searcher = CTCPrefixBeamSearch(self.config.decode)
            self.transformer_decode_reset()

    def reset_decoder_and_chunk(self):
        """reset decoder and chunk state for an new audio
        """
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            self.decoder.reset_decoder(batch_size=1)
            # init state box, for new audio request
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            self.transformer_decode_reset()

    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
        """decode one chunk

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            model_type (str): online model type

        Returns:
            str: one best result; None if the conformer branch hit an
                exception (it is logged).

        Raises:
            Exception: when the model type is unsupported.
        """
        logger.info("start to decoce chunk by chunk")
        if "deepspeech2online" in model_type:
            input_names = self.am_predictor.get_input_names()
            audio_handle = self.am_predictor.get_input_handle(input_names[0])
            audio_len_handle = self.am_predictor.get_input_handle(
                input_names[1])
            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
            c_box_handle = self.am_predictor.get_input_handle(input_names[3])

            audio_handle.reshape(x_chunk.shape)
            audio_handle.copy_from_cpu(x_chunk)

            audio_len_handle.reshape(x_chunk_lens.shape)
            audio_len_handle.copy_from_cpu(x_chunk_lens)

            h_box_handle.reshape(self.chunk_state_h_box.shape)
            h_box_handle.copy_from_cpu(self.chunk_state_h_box)

            c_box_handle.reshape(self.chunk_state_c_box.shape)
            c_box_handle.copy_from_cpu(self.chunk_state_c_box)

            output_names = self.am_predictor.get_output_names()
            output_handle = self.am_predictor.get_output_handle(output_names[0])
            output_lens_handle = self.am_predictor.get_output_handle(
                output_names[1])
            output_state_h_handle = self.am_predictor.get_output_handle(
                output_names[2])
            output_state_c_handle = self.am_predictor.get_output_handle(
                output_names[3])

            self.am_predictor.run()

            output_chunk_probs = output_handle.copy_to_cpu()
            output_chunk_lens = output_lens_handle.copy_to_cpu()
            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

            self.decoder.next(output_chunk_probs, output_chunk_lens)
            trans_best, trans_beam = self.decoder.decode()
            logger.info(f"decode one best result: {trans_best[0]}")
            return trans_best[0]

        elif "conformer" in model_type or "transformer" in model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advanced_decoding(x_chunk, x_chunk_lens)
                self.update_result()

                return self.result_transcripts[0]
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
        """Encode ``xs`` chunk by chunk, run the CTC search and optionally rescore."""
        logger.info("start to decode with advanced_decoding method")
        encoder_out, encoder_mask = self.encoder_forward(xs)
        ctc_probs = self.model.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        self.searcher.search(ctc_probs, xs.place)
        # update the one best result
        self.hyps = self.searcher.get_one_best_hyps()

        # now we support ctc_prefix_beam_search and attention_rescoring
        if "attention_rescoring" in self.config.decode.decoding_method:
            self.rescoring(encoder_out, xs.place)

    def encoder_forward(self, xs):
        """Run the streaming encoder over ``xs`` window by window.

        Returns:
            (ys, masks): concatenated encoder output and an all-True mask of
            shape [1, 1, ys.shape[1]].
        """
        logger.info("get the model out from the feat")
        cfg = self.config.decode
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.shape[1]
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        logger.info("start to do model forward")
        outputs = []

        # num_frames - context + 1 ensures the current frame gets a context window
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        masks = masks.unsqueeze(1)
        return ys, masks

    def rescoring(self, encoder_out, device):
        """Rescore the searcher's hyps with the attention decoder.

        Args:
            encoder_out: encoder output for the utterance.
            device: place on which to build the hyp tensors.

        Returns:
            The best hyp token ids, or None when the searcher has no hyps.
        """
        logger.info("start to rescoring the hyps")
        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return None

        # the searcher may return fewer hyps than the configured beam_size,
        # so the decoder batch must follow the actual number of hyps
        num_hyps = len(hyps)

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=device,
            dtype=paddle.long)  # (num_hyps,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = paddle.ones(
            (num_hyps, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)]
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # last decoder output token is `eos`, for the last decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.config.decode.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        self.hyps = [hyps[best_index][0]]
        return hyps[best_index][0]

    def transformer_decode_reset(self):
        """Clear the streaming encoder caches and the CTC search state."""
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.offset = 0
        # decoding reset
        self.searcher.reset()

    def update_result(self):
        """Convert ``self.hyps`` token ids into transcripts."""
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def extract_feat(self, samples, sample_rate):
        """extract feat

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        if "deepspeech2online" in self.model_type:
            # pcm16 -> pcm 32
            samples = pcm2float(samples)

            # read audio
            speech_segment = SpeechSegment.from_pcm(
                samples, sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, _ = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            audio = paddle.unsqueeze(audio, axis=0)

            x_chunk = audio.numpy()
            x_chunk_lens = np.array([audio_len])

            return x_chunk, x_chunk_lens
        elif "conformer_online" in self.model_type:

            if sample_rate != self.sample_rate:
                logger.info(f"audio sample rate {sample_rate} is not match, "
                            f"the model sample_rate is {self.sample_rate}")
            logger.info(f"ASR Engine use the {self.model_type} to process")
            logger.info("Create the preprocess instance")
            preprocess_conf = self.config.preprocess_config
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)

            logger.info("Read the audio file")
            logger.info(f"audio shape: {samples.shape}")
            # fbank
            x_chunk = preprocessing(samples, **preprocess_args)
            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            logger.info(
                f"process the audio feature success, feat shape: {x_chunk.shape}"
            )
            return x_chunk, x_chunk_lens


class ASREngine(BaseEngine):
    """ASR server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()
        logger.info("create the online asr engine instance")

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config (dict): engine config (read via attribute access, e.g.
                ``config.model_type``)

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = ""
        self.executor = ASRServerExecutor()
        self.config = config
        try:
            if self.config.get("device", None):
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            logger.info(f"paddlespeech_server set the device: {self.device}")
            paddle.set_device(self.device)
        except Exception:
            # best effort: keep going on the default device
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )

        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True

    def preprocess(self,
                   samples,
                   sample_rate,
                   model_type="deepspeech2online_aishell-zh-16k"):
        """preprocess

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate
            model_type (str): unused; the executor's model type is used

        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
        return x_chunk, x_chunk_lens

    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
        """run online engine

        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            decoder_chunk_size(int): unused
        """
        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
                                                     self.config.model_type)

    def postprocess(self):
        """postprocess
        """
        return self.output

    def reset(self):
        """reset engine decoder and inference state
        """
        self.executor.reset_decoder_and_chunk()
        self.output = ""