|
|
@ -287,7 +287,8 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
audio_file = input
|
|
|
|
audio_file = input
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
if isinstance(audio_file, (str, os.PathLike)):
|
|
|
|
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
|
|
|
|
|
|
|
|
# Get the object for feature extraction
|
|
|
|
# Get the object for feature extraction
|
|
|
|
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
|
|
|
|
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
|
|
|
@ -408,13 +409,8 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
|
|
|
|
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
if self.sample_rate != 16000 and self.sample_rate != 8000:
|
|
|
|
if self.sample_rate != 16000 and self.sample_rate != 8000:
|
|
|
|
logger.error("please input --sr 8000 or --sr 16000")
|
|
|
|
logger.error("invalid sample rate, please input --sr 8000 or --sr 16000")
|
|
|
|
raise Exception("invalid sample rate")
|
|
|
|
return False
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.isfile(audio_file):
|
|
|
|
|
|
|
|
logger.error("Please input the right audio file path")
|
|
|
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("checking the audio file format......")
|
|
|
|
logger.info("checking the audio file format......")
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -431,7 +427,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
sample rate: 8k \n \
|
|
|
|
sample rate: 8k \n \
|
|
|
|
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
|
|
|
|
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
|
|
|
|
")
|
|
|
|
")
|
|
|
|
sys.exit(-1)
|
|
|
|
return False
|
|
|
|
logger.info("The sample rate is %d" % audio_sample_rate)
|
|
|
|
logger.info("The sample rate is %d" % audio_sample_rate)
|
|
|
|
if audio_sample_rate != self.sample_rate:
|
|
|
|
if audio_sample_rate != self.sample_rate:
|
|
|
|
logger.warning("The sample rate of the input file is not {}.\n \
|
|
|
|
logger.warning("The sample rate of the input file is not {}.\n \
|
|
|
@ -465,6 +461,8 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
logger.info("The audio file format is right")
|
|
|
|
logger.info("The audio file format is right")
|
|
|
|
self.change_format = False
|
|
|
|
self.change_format = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Command line entry.
|
|
|
|
Command line entry.
|
|
|
@ -517,7 +515,8 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
Python API to call an executor.
|
|
|
|
Python API to call an executor.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
self._check(audio_file, sample_rate, force_yes)
|
|
|
|
if not self._check(audio_file, sample_rate, force_yes):
|
|
|
|
|
|
|
|
sys.exit(-1)
|
|
|
|
paddle.set_device(device)
|
|
|
|
paddle.set_device(device)
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, decode_method,
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, decode_method,
|
|
|
|
ckpt_path)
|
|
|
|
ckpt_path)
|
|
|
|