fix fastspeech2 multi speaker to static, test=tts

pull/1362/head
TianYuan 3 years ago
parent 4a133619a1
commit 1a9e59612a

@@ -257,6 +257,7 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
 ```

@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=${train_output_path}/test_e2e \
   --phones_dict=dump/phone_id_map.txt \
   --speaker_dict=dump/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=${train_output_path}/inference

@@ -240,13 +240,14 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --am_ckpt=fastspeech2_nosil_vctk_ckpt_0.5/snapshot_iter_66200.pdz \
   --am_stat=fastspeech2_nosil_vctk_ckpt_0.5/speech_stats.npy \
   --voc=pwgan_vctk \
-  --voc_config=pwg_vctk_ckpt_0.5/pwg_default.yaml \
-  --voc_ckpt=pwg_vctk_ckpt_0.5/pwg_snapshot_iter_1000000.pdz \
-  --voc_stat=pwg_vctk_ckpt_0.5/pwg_stats.npy \
+  --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+  --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+  --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
   --lang=en \
   --text=${BIN_DIR}/../sentences_en.txt \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_vctk_ckpt_0.5/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_vctk_ckpt_0.5/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
 ```

@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=${train_output_path}/test_e2e \
   --phones_dict=dump/phone_id_map.txt \
   --speaker_dict=dump/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=${train_output_path}/inference
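
The new `--inference_dir` flag makes `synthesize_e2e.py` export the acoustic model to static-graph files that Paddle Inference can load later. A minimal sketch of consuming such an export, assuming the files follow the `paddle.jit.save` naming of `<am>.pdmodel` / `<am>.pdiparams` under the inference directory (the model name and the expected input names are illustrative assumptions, not from this diff):

```python
# Minimal sketch: build a Paddle Inference predictor from the static-graph
# files written into --inference_dir. The model name "fastspeech2_aishell3"
# and the printed input names are assumptions for illustration.
import os
from paddle import inference

inference_dir = "exp/default/inference"
am = "fastspeech2_aishell3"

config = inference.Config(
    os.path.join(inference_dir, am + ".pdmodel"),
    os.path.join(inference_dir, am + ".pdiparams"))
config.enable_use_gpu(100, 0)  # or config.disable_gpu() for CPU-only runs
am_predictor = inference.create_predictor(config)
print(am_predictor.get_input_names())  # expected: phone ids (+ spk id)
```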

@@ -14,9 +14,11 @@
 import argparse
 from pathlib import Path
+import numpy
 import soundfile as sf
 from paddle import inference
+from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
@@ -29,20 +31,38 @@ def main():
         '--am',
         type=str,
         default='fastspeech2_csmsc',
-        choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'],
+        choices=[
+            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
+            'fastspeech2_vctk'
+        ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
         "--phones_dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker_dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        '--spk_id',
+        type=int,
+        default=0,
+        help='spk id for multi speaker acoustic model')
     # voc
     parser.add_argument(
         '--voc',
         type=str,
         default='pwgan_csmsc',
-        choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'],
+        choices=[
+            'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
+            'pwgan_vctk'
+        ],
         help='Choose vocoder type of tts task.')
     # other
+    parser.add_argument(
+        '--lang',
+        type=str,
+        default='zh',
+        help='Choose model language. zh or en')
     parser.add_argument(
         "--text",
         type=str,
@@ -53,8 +73,12 @@ def main():
     args, _ = parser.parse_known_args()
-    frontend = Frontend(
-        phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    # frontend
+    if args.lang == 'zh':
+        frontend = Frontend(
+            phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    elif args.lang == 'en':
+        frontend = English(phone_vocab_path=args.phones_dict)
     print("frontend done!")
     # model: {model_name}_{dataset}
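
With the new `--lang` switch the script picks between the Chinese frontend and the English one; both turn raw text into phone-id tensors through `get_input_ids`. A hedged sketch of that call for the English path (the sample sentence and dictionary path are placeholders):

```python
# Sketch: text-to-phone-ids with the English frontend, assuming a
# phone_id_map.txt exported alongside the acoustic model checkpoint.
from paddlespeech.t2s.frontend import English

frontend = English(phone_vocab_path="dump/phone_id_map.txt")
input_ids = frontend.get_input_ids(
    "Life was like a box of chocolates.", merge_sentences=True)
phone_ids = input_ids["phone_ids"]  # list of int64 tensors, one per chunk
print(phone_ids[0].shape)
```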
@@ -83,30 +107,52 @@ def main():
     print("in new inference")
+    # construct dataset for evaluation
+    sentences = []
     with open(args.text, 'rt') as f:
         for line in f:
             items = line.strip().split()
             utt_id = items[0]
-            sentence = "".join(items[1:])
+            if args.lang == 'zh':
+                sentence = "".join(items[1:])
+            elif args.lang == 'en':
+                sentence = " ".join(items[1:])
             sentences.append((utt_id, sentence))
     get_tone_ids = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
+    if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+        get_spk_id = True
+        spk_id = numpy.array([args.spk_id])
     am_input_names = am_predictor.get_input_names()
+    print("am_input_names:", am_input_names)
+    merge_sentences = True
     for utt_id, sentence in sentences:
-        input_ids = frontend.get_input_ids(
-            sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
+        if args.lang == 'zh':
+            input_ids = frontend.get_input_ids(
+                sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
+            phone_ids = input_ids["phone_ids"]
+        elif args.lang == 'en':
+            input_ids = frontend.get_input_ids(
+                sentence, merge_sentences=merge_sentences)
+            phone_ids = input_ids["phone_ids"]
+        else:
+            print("lang should in {'zh', 'en'}!")
         if get_tone_ids:
             tone_ids = input_ids["tone_ids"]
             tones = tone_ids[0].numpy()
             tones_handle = am_predictor.get_input_handle(am_input_names[1])
             tones_handle.reshape(tones.shape)
             tones_handle.copy_from_cpu(tones)
+        if get_spk_id:
+            spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
+            spk_id_handle.reshape(spk_id.shape)
+            spk_id_handle.copy_from_cpu(spk_id)
         phones = phone_ids[0].numpy()
         phones_handle = am_predictor.get_input_handle(am_input_names[0])
         phones_handle.reshape(phones.shape)
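
The hunk stops right after the phone ids are reshaped into the first input handle. For orientation, a hedged sketch of the steps that usually follow with the Paddle Inference API (feed the phones, run the predictor, read the generated mel back to CPU); it assumes the acoustic model exposes a single mel output:

```python
# Sketch of the remaining predictor steps (not part of this diff):
# copy the phone ids in, run the acoustic model, fetch the mel-spectrogram.
phones_handle.copy_from_cpu(phones)
am_predictor.run()

am_output_names = am_predictor.get_output_names()
mel_handle = am_predictor.get_output_handle(am_output_names[0])
mel = mel_handle.copy_to_cpu()  # numpy array, roughly [num_frames, n_mels]
print(mel.shape)
```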

@@ -159,9 +159,16 @@ def evaluate(args):
     # acoustic model
     if am_name == 'fastspeech2':
         if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
-            print(
-                "Haven't test dygraph to static for multi speaker fastspeech2 now!"
-            )
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64)
+                ])
+            paddle.jit.save(am_inference,
+                            os.path.join(args.inference_dir, args.am))
+            am_inference = paddle.jit.load(
+                os.path.join(args.inference_dir, args.am))
         else:
             am_inference = jit.to_static(
                 am_inference,
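
For the multi-speaker branch the network is now traced with two inputs, a variable-length phone-id sequence (`InputSpec([-1])`) and a single speaker id (`InputSpec([1])`), then saved and immediately re-loaded so the rest of the script runs against the exported graph. A small sanity-check sketch, assuming the reloaded layer keeps the `(text, spk_id)` call order implied by the `input_spec` list (dummy ids, illustrative only):

```python
# Sketch: call the re-loaded static multi-speaker FastSpeech2 once.
# The phone ids below are placeholders, not a real utterance.
import paddle

phone_ids = paddle.to_tensor([2, 15, 36, 7, 11], dtype='int64')  # shape [-1]
spk_id = paddle.to_tensor([0], dtype='int64')                    # shape [1]
with paddle.no_grad():
    mel = am_inference(phone_ids, spk_id)
print(mel.shape)  # e.g. [num_frames, 80]
```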

@@ -781,7 +781,7 @@ class FastSpeech2(nn.Layer):
         elif self.spk_embed_integration_type == "concat":
             # concat hidden states with spk embeds and then apply projection
             spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
-                shape=[-1, hs.shape[1], -1])
+                shape=[-1, paddle.shape(hs)[1], -1])
             hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
         else:
             raise NotImplementedError("support only add or concat.")
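
This one-line change is what makes the layer exportable: `hs.shape[1]` is a plain Python int that `jit.to_static` freezes at trace time, while `paddle.shape(hs)[1]` is a tensor evaluated at run time, so the exported graph keeps working for any sequence length. The pattern in isolation (tensor names and sizes here are illustrative, not from the model):

```python
# Sketch of the dynamic-shape expand pattern used above: the second
# dimension comes from paddle.shape(...) so it is resolved at run time
# instead of being baked in during dygraph-to-static tracing.
import paddle

hs = paddle.randn([2, 7, 256])      # [batch, num_frames, hidden]
spk_emb = paddle.randn([2, 256])    # one embedding per utterance

tiled = spk_emb.unsqueeze(1).expand(
    shape=[-1, paddle.shape(hs)[1], -1])  # -> [2, 7, 256]
print(tiled.shape)
```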
