[Cli]optimize the cli, add --yes, and delete transformer_aishell (#1154)

* optimize the cli/asr,test=asr

* test=doc_fix
pull/1157/head
Jackwaterveg 3 years ago committed by GitHub
parent 7c44fb9cd8
commit e9748faa71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,7 +40,7 @@ __all__ = ['ASRExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
@ -53,16 +53,6 @@ pretrained_models = {
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"transformer_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
'md5':
'4e8b63800c71040b9390b150e2a5d4c4',
'cfg_path':
'conf/transformer.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_20',
}
}
model_alias = {
@ -111,6 +101,11 @@ class ASRExecutor(BaseExecutor):
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument(
'--yes','-y',
action="store_true",
default=False,
help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate')
self.parser.add_argument(
'--device',
type=str,
@ -350,7 +345,7 @@ class ASRExecutor(BaseExecutor):
audio = np.round(audio).astype("int16")
return audio
def _check(self, audio_file: str, sample_rate: int):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("please input --sr 8000 or --sr 16000")
@ -384,22 +379,23 @@ class ASRExecutor(BaseExecutor):
If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate))
while (True):
logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip() == "Yes":
if force_yes == False:
while (True):
logger.info(
"change the sampele rate, channel to 16k and 1 channel")
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip() == "No":
logger.info("Exit the program")
exit(1)
else:
logger.warning("Not regular input, please input again")
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip() == "Yes":
logger.info(
"change the sampele rate, channel to 16k and 1 channel")
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip() == "No":
logger.info("Exit the program")
exit(1)
else:
logger.warning("Not regular input, please input again")
self.change_format = True
else:
@ -418,10 +414,11 @@ class ASRExecutor(BaseExecutor):
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
force_yes = parser_args.yes
device = parser_args.device
try:
res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
res = self(audio_file, model, lang, sample_rate, config, ckpt_path, force_yes,
device)
logger.info('ASR Result: {}'.format(res))
return True
@ -436,12 +433,13 @@ class ASRExecutor(BaseExecutor):
sample_rate: int=16000,
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
force_yes: bool=False,
device=paddle.get_device()):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate)
self._check(audio_file, sample_rate, force_yes)
paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, ckpt_path)
self.preprocess(model, audio_file)

Loading…
Cancel
Save