From e9748faa7114b6ce3709ee4a31ff7cc613db1774 Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Fri, 17 Dec 2021 14:09:43 +0800 Subject: [PATCH] [Cli]optimize the cli, add --yes, and delete transformer_aishell (#1154) * optimize the cli/asr,test=asr * test=doc_fix --- paddlespeech/cli/asr/infer.py | 56 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 89a9fcfa..db659a88 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -40,7 +40,7 @@ __all__ = ['ASRExecutor'] pretrained_models = { # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". + # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" "conformer_wenetspeech-zh-16k": { @@ -53,16 +53,6 @@ pretrained_models = { 'ckpt_path': 'exp/conformer/checkpoints/wenetspeech', }, - "transformer_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz', - 'md5': - '4e8b63800c71040b9390b150e2a5d4c4', - 'cfg_path': - 'conf/transformer.yaml', - 'ckpt_path': - 'exp/transformer/checkpoints/avg_20', - } } model_alias = { @@ -111,6 +101,11 @@ class ASRExecutor(BaseExecutor): type=str, default=None, help='Checkpoint file of model.') + self.parser.add_argument( + '--yes','-y', + action="store_true", + default=False, + help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate') self.parser.add_argument( '--device', type=str, @@ -350,7 +345,7 @@ class ASRExecutor(BaseExecutor): audio = np.round(audio).astype("int16") return audio - def _check(self, audio_file: str, sample_rate: int): + def _check(self, audio_file: str, sample_rate: int, force_yes: bool): self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error("please input --sr 8000 or --sr 16000") @@ -384,22 +379,23 @@ class ASRExecutor(BaseExecutor): If the result does not meet your expectations,\n \ Please input the 16k 16 bit 1 channel wav file. \ ".format(self.sample_rate, self.sample_rate)) - while (True): - logger.info( - "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." - ) - content = input("Input(Y/N):") - if content.strip() == "Y" or content.strip( - ) == "y" or content.strip() == "yes" or content.strip() == "Yes": + if force_yes == False: + while (True): logger.info( - "change the sampele rate, channel to 16k and 1 channel") - break - elif content.strip() == "N" or content.strip( - ) == "n" or content.strip() == "no" or content.strip() == "No": - logger.info("Exit the program") - exit(1) - else: - logger.warning("Not regular input, please input again") + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip() == "Yes": + logger.info( + "change the sampele rate, channel to 16k and 1 channel") + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip() == "No": + logger.info("Exit the program") + exit(1) + else: + logger.warning("Not regular input, please input again") self.change_format = True else: @@ -418,10 +414,11 @@ class ASRExecutor(BaseExecutor): config = parser_args.config ckpt_path = parser_args.ckpt_path audio_file = parser_args.input + force_yes = parser_args.yes device = parser_args.device try: - res = self(audio_file, model, lang, sample_rate, config, ckpt_path, + res = self(audio_file, model, lang, sample_rate, config, ckpt_path, force_yes, device) logger.info('ASR Result: {}'.format(res)) return True @@ -436,12 +433,13 @@ class ASRExecutor(BaseExecutor): sample_rate: int=16000, config: os.PathLike=None, ckpt_path: os.PathLike=None, + force_yes: bool=False, device=paddle.get_device()): """ Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - self._check(audio_file, sample_rate) + self._check(audio_file, sample_rate, force_yes) paddle.set_device(device) self._init_from_path(model, lang, sample_rate, config, ckpt_path) self.preprocess(model, audio_file)