added cli changed code, test=doc

pull/1475/head
WilliamZhang06 3 years ago
parent 7ebe904e20
commit 147018a8b4

@ -287,7 +287,8 @@ class ASRExecutor(BaseExecutor):
""" """
audio_file = input audio_file = input
logger.info("Preprocess audio_file:" + audio_file) if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction # Get the object for feature extraction
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
@ -408,13 +409,8 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool): def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("please input --sr 8000 or --sr 16000") logger.error("invalid sample rate, please input --sr 8000 or --sr 16000")
raise Exception("invalid sample rate") return False
sys.exit(-1)
if not os.path.isfile(audio_file):
logger.error("Please input the right audio file path")
sys.exit(-1)
logger.info("checking the audio file format......") logger.info("checking the audio file format......")
try: try:
@ -431,7 +427,7 @@ class ASRExecutor(BaseExecutor):
sample rate: 8k \n \ sample rate: 8k \n \
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
") ")
sys.exit(-1) return False
logger.info("The sample rate is %d" % audio_sample_rate) logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \ logger.warning("The sample rate of the input file is not {}.\n \
@ -465,6 +461,8 @@ class ASRExecutor(BaseExecutor):
logger.info("The audio file format is right") logger.info("The audio file format is right")
self.change_format = False self.change_format = False
return True
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
""" """
Command line entry. Command line entry.
@ -517,7 +515,8 @@ class ASRExecutor(BaseExecutor):
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate, force_yes) if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method, self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path) ckpt_path)

@ -74,4 +74,4 @@ class ServerExecutor(BaseExecutor):
config = get_config(args.config_file) config = get_config(args.config_file)
if self.init(config): if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)v uvicorn.run(app, host=config.host, port=config.port, debug=True)
Loading…
Cancel
Save