pull/1048/head
huangyuxin 3 years ago
parent 957f2e3a1c
commit a9d206c1bf

@ -22,6 +22,7 @@ import librosa
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr", "--sr",
type=int, type=int,
default=16000, default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path']) pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# load model # load model
params_path = self.ckpt_path + ".pdparams" model_dict = paddle.load(self.ckpt_path)
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor):
audio = librosa.resample(audio, audio_sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate) self.sample_rate)
audio_sample_rate = self.sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") audio = np.round(audio).astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]

Loading…
Cancel
Save