fix inference: pass speech_stretchs (mel-spectrum min/max) to the diffsinger acoustic model

pull/3005/head
lym0302 3 years ago
parent 9acc85205a
commit c9c6960f7e

@ -21,6 +21,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--voc_stat=pwgan_opencpop/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
--phones_dict=dump/phone_id_map.txt \
--speech_stretchs=dump/train/speech_stretchs.npy
fi

@ -333,14 +333,16 @@ def run_frontend(frontend: object,
# dygraph
def get_am_inference(am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False):
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False,
speech_stretchs: Optional[os.PathLike]=None, ):
with open(phones_dict, 'rt', encoding='utf-8') as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
@ -364,8 +366,18 @@ def get_am_inference(am: str='fastspeech2_csmsc',
am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
elif am_name == 'diffsinger':
with open(speech_stretchs, "r") as f:
spec_min = np.load(speech_stretchs)[0]
spec_max = np.load(speech_stretchs)[1]
spec_min = paddle.to_tensor(spec_min)
spec_max = paddle.to_tensor(spec_max)
am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am = am_class(
idim=vocab_size,
odim=odim,
**am_config["model"],
spec_min=spec_min,
spec_max=spec_max, )
elif am_name == 'speedyspeech':
am = am_class(
vocab_size=vocab_size,

@ -60,7 +60,8 @@ def evaluate(args):
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
speaker_dict=args.speaker_dict,
speech_stretchs=args.speech_stretchs, )
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
@ -221,6 +222,11 @@ def parse_args():
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--test_metadata", type=str, help="test metadata.")
parser.add_argument("--output_dir", type=str, help="output dir.")
parser.add_argument(
"--speech_stretchs",
type=str,
default=None,
help="The min and max values of the mel spectrum.")
args = parser.parse_args()
return args

Loading…
Cancel
Save