From 90d648a601d64aefea0dd9d4a63a87eede1a8b09 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 1 Dec 2021 12:12:55 +0000 Subject: [PATCH] support using by __call__ --- paddlespeech/cli/asr/infer.py | 138 ++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index a0ae53507..ea1828b6b 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,7 +119,8 @@ class ASRExecutor(BaseExecutor): lang: str='zh', model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None): + ckpt_path: Optional[os.PathLike]=None, + device: str='cpu'): """ Init model and other resources from a specific path. """ @@ -140,12 +141,8 @@ class ASRExecutor(BaseExecutor): res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - # Enter the path of model root - os.chdir(res_path) - #Init body. - parser_args = self.parser_args - paddle.set_device(parser_args.device) + paddle.set_device(device) self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" @@ -153,29 +150,35 @@ class ASRExecutor(BaseExecutor): logger.info(model_conf) with UpdateConfig(model_conf): - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + if model_type == "ds2_online" or model_type == "ds2_offline": from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) - self.config.collator.vocab_filepath = os.path.join( + self.config.collator.mean_std_filepath = os.path.join( res_path, self.config.collator.cmvn_path) self.collate_fn_test = SpeechCollator.from_config(self.config) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) model_conf.input_dim = self.collate_fn_test.feature_size - model_conf.output_dim = self.text_feature.vocab_size - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": - + model_conf.output_dim = text_feature.vocab_size + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) - self.text_feature = TextFeaturizer( + text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, vocab_filepath=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) model_conf.input_dim = self.config.collator.feat_dim - model_conf.output_dim = self.text_feature.vocab_size + model_conf.output_dim = text_feature.vocab_size else: raise Exception("wrong type") self.config.freeze() - model_class = dynamic_import(parser_args.model, model_alias) + # Enter the path of model root + os.chdir(res_path) + + model_class = dynamic_import(model_type, model_alias) model = model_class.from_config(model_conf) self.model = model self.model.eval() @@ -185,31 +188,31 @@ class ASRExecutor(BaseExecutor): model_dict = paddle.load(params_path) self.model.set_state_dict(model_dict) - def preprocess(self, input: Union[str, os.PathLike]): + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). """ - parser_args = self.parser_args - config = self.config audio_file = input logger.info("Preprocess audio_file:" + audio_file) - self.sr = config.collator.target_sample_rate + config_target_sample_rate = self.config.collator.target_sample_rate # Get the object for feature extraction - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + if model_type == "ds2_online" or model_type == "ds2_offline": audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] audio = paddle.to_tensor(audio, dtype='float32') - self.audio_len = paddle.to_tensor(audio_len) - self.audio = paddle.unsqueeze(audio, axis=0) - self.vocab_list = collate_fn_test.vocab_list - logger.info(f"audio feat shape: {self.audio.shape}") - - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + vocab_list = collate_fn_test.vocab_list + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": logger.info("get the preprocess conf") preprocess_conf = os.path.join( os.path.dirname(os.path.abspath(self.cfg_path)), @@ -235,7 +238,7 @@ class ASRExecutor(BaseExecutor): else: audio = audio[:, 0] - if sample_rate != self.sr: + if sample_rate != config_target_sample_rate: logger.error( f"sample rate error: {sample_rate}, need {self.sr} ") sys.exit(-1) @@ -243,29 +246,36 @@ class ASRExecutor(BaseExecutor): # fbank audio = preprocessing(audio, **preprocess_args) - self.audio_len = paddle.to_tensor(audio.shape[0]) - self.audio = paddle.to_tensor( - audio, dtype='float32').unsqueeze(axis=0) - logger.info(f"audio feat shape: {self.audio.shape}") + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") else: raise Exception("wrong type") @paddle.no_grad() - def infer(self): + def infer(self, model_type: str): """ Model inference and result stored in self.output. """ + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) cfg = self.config.decoding - parser_args = self.parser_args - audio = self.audio - audio_len = self.audio_len - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": - vocab_list = self.vocab_list + audio = self._inputs["audio"] + audio_len = self._inputs["audio_len"] + if model_type == "ds2_online" or model_type == "ds2_offline": result_transcripts = self.model.decode( audio, audio_len, - vocab_list, + text_feature.vocab_list, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -274,14 +284,13 @@ class ASRExecutor(BaseExecutor): cutoff_prob=cfg.cutoff_prob, cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - self.result_transcripts = result_transcripts[0] + self._outputs["result"] = result_transcripts[0] - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": - text_feature = self.text_feature + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": result_transcripts = self.model.decode( audio, audio_len, - text_feature=self.text_feature, + text_feature=text_feature, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -294,23 +303,22 @@ class ASRExecutor(BaseExecutor): decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, simulate_streaming=cfg.simulate_streaming) - self.result_transcripts = result_transcripts[0][0] + self._outputs["result"] = result_transcripts[0][0] else: raise Exception("invalid model name") - pass - def postprocess(self) -> Union[str, os.PathLike]: """ Output postprocess and return human-readable results such as texts and audio files. """ - return self.result_transcripts + return self._outputs["result"] def _check(self, audio_file: str, model_sample_rate: int): self.target_sample_rate = model_sample_rate if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: logger.error( - "please input --model_sample_rate 8000 or --model_sample_rate 16000") + "please input --model_sample_rate 8000 or --model_sample_rate 16000" + ) raise Exception("invalid sample rate") sys.exit(-1) @@ -336,11 +344,13 @@ class ASRExecutor(BaseExecutor): sys.exit(-1) logger.info("The sample rate is %d" % sample_rate) if sample_rate != self.target_sample_rate: - logger.warning("The sample rate of the input file is not {}.\n \ + logger.warning( + "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16bit 1 channel wav file. \ - ".format(self.target_sample_rate, self.target_sample_rate)) + " + .format(self.target_sample_rate, self.target_sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -367,34 +377,36 @@ class ASRExecutor(BaseExecutor): """ Command line entry. """ - self.parser_args = self.parser.parse_args(argv) + parser_args = self.parser.parse_args(argv) - model = self.parser_args.model - lang = self.parser_args.lang - model_sample_rate = self.parser_args.model_sample_rate - config = self.parser_args.config - ckpt_path = self.parser_args.ckpt_path - audio_file = os.path.abspath(self.parser_args.input) - device = self.parser_args.device + model = parser_args.model + lang = parser_args.lang + model_sample_rate = parser_args.model_sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + audio_file = parser_args.input + device = parser_args.device try: - res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file, - device) + res = self(model, lang, model_sample_rate, config, ckpt_path, + audio_file, device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) return False - def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file, - device): + def __call__(self, model, lang, model_sample_rate, config, ckpt_path, + audio_file, device): """ Python API to call an executor. """ + audio_file = os.path.abspath(audio_file) self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) - self.preprocess(audio_file) - self.infer() + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, + device) + self.preprocess(model, audio_file) + self.infer(model) res = self.postprocess() # Retrieve result of asr. return res