From 3fadcde5e259aa32a5f9af59843eddf4b0c22b63 Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Wed, 1 Dec 2021 10:26:18 +0000
Subject: [PATCH 1/2] revise the asr infer.py

---
 paddlespeech/cli/asr/infer.py | 119 +++++++++++++++++++++++++++++-----
 1 file changed, 102 insertions(+), 17 deletions(-)

diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index e5c64e9a..a0ae5350 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -18,17 +18,17 @@ from typing import List
 from typing import Optional
 from typing import Union
 
+import librosa
 import paddle
 import soundfile
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import logger
 from ..utils import MODEL_HOME
-from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.io.collator import SpeechCollator
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
@@ -36,7 +36,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']
 
 pretrained_models = {
-    "wenetspeech_zh": {
+    "wenetspeech_zh_16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
         'md5':
@@ -73,7 +73,15 @@ class ASRExecutor(BaseExecutor):
             default='wenetspeech',
             help='Choose model type of asr task.')
         self.parser.add_argument(
-            '--lang', type=str, default='zh', help='Choose model language.')
+            '--lang',
+            type=str,
+            default='zh',
+            help='Choose model language. zh or en')
+        self.parser.add_argument(
+            "--model_sample_rate",
+            type=int,
+            default=16000,
+            help='Choose the audio sample rate of the model. 8000 or 16000')
         self.parser.add_argument(
             '--config',
             type=str,
@@ -109,13 +117,15 @@ class ASRExecutor(BaseExecutor):
     def _init_from_path(self,
                         model_type: str='wenetspeech',
                         lang: str='zh',
+                        model_sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
                         ckpt_path: Optional[os.PathLike]=None):
         """
             Init model and other resources from a specific path.
         """
         if cfg_path is None or ckpt_path is None:
-            tag = model_type + '_' + lang
+            model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
+            tag = model_type + '_' + lang + '_' + model_sample_rate_str
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.cfg_path = os.path.join(res_path,
                                          pretrained_models[tag]['cfg_path'])
@@ -136,23 +146,24 @@ class ASRExecutor(BaseExecutor):
         #Init body.
         parser_args = self.parser_args
         paddle.set_device(parser_args.device)
-        self.config = get_cfg_defaults()
+        self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
         self.config.decoding.decoding_method = "attention_rescoring"
-        #self.config.freeze()
         model_conf = self.config.model
         logger.info(model_conf)
 
         with UpdateConfig(model_conf):
             if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+                from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.cmvn_path)
                 self.collate_fn_test = SpeechCollator.from_config(self.config)
