[Cli]optimize the cli, add --yes, and delete transformer_aishell (#1154)

* optimize the cli/asr,test=asr

* test=doc_fix
pull/1157/head
Jackwaterveg 3 years ago committed by GitHub
parent 7c44fb9cd8
commit e9748faa71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,7 +40,7 @@ __all__ = ['ASRExecutor']
pretrained_models = { pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": { "conformer_wenetspeech-zh-16k": {
@ -53,16 +53,6 @@ pretrained_models = {
'ckpt_path': 'ckpt_path':
'exp/conformer/checkpoints/wenetspeech', 'exp/conformer/checkpoints/wenetspeech',
}, },
"transformer_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
'md5':
'4e8b63800c71040b9390b150e2a5d4c4',
'cfg_path':
'conf/transformer.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_20',
}
} }
model_alias = { model_alias = {
@ -111,6 +101,11 @@ class ASRExecutor(BaseExecutor):
type=str, type=str,
default=None, default=None,
help='Checkpoint file of model.') help='Checkpoint file of model.')
self.parser.add_argument(
'--yes','-y',
action="store_true",
default=False,
help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate')
self.parser.add_argument( self.parser.add_argument(
'--device', '--device',
type=str, type=str,
@ -350,7 +345,7 @@ class ASRExecutor(BaseExecutor):
audio = np.round(audio).astype("int16") audio = np.round(audio).astype("int16")
return audio return audio
def _check(self, audio_file: str, sample_rate: int): def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("please input --sr 8000 or --sr 16000") logger.error("please input --sr 8000 or --sr 16000")
@ -384,22 +379,23 @@ class ASRExecutor(BaseExecutor):
If the result does not meet your expectations\n \ If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate)) ".format(self.sample_rate, self.sample_rate))
while (True): if force_yes == False:
logger.info( while (True):
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip() == "Yes":
logger.info( logger.info(
"change the sampele rate, channel to 16k and 1 channel") "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
break )
elif content.strip() == "N" or content.strip( content = input("Input(Y/N):")
) == "n" or content.strip() == "no" or content.strip() == "No": if content.strip() == "Y" or content.strip(
logger.info("Exit the program") ) == "y" or content.strip() == "yes" or content.strip() == "Yes":
exit(1) logger.info(
else: "change the sampele rate, channel to 16k and 1 channel")
logger.warning("Not regular input, please input again") break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip() == "No":
logger.info("Exit the program")
exit(1)
else:
logger.warning("Not regular input, please input again")
self.change_format = True self.change_format = True
else: else:
@ -418,10 +414,11 @@ class ASRExecutor(BaseExecutor):
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input audio_file = parser_args.input
force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
try: try:
res = self(audio_file, model, lang, sample_rate, config, ckpt_path, res = self(audio_file, model, lang, sample_rate, config, ckpt_path, force_yes,
device) device)
logger.info('ASR Result: {}'.format(res)) logger.info('ASR Result: {}'.format(res))
return True return True
@ -436,12 +433,13 @@ class ASRExecutor(BaseExecutor):
sample_rate: int=16000, sample_rate: int=16000,
config: os.PathLike=None, config: os.PathLike=None,
ckpt_path: os.PathLike=None, ckpt_path: os.PathLike=None,
force_yes: bool=False,
device=paddle.get_device()): device=paddle.get_device()):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate) self._check(audio_file, sample_rate, force_yes)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, ckpt_path) self._init_from_path(model, lang, sample_rate, config, ckpt_path)
self.preprocess(model, audio_file) self.preprocess(model, audio_file)

Loading…
Cancel
Save