|
|
|
@ -78,7 +78,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
default='zh',
|
|
|
|
|
help='Choose model language. zh or en')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
"--model_sample_rate",
|
|
|
|
|
"--sr",
|
|
|
|
|
type=int,
|
|
|
|
|
default=16000,
|
|
|
|
|
help='Choose the audio sample rate of the model. 8000 or 16000')
|
|
|
|
@ -117,7 +117,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
def _init_from_path(self,
|
|
|
|
|
model_type: str='wenetspeech',
|
|
|
|
|
lang: str='zh',
|
|
|
|
|
model_sample_rate: int=16000,
|
|
|
|
|
sample_rate: int=16000,
|
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
|
ckpt_path: Optional[os.PathLike]=None,
|
|
|
|
|
device: str='cpu'):
|
|
|
|
@ -125,8 +125,8 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
|
"""
|
|
|
|
|
if cfg_path is None or ckpt_path is None:
|
|
|
|
|
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
|
|
|
|
|
tag = model_type + '_' + lang + '_' + model_sample_rate_str
|
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
|
tag = model_type + '_' + lang + '_' + sample_rate_str
|
|
|
|
|
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
|
|
|
|
self.cfg_path = os.path.join(res_path,
|
|
|
|
|
pretrained_models[tag]['cfg_path'])
|
|
|
|
@ -197,8 +197,6 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
audio_file = input
|
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
|
|
|
|
|
|
config_target_sample_rate = self.config.collator.target_sample_rate
|
|
|
|
|
|
|
|
|
|
# Get the object for feature extraction
|
|
|
|
|
if model_type == "ds2_online" or model_type == "ds2_offline":
|
|
|
|
|
audio, _ = self.collate_fn_test.process_utterance(
|
|
|
|
@ -222,7 +220,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
preprocess_args = {"train": False}
|
|
|
|
|
preprocessing = Transformation(preprocess_conf)
|
|
|
|
|
logger.info("read the audio file")
|
|
|
|
|
audio, sample_rate = soundfile.read(
|
|
|
|
|
audio, audio_sample_rate = soundfile.read(
|
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
|
|
|
|
|
|
if self.change_format:
|
|
|
|
@ -231,17 +229,13 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
else:
|
|
|
|
|
audio = audio[:, 0]
|
|
|
|
|
audio = audio.astype("float32")
|
|
|
|
|
audio = librosa.resample(audio, sample_rate,
|
|
|
|
|
self.target_sample_rate)
|
|
|
|
|
sample_rate = self.target_sample_rate
|
|
|
|
|
audio = librosa.resample(audio, audio_sample_rate,
|
|
|
|
|
self.sample_rate)
|
|
|
|
|
audio_sample_rate = self.sample_rate
|
|
|
|
|
audio = audio.astype("int16")
|
|
|
|
|
else:
|
|
|
|
|
audio = audio[:, 0]
|
|
|
|
|
|
|
|
|
|
if sample_rate != config_target_sample_rate:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"sample rate error: {sample_rate}, need {self.sr} ")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
logger.info(f"audio shape: {audio.shape}")
|
|
|
|
|
# fbank
|
|
|
|
|
audio = preprocessing(audio, **preprocess_args)
|
|
|
|
@ -313,11 +307,11 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
return self._outputs["result"]
|
|
|
|
|
|
|
|
|
|
def _check(self, audio_file: str, model_sample_rate: int):
|
|
|
|
|
self.target_sample_rate = model_sample_rate
|
|
|
|
|
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
|
|
|
|
|
def _check(self, audio_file: str, sample_rate: int):
|
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
|
if self.sample_rate != 16000 and self.sample_rate != 8000:
|
|
|
|
|
logger.error(
|
|
|
|
|
"please input --model_sample_rate 8000 or --model_sample_rate 16000"
|
|
|
|
|
"please input --sr 8000 or --sr 16000"
|
|
|
|
|
)
|
|
|
|
|
raise Exception("invalid sample rate")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
@ -328,7 +322,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
logger.info("checking the audio file format......")
|
|
|
|
|
try:
|
|
|
|
|
sig, sample_rate = soundfile.read(
|
|
|
|
|
audio, audio_sample_rate = soundfile.read(
|
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(str(e))
|
|
|
|
@ -342,15 +336,15 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
|
|
|
|
|
")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
logger.info("The sample rate is %d" % sample_rate)
|
|
|
|
|
if sample_rate != self.target_sample_rate:
|
|
|
|
|
logger.info("The sample rate is %d" % audio_sample_rate)
|
|
|
|
|
if audio_sample_rate != self.sample_rate:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"The sample rate of the input file is not {}.\n \
|
|
|
|
|
The program will resample the wav file to {}.\n \
|
|
|
|
|
If the result does not meet your expectations,\n \
|
|
|
|
|
Please input the 16k 16bit 1 channel wav file. \
|
|
|
|
|
"
|
|
|
|
|
.format(self.target_sample_rate, self.target_sample_rate))
|
|
|
|
|
.format(self.sample_rate, self.sample_rate))
|
|
|
|
|
while (True):
|
|
|
|
|
logger.info(
|
|
|
|
|
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
|
|
|
|
@ -381,14 +375,14 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
model = parser_args.model
|
|
|
|
|
lang = parser_args.lang
|
|
|
|
|
model_sample_rate = parser_args.model_sample_rate
|
|
|
|
|
sample_rate = parser_args.sr
|
|
|
|
|
config = parser_args.config
|
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
|
audio_file = parser_args.input
|
|
|
|
|
device = parser_args.device
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
res = self(model, lang, model_sample_rate, config, ckpt_path,
|
|
|
|
|
res = self(model, lang, sample_rate, config, ckpt_path,
|
|
|
|
|
audio_file, device)
|
|
|
|
|
logger.info('ASR Result: {}'.format(res))
|
|
|
|
|
return True
|
|
|
|
@ -396,14 +390,14 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
print(e)
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
|
|
|
|
|
def __call__(self, model, lang, sample_rate, config, ckpt_path,
|
|
|
|
|
audio_file, device):
|
|
|
|
|
"""
|
|
|
|
|
Python API to call an executor.
|
|
|
|
|
"""
|
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
|
self._check(audio_file, model_sample_rate)
|
|
|
|
|
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path,
|
|
|
|
|
self._check(audio_file, sample_rate)
|
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, ckpt_path,
|
|
|
|
|
device)
|
|
|
|
|
self.preprocess(model, audio_file)
|
|
|
|
|
self.infer(model)
|
|
|
|
|