Merge pull request #1723 from yt605155624/refactor_syn_util

[TTS]restructure syn_utils.py, test=tts
Hui Zhang authored 3 years ago, committed by GitHub
commit 523d5bd6d4

@@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         --phones_dict=dump/phone_id_map.txt
 fi
-# style melgan
+# style melgan's Dygraph to Static Graph is not ready now
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    python3 ${BIN_DIR}/../inference.py \
-        --inference_dir=${train_output_path}/inference \
-        --am=tacotron2_csmsc \
-        --voc=style_melgan_csmsc \
-        --text=${BIN_DIR}/../sentences.txt \
-        --output_dir=${train_output_path}/pd_infer_out \
-        --phones_dict=dump/phone_id_map.txt
-fi
 # hifigan
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     python3 ${BIN_DIR}/../inference.py \
         --inference_dir=${train_output_path}/inference \
         --am=tacotron2_csmsc \

@@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         --phones_dict=dump/phone_id_map.txt
 fi
 # hifigan
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     python3 ${BIN_DIR}/../inference.py \

@@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/test_e2e \
-        --phones_dict=dump/phone_id_map.txt \
-        --inference_dir=${train_output_path}/inference
+        --phones_dict=dump/phone_id_map.txt #\
+        # --inference_dir=${train_output_path}/inference
 fi

@@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
 fi
 # hifigan
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     FLAGS_allocator_strategy=naive_best_fit \
     FLAGS_fraction_of_gpu_memory_to_use=0.01 \
     python3 ${BIN_DIR}/../synthesize.py \

@@ -102,20 +102,31 @@ def parse_args():
 def main():
     args = parse_args()
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     # am_predictor
-    am_predictor = get_predictor(args, filed='am')
+    am_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.am + ".pdmodel",
+        params_file=args.am + ".pdiparams",
+        device=args.device)
     # model: {model_name}_{dataset}
     am_dataset = args.am[args.am.rindex('_') + 1:]
     # voc_predictor
-    voc_predictor = get_predictor(args, filed='voc')
+    voc_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.voc + ".pdmodel",
+        params_file=args.voc + ".pdiparams",
+        device=args.device)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     merge_sentences = True
     fs = 24000 if am_dataset != 'ljspeech' else 22050
@@ -123,11 +134,13 @@ def main():
     for utt_id, sentence in sentences[:3]:
         with timer() as t:
             am_output_data = get_am_output(
-                args,
+                input=sentence,
                 am_predictor=am_predictor,
+                am=args.am,
                 frontend=frontend,
+                lang=args.lang,
                 merge_sentences=merge_sentences,
-                input=sentence)
+                speaker_dict=args.speaker_dict, )
             wav = get_voc_output(
                 voc_predictor=voc_predictor, input=am_output_data)
         speed = wav.size / t.elapse
@@ -143,11 +156,13 @@ def main():
     for utt_id, sentence in sentences:
         with timer() as t:
             am_output_data = get_am_output(
-                args,
+                input=sentence,
                 am_predictor=am_predictor,
+                am=args.am,
                 frontend=frontend,
+                lang=args.lang,
                 merge_sentences=merge_sentences,
-                input=sentence)
+                speaker_dict=args.speaker_dict, )
             wav = get_voc_output(
                 voc_predictor=voc_predictor, input=am_output_data)
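Taken together, the hunks above show the point of the refactor: the helpers in syn_utils.py now accept explicit keyword arguments instead of an argparse namespace, so they can be driven from any script. A minimal sketch of programmatic use, assuming fastspeech2_csmsc/hifigan_csmsc models have already been exported to ./inference (paths and model names here are illustrative assumptions, not part of this PR):

    from paddlespeech.t2s.exps.syn_utils import get_am_output
    from paddlespeech.t2s.exps.syn_utils import get_frontend
    from paddlespeech.t2s.exps.syn_utils import get_predictor
    from paddlespeech.t2s.exps.syn_utils import get_voc_output

    # text frontend (Chinese), built from the exported phone map
    frontend = get_frontend(lang='zh', phones_dict='dump/phone_id_map.txt')
    # acoustic model and vocoder predictors (Paddle Inference)
    am_predictor = get_predictor(
        model_dir='./inference',
        model_file='fastspeech2_csmsc.pdmodel',
        params_file='fastspeech2_csmsc.pdiparams',
        device='cpu')
    voc_predictor = get_predictor(
        model_dir='./inference',
        model_file='hifigan_csmsc.pdmodel',
        params_file='hifigan_csmsc.pdiparams',
        device='cpu')
    # text -> mel -> waveform
    mel = get_am_output(
        input='你好，欢迎使用语音合成。',
        am_predictor=am_predictor,
        am='fastspeech2_csmsc',
        frontend=frontend,
        lang='zh')
    wav = get_voc_output(voc_predictor=voc_predictor, input=mel)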

@@ -25,7 +25,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_predictor
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
-from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor
 from paddlespeech.t2s.exps.syn_utils import get_voc_output
 from paddlespeech.t2s.utils import str2bool
@@ -102,22 +101,43 @@ def parse_args():
 def main():
     args = parse_args()
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     # am_predictor
-    am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor(
-        args)
+    am_encoder_infer_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_encoder_infer" + ".pdmodel",
+        params_file=args.am + "_am_encoder_infer" + ".pdiparams",
+        device=args.device)
+    am_decoder_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_decoder" + ".pdmodel",
+        params_file=args.am + "_am_decoder" + ".pdiparams",
+        device=args.device)
+    am_postnet_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_postnet" + ".pdmodel",
+        params_file=args.am + "_am_postnet" + ".pdiparams",
+        device=args.device)
     am_mu, am_std = np.load(args.am_stat)
     # model: {model_name}_{dataset}
     am_dataset = args.am[args.am.rindex('_') + 1:]
     # voc_predictor
-    voc_predictor = get_predictor(args, filed='voc')
+    voc_predictor = get_predictor(
+        model_dir=args.inference_dir,
+        model_file=args.voc + ".pdmodel",
+        params_file=args.voc + ".pdiparams",
+        device=args.device)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     merge_sentences = True
@@ -126,13 +146,13 @@ def main():
     for utt_id, sentence in sentences[:3]:
         with timer() as t:
             normalized_mel = get_streaming_am_output(
-                args,
+                input=sentence,
                 am_encoder_infer_predictor=am_encoder_infer_predictor,
                 am_decoder_predictor=am_decoder_predictor,
                 am_postnet_predictor=am_postnet_predictor,
                 frontend=frontend,
-                merge_sentences=merge_sentences,
-                input=sentence)
+                lang=args.lang,
+                merge_sentences=merge_sentences, )
             mel = denorm(normalized_mel, am_mu, am_std)
             wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
         speed = wav.size / t.elapse
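Since the dedicated get_streaming_am_predictor helper is gone, the streaming script now builds its three sub-model predictors with the same generic get_predictor call. A hedged sketch of how the three repeated calls above could be collapsed inside main() (purely illustrative; the PR keeps them written out):

    # assumes the args namespace produced by parse_args() in this script
    sub_predictors = {}
    for part in ('am_encoder_infer', 'am_decoder', 'am_postnet'):
        sub_predictors[part] = get_predictor(
            model_dir=args.inference_dir,
            model_file=args.am + '_' + part + '.pdmodel',
            params_file=args.am + '_' + part + '.pdiparams',
            device=args.device)
    am_encoder_infer_predictor = sub_predictors['am_encoder_infer']
    am_decoder_predictor = sub_predictors['am_decoder']
    am_postnet_predictor = sub_predictors['am_postnet']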

@@ -30,7 +30,7 @@ def ort_predict(args):
     test_metadata = list(reader)
     am_name = args.am[:args.am.rindex('_')]
     am_dataset = args.am[args.am.rindex('_') + 1:]
-    test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
+    test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -38,10 +38,18 @@ def ort_predict(args):
     fs = 24000 if am_dataset != 'ljspeech' else 22050
     # am
-    am_sess = get_sess(args, filed='am')
+    am_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.am + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     # vocoder
-    voc_sess = get_sess(args, filed='voc')
+    voc_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.voc + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     # am warmup
     for T in [27, 38, 54]:

@@ -27,21 +27,31 @@ from paddlespeech.t2s.utils import str2bool
 def ort_predict(args):
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     am_name = args.am[:args.am.rindex('_')]
     am_dataset = args.am[args.am.rindex('_') + 1:]
     fs = 24000 if am_dataset != 'ljspeech' else 22050
-    # am
-    am_sess = get_sess(args, filed='am')
+    am_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.am + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     # vocoder
-    voc_sess = get_sess(args, filed='voc')
+    voc_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.voc + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     # frontend warmup
     # Loading model cost 0.5+ seconds
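The ONNX path gets the same treatment: get_sess now takes model_dir, model_file, device, cpu_threads and use_trt directly instead of reading them off args. A small hedged example of building sessions outside the script (file names are assumptions):

    from paddlespeech.t2s.exps.syn_utils import get_sess

    am_sess = get_sess(
        model_dir='./inference_onnx',
        model_file='fastspeech2_csmsc.onnx',
        device='cpu',
        cpu_threads=2)
    # on GPU, keep use_trt=False for these models: per the comment in get_sess,
    # fastspeech2/mb_melgan can't use TensorRT yet
    voc_sess = get_sess(
        model_dir='./inference_onnx',
        model_file='hifigan_csmsc.onnx',
        device='gpu',
        use_trt=False)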

@@ -23,30 +23,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_sess
-from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess
 from paddlespeech.t2s.utils import str2bool
 def ort_predict(args):
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     am_name = args.am[:args.am.rindex('_')]
     am_dataset = args.am[args.am.rindex('_') + 1:]
     fs = 24000 if am_dataset != 'ljspeech' else 22050
-    # am
-    am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess(
-        args)
+    # streaming acoustic model
+    am_encoder_infer_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_encoder_infer" + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
+    am_decoder_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_decoder" + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
+    am_postnet_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.am + "_am_postnet" + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     am_mu, am_std = np.load(args.am_stat)
     # vocoder
-    voc_sess = get_sess(args, filed='voc')
+    voc_sess = get_sess(
+        model_dir=args.inference_dir,
+        model_file=args.voc + ".onnx",
+        device=args.device,
+        cpu_threads=args.cpu_threads)
     # frontend warmup
     # Loading model cost 0.5+ seconds

@@ -14,6 +14,10 @@
 import math
 import os
 from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
 import numpy as np
 import onnxruntime as ort
@@ -21,6 +25,7 @@ import paddle
 from paddle import inference
 from paddle import jit
 from paddle.static import InputSpec
+from yacs.config import CfgNode
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.t2s.datasets.data_table import DataTable
@@ -70,7 +75,7 @@ def denorm(data, mean, std):
     return data * std + mean
-def get_chunks(data, chunk_size, pad_size):
+def get_chunks(data, chunk_size: int, pad_size: int):
     data_len = data.shape[1]
     chunks = []
     n = math.ceil(data_len / chunk_size)
@@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size):
 # input
-def get_sentences(args):
+def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
     # construct dataset for evaluation
     sentences = []
-    with open(args.text, 'rt') as f:
+    with open(text_file, 'rt') as f:
         for line in f:
             items = line.strip().split()
             utt_id = items[0]
-            if 'lang' in args and args.lang == 'zh':
+            if lang == 'zh':
                 sentence = "".join(items[1:])
-            elif 'lang' in args and args.lang == 'en':
+            elif lang == 'en':
                 sentence = " ".join(items[1:])
             sentences.append((utt_id, sentence))
     return sentences
-def get_test_dataset(args, test_metadata, am_name, am_dataset):
+def get_test_dataset(test_metadata: List[Dict[str, Any]],
+                     am: str,
+                     speaker_dict: Optional[os.PathLike]=None,
+                     voice_cloning: bool=False):
+    # model: {model_name}_{dataset}
+    am_name = am[:am.rindex('_')]
+    am_dataset = am[am.rindex('_') + 1:]
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
-        if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             print("multiple speaker fastspeech2!")
             fields += ["spk_id"]
-        elif 'voice_cloning' in args and args.voice_cloning:
+        elif voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
         else:
@@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
         fields = ["utt_id", "phones", "tones"]
     elif am_name == 'tacotron2':
         fields = ["utt_id", "text"]
-        if 'voice_cloning' in args and args.voice_cloning:
+        if voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
@@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
 # frontend
-def get_frontend(args):
-    if 'lang' in args and args.lang == 'zh':
+def get_frontend(lang: str='zh',
+                 phones_dict: Optional[os.PathLike]=None,
+                 tones_dict: Optional[os.PathLike]=None):
+    if lang == 'zh':
         frontend = Frontend(
-            phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
-    elif 'lang' in args and args.lang == 'en':
-        frontend = English(phone_vocab_path=args.phones_dict)
+            phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
+    elif lang == 'en':
+        frontend = English(phone_vocab_path=phones_dict)
     else:
         print("wrong lang!")
     print("frontend done!")
@@ -134,30 +147,37 @@ def get_frontend(args):
 # dygraph
-def get_am_inference(args, am_config):
-    with open(args.phones_dict, "r") as f:
+def get_am_inference(
+        am: str='fastspeech2_csmsc',
+        am_config: CfgNode=None,
+        am_ckpt: Optional[os.PathLike]=None,
+        am_stat: Optional[os.PathLike]=None,
+        phones_dict: Optional[os.PathLike]=None,
+        tones_dict: Optional[os.PathLike]=None,
+        speaker_dict: Optional[os.PathLike]=None, ):
+    with open(phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     vocab_size = len(phn_id)
     print("vocab_size:", vocab_size)
     tone_size = None
-    if 'tones_dict' in args and args.tones_dict:
-        with open(args.tones_dict, "r") as f:
+    if tones_dict is not None:
+        with open(tones_dict, "r") as f:
             tone_id = [line.strip().split() for line in f.readlines()]
         tone_size = len(tone_id)
         print("tone_size:", tone_size)
     spk_num = None
-    if 'speaker_dict' in args and args.speaker_dict:
-        with open(args.speaker_dict, 'rt') as f:
+    if speaker_dict is not None:
+        with open(speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
         print("spk_num:", spk_num)
     odim = am_config.n_mels
     # model: {model_name}_{dataset}
-    am_name = args.am[:args.am.rindex('_')]
-    am_dataset = args.am[args.am.rindex('_') + 1:]
+    am_name = am[:am.rindex('_')]
+    am_dataset = am[am.rindex('_') + 1:]
     am_class = dynamic_import(am_name, model_alias)
     am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@@ -174,34 +194,38 @@ def get_am_inference(args, am_config):
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-    am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
+    am.set_state_dict(paddle.load(am_ckpt)["main_params"])
     am.eval()
-    am_mu, am_std = np.load(args.am_stat)
+    am_mu, am_std = np.load(am_stat)
     am_mu = paddle.to_tensor(am_mu)
     am_std = paddle.to_tensor(am_std)
     am_normalizer = ZScore(am_mu, am_std)
     am_inference = am_inference_class(am_normalizer, am)
     am_inference.eval()
     print("acoustic model done!")
-    return am_inference, am_name, am_dataset
+    return am_inference
-def get_voc_inference(args, voc_config):
+def get_voc_inference(
+        voc: str='pwgan_csmsc',
+        voc_config: Optional[os.PathLike]=None,
+        voc_ckpt: Optional[os.PathLike]=None,
+        voc_stat: Optional[os.PathLike]=None, ):
     # model: {model_name}_{dataset}
-    voc_name = args.voc[:args.voc.rindex('_')]
+    voc_name = voc[:voc.rindex('_')]
     voc_class = dynamic_import(voc_name, model_alias)
     voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
     if voc_name != 'wavernn':
         voc = voc_class(**voc_config["generator_params"])
-        voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
+        voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
         voc.remove_weight_norm()
         voc.eval()
     else:
         voc = voc_class(**voc_config["model"])
-        voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
+        voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
         voc.eval()
-    voc_mu, voc_std = np.load(args.voc_stat)
+    voc_mu, voc_std = np.load(voc_stat)
     voc_mu = paddle.to_tensor(voc_mu)
     voc_std = paddle.to_tensor(voc_std)
     voc_normalizer = ZScore(voc_mu, voc_std)
@@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config):
     return voc_inference
-# to static
-def am_to_static(args, am_inference, am_name, am_dataset):
+# dygraph to static graph
+def am_to_static(am_inference,
+                 am: str='fastspeech2_csmsc',
+                 inference_dir=Optional[os.PathLike],
+                 speaker_dict: Optional[os.PathLike]=None):
+    # model: {model_name}_{dataset}
+    am_name = am[:am.rindex('_')]
+    am_dataset = am[am.rindex('_') + 1:]
     if am_name == 'fastspeech2':
-        if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset):
                 am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
     elif am_name == 'speedyspeech':
-        if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset):
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
-    paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
-    am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
+    paddle.jit.save(am_inference, os.path.join(inference_dir, am))
+    am_inference = paddle.jit.load(os.path.join(inference_dir, am))
     return am_inference
-def voc_to_static(args, voc_inference):
+def voc_to_static(voc_inference,
+                  voc: str='pwgan_csmsc',
+                  inference_dir=Optional[os.PathLike]):
     voc_inference = jit.to_static(
         voc_inference, input_spec=[
             InputSpec([-1, 80], dtype=paddle.float32),
         ])
-    paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
-    voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
+    paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
+    voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
     return voc_inference
 # inference
-def get_predictor(args, filed='am'):
-    full_name = ''
-    if filed == 'am':
-        full_name = args.am
-    elif filed == 'voc':
-        full_name = args.voc
+def get_predictor(model_dir: Optional[os.PathLike]=None,
+                  model_file: Optional[os.PathLike]=None,
+                  params_file: Optional[os.PathLike]=None,
+                  device: str='cpu'):
     config = inference.Config(
-        str(Path(args.inference_dir) / (full_name + ".pdmodel")),
-        str(Path(args.inference_dir) / (full_name + ".pdiparams")))
-    if args.device == "gpu":
+        str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
+    if device == "gpu":
         config.enable_use_gpu(100, 0)
-    elif args.device == "cpu":
+    elif device == "cpu":
         config.disable_gpu()
     config.enable_memory_optim()
     predictor = inference.create_predictor(config)
     return predictor
-def get_am_output(args, am_predictor, frontend, merge_sentences, input):
-    am_name = args.am[:args.am.rindex('_')]
-    am_dataset = args.am[args.am.rindex('_') + 1:]
+def get_am_output(
+        input: str,
+        am_predictor,
+        am,
+        frontend,
+        lang: str='zh',
+        merge_sentences: bool=True,
+        speaker_dict: Optional[os.PathLike]=None,
+        spk_id: int=0, ):
+    am_name = am[:am.rindex('_')]
+    am_dataset = am[am.rindex('_') + 1:]
     am_input_names = am_predictor.get_input_names()
     get_tone_ids = False
     get_spk_id = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
-    if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+    if am_dataset in {"aishell3", "vctk"} and speaker_dict:
         get_spk_id = True
-        spk_id = np.array([args.spk_id])
-    if args.lang == 'zh':
+        spk_id = np.array([spk_id])
+    if lang == 'zh':
         input_ids = frontend.get_input_ids(
             input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
         phone_ids = input_ids["phone_ids"]
-    elif args.lang == 'en':
+    elif lang == 'en':
         input_ids = frontend.get_input_ids(
             input, merge_sentences=merge_sentences)
         phone_ids = input_ids["phone_ids"]
@@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input):
     return wav
-# streaming am
-def get_streaming_am_predictor(args):
-    full_name = args.am
-    am_encoder_infer_config = inference.Config(
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_encoder_infer" + ".pdmodel")),
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_encoder_infer" + ".pdiparams")))
-    am_decoder_config = inference.Config(
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_decoder" + ".pdmodel")),
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_decoder" + ".pdiparams")))
-    am_postnet_config = inference.Config(
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_postnet" + ".pdmodel")),
-        str(
-            Path(args.inference_dir) /
-            (full_name + "_am_postnet" + ".pdiparams")))
-    if args.device == "gpu":
-        am_encoder_infer_config.enable_use_gpu(100, 0)
-        am_decoder_config.enable_use_gpu(100, 0)
-        am_postnet_config.enable_use_gpu(100, 0)
-    elif args.device == "cpu":
-        am_encoder_infer_config.disable_gpu()
-        am_decoder_config.disable_gpu()
-        am_postnet_config.disable_gpu()
-    am_encoder_infer_config.enable_memory_optim()
-    am_decoder_config.enable_memory_optim()
-    am_postnet_config.enable_memory_optim()
-    am_encoder_infer_predictor = inference.create_predictor(
-        am_encoder_infer_config)
-    am_decoder_predictor = inference.create_predictor(am_decoder_config)
-    am_postnet_predictor = inference.create_predictor(am_postnet_config)
-    return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor
 def get_am_sublayer_output(am_sublayer_predictor, input):
     am_sublayer_input_names = am_sublayer_predictor.get_input_names()
     input_handle = am_sublayer_predictor.get_input_handle(
@@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input):
     return am_sublayer_output
-def get_streaming_am_output(args, am_encoder_infer_predictor,
-                            am_decoder_predictor, am_postnet_predictor,
-                            frontend, merge_sentences, input):
+def get_streaming_am_output(input: str,
+                            am_encoder_infer_predictor,
+                            am_decoder_predictor,
+                            am_postnet_predictor,
+                            frontend,
+                            lang: str='zh',
+                            merge_sentences: bool=True):
     get_tone_ids = False
-    if args.lang == 'zh':
+    if lang == 'zh':
         input_ids = frontend.get_input_ids(
             input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
         phone_ids = input_ids["phone_ids"]
@@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor,
     return normalized_mel
-def get_sess(args, filed='am'):
-    full_name = ''
-    if filed == 'am':
-        full_name = args.am
-    elif filed == 'voc':
-        full_name = args.voc
-    model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
+# onnx
+def get_sess(model_dir: Optional[os.PathLike]=None,
+             model_file: Optional[os.PathLike]=None,
+             device: str='cpu',
+             cpu_threads: int=1,
+             use_trt: bool=False):
+    model_dir = str(Path(model_dir) / model_file)
     sess_options = ort.SessionOptions()
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-    if args.device == "gpu":
+    if device == "gpu":
         # fastspeech2/mb_melgan can't use trt now!
-        if args.use_trt:
+        if use_trt:
             providers = ['TensorrtExecutionProvider']
         else:
             providers = ['CUDAExecutionProvider']
-    elif args.device == "cpu":
+    elif device == "cpu":
         providers = ['CPUExecutionProvider']
-    sess_options.intra_op_num_threads = args.cpu_threads
+    sess_options.intra_op_num_threads = cpu_threads
     sess = ort.InferenceSession(
         model_dir, providers=providers, sess_options=sess_options)
     return sess
-# streaming am
-def get_streaming_am_sess(args):
-    full_name = args.am
-    am_encoder_infer_model_dir = str(
-        Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx"))
-    am_decoder_model_dir = str(
-        Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx"))
-    am_postnet_model_dir = str(
-        Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx"))
-    sess_options = ort.SessionOptions()
-    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-    if args.device == "gpu":
-        # fastspeech2/mb_melgan can't use trt now!
-        if args.use_trt:
-            providers = ['TensorrtExecutionProvider']
-        else:
-            providers = ['CUDAExecutionProvider']
-    elif args.device == "cpu":
-        providers = ['CPUExecutionProvider']
-    sess_options.intra_op_num_threads = args.cpu_threads
-    am_encoder_infer_sess = ort.InferenceSession(
-        am_encoder_infer_model_dir,
-        providers=providers,
-        sess_options=sess_options)
-    am_decoder_sess = ort.InferenceSession(
-        am_decoder_model_dir, providers=providers, sess_options=sess_options)
-    am_postnet_sess = ort.InferenceSession(
-        am_postnet_model_dir, providers=providers, sess_options=sess_options)
-    return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess
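With the argparse coupling removed from syn_utils.py, the helpers above can be imported and exercised in isolation, which is the main payoff of this refactor. A short hedged sketch (paths are placeholders, not from the PR):

    from paddlespeech.t2s.exps.syn_utils import get_frontend
    from paddlespeech.t2s.exps.syn_utils import get_sentences

    # one utterance per line: "<utt_id> <text>"
    sentences = get_sentences(text_file='sentences.txt', lang='zh')
    frontend = get_frontend(lang='zh', phones_dict='dump/phone_id_map.txt')
    for utt_id, sentence in sentences[:1]:
        input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
        print(utt_id, len(input_ids['phone_ids']))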

@@ -50,11 +50,29 @@ def evaluate(args):
     print(voc_config)
     # acoustic model
-    am_inference, am_name, am_dataset = get_am_inference(args, am_config)
-    test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
+    am_inference = get_am_inference(
+        am=args.am,
+        am_config=am_config,
+        am_ckpt=args.am_ckpt,
+        am_stat=args.am_stat,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict,
+        speaker_dict=args.speaker_dict)
+    test_dataset = get_test_dataset(
+        test_metadata=test_metadata,
+        am=args.am,
+        speaker_dict=args.speaker_dict,
+        voice_cloning=args.voice_cloning)
     # vocoder
-    voc_inference = get_voc_inference(args, voc_config)
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)

@@ -42,24 +42,48 @@ def evaluate(args):
     print(am_config)
     print(voc_config)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     # acoustic model
-    am_inference, am_name, am_dataset = get_am_inference(args, am_config)
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
+    am_inference = get_am_inference(
+        am=args.am,
+        am_config=am_config,
+        am_ckpt=args.am_ckpt,
+        am_stat=args.am_stat,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict,
+        speaker_dict=args.speaker_dict)
     # vocoder
-    voc_inference = get_voc_inference(args, voc_config)
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
     # whether dygraph to static
     if args.inference_dir:
         # acoustic model
-        am_inference = am_to_static(args, am_inference, am_name, am_dataset)
+        am_inference = am_to_static(
+            am_inference=am_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
         # vocoder
-        voc_inference = voc_to_static(args, voc_inference)
+        voc_inference = voc_to_static(
+            voc_inference=voc_inference,
+            voc=args.voc,
+            inference_dir=args.inference_dir)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
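The export path above also shows the new keyword signatures for am_to_static and voc_to_static. A hedged standalone sketch of exporting an already-built dygraph pair (model names and paths are placeholders); paddle.jit.save writes {am}.pdmodel/{am}.pdiparams, and likewise for the vocoder, into inference_dir, which is exactly what inference.py later loads with get_predictor:

    from paddlespeech.t2s.exps.syn_utils import am_to_static
    from paddlespeech.t2s.exps.syn_utils import voc_to_static

    # am_inference / voc_inference are dygraph inference layers obtained from
    # get_am_inference / get_voc_inference earlier in the script
    am_inference = am_to_static(
        am_inference=am_inference,
        am='fastspeech2_csmsc',
        inference_dir='./inference',
        speaker_dict=None)
    voc_inference = voc_to_static(
        voc_inference=voc_inference,
        voc='hifigan_csmsc',
        inference_dir='./inference')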

@@ -49,10 +49,13 @@ def evaluate(args):
     print(am_config)
     print(voc_config)
-    sentences = get_sentences(args)
+    sentences = get_sentences(text_file=args.text, lang=args.lang)
     # frontend
-    frontend = get_frontend(args)
+    frontend = get_frontend(
+        lang=args.lang,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict)
     with open(args.phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
@@ -60,7 +63,6 @@ def evaluate(args):
     print("vocab_size:", vocab_size)
     # acoustic model, only support fastspeech2 here now!
-    # am_inference, am_name, am_dataset = get_am_inference(args, am_config)
     # model: {model_name}_{dataset}
     am_name = args.am[:args.am.rindex('_')]
     am_dataset = args.am[args.am.rindex('_') + 1:]
@@ -80,7 +82,11 @@ def evaluate(args):
     am_postnet = am.postnet
     # vocoder
-    voc_inference = get_voc_inference(args, voc_config)
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
     # whether dygraph to static
     if args.inference_dir:
@@ -115,7 +121,10 @@ def evaluate(args):
             os.path.join(args.inference_dir, args.am + "_am_postnet"))
         # vocoder
-        voc_inference = voc_to_static(args, voc_inference)
+        voc_inference = voc_to_static(
+            voc_inference=voc_inference,
+            voc=args.voc,
+            inference_dir=args.inference_dir)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)

@@ -66,10 +66,19 @@ def voice_cloning(args):
     print("frontend done!")
     # acoustic model
-    am_inference, *_ = get_am_inference(args, am_config)
+    am_inference = get_am_inference(
+        am=args.am,
+        am_config=am_config,
+        am_ckpt=args.am_ckpt,
+        am_stat=args.am_stat,
+        phones_dict=args.phones_dict)
     # vocoder
-    voc_inference = get_voc_inference(args, voc_config)
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)

@@ -58,8 +58,7 @@ def main():
     else:
         print("ngpu should >= 0 !")
-    model = WaveRNN(
-        hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
+    model = WaveRNN(**config["model"])
     state_dict = paddle.load(args.checkpoint)
     model.set_state_dict(state_dict["main_params"])
