Fix the code format, test=tts

pull/1302/head
Jerryuhoo 4 years ago
parent 1e710ef570
commit 111a452378

@ -15,21 +15,22 @@
# for mb melgan finetune # for mb melgan finetune
# 长度和原本的 mel 不一致怎么办? # 长度和原本的 mel 不一致怎么办?
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from yacs.config import CfgNode
from tqdm import tqdm from tqdm import tqdm
import os from yacs.config import CfgNode
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.models.speedyspeech import SpeedySpeech from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
from paddlespeech.t2s.models.speedyspeech import SpeedySpeechInference from paddlespeech.t2s.models.speedyspeech import SpeedySpeechInference
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.t2s.frontend.zh_frontend import Frontend
def evaluate(args, speedyspeech_config): def evaluate(args, speedyspeech_config):
rootdir = Path(args.rootdir).expanduser() rootdir = Path(args.rootdir).expanduser()
@ -50,17 +51,21 @@ def evaluate(args, speedyspeech_config):
tone_size = len(tone_id) tone_size = len(tone_id)
print("tone_size:", tone_size) print("tone_size:", tone_size)
frontend = Frontend(phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
if args.speaker_dict: if args.speaker_dict:
with open(args.speaker_dict, 'rt') as f: with open(args.speaker_dict, 'rt') as f:
spk_id_list = [line.strip().split() for line in f.readlines()] spk_id_list = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id_list) spk_num = len(spk_id_list)
else: else:
spk_num=None spk_num = None
model = SpeedySpeech( model = SpeedySpeech(
vocab_size=vocab_size, tone_size=tone_size, **speedyspeech_config["model"], spk_num=spk_num) vocab_size=vocab_size,
tone_size=tone_size,
**speedyspeech_config["model"],
spk_num=spk_num)
model.set_state_dict( model.set_state_dict(
paddle.load(args.speedyspeech_checkpoint)["main_params"]) paddle.load(args.speedyspeech_checkpoint)["main_params"])
@ -105,9 +110,15 @@ def evaluate(args, speedyspeech_config):
else: else:
train_wav_files += wav_files train_wav_files += wav_files
train_wav_files = [os.path.basename(str(str_path)) for str_path in train_wav_files] train_wav_files = [
dev_wav_files = [os.path.basename(str(str_path)) for str_path in dev_wav_files] os.path.basename(str(str_path)) for str_path in train_wav_files
test_wav_files = [os.path.basename(str(str_path)) for str_path in test_wav_files] ]
dev_wav_files = [
os.path.basename(str(str_path)) for str_path in dev_wav_files
]
test_wav_files = [
os.path.basename(str(str_path)) for str_path in test_wav_files
]
for i, utt_id in enumerate(tqdm(sentences)): for i, utt_id in enumerate(tqdm(sentences)):
phones = sentences[utt_id][0] phones = sentences[utt_id][0]
@ -122,8 +133,7 @@ def evaluate(args, speedyspeech_config):
durations = durations[:-1] durations = durations[:-1]
phones = phones[:-1] phones = phones[:-1]
phones, tones = frontend._get_phone_tone( phones, tones = frontend._get_phone_tone(phones, get_tone_ids=True)
phones, get_tone_ids=True)
if tones: if tones:
tone_ids = frontend._t2id(tones) tone_ids = frontend._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids) tone_ids = paddle.to_tensor(tone_ids)
@ -132,7 +142,8 @@ def evaluate(args, speedyspeech_config):
phone_ids = paddle.to_tensor(phone_ids) phone_ids = paddle.to_tensor(phone_ids)
if args.speaker_dict: if args.speaker_dict:
speaker_id = int([item[1] for item in spk_id_list if speaker == item[0]][0]) speaker_id = int(
[item[1] for item in spk_id_list if speaker == item[0]][0])
speaker_id = paddle.to_tensor(speaker_id) speaker_id = paddle.to_tensor(speaker_id)
else: else:
speaker_id = None speaker_id = None
@ -155,7 +166,8 @@ def evaluate(args, speedyspeech_config):
sub_output_dir.mkdir(parents=True, exist_ok=True) sub_output_dir.mkdir(parents=True, exist_ok=True)
with paddle.no_grad(): with paddle.no_grad():
mel = speedyspeech_inference(phone_ids, tone_ids, durations=durations, spk_id=speaker_id) mel = speedyspeech_inference(
phone_ids, tone_ids, durations=durations, spk_id=speaker_id)
np.save(sub_output_dir / (utt_id + "_feats.npy"), mel) np.save(sub_output_dir / (utt_id + "_feats.npy"), mel)
@ -193,10 +205,7 @@ def main():
default="tone_id_map.txt", default="tone_id_map.txt",
help="tone vocabulary file.") help="tone vocabulary file.")
parser.add_argument( parser.add_argument(
"--speaker-dict", "--speaker-dict", type=str, default=None, help="speaker id map file.")
type=str,
default=None,
help="speaker id map file.")
parser.add_argument( parser.add_argument(
"--dur-file", default=None, type=str, help="path to durations.txt.") "--dur-file", default=None, type=str, help="path to durations.txt.")

@ -272,9 +272,6 @@ class SpeedySpeechInference(nn.Layer):
def forward(self, phones, tones, durations=None, spk_id=None): def forward(self, phones, tones, durations=None, spk_id=None):
normalized_mel = self.acoustic_model.inference( normalized_mel = self.acoustic_model.inference(
phones, phones, tones, durations=durations, spk_id=spk_id)
tones,
durations=durations,
spk_id=spk_id)
logmel = self.normalizer.inverse(normalized_mel) logmel = self.normalizer.inverse(normalized_mel)
return logmel return logmel

@ -20,6 +20,7 @@ import jsonlines
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
def main(): def main():
# parse config and args # parse config and args
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -61,9 +62,10 @@ def main():
try: try:
wav = np.load(old_dump_dir / sub / ("raw/" + wave_name)) wav = np.load(old_dump_dir / sub / ("raw/" + wave_name))
os.symlink(old_dump_dir / sub / ("raw/" + wave_name), os.symlink(old_dump_dir / sub / ("raw/" + wave_name),
output_dir / ("raw/" + wave_name)) output_dir / ("raw/" + wave_name))
except FileNotFoundError: except FileNotFoundError:
print("delete " + name + " because it cannot be found in the dump folder") print("delete " + name +
" because it cannot be found in the dump folder")
os.remove(output_dir / "raw" / name) os.remove(output_dir / "raw" / name)
continue continue
except FileExistsError: except FileExistsError:

Loading…
Cancel
Save