pull/1048/head
huangyuxin 3 years ago
parent 957f2e3a1c
commit b0356ae489

@ -22,6 +22,7 @@ import librosa
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
@ -81,6 +82,7 @@ class ASRExecutor(BaseExecutor):
"--sr", "--sr",
type=int, type=int,
default=16000, default=16000,
choices=[8000, 16000],
help='Choose the audio sample rate of the model. 8000 or 16000') help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
@ -131,13 +133,13 @@ class ASRExecutor(BaseExecutor):
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path']) pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
@ -183,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# load model # load model
params_path = self.ckpt_path + ".pdparams" model_dict = paddle.load(self.ckpt_path)
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
@ -227,11 +228,16 @@ class ASRExecutor(BaseExecutor):
audio = audio.mean(axis=1) audio = audio.mean(axis=1)
else: else:
audio = audio[:, 0] audio = audio[:, 0]
# pcm16 -> pcm 32
audio = audio.astype("float32") audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
audio = librosa.resample(audio, audio_sample_rate, audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate) self.sample_rate)
audio_sample_rate = self.sample_rate audio_sample_rate = self.sample_rate
audio = audio.astype("int16") # pcm 32 -> pcm16
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
else: else:
audio = audio[:, 0] audio = audio[:, 0]
@ -341,7 +347,7 @@ class ASRExecutor(BaseExecutor):
"The sample rate of the input file is not {}.\n \ "The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \ If the result does not meet your expectations\n \
Please input the 16k 16bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
" "
.format(self.sample_rate, self.sample_rate)) .format(self.sample_rate, self.sample_rate))
while (True): while (True):

Loading…
Cancel
Save