From a9d206c1bfc433f1aec6cebbb783f8539e0bb6a9 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 2 Dec 2021 05:58:20 +0000 Subject: [PATCH] revise --- paddlespeech/cli/asr/infer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 6ae038539..e9d8c0b11 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -22,6 +22,7 @@ import librosa import paddle import soundfile from yacs.config import CfgNode +import numpy as np from ..executor import BaseExecutor from ..utils import cli_register @@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor): "--sr", type=int, default=16000, + choices=[8000, 16000], help='Choose the audio sample rate of the model. 8000 or 16000') self.parser.add_argument( '--config', @@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor): self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]['ckpt_path']) + pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) logger.info(self.cfg_path) logger.info(self.ckpt_path) else: self.cfg_path = os.path.abspath(cfg_path) - self.ckpt_path = os.path.abspath(ckpt_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) @@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor): self.model.eval() # load model - params_path = self.ckpt_path + ".pdparams" - model_dict = paddle.load(params_path) + model_dict = paddle.load(self.ckpt_path) self.model.set_state_dict(model_dict) def preprocess(self, model_type: str, input: Union[str, os.PathLike]): @@ -231,7 +232,7 @@ class ASRExecutor(BaseExecutor): audio = librosa.resample(audio, audio_sample_rate, self.sample_rate) audio_sample_rate = self.sample_rate - audio = audio.astype("int16") + audio = np.round(audio).astype("int16") else: audio = audio[:, 0]