revise the sample rate

pull/1048/head
huangyuxin 4 years ago
parent 90d648a601
commit aee530af27

@ -78,7 +78,7 @@ class ASRExecutor(BaseExecutor):
default='zh', default='zh',
help='Choose model language. zh or en') help='Choose model language. zh or en')
self.parser.add_argument( self.parser.add_argument(
"--model_sample_rate", "--sr",
type=int, type=int,
default=16000, default=16000,
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
@ -117,7 +117,7 @@ class ASRExecutor(BaseExecutor):
def _init_from_path(self, def _init_from_path(self,
model_type: str='wenetspeech', model_type: str='wenetspeech',
lang: str='zh', lang: str='zh',
model_sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None,
device: str='cpu'): device: str='cpu'):
@ -125,8 +125,8 @@ class ASRExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + model_sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
@ -197,8 +197,6 @@ class ASRExecutor(BaseExecutor):
audio_file = input audio_file = input
logger.info("Preprocess audio_file:" + audio_file) logger.info("Preprocess audio_file:" + audio_file)
config_target_sample_rate = self.config.collator.target_sample_rate
# Get the object for feature extraction # Get the object for feature extraction
if model_type == "ds2_online" or model_type == "ds2_offline": if model_type == "ds2_online" or model_type == "ds2_offline":
audio, _ = self.collate_fn_test.process_utterance( audio, _ = self.collate_fn_test.process_utterance(
@ -222,7 +220,7 @@ class ASRExecutor(BaseExecutor):
preprocess_args = {"train": False} preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf) preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file") logger.info("read the audio file")
audio, sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
if self.change_format: if self.change_format:
@ -231,17 +229,13 @@ class ASRExecutor(BaseExecutor):
else: else:
audio = audio[:, 0] audio = audio[:, 0]
audio = audio.astype("float32") audio = audio.astype("float32")
audio = librosa.resample(audio, sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.target_sample_rate) self.sample_rate)
sample_rate = self.target_sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") audio = audio.astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]
if sample_rate != config_target_sample_rate:
logger.error(
f"sample rate error: {sample_rate}, need {self.sr} ")
sys.exit(-1)
logger.info(f"audio shape: {audio.shape}") logger.info(f"audio shape: {audio.shape}")
# fbank # fbank
audio = preprocessing(audio, **preprocess_args) audio = preprocessing(audio, **preprocess_args)
@ -313,11 +307,11 @@ class ASRExecutor(BaseExecutor):
""" """
return self._outputs["result"] return self._outputs["result"]
def _check(self, audio_file: str, model_sample_rate: int): def _check(self, audio_file: str, sample_rate: int):
self.target_sample_rate = model_sample_rate self.sample_rate = sample_rate
if self.target_sample_rate != 16000 and self.target_sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error( logger.error(
"please input --model_sample_rate 8000 or --model_sample_rate 16000" "please input --sr 8000 or --sr 16000"
) )
raise Exception("invalid sample rate") raise Exception("invalid sample rate")
sys.exit(-1) sys.exit(-1)
@ -328,7 +322,7 @@ class ASRExecutor(BaseExecutor):
logger.info("checking the audio file format......") logger.info("checking the audio file format......")
try: try:
sig, sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
except Exception as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))
@ -342,15 +336,15 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
") ")
sys.exit(-1) sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate) logger.info("The sample rate is %d" % audio_sample_rate)
if sample_rate != self.target_sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning( logger.warning(
"The sample rate of the input file is not {}.\n \ "The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \ If the result does not meet your expectations\n \
Please input the 16k 16bit 1 channel wav file. \ Please input the 16k 16bit 1 channel wav file. \
" "
.format(self.target_sample_rate, self.target_sample_rate)) .format(self.sample_rate, self.sample_rate))
while (True): while (True):
logger.info( logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@ -381,14 +375,14 @@ class ASRExecutor(BaseExecutor):
model = parser_args.model model = parser_args.model
lang = parser_args.lang lang = parser_args.lang
model_sample_rate = parser_args.model_sample_rate sample_rate = parser_args.sr
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input audio_file = parser_args.input
device = parser_args.device device = parser_args.device
try: try:
res = self(model, lang, model_sample_rate, config, ckpt_path, res = self(model, lang, sample_rate, config, ckpt_path,
audio_file, device) audio_file, device)
logger.info('ASR Result: {}'.format(res)) logger.info('ASR Result: {}'.format(res))
return True return True
@ -396,14 +390,14 @@ class ASRExecutor(BaseExecutor):
print(e) print(e)
return False return False
def __call__(self, model, lang, model_sample_rate, config, ckpt_path, def __call__(self, model, lang, sample_rate, config, ckpt_path,
audio_file, device): audio_file, device):
""" """
Python API to call an executor. Python API to call an executor.
""" """
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
self._check(audio_file, model_sample_rate) self._check(audio_file, sample_rate)
self._init_from_path(model, lang, model_sample_rate, config, ckpt_path, self._init_from_path(model, lang, sample_rate, config, ckpt_path,
device) device)
self.preprocess(model, audio_file) self.preprocess(model, audio_file)
self.infer(model) self.infer(model)

Loading…
Cancel
Save