From 1b57d05d1b3664dc39cd9048a8f7550b21389a23 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 3 Dec 2021 11:56:52 +0000 Subject: [PATCH] rm the os.chdir in cli asr --- paddlespeech/cli/asr/infer.py | 70 +++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 5ea3e59a..b40516e9 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import io import os import sys from typing import List @@ -19,10 +20,11 @@ from typing import Optional from typing import Union import librosa +import numpy as np import paddle import soundfile +import yaml from yacs.config import CfgNode -import numpy as np from ..executor import BaseExecutor from ..utils import cli_register @@ -131,8 +133,7 @@ class ASRExecutor(BaseExecutor): lang: str='zh', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None - ): + ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ @@ -140,10 +141,11 @@ class ASRExecutor(BaseExecutor): sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '_' + lang + '_' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh + self.res_path = res_path self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) - self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]['ckpt_path'] + ".pdparams") + self.ckpt_path = os.path.join( + res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) logger.info(self.cfg_path) logger.info(self.ckpt_path) @@ -157,10 +159,8 @@ class ASRExecutor(BaseExecutor): self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" - model_conf = self.config.model - logger.info(model_conf) - with UpdateConfig(model_conf): + with UpdateConfig(self.config): if model_type == "ds2_online" or model_type == "ds2_offline": from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( @@ -172,24 +172,29 @@ class ASRExecutor(BaseExecutor): unit_type=self.config.collator.unit_type, vocab_filepath=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) - model_conf.input_dim = self.collate_fn_test.feature_size - model_conf.output_dim = text_feature.vocab_size + self.config.model.input_dim = self.collate_fn_test.feature_size + self.config.model.output_dim = text_feature.vocab_size elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) + self.config.collator.augmentation_config = os.path.join( + res_path, self.config.collator.augmentation_config) + self.config.collator.spm_model_prefix = os.path.join( + res_path, self.config.collator.spm_model_prefix) text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, vocab_filepath=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) - model_conf.input_dim = self.config.collator.feat_dim - model_conf.output_dim = text_feature.vocab_size + self.config.model.input_dim = self.config.collator.feat_dim + self.config.model.output_dim = text_feature.vocab_size + else: raise Exception("wrong type") - self.config.freeze() # Enter the path of model root - os.chdir(res_path) model_class = dynamic_import(model_type, model_alias) + model_conf = self.config.model + logger.info(model_conf) model = model_class.from_config(model_conf) self.model = model self.model.eval() @@ -222,10 +227,17 @@ class ASRExecutor(BaseExecutor): elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": logger.info("get the preprocess conf") - preprocess_conf = os.path.join( - os.path.dirname(os.path.abspath(self.cfg_path)), - "preprocess.yaml") - + preprocess_conf_file = self.config.collator.augmentation_config + # redirect the cmvn path + with io.open(preprocess_conf_file, encoding="utf-8") as f: + preprocess_conf = yaml.safe_load(f) + for idx, process in enumerate(preprocess_conf["process"]): + if process['type'] == "cmvn_json": + preprocess_conf["process"][idx][ + "cmvn_path"] = os.path.join( + self.res_path, + preprocess_conf["process"][idx]["cmvn_path"]) + break logger.info(preprocess_conf) preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) @@ -320,14 +332,14 @@ class ASRExecutor(BaseExecutor): return self._outputs["result"] def _pcm16to32(self, audio): - assert(audio.dtype == np.int16) + assert (audio.dtype == np.int16) audio = audio.astype("float32") bits = np.iinfo(np.int16).bits audio = audio / (2**(bits - 1)) return audio def _pcm32to16(self, audio): - assert(audio.dtype == np.float32) + assert (audio.dtype == np.float32) bits = np.iinfo(np.int16).bits audio = audio * (2**(bits - 1)) audio = np.round(audio).astype("int16") @@ -336,9 +348,7 @@ class ASRExecutor(BaseExecutor): def _check(self, audio_file: str, sample_rate: int): self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: - logger.error( - "please input --sr 8000 or --sr 16000" - ) + logger.error("please input --sr 8000 or --sr 16000") raise Exception("invalid sample rate") sys.exit(-1) @@ -364,13 +374,11 @@ class ASRExecutor(BaseExecutor): sys.exit(-1) logger.info("The sample rate is %d" % audio_sample_rate) if audio_sample_rate != self.sample_rate: - logger.warning( - "The sample rate of the input file is not {}.\n \ + logger.warning("The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16 bit 1 channel wav file. \ - " - .format(self.sample_rate, self.sample_rate)) + ".format(self.sample_rate, self.sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -408,16 +416,16 @@ class ASRExecutor(BaseExecutor): device = parser_args.device try: - res = self(model, lang, sample_rate, config, ckpt_path, - audio_file, device) + res = self(model, lang, sample_rate, config, ckpt_path, audio_file, + device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) return False - def __call__(self, model, lang, sample_rate, config, ckpt_path, - audio_file, device): + def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file, + device): """ Python API to call an executor. """