|
|
@ -22,6 +22,7 @@ import librosa
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
import soundfile
|
|
|
|
import soundfile
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
from ..utils import cli_register
|
|
|
|
from ..utils import cli_register
|
|
|
@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
default='zh',
|
|
|
|
default='zh',
|
|
|
|
help='Choose model language. zh or en')
|
|
|
|
help='Choose model language. zh or en')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
"--model_sample_rate",
|
|
|
|
"--sr",
|
|
|
|
type=int,
|
|
|
|
type=int,
|
|
|
|
default=16000,
|
|
|
|
default=16000,
|
|
|
|
|
|
|
|
choices=[8000, 16000],
|
|
|
|
help='Choose the audio sample rate of the model. 8000 or 16000')
|
|
|
|
help='Choose the audio sample rate of the model. 8000 or 16000')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--config',
|
|
|
|
'--config',
|
|
|
@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
def _init_from_path(self,
|
|
|
|
def _init_from_path(self,
|
|
|
|
model_type: str='wenetspeech',
|
|
|
|
model_type: str='wenetspeech',
|
|
|
|
lang: str='zh',
|
|
|
|
lang: str='zh',
|
|
|
|
model_sample_rate: int=16000,
|
|
|
|
sample_rate: int=16000,
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
ckpt_path: Optional[os.PathLike]=None):
|
|
|
|
ckpt_path: Optional[os.PathLike]=None
|
|
|
|
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if cfg_path is None or ckpt_path is None:
|
|
|
|
if cfg_path is None or ckpt_path is None:
|
|
|
|
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
tag = model_type + '_' + lang + '_' + model_sample_rate_str
|
|
|
|
tag = model_type + '_' + lang + '_' + sample_rate_str
|
|
|
|
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
|
|
|
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
|
|
|
self.cfg_path = os.path.join(res_path,
|
|
|
|
self.cfg_path = os.path.join(res_path,
|
|
|
|
pretrained_models[tag]['cfg_path'])
|
|
|
|
pretrained_models[tag]['cfg_path'])
|
|
|
|
self.ckpt_path = os.path.join(res_path,
|
|
|
|
self.ckpt_path = os.path.join(res_path,
|
|
|
|
pretrained_models[tag]['ckpt_path'])
|
|
|
|
pretrained_models[tag]['ckpt_path'] + ".pdparams")
|
|
|
|
logger.info(res_path)
|
|
|
|
logger.info(res_path)
|
|
|
|
logger.info(self.cfg_path)
|
|
|
|
logger.info(self.cfg_path)
|
|
|
|
logger.info(self.ckpt_path)
|
|
|
|
logger.info(self.ckpt_path)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.cfg_path = os.path.abspath(cfg_path)
|
|
|
|
self.cfg_path = os.path.abspath(cfg_path)
|
|
|
|
self.ckpt_path = os.path.abspath(ckpt_path)
|
|
|
|
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
|
|
|
|
res_path = os.path.dirname(
|
|
|
|
res_path = os.path.dirname(
|
|
|
|
os.path.dirname(os.path.abspath(self.cfg_path)))
|
|
|
|
os.path.dirname(os.path.abspath(self.cfg_path)))
|
|
|
|
|
|
|
|
|
|
|
@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
self.model.eval()
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
# load model
|
|
|
|
# load model
|
|
|
|
params_path = self.ckpt_path + ".pdparams"
|
|
|
|
model_dict = paddle.load(self.ckpt_path)
|
|
|
|
model_dict = paddle.load(params_path)
|
|
|
|
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
|
|
|
|
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
|
|
|
@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
audio_file = input
|
|
|
|
audio_file = input
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
logger.info("Preprocess audio_file:" + audio_file)
|
|
|
|
|
|
|
|
|
|
|
|
config_target_sample_rate = self.config.collator.target_sample_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get the object for feature extraction
|
|
|
|
# Get the object for feature extraction
|
|
|
|
if model_type == "ds2_online" or model_type == "ds2_offline":
|
|
|
|
if model_type == "ds2_online" or model_type == "ds2_offline":
|
|
|
|
audio, _ = self.collate_fn_test.process_utterance(
|
|
|
|
audio, _ = self.collate_fn_test.process_utterance(
|
|
|
@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
preprocess_args = {"train": False}
|
|
|
|
preprocess_args = {"train": False}
|
|
|
|
preprocessing = Transformation(preprocess_conf)
|
|
|
|
preprocessing = Transformation(preprocess_conf)
|
|
|
|
logger.info("read the audio file")
|
|
|
|
logger.info("read the audio file")
|
|
|
|
audio, sample_rate = soundfile.read(
|
|
|
|
audio, audio_sample_rate = soundfile.read(
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
|
|
|
|
|
|
|
|
if self.change_format:
|
|
|
|
if self.change_format:
|
|
|
@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
audio = audio[:, 0]
|
|
|
|
audio = audio[:, 0]
|
|
|
|
audio = audio.astype("float32")
|
|
|
|
audio = audio.astype("float32")
|
|
|
|
audio = librosa.resample(audio, sample_rate,
|
|
|
|
audio = librosa.resample(audio, audio_sample_rate,
|
|
|
|
self.target_sample_rate)
|
|
|
|
self.sample_rate)
|
|
|
|
sample_rate = self.target_sample_rate
|
|
|
|
audio_sample_rate = self.sample_rate
|
|
|
|
audio = audio.astype("int16")
|
|
|
|
audio = np.round(audio).astype("int16")
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
audio = audio[:, 0]
|
|
|
|
audio = audio[:, 0]
|
|
|
|
|
|
|
|
|
|
|
|
if sample_rate != config_target_sample_rate:
|
|
|
|
|
|
|
|
logger.error(
|
|
|
|
|
|
|
|
f"sample rate error: {sample_rate}, need {self.sr} ")
|
|
|
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
logger.info(f"audio shape: {audio.shape}")
|
|
|
|
logger.info(f"audio shape: {audio.shape}")
|
|
|
|
# fbank
|
|
|
|
# fbank
|
|
|
|
audio = preprocessing(audio, **preprocess_args)
|
|
|
|
audio = preprocessing(audio, **preprocess_args)
|
|
|
@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return self._outputs["result"]
|
|
|
|
return self._outputs["result"]
|
|
|
|
|
|
|
|
|
|
|
|
def _check(self, audio_file: str, model_sample_rate: int):
|
|
|
|
def _check(self, audio_file: str, sample_rate: int):
|
|
|
|
self.target_sample_rate = model_sample_rate
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
|
|
|
|
if self.sample_rate != 16000 and self.sample_rate != 8000:
|
|
|
|
logger.error(
|
|
|
|
logger.error(
|
|
|
|
"please input --model_sample_rate 8000 or --model_sample_rate 16000"
|
|
|
|
"please input --sr 8000 or --sr 16000"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
raise Exception("invalid sample rate")
|
|
|
|
raise Exception("invalid sample rate")
|
|
|
|
sys.exit(-1)
|
|
|
|
sys.exit(-1)
|
|
|
@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("checking the audio file format......")
|
|
|
|
logger.info("checking the audio file format......")
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
sig, sample_rate = soundfile.read(
|
|
|
|
audio, audio_sample_rate = soundfile.read(
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
audio_file, dtype="int16", always_2d=True)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(str(e))
|
|
|
|
logger.error(str(e))
|
|
|
@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
|
|
|
|
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
|
|
|
|
")
|
|
|
|
")
|
|
|
|
sys.exit(-1)
|
|
|
|
sys.exit(-1)
|
|
|
|
logger.info("The sample rate is %d" % sample_rate)
|
|
|
|
logger.info("The sample rate is %d" % audio_sample_rate)
|
|
|
|
if sample_rate != self.target_sample_rate:
|
|
|
|
if audio_sample_rate != self.sample_rate:
|
|
|
|
logger.warning(
|
|
|
|
logger.warning(
|
|
|
|
"The sample rate of the input file is not {}.\n \
|
|
|
|
"The sample rate of the input file is not {}.\n \
|
|
|
|
The program will resample the wav file to {}.\n \
|
|
|
|
The program will resample the wav file to {}.\n \
|
|
|
|
If the result does not meet your expectations,\n \
|
|
|
|
If the result does not meet your expectations,\n \
|
|
|
|
Please input the 16k 16bit 1 channel wav file. \
|
|
|
|
Please input the 16k 16bit 1 channel wav file. \
|
|
|
|
"
|
|
|
|
"
|
|
|
|
.format(self.target_sample_rate, self.target_sample_rate))
|
|
|
|
.format(self.sample_rate, self.sample_rate))
|
|
|
|
while (True):
|
|
|
|
while (True):
|
|
|
|
logger.info(
|
|
|
|
logger.info(
|
|
|
|
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
|
|
|
|
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
|
|
|
@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
model = parser_args.model
|
|
|
|
model = parser_args.model
|
|
|
|
lang = parser_args.lang
|
|
|
|
lang = parser_args.lang
|
|
|
|
model_sample_rate = parser_args.model_sample_rate
|
|
|
|
sample_rate = parser_args.sr
|
|
|
|
config = parser_args.config
|
|
|
|
config = parser_args.config
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
audio_file = parser_args.input
|
|
|
|
audio_file = parser_args.input
|
|
|
|
device = parser_args.device
|
|
|
|
device = parser_args.device
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
res = self(model, lang, model_sample_rate, config, ckpt_path,
|
|
|
|
res = self(model, lang, sample_rate, config, ckpt_path,
|
|
|
|
audio_file, device)
|
|
|
|
audio_file, device)
|
|
|
|
logger.info('ASR Result: {}'.format(res))
|
|
|
|
logger.info('ASR Result: {}'.format(res))
|
|
|
|
return True
|
|
|
|
return True
|
|
|
@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
print(e)
|
|
|
|
print(e)
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
|
|
|
|
def __call__(self, model, lang, sample_rate, config, ckpt_path,
|
|
|
|
audio_file, device):
|
|
|
|
audio_file, device):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Python API to call an executor.
|
|
|
|
Python API to call an executor.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
self._check(audio_file, model_sample_rate)
|
|
|
|
self._check(audio_file, sample_rate)
|
|
|
|
|
|
|
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
paddle.set_device(device)
|
|
|
|
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, ckpt_path)
|
|
|
|
self.preprocess(model, audio_file)
|
|
|
|
self.preprocess(model, audio_file)
|
|
|
|
self.infer(model)
|
|
|
|
self.infer(model)
|
|
|
|
res = self.postprocess() # Retrieve result of asr.
|
|
|
|
res = self.postprocess() # Retrieve result of asr.
|
|
|
|