-                model_conf.feat_size = self.collate_fn_test.feature_size
-                model_conf.dict_size = self.text_feature.vocab_size
+                model_conf.input_dim = self.collate_fn_test.feature_size
+                model_conf.output_dim = self.text_feature.vocab_size
             elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
                 self.text_feature = TextFeaturizer(
@@ -163,6 +174,7 @@ class ASRExecutor(BaseExecutor):
                 model_conf.output_dim = self.text_feature.vocab_size
             else:
                 raise Exception("wrong type")
+        self.config.freeze()
         model_class = dynamic_import(parser_args.model, model_alias)
         model = model_class.from_config(model_conf)
         self.model = model
@@ -182,13 +194,13 @@ class ASRExecutor(BaseExecutor):
         parser_args = self.parser_args
         config = self.config
         audio_file = input
-        logger.info("audio_file" + audio_file)
+        logger.info("Preprocess audio_file:" + audio_file)
 
         self.sr = config.collator.target_sample_rate
 
         # Get the object for feature extraction
         if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            audio, _ = collate_fn_test.process_utterance(
+            audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
             audio = paddle.to_tensor(audio, dtype='float32')
@@ -203,18 +215,30 @@ class ASRExecutor(BaseExecutor):
                 os.path.dirname(os.path.abspath(self.cfg_path)),
                 "preprocess.yaml")
-            cmvn_path: data / mean_std.json
             logger.info(preprocess_conf)
             preprocess_args = {"train": False}
             preprocessing = Transformation(preprocess_conf)
+            logger.info("read the audio file")
             audio, sample_rate = soundfile.read(
                 audio_file, dtype="int16", always_2d=True)
+
+            if self.change_format:
+                if audio.shape[1] >= 2:
+                    audio = audio.mean(axis=1)
+                else:
+                    audio = audio[:, 0]
+                audio = audio.astype("float32")
+                audio = librosa.resample(audio, sample_rate,
+                                         self.target_sample_rate)
+                sample_rate = self.target_sample_rate
+                audio = audio.astype("int16")
+            else:
+                audio = audio[:, 0]
 
             if sample_rate != self.sr:
                 logger.error(
                     f"sample rate error: {sample_rate}, need {self.sr} ")
                 sys.exit(-1)
-            audio = audio[:, 0]
             logger.info(f"audio shape: {audio.shape}")
             # fbank
             audio = preprocessing(audio, **preprocess_args)
 
@@ -282,6 +306,63 @@ class ASRExecutor(BaseExecutor):
         """
         return self.result_transcripts
 
+    def _check(self, audio_file: str, model_sample_rate: int):
+        self.target_sample_rate = model_sample_rate
+        if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
+            logger.error(
+                "please input --model_sample_rate 8000 or --model_sample_rate 16000")
+            raise Exception("invalid sample rate")
+            sys.exit(-1)
+
+        if not os.path.isfile(audio_file):
+            logger.error("Please input the right audio file path")
+            sys.exit(-1)
+
+        logger.info("checking the audio file format......")
+        try:
+            sig, sample_rate = soundfile.read(
+                audio_file, dtype="int16", always_2d=True)
+        except Exception as e:
+            logger.error(str(e))
+            logger.error(
+                "cannot open the audio file. Please check that the audio file format is 'wav'. \n \
+                you can try to use sox to change the file format.\n \
+                For example: \n \
+                sample rate: 16k \n \
+                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+                sample rate: 8k \n \
+                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+                ")
+            sys.exit(-1)
+        logger.info("The sample rate is %d" % sample_rate)
+        if sample_rate != self.target_sample_rate:
+            logger.warning("The sample rate of the input file is not {}.\n \
+                            The program will resample the wav file to {}.\n \
+                            If the result does not meet your expectations,\n \
+                            Please input the 16k 16bit 1 channel wav file. \
+                        ".format(self.target_sample_rate, self.target_sample_rate))
+            while True:
+                logger.info(
+                    "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program."
+                )
+                content = input("Input(Y/N):")
+                if content.strip() == "Y" or content.strip(
+                ) == "y" or content.strip() == "yes" or content.strip() == "Yes":
+                    logger.info(
+                        "change the sample rate, channel to 16k and 1 channel")
+                    break
+                elif content.strip() == "N" or content.strip(
+                ) == "n" or content.strip() == "no" or content.strip() == "No":
+                    logger.info("Exit the program")
+                    exit(1)
+                else:
+                    logger.warning("Invalid input, please input again")
+
+            self.change_format = True
+        else:
+            logger.info("The audio file format is right")
+            self.change_format = False
+
     def execute(self, argv: List[str]) -> bool:
         """
             Command line entry.
@@ -288,26 +369,30 @@ class ASRExecutor(BaseExecutor):
         """
         self.parser_args = self.parser.parse_args(argv)
 
         model = self.parser_args.model
         lang = self.parser_args.lang
+        model_sample_rate = self.parser_args.model_sample_rate
         config = self.parser_args.config
         ckpt_path = self.parser_args.ckpt_path
         audio_file = os.path.abspath(self.parser_args.input)
         device = self.parser_args.device
 
         try:
-            res = self(model, lang, config, ckpt_path, audio_file, device)
+            res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file,
+                       device)
            logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
             print(e)
             return False
 
-    def __call__(self, model, lang, config, ckpt_path, audio_file, device):
+    def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file,
+                 device):
         """
             Python API to call an executor.
         """
-        self._init_from_path(model, lang, config, ckpt_path)
+        self._check(audio_file, model_sample_rate)
+        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
         self.preprocess(audio_file)
         self.infer()
         res = self.postprocess()  # Retrieve result of asr.
From 90d648a601d64aefea0dd9d4a63a87eede1a8b09 Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Wed, 1 Dec 2021 12:12:55 +0000
Subject: [PATCH 2/2] support usage via __call__

---
 paddlespeech/cli/asr/infer.py | 140 ++++++++++++++++++----------------
 1 file changed, 76 insertions(+), 64 deletions(-)

diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index a0ae5350..ea1828b6 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -119,7 +119,8 @@ class ASRExecutor(BaseExecutor):
                         lang: str='zh',
                         model_sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
-                        ckpt_path: Optional[os.PathLike]=None):
+                        ckpt_path: Optional[os.PathLike]=None,
+                        device: str='cpu'):
         """
             Init model and other resources from a specific path.
         """
@@ -140,12 +141,8 @@ class ASRExecutor(BaseExecutor):
             res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
 
-        # Enter the path of model root
-        os.chdir(res_path)
 
-        #Init body.
-        parser_args = self.parser_args
-        paddle.set_device(parser_args.device)
+        paddle.set_device(device)
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
         self.config.decoding.decoding_method = "attention_rescoring"
@@ -153,29 +150,35 @@ class ASRExecutor(BaseExecutor):
         logger.info(model_conf)
 
         with UpdateConfig(model_conf):
-            if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+            if model_type == "ds2_online" or model_type == "ds2_offline":
                 from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.config.collator.vocab_filepath = os.path.join(
+                self.config.collator.mean_std_filepath = os.path.join(
                     res_path, self.config.collator.cmvn_path)
                 self.collate_fn_test = SpeechCollator.from_config(self.config)
+                text_feature = TextFeaturizer(
+                    unit_type=self.config.collator.unit_type,
+                    vocab_filepath=self.config.collator.vocab_filepath,
+                    spm_model_prefix=self.config.collator.spm_model_prefix)
                 model_conf.input_dim = self.collate_fn_test.feature_size
-                model_conf.output_dim = self.text_feature.vocab_size
-            elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
-
+                model_conf.output_dim = text_feature.vocab_size
+            elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.text_feature = TextFeaturizer(
+                text_feature = TextFeaturizer(
                     unit_type=self.config.collator.unit_type,
                     vocab_filepath=self.config.collator.vocab_filepath,
                     spm_model_prefix=self.config.collator.spm_model_prefix)
                 model_conf.input_dim = self.config.collator.feat_dim
-                model_conf.output_dim = self.text_feature.vocab_size
+                model_conf.output_dim = text_feature.vocab_size
             else:
                 raise Exception("wrong type")
         self.config.freeze()
-        model_class = dynamic_import(parser_args.model, model_alias)
+        # Enter the path of model root
+        os.chdir(res_path)
+
+        model_class = dynamic_import(model_type, model_alias)
         model = model_class.from_config(model_conf)
         self.model = model
         self.model.eval()
@@ -185,31 +188,31 @@ class ASRExecutor(BaseExecutor):
         model_dict = paddle.load(params_path)
         self.model.set_state_dict(model_dict)
 
-    def preprocess(self, input: Union[str, os.PathLike]):
+    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
         """
             Input preprocess and return paddle.Tensor stored in self.input.
             Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
         """
-        parser_args = self.parser_args
-        config = self.config
         audio_file = input
         logger.info("Preprocess audio_file:" + audio_file)
 
-        self.sr = config.collator.target_sample_rate
+        config_target_sample_rate = self.config.collator.target_sample_rate
 
         # Get the object for feature extraction
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+        if model_type == "ds2_online" or model_type == "ds2_offline":
             audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
             audio = paddle.to_tensor(audio, dtype='float32')
-            self.audio_len = paddle.to_tensor(audio_len)
-            self.audio = paddle.unsqueeze(audio, axis=0)
-            self.vocab_list = collate_fn_test.vocab_list
-            logger.info(f"audio feat shape: {self.audio.shape}")
-
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+            audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+            vocab_list = self.collate_fn_test.vocab_list
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
+
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             logger.info("get the preprocess conf")
             preprocess_conf = os.path.join(
                 os.path.dirname(os.path.abspath(self.cfg_path)),
                 "preprocess.yaml")
@@ -235,7 +238,7 @@ class ASRExecutor(BaseExecutor):
             else:
                 audio = audio[:, 0]
 
-            if sample_rate != self.sr:
+            if sample_rate != config_target_sample_rate:
                 logger.error(
-                    f"sample rate error: {sample_rate}, need {self.sr} ")
+                    f"sample rate error: {sample_rate}, need {config_target_sample_rate} ")
                 sys.exit(-1)
@@ -243,29 +246,36 @@ class ASRExecutor(BaseExecutor):
             logger.info(f"audio shape: {audio.shape}")
             # fbank
             audio = preprocessing(audio, **preprocess_args)
 
-            self.audio_len = paddle.to_tensor(audio.shape[0])
-            self.audio = paddle.to_tensor(
-                audio, dtype='float32').unsqueeze(axis=0)
-            logger.info(f"audio feat shape: {self.audio.shape}")
+            audio_len = paddle.to_tensor(audio.shape[0])
+            audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
+            text_feature = TextFeaturizer(
+                unit_type=self.config.collator.unit_type,
+                vocab_filepath=self.config.collator.vocab_filepath,
+                spm_model_prefix=self.config.collator.spm_model_prefix)
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
 
         else:
             raise Exception("wrong type")
 
     @paddle.no_grad()
-    def infer(self):
+    def infer(self, model_type: str):
        """
             Model inference and result stored in self.output.
         """
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
         cfg = self.config.decoding
-        parser_args = self.parser_args
-        audio = self.audio
-        audio_len = self.audio_len
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            vocab_list = self.vocab_list
+        audio = self._inputs["audio"]
+        audio_len = self._inputs["audio_len"]
+        if model_type == "ds2_online" or model_type == "ds2_offline":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                vocab_list,
+                text_feature.vocab_list,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
@@ -274,14 +284,13 @@ class ASRExecutor(BaseExecutor):
                 beam_size=cfg.beam_size,
                 cutoff_prob=cfg.cutoff_prob,
                 cutoff_top_n=cfg.cutoff_top_n,
                 num_processes=cfg.num_proc_bsearch)
-            self.result_transcripts = result_transcripts[0]
+            self._outputs["result"] = result_transcripts[0]
 
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
-            text_feature = self.text_feature
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                text_feature=self.text_feature,
+                text_feature=text_feature,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
@@ -294,23 +303,22 @@ class ASRExecutor(BaseExecutor):
                 ctc_weight=cfg.ctc_weight,
                 decoding_chunk_size=cfg.decoding_chunk_size,
                 num_decoding_left_chunks=cfg.num_decoding_left_chunks,
                 simulate_streaming=cfg.simulate_streaming)
-            self.result_transcripts = result_transcripts[0][0]
+            self._outputs["result"] = result_transcripts[0][0]
         else:
             raise Exception("invalid model name")
-        pass
-
     def postprocess(self) -> Union[str, os.PathLike]:
         """
             Output postprocess and return human-readable results such as texts and audio files.
         """
-        return self.result_transcripts
+        return self._outputs["result"]
 
     def _check(self, audio_file: str, model_sample_rate: int):
         self.target_sample_rate = model_sample_rate
         if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
             logger.error(
-                "please input --model_sample_rate 8000 or --model_sample_rate 16000")
+                "please input --model_sample_rate 8000 or --model_sample_rate 16000"
+            )
             raise Exception("invalid sample rate")
             sys.exit(-1)
@@ -336,11 +344,13 @@ class ASRExecutor(BaseExecutor):
             sys.exit(-1)
         logger.info("The sample rate is %d" % sample_rate)
         if sample_rate != self.target_sample_rate:
-            logger.warning("The sample rate of the input file is not {}.\n \
+            logger.warning(
+                "The sample rate of the input file is not {}.\n \
                             The program will resample the wav file to {}.\n \
                             If the result does not meet your expectations,\n \
                             Please input the 16k 16bit 1 channel wav file. \
-                        ".format(self.target_sample_rate, self.target_sample_rate))
+                        "
+                .format(self.target_sample_rate, self.target_sample_rate))
             while True:
                 logger.info(
                     "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program."
                 )
@@ -366,34 +376,36 @@ class ASRExecutor(BaseExecutor):
     def execute(self, argv: List[str]) -> bool:
         """
             Command line entry.
         """
-        self.parser_args = self.parser.parse_args(argv)
+        parser_args = self.parser.parse_args(argv)
 
-        model = self.parser_args.model
-        lang = self.parser_args.lang
-        model_sample_rate = self.parser_args.model_sample_rate
-        config = self.parser_args.config
-        ckpt_path = self.parser_args.ckpt_path
-        audio_file = os.path.abspath(self.parser_args.input)
-        device = self.parser_args.device
+        model = parser_args.model
+        lang = parser_args.lang
+        model_sample_rate = parser_args.model_sample_rate
+        config = parser_args.config
+        ckpt_path = parser_args.ckpt_path
+        audio_file = parser_args.input
+        device = parser_args.device
 
         try:
-            res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file,
-                       device)
+            res = self(model, lang, model_sample_rate, config, ckpt_path,
+                       audio_file, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
             print(e)
             return False
 
-    def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file,
-                 device):
+    def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
+                 audio_file, device):
         """
             Python API to call an executor.
         """
+        audio_file = os.path.abspath(audio_file)
         self._check(audio_file, model_sample_rate)
-        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
-        self.preprocess(audio_file)
-        self.infer()
+        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path,
+                             device)
+        self.preprocess(model, audio_file)
+        self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
         return res