revise the sample rate

pull/1048/head
huangyuxin 3 years ago
parent 90d648a601
commit aee530af27

@ -78,7 +78,7 @@ class ASRExecutor(BaseExecutor):
default='zh',
help='Choose model language. zh or en')
self.parser.add_argument(
"--model_sample_rate",
"--sr",
type=int,
default=16000,
help='Choose the audio sample rate of the model. 8000 or 16000')
@ -117,7 +117,7 @@ class ASRExecutor(BaseExecutor):
def _init_from_path(self,
model_type: str='wenetspeech',
lang: str='zh',
model_sample_rate: int=16000,
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
device: str='cpu'):
@ -125,8 +125,8 @@ class ASRExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if cfg_path is None or ckpt_path is None:
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + model_sample_rate_str
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
@ -197,8 +197,6 @@ class ASRExecutor(BaseExecutor):
audio_file = input
logger.info("Preprocess audio_file:" + audio_file)
config_target_sample_rate = self.config.collator.target_sample_rate
# Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline":
audio, _ = self.collate_fn_test.process_utterance(
@ -222,7 +220,7 @@ class ASRExecutor(BaseExecutor):
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file")
audio, sample_rate = soundfile.read(
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
if self.change_format:
@ -231,17 +229,13 @@ class ASRExecutor(BaseExecutor):
else:
audio = audio[:, 0]
audio = audio.astype("float32")
audio = librosa.resample(audio, sample_rate,
self.target_sample_rate)
sample_rate = self.target_sample_rate
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
audio = audio.astype("int16")
else:
audio = audio[:, 0]
if sample_rate != config_target_sample_rate:
logger.error(
f"sample rate error: {sample_rate}, need {self.sr} ")
sys.exit(-1)
logger.info(f"audio shape: {audio.shape}")
# fbank
audio = preprocessing(audio, **preprocess_args)
@ -313,11 +307,11 @@ class ASRExecutor(BaseExecutor):
"""
return self._outputs["result"]
def _check(self, audio_file: str, model_sample_rate: int):
self.target_sample_rate = model_sample_rate
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
def _check(self, audio_file: str, sample_rate: int):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"please input --model_sample_rate 8000 or --model_sample_rate 16000"
"please input --sr 8000 or --sr 16000"
)
raise Exception("invalid sample rate")
sys.exit(-1)
@ -328,7 +322,7 @@ class ASRExecutor(BaseExecutor):
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
except Exception as e:
logger.error(str(e))
@ -342,15 +336,15 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
if sample_rate != self.target_sample_rate:
logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning(
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16bit 1 channel wav file. \
"
.format(self.target_sample_rate, self.target_sample_rate))
.format(self.sample_rate, self.sample_rate))
while (True):
logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@ -381,14 +375,14 @@ class ASRExecutor(BaseExecutor):
model = parser_args.model
lang = parser_args.lang
model_sample_rate = parser_args.model_sample_rate
sample_rate = parser_args.sr
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device
try:
res = self(model, lang, model_sample_rate, config, ckpt_path,
res = self(model, lang, sample_rate, config, ckpt_path,
audio_file, device)
logger.info('ASR Result: {}'.format(res))
return True
@ -396,14 +390,14 @@ class ASRExecutor(BaseExecutor):
print(e)
return False
def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
def __call__(self, model, lang, sample_rate, config, ckpt_path,
audio_file, device):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
self._check(audio_file, model_sample_rate)
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path,
self._check(audio_file, sample_rate)
self._init_from_path(model, lang, sample_rate, config, ckpt_path,
device)
self.preprocess(model, audio_file)
self.infer(model)

Loading…
Cancel
Save