pull/1048/head
huangyuxin 3 years ago
parent 957f2e3a1c
commit b0356ae489

@ -22,6 +22,7 @@ import librosa
import paddle
import soundfile
from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor
from ..utils import cli_register
@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr",
type=int,
default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument(
'--config',
@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path'])
pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval()
# load model
params_path = self.ckpt_path + ".pdparams"
model_dict = paddle.load(params_path)
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor):
audio = audio.mean(axis=1)
else:
audio = audio[:, 0]
# pcm16 -> pcm 32
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
audio = audio.astype("int16")
# pcm16 -> pcm 32
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
else:
audio = audio[:, 0]
@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor):
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16bit 1 channel wav file. \
Please input the 16k 16 bit 1 channel wav file. \
"
.format(self.sample_rate, self.sample_rate))
while (True):

Loading…
Cancel
Save