pull/1048/head
huangyuxin 3 years ago
parent 957f2e3a1c
commit a9d206c1bf

@ -22,6 +22,7 @@ import librosa
import paddle
import soundfile
from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor
from ..utils import cli_register
@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr",
type=int,
default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument(
'--config',
@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path'])
pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval()
# load model
params_path = self.ckpt_path + ".pdparams"
model_dict = paddle.load(params_path)
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor):
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
audio = audio.astype("int16")
audio = np.round(audio).astype("int16")
else:
audio = audio[:, 0]

Loading…
Cancel
Save