From e9798498d686e568d4d3488952f8cd2abec9a05f Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Mon, 29 Nov 2021 18:01:39 +0800 Subject: [PATCH 01/15] Update asr inference in paddlespeech.cli. --- paddlespeech/cli/executor.py | 9 +-- paddlespeech/cli/s2t/conf/default_conf.yaml | 0 paddlespeech/cli/s2t/infer.py | 67 +++++++++++++--- paddlespeech/cli/utils.py | 86 ++++++++++++++++++--- 4 files changed, 136 insertions(+), 26 deletions(-) delete mode 100644 paddlespeech/cli/s2t/conf/default_conf.yaml diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 45472fa4b..2261e011b 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,7 +14,6 @@ import os from abc import ABC from abc import abstractmethod -from typing import Optional from typing import Union import paddle @@ -30,16 +29,16 @@ class BaseExecutor(ABC): self.output = None @abstractmethod - def _get_default_cfg_path(self): + def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Returns a default config file path of current task. + Download and returns pretrained resources path of current task. """ pass @abstractmethod - def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): + def _init_from_path(self, *args, **kwargs): """ - Init model from a specific config file. + Init model and other resources from a specific path. """ pass diff --git a/paddlespeech/cli/s2t/conf/default_conf.yaml b/paddlespeech/cli/s2t/conf/default_conf.yaml deleted file mode 100644 index e69de29bb..000000000 diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/s2t/infer.py index 682279852..6aa29addf 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/s2t/infer.py @@ -21,9 +21,21 @@ import paddle from ..executor import BaseExecutor from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import logger +from ..utils import MODEL_HOME __all__ = ['S2TExecutor'] +pretrained_models = { + "wenetspeech_zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', + 'md5': + '54e7a558a6e020c2f5fb224874943f97', + } +} + @cli_register( name='paddlespeech.s2t', description='Speech to text infer command.') @@ -33,11 +45,23 @@ class S2TExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog='paddlespeech.s2t', add_help=True) + self.parser.add_argument( + '--model', + type=str, + default='wenetspeech', + help='Choose model type of asr task.') + self.parser.add_argument( + '--lang', type=str, default='zh', help='Choose model language.') self.parser.add_argument( '--config', type=str, default=None, help='Config of s2t task. Use deault config when it is None.') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') self.parser.add_argument( '--input', type=str, help='Audio file to recognize.') self.parser.add_argument( @@ -46,16 +70,39 @@ class S2TExecutor(BaseExecutor): default='cpu', help='Choose device to execute model inference.') - def _get_default_cfg_path(self): + def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Returns a default config file path of current task. + Download and returns pretrained resources path of current task. """ - pass + assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format( + tag) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path - def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): + def _init_from_path(self, + model_type: str='wenetspeech', + lang: str='zh', + cfg_path: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None): """ - Init model from a specific config file. + Init model and other resources from a specific path. """ + if cfg_path is None or ckpt_path is None: + res_path = self._get_pretrained_path( + model_type + '_' + lang) # wenetspeech_zh + cfg_path = os.path.join(res_path, 'conf/conformer.yaml') + ckpt_path = os.path.join( + res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams') + logger.info(res_path) + logger.info(cfg_path) + logger.info(ckpt_path) + + # Init body. pass def preprocess(self, input: Union[str, os.PathLike]): @@ -82,17 +129,15 @@ class S2TExecutor(BaseExecutor): parser_args = self.parser.parse_args(argv) print(parser_args) + model = parser_args.model + lang = parser_args.lang config = parser_args.config + ckpt_path = parser_args.ckpt_path audio_file = parser_args.input device = parser_args.device - if config is not None: - assert os.path.isfile(config), 'Config file is not valid.' - else: - config = self._get_default_cfg_path() - try: - self._init_from_cfg(config) + self._init_from_path(model, lang, config, ckpt_path) self.preprocess(audio_file) self.infer() res = self.postprocess() # Retrieve result of s2t. diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index c83deee89..edf579f71 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging import os from typing import Any from typing import Dict -from typing import List from paddle.framework import load from paddle.utils import download @@ -26,6 +27,7 @@ __all__ = [ 'get_command', 'download_and_decompress', 'load_state_dict_from_url', + 'logger', ] @@ -53,29 +55,27 @@ def get_command(name: str) -> Any: return com['_entry'] -def decompress(file: str): +def decompress(file: str) -> os.PathLike: """ Extracts all files from a compressed file. """ assert os.path.isfile(file), "File: {} not exists.".format(file) - download._decompress(file) + return download._decompress(file) -def download_and_decompress(archives: List[Dict[str, str]], path: str): +def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: """ Download archieves and decompress to specific path. """ if not os.path.isdir(path): os.makedirs(path) - for archive in archives: - assert 'url' in archive and 'md5' in archive, \ - 'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' + assert 'url' in archive and 'md5' in archive, \ + 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) + return download.get_path_from_url(archive['url'], path, archive['md5']) - download.get_path_from_url(archive['url'], path, archive['md5']) - -def load_state_dict_from_url(url: str, path: str, md5: str=None): +def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike: """ Download and load a state dict from url """ @@ -84,3 +84,69 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None): download.get_path_from_url(url, path, md5) return load(os.path.join(path, os.path.basename(url))) + + +def _get_user_home(): + return os.path.expanduser('~') + + +def _get_paddlespcceh_home(): + if 'PPSPEECH_HOME' in os.environ: + home_path = os.environ['PPSPEECH_HOME'] + if os.path.exists(home_path): + if os.path.isdir(home_path): + return home_path + else: + raise RuntimeError( + 'The environment variable PPSPEECH_HOME {} is not a directory.'. + format(home_path)) + else: + return home_path + return os.path.join(_get_user_home(), '.paddlespeech') + + +def _get_sub_home(directory): + home = os.path.join(_get_paddlespcceh_home(), directory) + if not os.path.exists(home): + os.makedirs(home) + return home + + +PPSPEECH_HOME = _get_paddlespcceh_home() +MODEL_HOME = _get_sub_home('models') + + +class Logger(object): + def __init__(self, name: str=None): + name = 'PaddleSpeech' if not name else name + self.logger = logging.getLogger(name) + + log_config = { + 'DEBUG': 10, + 'INFO': 20, + 'TRAIN': 21, + 'EVAL': 22, + 'WARNING': 30, + 'ERROR': 40, + 'CRITICAL': 50 + } + for key, level in log_config.items(): + logging.addLevelName(level, key) + self.__dict__[key.lower()] = functools.partial(self.__call__, level) + + self.format = logging.Formatter( + fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s' + ) + + self.handler = logging.StreamHandler() + self.handler.setFormatter(self.format) + + self.logger.addHandler(self.handler) + self.logger.setLevel(logging.DEBUG) + self.logger.propagate = False + + def __call__(self, log_level: str, msg: str): + self.logger.log(log_level, msg) + + +logger = Logger() From c94ebdc52cdcf52b9e400fe2090efc953f895b4e Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 30 Nov 2021 14:22:32 +0800 Subject: [PATCH 02/15] Add python api for executor. --- paddlespeech/cli/executor.py | 15 +++++++++++++++ paddlespeech/cli/s2t/infer.py | 19 +++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 2261e011b..2314bd6d3 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,6 +14,7 @@ import os from abc import ABC from abc import abstractmethod +from typing import List from typing import Union import paddle @@ -64,3 +65,17 @@ class BaseExecutor(ABC): Output postprocess and return human-readable results such as texts and audio files. """ pass + + @abstractmethod + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + pass + + @abstractmethod + def __call__(self, *arg, **kwargs): + """ + Python API to call an executor. + """ + pass diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/s2t/infer.py index 6aa29addf..9509e311c 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/s2t/infer.py @@ -126,6 +126,9 @@ class S2TExecutor(BaseExecutor): pass def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ parser_args = self.parser.parse_args(argv) print(parser_args) @@ -137,12 +140,20 @@ class S2TExecutor(BaseExecutor): device = parser_args.device try: - self._init_from_path(model, lang, config, ckpt_path) - self.preprocess(audio_file) - self.infer() - res = self.postprocess() # Retrieve result of s2t. + res = self(model, lang, config, ckpt_path, audio_file, device) print(res) return True except Exception as e: print(e) return False + + def __call__(self, model, lang, config, ckpt_path, audio_file, device): + """ + Python API to call an executor. + """ + self._init_from_path(model, lang, config, ckpt_path) + self.preprocess(audio_file) + self.infer() + res = self.postprocess() # Retrieve result of s2t. + + return res From cdc8520969bda11eb348f6784a93b607223db9d6 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 30 Nov 2021 07:32:23 +0000 Subject: [PATCH 03/15] add the infer --- paddlespeech/cli/s2t/infer.py | 239 ++++++++++++++++++++++++++++++---- 1 file changed, 212 insertions(+), 27 deletions(-) diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/s2t/infer.py index 6aa29addf..912d1df08 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/s2t/infer.py @@ -13,17 +13,24 @@ # limitations under the License. import argparse import os +import sys from typing import List from typing import Optional from typing import Union +import soundfile import paddle - -from ..executor import BaseExecutor -from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import logger -from ..utils import MODEL_HOME +from paddlespeech.cli.executor import BaseExecutor +from paddlespeech.cli.utils import cli_register +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import logger +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.s2t.exps.u2.config import get_cfg_defaults +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.io.collator import SpeechCollator +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['S2TExecutor'] @@ -33,9 +40,44 @@ pretrained_models = { 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', 'md5': '54e7a558a6e020c2f5fb224874943f97', + 'cfg_path': + 'conf/conformer.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/wenetspeech', } } +model_alias = { + "ds2_offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", + "ds2_online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", + "conformer": "paddlespeech.s2t.models.u2:U2Model", + "transformer": "paddlespeech.s2t.models.u2:U2Model", + "wenetspeech": "paddlespeech.s2t.models.u2:U2Model", +} + +pretrain_model_alias = { + "ds2_online_zn": [ + "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz", + "", "" + ], + "ds2_offline_zn": [ + "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz", + "", "" + ], + "transformer_zn": [ + "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz", + "", "" + ], + "conformer_zn": [ + "https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz", + "", "" + ], + "wenetspeech_zn": [ + "https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz", + "conf/conformer.yaml", "exp/conformer/checkpoints/wenetspeech" + ], +} + @cli_register( name='paddlespeech.s2t', description='Speech to text infer command.') @@ -63,7 +105,10 @@ class S2TExecutor(BaseExecutor): default=None, help='Checkpoint file of model.') self.parser.add_argument( - '--input', type=str, help='Audio file to recognize.') + '--input', + type=str, + default="../Downloads/asr-demo-1.wav", + help='Audio file to recognize.') self.parser.add_argument( '--device', type=str, @@ -80,8 +125,10 @@ class S2TExecutor(BaseExecutor): res_path = os.path.join(MODEL_HOME, tag) decompressed_path = download_and_decompress(pretrained_models[tag], res_path) + decompressed_path = os.path.abspath(decompressed_path) logger.info( 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path def _init_from_path(self, @@ -93,56 +140,194 @@ class S2TExecutor(BaseExecutor): Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - res_path = self._get_pretrained_path( - model_type + '_' + lang) # wenetspeech_zh - cfg_path = os.path.join(res_path, 'conf/conformer.yaml') - ckpt_path = os.path.join( - res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams') + tag = model_type + '_' + lang + res_path = self._get_pretrained_path(tag) # wenetspeech_zh + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) + self.ckpt_path = os.path.join(res_path, + pretrained_models[tag]['ckpt_path']) logger.info(res_path) - logger.info(cfg_path) - logger.info(ckpt_path) + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path) + res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) - # Init body. - pass + os.chdir(res_path) + #Init body. + parser_args = self.parser_args + paddle.set_device(parser_args.device) + self.config = get_cfg_defaults() + self.config.merge_from_file(self.cfg_path) + self.config.decoding.decoding_method = "attention_rescoring" + #self.config.freeze() + model_conf = self.config.model + logger.info(model_conf) + + with UpdateConfig(model_conf): + if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + self.config.collator.vocab_filepath = os.path.join( + res_path, self.config.collator.vocab_filepath) + self.config.collator.vocab_filepath = os.path.join( + res_path, self.config.collator.cmvn_path) + self.collate_fn_test = SpeechCollator.from_config(self.config) + model_conf.feat_size = self.collate_fn_test.feature_size + model_conf.dict_size = self.text_feature.vocab_size + elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + self.config.collator.vocab_filepath = os.path.join( + res_path, self.config.collator.vocab_filepath) + self.text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + model_conf.input_dim = self.config.collator.feat_dim + model_conf.output_dim = self.text_feature.vocab_size + else: + raise Exception("wrong type") + model_class = dynamic_import(parser_args.model, model_alias) + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + params_path = self.ckpt_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) def preprocess(self, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. Input content can be a text(t2s), a file(s2t, cls) or a streaming(not supported yet). """ - pass + + parser_args = self.parser_args + config = self.config + audio_file = input + #print("audio_file", audio_file) + logger.info("audio_file"+ audio_file) + + self.sr = config.collator.target_sample_rate + + # Get the object for feature extraction + if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + audio, _ = collate_fn_test.process_utterance( + audio_file=audio_file, transcript=" ") + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + self.audio_len = paddle.to_tensor(audio_len) + self.audio = paddle.unsqueeze(audio, axis=0) + self.vocab_list = collate_fn_test.vocab_list + logger.info(f"audio feat shape: {self.audio.shape}") + + elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + logger.info("get the preprocess conf") + preprocess_conf = os.path.join( + os.path.dirname(os.path.abspath(self.cfg_path)), + "preprocess.yaml") + + cmvn_path: data / mean_std.json + + logger.info(preprocess_conf) + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + audio, sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + if sample_rate != self.sr: + logger.error( + f"sample rate error: {sample_rate}, need {self.sr} ") + sys.exit(-1) + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + self.audio_len = paddle.to_tensor(audio.shape[0]) + self.audio = paddle.to_tensor( + audio, dtype='float32').unsqueeze(axis=0) + logger.info(f"audio feat shape: {self.audio.shape}") + + else: + raise Exception("wrong type") @paddle.no_grad() def infer(self): """ Model inference and result stored in self.output. """ + cfg = self.config.decoding + parser_args = self.parser_args + audio = self.audio + audio_len = self.audio_len + if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + vocab_list = self.vocab_list + result_transcripts = self.model.decode( + audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + self.result_transcripts = result_transcripts[0] + + elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + text_feature = self.text_feature + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self.result_transcripts = result_transcripts[0][0] + else: + raise Exception("invalid model name") + pass def postprocess(self) -> Union[str, os.PathLike]: """ Output postprocess and return human-readable results such as texts and audio files. """ - pass + return self.result_transcripts def execute(self, argv: List[str]) -> bool: - parser_args = self.parser.parse_args(argv) - print(parser_args) + self.parser_args = self.parser.parse_args(argv) - model = parser_args.model - lang = parser_args.lang - config = parser_args.config - ckpt_path = parser_args.ckpt_path - audio_file = parser_args.input - device = parser_args.device + model = self.parser_args.model + lang = self.parser_args.lang + config = self.parser_args.config + ckpt_path = self.parser_args.ckpt_path + audio_file = os.path.abspath(self.parser_args.input) + device = self.parser_args.device try: self._init_from_path(model, lang, config, ckpt_path) self.preprocess(audio_file) self.infer() res = self.postprocess() # Retrieve result of s2t. - print(res) + logger.info(res) return True except Exception as e: print(e) return False + + +if __name__ == "__main__": + exe = S2TExecutor() + exe.execute('') From 000294132cd2e37d04cb09d68450bdad9494ac5f Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 30 Nov 2021 17:55:44 +0800 Subject: [PATCH 04/15] Rename s2t to asr. --- paddlespeech/cli/README.md | 4 ++-- paddlespeech/cli/__init__.py | 2 +- paddlespeech/cli/{s2t => asr}/__init__.py | 2 +- paddlespeech/cli/{s2t => asr}/infer.py | 26 ++++++++++------------- paddlespeech/cli/entry.py | 5 ++++- paddlespeech/cli/executor.py | 2 +- paddlespeech/cli/{t2s => tts}/__init.__py | 0 7 files changed, 20 insertions(+), 21 deletions(-) rename paddlespeech/cli/{s2t => asr}/__init__.py (95%) rename paddlespeech/cli/{s2t => asr}/infer.py (95%) rename paddlespeech/cli/{t2s => tts}/__init.__py (100%) diff --git a/paddlespeech/cli/README.md b/paddlespeech/cli/README.md index 4cea85b14..bd6572f19 100644 --- a/paddlespeech/cli/README.md +++ b/paddlespeech/cli/README.md @@ -5,5 +5,5 @@ ## Help `paddlespeech help` - ## S2T - `paddlespeech s2t --config ./s2t.yaml --input ./zh.wav --device gpu` + ## ASR + `paddlespeech asr --input ./test_audio.wav --device gpu` diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index 1cc7e27f5..7e0329041 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .asr import ASRExecutor from .base_commands import BaseCommand from .base_commands import HelpCommand -from .s2t import S2TExecutor diff --git a/paddlespeech/cli/s2t/__init__.py b/paddlespeech/cli/asr/__init__.py similarity index 95% rename from paddlespeech/cli/s2t/__init__.py rename to paddlespeech/cli/asr/__init__.py index 57e814b9e..8ab0991fc 100644 --- a/paddlespeech/cli/s2t/__init__.py +++ b/paddlespeech/cli/asr/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .infer import S2TExecutor +from .infer import ASRExecutor diff --git a/paddlespeech/cli/s2t/infer.py b/paddlespeech/cli/asr/infer.py similarity index 95% rename from paddlespeech/cli/s2t/infer.py rename to paddlespeech/cli/asr/infer.py index b3507cb60..605163803 100644 --- a/paddlespeech/cli/s2t/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -33,7 +33,7 @@ from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig -__all__ = ['S2TExecutor'] +__all__ = ['ASRExecutor'] pretrained_models = { "wenetspeech_zh": { @@ -58,13 +58,15 @@ model_alias = { @cli_register( - name='paddlespeech.s2t', description='Speech to text infer command.') -class S2TExecutor(BaseExecutor): + name='paddlespeech.asr', description='Speech to text infer command.') +class ASRExecutor(BaseExecutor): def __init__(self): - super(S2TExecutor, self).__init__() + super(ASRExecutor, self).__init__() self.parser = argparse.ArgumentParser( - prog='paddlespeech.s2t', add_help=True) + prog='paddlespeech.asr', add_help=True) + self.parser.add_argument( + '--input', type=str, required=True, help='Audio file to recognize.') self.parser.add_argument( '--model', type=str, @@ -76,16 +78,12 @@ class S2TExecutor(BaseExecutor): '--config', type=str, default=None, - help='Config of s2t task. Use deault config when it is None.') + help='Config of asr task. Use deault config when it is None.') self.parser.add_argument( '--ckpt_path', type=str, default=None, help='Checkpoint file of model.') - self.parser.add_argument( - '--input', - type=str, - help='Audio file to recognize.') self.parser.add_argument( '--device', type=str, @@ -178,13 +176,12 @@ class S2TExecutor(BaseExecutor): def preprocess(self, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(t2s), a file(s2t, cls) or a streaming(not supported yet). + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). """ parser_args = self.parser_args config = self.config audio_file = input - #print("audio_file", audio_file) logger.info("audio_file" + audio_file) self.sr = config.collator.target_sample_rate @@ -290,7 +287,6 @@ class S2TExecutor(BaseExecutor): Command line entry. """ self.parser_args = self.parser.parse_args(argv) - print(self.parser_args) model = self.parser_args.model lang = self.parser_args.lang @@ -301,7 +297,7 @@ class S2TExecutor(BaseExecutor): try: res = self(model, lang, config, ckpt_path, audio_file, device) - print(res) + logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) @@ -314,6 +310,6 @@ class S2TExecutor(BaseExecutor): self._init_from_path(model, lang, config, ckpt_path) self.preprocess(audio_file) self.infer() - res = self.postprocess() # Retrieve result of s2t. + res = self.postprocess() # Retrieve result of asr. return res diff --git a/paddlespeech/cli/entry.py b/paddlespeech/cli/entry.py index 726cff1af..32123ece7 100644 --- a/paddlespeech/cli/entry.py +++ b/paddlespeech/cli/entry.py @@ -23,9 +23,12 @@ def _CommandDict(): def _execute(): com = commands - for idx, _argv in enumerate(['paddlespeech'] + sys.argv[1:]): + + idx = 0 + for _argv in (['paddlespeech'] + sys.argv[1:]): if _argv not in com: break + idx += 1 com = com[_argv] # The method 'execute' of a command instance returns 'True' for a success diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 2314bd6d3..e307a287b 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -47,7 +47,7 @@ class BaseExecutor(ABC): def preprocess(self, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(t2s), a file(s2t, cls) or a streaming(not supported yet). + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). """ pass diff --git a/paddlespeech/cli/t2s/__init.__py b/paddlespeech/cli/tts/__init.__py similarity index 100% rename from paddlespeech/cli/t2s/__init.__py rename to paddlespeech/cli/tts/__init.__py From 17072444726509b4653e3c5b7eaf40490d85e4f6 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 30 Nov 2021 18:14:07 +0800 Subject: [PATCH 05/15] Update device usage. --- paddlespeech/cli/README.md | 2 +- paddlespeech/cli/asr/infer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlespeech/cli/README.md b/paddlespeech/cli/README.md index bd6572f19..56afb939c 100644 --- a/paddlespeech/cli/README.md +++ b/paddlespeech/cli/README.md @@ -6,4 +6,4 @@ `paddlespeech help` ## ASR - `paddlespeech asr --input ./test_audio.wav --device gpu` + `paddlespeech asr --input ./test_audio.wav` diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 605163803..e5c64e9ab 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -87,7 +87,7 @@ class ASRExecutor(BaseExecutor): self.parser.add_argument( '--device', type=str, - default='cpu', + default=paddle.get_device(), help='Choose device to execute model inference.') def _get_pretrained_path(self, tag: str) -> os.PathLike: From 3fadcde5e259aa32a5f9af59843eddf4b0c22b63 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 1 Dec 2021 10:26:18 +0000 Subject: [PATCH 06/15] revise the asr infer.py --- paddlespeech/cli/asr/infer.py | 119 +++++++++++++++++++++++++++++----- 1 file changed, 102 insertions(+), 17 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index e5c64e9ab..a0ae53507 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -18,17 +18,17 @@ from typing import List from typing import Optional from typing import Union +import librosa import paddle import soundfile +from yacs.config import CfgNode from ..executor import BaseExecutor from ..utils import cli_register from ..utils import download_and_decompress from ..utils import logger from ..utils import MODEL_HOME -from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig @@ -36,7 +36,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] pretrained_models = { - "wenetspeech_zh": { + "wenetspeech_zh_16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', 'md5': @@ -73,7 +73,15 @@ class ASRExecutor(BaseExecutor): default='wenetspeech', help='Choose model type of asr task.') self.parser.add_argument( - '--lang', type=str, default='zh', help='Choose model language.') + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + self.parser.add_argument( + "--model_sample_rate", + type=int, + default=16000, + help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', type=str, @@ -109,13 +117,15 @@ class ASRExecutor(BaseExecutor): def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', + model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - tag = model_type + '_' + lang + model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' + tag = model_type + '_' + lang + '_' + model_sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -136,23 +146,24 @@ class ASRExecutor(BaseExecutor): #Init body. parser_args = self.parser_args paddle.set_device(parser_args.device) - self.config = get_cfg_defaults() + self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" - #self.config.freeze() model_conf = self.config.model logger.info(model_conf) with UpdateConfig(model_conf): if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.cmvn_path) self.collate_fn_test = SpeechCollator.from_config(self.config) - model_conf.feat_size = self.collate_fn_test.feature_size - model_conf.dict_size = self.text_feature.vocab_size + model_conf.input_dim = self.collate_fn_test.feature_size + model_conf.output_dim = self.text_feature.vocab_size elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) self.text_feature = TextFeaturizer( @@ -163,6 +174,7 @@ class ASRExecutor(BaseExecutor): model_conf.output_dim = self.text_feature.vocab_size else: raise Exception("wrong type") + self.config.freeze() model_class = dynamic_import(parser_args.model, model_alias) model = model_class.from_config(model_conf) self.model = model @@ -182,13 +194,13 @@ class ASRExecutor(BaseExecutor): parser_args = self.parser_args config = self.config audio_file = input - logger.info("audio_file" + audio_file) + logger.info("Preprocess audio_file:" + audio_file) self.sr = config.collator.target_sample_rate # Get the object for feature extraction if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": - audio, _ = collate_fn_test.process_utterance( + audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] audio = paddle.to_tensor(audio, dtype='float32') @@ -203,18 +215,30 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path)), "preprocess.yaml") - cmvn_path: data / mean_std.json - logger.info(preprocess_conf) preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) + logger.info("read the audio file") audio, sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) + + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1) + else: + audio = audio[:, 0] + audio = audio.astype("float32") + audio = librosa.resample(audio, sample_rate, + self.target_sample_rate) + sample_rate = self.target_sample_rate + audio = audio.astype("int16") + else: + audio = audio[:, 0] + if sample_rate != self.sr: logger.error( f"sample rate error: {sample_rate}, need {self.sr} ") sys.exit(-1) - audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") # fbank audio = preprocessing(audio, **preprocess_args) @@ -282,6 +306,63 @@ class ASRExecutor(BaseExecutor): """ return self.result_transcripts + def _check(self, audio_file: str, model_sample_rate: int): + self.target_sample_rate = model_sample_rate + if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: + logger.error( + "please input --model_sample_rate 8000 or --model_sample_rate 16000") + raise Exception("invalid sample rate") + sys.exit(-1) + + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the audio file, please check the audio file format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + if sample_rate != self.target_sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16bit 1 channel wav file. \ + ".format(self.target_sample_rate, self.target_sample_rate)) + while (True): + logger.info( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip() == "Yes": + logger.info( + "change the sampele rate, channel to 16k and 1 channel") + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip() == "No": + logger.info("Exit the program") + exit(1) + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.info("The audio file format is right") + self.change_format = False + def execute(self, argv: List[str]) -> bool: """ Command line entry. @@ -290,24 +371,28 @@ class ASRExecutor(BaseExecutor): model = self.parser_args.model lang = self.parser_args.lang + model_sample_rate = self.parser_args.model_sample_rate config = self.parser_args.config ckpt_path = self.parser_args.ckpt_path audio_file = os.path.abspath(self.parser_args.input) device = self.parser_args.device try: - res = self(model, lang, config, ckpt_path, audio_file, device) + res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file, + device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) return False - def __call__(self, model, lang, config, ckpt_path, audio_file, device): + def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file, + device): """ Python API to call an executor. """ - self._init_from_path(model, lang, config, ckpt_path) + self._check(audio_file, model_sample_rate) + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) self.preprocess(audio_file) self.infer() res = self.postprocess() # Retrieve result of asr. From 44e9b032d5a247e09524d7b0776db6c5fb4a11aa Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Wed, 1 Dec 2021 19:09:20 +0800 Subject: [PATCH 07/15] Update inputs and outputs of executor. --- paddlespeech/cli/executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index e307a287b..c132b3b87 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -26,8 +26,8 @@ class BaseExecutor(ABC): """ def __init__(self): - self.input = None - self.output = None + self._inputs = dict() + self._outputs = dict() @abstractmethod def _get_pretrained_path(self, tag: str) -> os.PathLike: From 90d648a601d64aefea0dd9d4a63a87eede1a8b09 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 1 Dec 2021 12:12:55 +0000 Subject: [PATCH 08/15] support using by __call__ --- paddlespeech/cli/asr/infer.py | 138 ++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index a0ae53507..ea1828b6b 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,7 +119,8 @@ class ASRExecutor(BaseExecutor): lang: str='zh', model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None): + ckpt_path: Optional[os.PathLike]=None, + device: str='cpu'): """ Init model and other resources from a specific path. """ @@ -140,12 +141,8 @@ class ASRExecutor(BaseExecutor): res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - # Enter the path of model root - os.chdir(res_path) - #Init body. - parser_args = self.parser_args - paddle.set_device(parser_args.device) + paddle.set_device(device) self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" @@ -153,29 +150,35 @@ class ASRExecutor(BaseExecutor): logger.info(model_conf) with UpdateConfig(model_conf): - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + if model_type == "ds2_online" or model_type == "ds2_offline": from paddlespeech.s2t.io.collator import SpeechCollator self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) - self.config.collator.vocab_filepath = os.path.join( + self.config.collator.mean_std_filepath = os.path.join( res_path, self.config.collator.cmvn_path) self.collate_fn_test = SpeechCollator.from_config(self.config) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) model_conf.input_dim = self.collate_fn_test.feature_size - model_conf.output_dim = self.text_feature.vocab_size - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": - + model_conf.output_dim = text_feature.vocab_size + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": self.config.collator.vocab_filepath = os.path.join( res_path, self.config.collator.vocab_filepath) - self.text_feature = TextFeaturizer( + text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, vocab_filepath=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) model_conf.input_dim = self.config.collator.feat_dim - model_conf.output_dim = self.text_feature.vocab_size + model_conf.output_dim = text_feature.vocab_size else: raise Exception("wrong type") self.config.freeze() - model_class = dynamic_import(parser_args.model, model_alias) + # Enter the path of model root + os.chdir(res_path) + + model_class = dynamic_import(model_type, model_alias) model = model_class.from_config(model_conf) self.model = model self.model.eval() @@ -185,31 +188,31 @@ class ASRExecutor(BaseExecutor): model_dict = paddle.load(params_path) self.model.set_state_dict(model_dict) - def preprocess(self, input: Union[str, os.PathLike]): + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): """ Input preprocess and return paddle.Tensor stored in self.input. Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). """ - parser_args = self.parser_args - config = self.config audio_file = input logger.info("Preprocess audio_file:" + audio_file) - self.sr = config.collator.target_sample_rate + config_target_sample_rate = self.config.collator.target_sample_rate # Get the object for feature extraction - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": + if model_type == "ds2_online" or model_type == "ds2_offline": audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] audio = paddle.to_tensor(audio, dtype='float32') - self.audio_len = paddle.to_tensor(audio_len) - self.audio = paddle.unsqueeze(audio, axis=0) - self.vocab_list = collate_fn_test.vocab_list - logger.info(f"audio feat shape: {self.audio.shape}") - - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": + audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + vocab_list = collate_fn_test.vocab_list + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": logger.info("get the preprocess conf") preprocess_conf = os.path.join( os.path.dirname(os.path.abspath(self.cfg_path)), @@ -235,7 +238,7 @@ class ASRExecutor(BaseExecutor): else: audio = audio[:, 0] - if sample_rate != self.sr: + if sample_rate != config_target_sample_rate: logger.error( f"sample rate error: {sample_rate}, need {self.sr} ") sys.exit(-1) @@ -243,29 +246,36 @@ class ASRExecutor(BaseExecutor): # fbank audio = preprocessing(audio, **preprocess_args) - self.audio_len = paddle.to_tensor(audio.shape[0]) - self.audio = paddle.to_tensor( - audio, dtype='float32').unsqueeze(axis=0) - logger.info(f"audio feat shape: {self.audio.shape}") + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") else: raise Exception("wrong type") @paddle.no_grad() - def infer(self): + def infer(self, model_type: str): """ Model inference and result stored in self.output. """ + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) cfg = self.config.decoding - parser_args = self.parser_args - audio = self.audio - audio_len = self.audio_len - if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline": - vocab_list = self.vocab_list + audio = self._inputs["audio"] + audio_len = self._inputs["audio_len"] + if model_type == "ds2_online" or model_type == "ds2_offline": result_transcripts = self.model.decode( audio, audio_len, - vocab_list, + text_feature.vocab_list, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -274,14 +284,13 @@ class ASRExecutor(BaseExecutor): cutoff_prob=cfg.cutoff_prob, cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - self.result_transcripts = result_transcripts[0] + self._outputs["result"] = result_transcripts[0] - elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech": - text_feature = self.text_feature + elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": result_transcripts = self.model.decode( audio, audio_len, - text_feature=self.text_feature, + text_feature=text_feature, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -294,23 +303,22 @@ class ASRExecutor(BaseExecutor): decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, simulate_streaming=cfg.simulate_streaming) - self.result_transcripts = result_transcripts[0][0] + self._outputs["result"] = result_transcripts[0][0] else: raise Exception("invalid model name") - pass - def postprocess(self) -> Union[str, os.PathLike]: """ Output postprocess and return human-readable results such as texts and audio files. """ - return self.result_transcripts + return self._outputs["result"] def _check(self, audio_file: str, model_sample_rate: int): self.target_sample_rate = model_sample_rate if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: logger.error( - "please input --model_sample_rate 8000 or --model_sample_rate 16000") + "please input --model_sample_rate 8000 or --model_sample_rate 16000" + ) raise Exception("invalid sample rate") sys.exit(-1) @@ -336,11 +344,13 @@ class ASRExecutor(BaseExecutor): sys.exit(-1) logger.info("The sample rate is %d" % sample_rate) if sample_rate != self.target_sample_rate: - logger.warning("The sample rate of the input file is not {}.\n \ + logger.warning( + "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16bit 1 channel wav file. \ - ".format(self.target_sample_rate, self.target_sample_rate)) + " + .format(self.target_sample_rate, self.target_sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -367,34 +377,36 @@ class ASRExecutor(BaseExecutor): """ Command line entry. """ - self.parser_args = self.parser.parse_args(argv) + parser_args = self.parser.parse_args(argv) - model = self.parser_args.model - lang = self.parser_args.lang - model_sample_rate = self.parser_args.model_sample_rate - config = self.parser_args.config - ckpt_path = self.parser_args.ckpt_path - audio_file = os.path.abspath(self.parser_args.input) - device = self.parser_args.device + model = parser_args.model + lang = parser_args.lang + model_sample_rate = parser_args.model_sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + audio_file = parser_args.input + device = parser_args.device try: - res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file, - device) + res = self(model, lang, model_sample_rate, config, ckpt_path, + audio_file, device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: print(e) return False - def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file, - device): + def __call__(self, model, lang, model_sample_rate, config, ckpt_path, + audio_file, device): """ Python API to call an executor. """ + audio_file = os.path.abspath(audio_file) self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) - self.preprocess(audio_file) - self.infer() + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, + device) + self.preprocess(model, audio_file) + self.infer(model) res = self.postprocess() # Retrieve result of asr. return res From e0642ffc772d80a0d8fdc34c2f4d93c704c74571 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Wed, 1 Dec 2021 20:36:55 +0800 Subject: [PATCH 09/15] Update doc strings. --- paddlespeech/cli/executor.py | 43 +++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index c132b3b87..00371371d 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -14,6 +14,7 @@ import os from abc import ABC from abc import abstractmethod +from typing import Any from typing import List from typing import Union @@ -32,50 +33,70 @@ class BaseExecutor(ABC): @abstractmethod def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Download and returns pretrained resources path of current task. + Download and returns pretrained resources path of current task. + + Args: + tag (str): A tag of pretrained model. + + Returns: + os.PathLike: The path on which resources of pretrained model locate. """ pass @abstractmethod def _init_from_path(self, *args, **kwargs): """ - Init model and other resources from a specific path. + Init model and other resources from arguments. This method should be called by `__call__()`. """ pass @abstractmethod - def preprocess(self, input: Union[str, os.PathLike]): + def preprocess(self, input: Any, *args, **kwargs): """ - Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + Input preprocess and return paddle.Tensor stored in self._inputs. + Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed. + + Args: + input (Any): Input text/file/stream or other content. """ pass @paddle.no_grad() @abstractmethod - def infer(self, device: str): + def infer(self, *args, **kwargs): """ - Model inference and result stored in self.output. + Model inference and put results into self._outputs. + This method get input tensors from self._inputs, and write output tensors into self._outputs. """ pass @abstractmethod - def postprocess(self) -> Union[str, os.PathLike]: + def postprocess(self, *args, **kwargs) -> Union[str, os.PathLike]: """ - Output postprocess and return human-readable results such as texts and audio files. + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. """ pass @abstractmethod def execute(self, argv: List[str]) -> bool: """ - Command line entry. + Command line entry. This method can only be accessed by a command line such as `paddlespeech asr`. + + Args: + argv (List[str]): Arguments from command line. + + Returns: + int: Result of the command execution. `True` for a success and `False` for a failure. """ pass @abstractmethod def __call__(self, *arg, **kwargs): """ - Python API to call an executor. + Python API to call an executor. """ pass From a19e51d7da83a6794652ac2d965f6cd880a10b86 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Wed, 1 Dec 2021 20:45:34 +0800 Subject: [PATCH 10/15] Update python api. --- paddlespeech/cli/asr/infer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index ea1828b6b..00216356c 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,8 +119,7 @@ class ASRExecutor(BaseExecutor): lang: str='zh', model_sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, - ckpt_path: Optional[os.PathLike]=None, - device: str='cpu'): + ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ @@ -142,7 +141,6 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path))) #Init body. - paddle.set_device(device) self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) self.config.decoding.decoding_method = "attention_rescoring" @@ -403,8 +401,9 @@ class ASRExecutor(BaseExecutor): """ audio_file = os.path.abspath(audio_file) self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, - device) + + paddle.set_device(device) + self._init_from_path(model, lang, model_sample_rate, config, ckpt_path) self.preprocess(model, audio_file) self.infer(model) res = self.postprocess() # Retrieve result of asr. From aee530af2773e5127d28727468d4c30128ba3527 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 2 Dec 2021 03:24:07 +0000 Subject: [PATCH 11/15] revise the sample rate --- paddlespeech/cli/asr/infer.py | 48 +++++++++++++++-------------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index ea1828b6b..c9ec058cd 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -78,7 +78,7 @@ class ASRExecutor(BaseExecutor): default='zh', help='Choose model language. zh or en') self.parser.add_argument( - "--model_sample_rate", + "--sr", type=int, default=16000, help='Choose the audio sample rate of the model. 8000 or 16000') @@ -117,7 +117,7 @@ class ASRExecutor(BaseExecutor): def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', - model_sample_rate: int=16000, + sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None, device: str='cpu'): @@ -125,8 +125,8 @@ class ASRExecutor(BaseExecutor): Init model and other resources from a specific path. """ if cfg_path is None or ckpt_path is None: - model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' - tag = model_type + '_' + lang + '_' + model_sample_rate_str + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '_' + lang + '_' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -197,8 +197,6 @@ class ASRExecutor(BaseExecutor): audio_file = input logger.info("Preprocess audio_file:" + audio_file) - config_target_sample_rate = self.config.collator.target_sample_rate - # Get the object for feature extraction if model_type == "ds2_online" or model_type == "ds2_offline": audio, _ = self.collate_fn_test.process_utterance( @@ -222,7 +220,7 @@ class ASRExecutor(BaseExecutor): preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) logger.info("read the audio file") - audio, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) if self.change_format: @@ -231,17 +229,13 @@ class ASRExecutor(BaseExecutor): else: audio = audio[:, 0] audio = audio.astype("float32") - audio = librosa.resample(audio, sample_rate, - self.target_sample_rate) - sample_rate = self.target_sample_rate + audio = librosa.resample(audio, audio_sample_rate, + self.sample_rate) + audio_sample_rate = self.sample_rate audio = audio.astype("int16") else: audio = audio[:, 0] - if sample_rate != config_target_sample_rate: - logger.error( - f"sample rate error: {sample_rate}, need {self.sr} ") - sys.exit(-1) logger.info(f"audio shape: {audio.shape}") # fbank audio = preprocessing(audio, **preprocess_args) @@ -313,11 +307,11 @@ class ASRExecutor(BaseExecutor): """ return self._outputs["result"] - def _check(self, audio_file: str, model_sample_rate: int): - self.target_sample_rate = model_sample_rate - if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: + def _check(self, audio_file: str, sample_rate: int): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error( - "please input --model_sample_rate 8000 or --model_sample_rate 16000" + "please input --sr 8000 or --sr 16000" ) raise Exception("invalid sample rate") sys.exit(-1) @@ -328,7 +322,7 @@ class ASRExecutor(BaseExecutor): logger.info("checking the audio file format......") try: - sig, sample_rate = soundfile.read( + audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) except Exception as e: logger.error(str(e)) @@ -342,15 +336,15 @@ class ASRExecutor(BaseExecutor): sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ ") sys.exit(-1) - logger.info("The sample rate is %d" % sample_rate) - if sample_rate != self.target_sample_rate: + logger.info("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: logger.warning( "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16bit 1 channel wav file. \ " - .format(self.target_sample_rate, self.target_sample_rate)) + .format(self.sample_rate, self.sample_rate)) while (True): logger.info( "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." @@ -381,14 +375,14 @@ class ASRExecutor(BaseExecutor): model = parser_args.model lang = parser_args.lang - model_sample_rate = parser_args.model_sample_rate + sample_rate = parser_args.sr config = parser_args.config ckpt_path = parser_args.ckpt_path audio_file = parser_args.input device = parser_args.device try: - res = self(model, lang, model_sample_rate, config, ckpt_path, + res = self(model, lang, sample_rate, config, ckpt_path, audio_file, device) logger.info('ASR Result: {}'.format(res)) return True @@ -396,14 +390,14 @@ class ASRExecutor(BaseExecutor): print(e) return False - def __call__(self, model, lang, model_sample_rate, config, ckpt_path, + def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file, device): """ Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - self._check(audio_file, model_sample_rate) - self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, + self._check(audio_file, sample_rate) + self._init_from_path(model, lang, sample_rate, config, ckpt_path, device) self.preprocess(model, audio_file) self.infer(model) From a9d206c1bfc433f1aec6cebbb783f8539e0bb6a9 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 2 Dec 2021 05:58:20 +0000 Subject: [PATCH 12/15] revise --- paddlespeech/cli/asr/infer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 6ae038539..e9d8c0b11 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -22,6 +22,7 @@ import librosa import paddle import soundfile from yacs.config import CfgNode +import numpy as np from ..executor import BaseExecutor from ..utils import cli_register @@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor): "--sr", type=int, default=16000, + choices=[8000, 16000], help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', @@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor): self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]['ckpt_path']) + pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) logger.info(self.cfg_path) logger.info(self.ckpt_path) else: self.cfg_path = os.path.abspath(cfg_path) - self.ckpt_path = os.path.abspath(ckpt_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) @@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor): self.model.eval() # load model - params_path = self.ckpt_path + ".pdparams" - model_dict = paddle.load(params_path) + model_dict = paddle.load(self.ckpt_path) self.model.set_state_dict(model_dict) def preprocess(self, model_type: str, input: Union[str, os.PathLike]): @@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor): audio = librosa.resample(audio, audio_sample_rate, self.sample_rate) audio_sample_rate = self.sample_rate - audio = audio.astype("int16") + audio = np.round(audio).astype("int16") else: audio = audio[:, 0] From b0356ae4892c85984804ecc1fda1f9cf4d5018ac Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 2 Dec 2021 05:58:20 +0000 Subject: [PATCH 13/15] revise --- paddlespeech/cli/asr/infer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 6ae038539..640cf729f 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -22,6 +22,7 @@ import librosa import paddle import soundfile from yacs.config import CfgNode +import numpy as np from ..executor import BaseExecutor from ..utils import cli_register @@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor): "--sr", type=int, default=16000, + choices=[8000, 16000], help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', @@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor): self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]['ckpt_path']) + pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) logger.info(self.cfg_path) logger.info(self.ckpt_path) else: self.cfg_path = os.path.abspath(cfg_path) - self.ckpt_path = os.path.abspath(ckpt_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) @@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor): self.model.eval() # load model - params_path = self.ckpt_path + ".pdparams" - model_dict = paddle.load(params_path) + model_dict = paddle.load(self.ckpt_path) self.model.set_state_dict(model_dict) def preprocess(self, model_type: str, input: Union[str, os.PathLike]): @@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor): audio = audio.mean(axis=1) else: audio = audio[:, 0] + # pcm16 -> pcm 32 audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) audio = librosa.resample(audio, audio_sample_rate, self.sample_rate) audio_sample_rate = self.sample_rate - audio = audio.astype("int16") + # pcm16 -> pcm 32 + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") else: audio = audio[:, 0] @@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor): "The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ - Please input the 16k 16bit 1 channel wav file. \ + Please input the 16k 16 bit 1 channel wav file. \ " .format(self.sample_rate, self.sample_rate)) while (True): From 8ec576f477603269a677d11ddd87a56163a656aa Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Thu, 2 Dec 2021 15:03:04 +0800 Subject: [PATCH 14/15] Update infer.py --- paddlespeech/cli/asr/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 640cf729f..66a2f169f 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -235,7 +235,7 @@ class ASRExecutor(BaseExecutor): audio = librosa.resample(audio, audio_sample_rate, self.sample_rate) audio_sample_rate = self.sample_rate - # pcm16 -> pcm 32 + # pcm32 -> pcm 16 audio = audio * (2**(bits - 1)) audio = np.round(audio).astype("int16") else: From a258a34ec037066d309e6dec01da48b35a3317eb Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 2 Dec 2021 07:22:12 +0000 Subject: [PATCH 15/15] revise the convert pcm --- paddlespeech/cli/asr/infer.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 66a2f169f..48772997a 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -225,19 +225,16 @@ class ASRExecutor(BaseExecutor): if self.change_format: if audio.shape[1] >= 2: - audio = audio.mean(axis=1) + audio = audio.mean(axis=1, dtype=np.int16) else: audio = audio[:, 0] # pcm16 -> pcm 32 - audio = audio.astype("float32") - bits = np.iinfo(np.int16).bits - audio = audio / (2**(bits - 1)) + audio = self._pcm16to32(audio) audio = librosa.resample(audio, audio_sample_rate, self.sample_rate) audio_sample_rate = self.sample_rate # pcm32 -> pcm 16 - audio = audio * (2**(bits - 1)) - audio = np.round(audio).astype("int16") + audio = self._pcm32to16(audio) else: audio = audio[:, 0] @@ -312,6 +309,20 @@ class ASRExecutor(BaseExecutor): """ return self._outputs["result"] + def _pcm16to32(self, audio): + assert(audio.dtype == np.int16) + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio + + def _pcm32to16(self, audio): + assert(audio.dtype == np.float32) + bits = np.iinfo(np.int16).bits + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") + return audio + def _check(self, audio_file: str, sample_rate: int): self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: