Merge pull request #1723 from yt605155624/refactor_syn_util

[TTS]restructure syn_utils.py, test=tts
Hui Zhang authored 3 years ago; committed via GitHub
commit 523d5bd6d4

@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \
--voc=style_melgan_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \

@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \

@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference
--phones_dict=dump/phone_id_map.txt #\
# --inference_dir=${train_output_path}/inference
fi

@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
# hifigan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \

@ -102,20 +102,31 @@ def parse_args():
def main():
args = parse_args()
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_predictor = get_predictor(args, filed='am')
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
fs = 24000 if am_dataset != 'ljspeech' else 22050
@ -123,11 +134,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)
speed = wav.size / t.elapse
@ -143,11 +156,13 @@ def main():
for utt_id, sentence in sentences:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)
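To make the refactored call pattern above concrete, here is a minimal sketch of driving the new keyword-argument API directly instead of passing the whole argparse namespace. The model names, dictionary paths, output file, and the soundfile-based saving are illustrative assumptions, not part of this commit.

import soundfile as sf
from paddlespeech.t2s.exps.syn_utils import get_am_output
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_voc_output

# placeholder paths / model names (assumed for illustration)
inference_dir = "exp/default/inference"
frontend = get_frontend(lang="zh", phones_dict="dump/phone_id_map.txt")
am_predictor = get_predictor(
    model_dir=inference_dir,
    model_file="fastspeech2_csmsc.pdmodel",
    params_file="fastspeech2_csmsc.pdiparams",
    device="cpu")
voc_predictor = get_predictor(
    model_dir=inference_dir,
    model_file="hifigan_csmsc.pdmodel",
    params_file="hifigan_csmsc.pdiparams",
    device="cpu")
# text frontend -> acoustic model -> vocoder
am_output_data = get_am_output(
    input="欢迎使用飞桨语音合成。",
    am_predictor=am_predictor,
    am="fastspeech2_csmsc",
    frontend=frontend,
    lang="zh",
    merge_sentences=True)
wav = get_voc_output(voc_predictor=voc_predictor, input=am_output_data)
sf.write("demo.wav", wav, samplerate=24000)  # csmsc models run at 24 kHz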

@ -25,7 +25,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor
from paddlespeech.t2s.exps.syn_utils import get_voc_output
from paddlespeech.t2s.utils import str2bool
@ -102,22 +101,43 @@ def parse_args():
def main():
args = parse_args()
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor(
args)
am_encoder_infer_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".pdmodel",
params_file=args.am + "_am_encoder_infer" + ".pdiparams",
device=args.device)
am_decoder_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".pdmodel",
params_file=args.am + "_am_decoder" + ".pdiparams",
device=args.device)
am_postnet_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".pdmodel",
params_file=args.am + "_am_postnet" + ".pdiparams",
device=args.device)
am_mu, am_std = np.load(args.am_stat)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
@ -126,13 +146,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
normalized_mel = get_streaming_am_output(
args,
input=sentence,
am_encoder_infer_predictor=am_encoder_infer_predictor,
am_decoder_predictor=am_decoder_predictor,
am_postnet_predictor=am_postnet_predictor,
frontend=frontend,
merge_sentences=merge_sentences,
input=sentence)
lang=args.lang,
merge_sentences=merge_sentences, )
mel = denorm(normalized_mel, am_mu, am_std)
wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
speed = wav.size / t.elapse
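The streaming path exports the acoustic model as three sub-networks (encoder_infer, decoder, postnet) and denormalizes the predicted mel before vocoding. A minimal sketch of that flow with the refactored helpers follows; the model name, vocoder, and file paths are placeholders assumed for illustration.

import numpy as np
from paddlespeech.t2s.exps.syn_utils import denorm
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_voc_output

am = "fastspeech2_csmsc"                 # placeholder model name
inference_dir = "exp/default/inference"  # placeholder export dir
frontend = get_frontend(lang="zh", phones_dict="dump/phone_id_map.txt")
# one predictor per exported sub-network
parts = {}
for part in ("am_encoder_infer", "am_decoder", "am_postnet"):
    parts[part] = get_predictor(
        model_dir=inference_dir,
        model_file=am + "_" + part + ".pdmodel",
        params_file=am + "_" + part + ".pdiparams",
        device="cpu")
voc_predictor = get_predictor(
    model_dir=inference_dir,
    model_file="mb_melgan_csmsc.pdmodel",
    params_file="mb_melgan_csmsc.pdiparams",
    device="cpu")
am_mu, am_std = np.load("dump/train/speech_stats.npy")  # placeholder stats path
normalized_mel = get_streaming_am_output(
    input="流式合成的一个简单示例。",
    am_encoder_infer_predictor=parts["am_encoder_infer"],
    am_decoder_predictor=parts["am_decoder"],
    am_postnet_predictor=parts["am_postnet"],
    frontend=frontend,
    lang="zh",
    merge_sentences=True)
mel = denorm(normalized_mel, am_mu, am_std)
wav = get_voc_output(voc_predictor=voc_predictor, input=mel)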

@ -30,7 +30,7 @@ def ort_predict(args):
test_metadata = list(reader)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@ -38,10 +38,18 @@ def ort_predict(args):
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# am warmup
for T in [27, 38, 54]:

@ -27,21 +27,31 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds
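get_sess now takes the model location and runtime options explicitly instead of the parsed args. A minimal sketch of building the two onnxruntime sessions; the inference directory and file names are placeholders.

from paddlespeech.t2s.exps.syn_utils import get_sess

inference_dir = "exp/default/inference_onnx"  # placeholder
am_sess = get_sess(
    model_dir=inference_dir,
    model_file="fastspeech2_csmsc.onnx",
    device="cpu",
    cpu_threads=2)
voc_sess = get_sess(
    model_dir=inference_dir,
    model_file="hifigan_csmsc.onnx",
    device="cpu",
    cpu_threads=2)
# device="gpu" selects CUDAExecutionProvider (or TensorrtExecutionProvider with use_trt=True)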

@ -23,30 +23,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess
from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess(
args)
# streaming acoustic model
am_encoder_infer_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_decoder_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_postnet_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_mu, am_std = np.load(args.am_stat)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds

@ -14,6 +14,10 @@
import math
import os
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
import onnxruntime as ort
@ -21,6 +25,7 @@ import paddle
from paddle import inference
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable
@ -70,7 +75,7 @@ def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, chunk_size, pad_size):
def get_chunks(data, chunk_size: int, pad_size: int):
data_len = data.shape[1]
chunks = []
n = math.ceil(data_len / chunk_size)
@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size):
# input
def get_sentences(args):
def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
# construct dataset for evaluation
sentences = []
with open(args.text, 'rt') as f:
with open(text_file, 'rt') as f:
for line in f:
items = line.strip().split()
utt_id = items[0]
if 'lang' in args and args.lang == 'zh':
if lang == 'zh':
sentence = "".join(items[1:])
elif 'lang' in args and args.lang == 'en':
elif lang == 'en':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
return sentences
def get_test_dataset(args, test_metadata, am_name, am_dataset):
def get_test_dataset(test_metadata: List[Dict[str, Any]],
am: str,
speaker_dict: Optional[os.PathLike]=None,
voice_cloning: bool=False):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
fields += ["spk_id"]
elif 'voice_cloning' in args and args.voice_cloning:
elif voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
else:
@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':
fields = ["utt_id", "text"]
if 'voice_cloning' in args and args.voice_cloning:
if voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
# frontend
def get_frontend(args):
if 'lang' in args and args.lang == 'zh':
def get_frontend(lang: str='zh',
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None):
if lang == 'zh':
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
elif 'lang' in args and args.lang == 'en':
frontend = English(phone_vocab_path=args.phones_dict)
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
else:
print("wrong lang!")
print("frontend done!")
@ -134,30 +147,37 @@ def get_frontend(args):
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None
if 'tones_dict' in args and args.tones_dict:
with open(args.tones_dict, "r") as f:
if tones_dict is not None:
with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if 'speaker_dict' in args and args.speaker_dict:
with open(args.speaker_dict, 'rt') as f:
if speaker_dict is not None:
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@ -174,34 +194,38 @@ def get_am_inference(args, am_config):
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(args.am_stat)
am_mu, am_std = np.load(am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am)
am_inference.eval()
print("acoustic model done!")
return am_inference, am_name, am_dataset
return am_inference
def get_voc_inference(args, voc_config):
def get_voc_inference(
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, ):
# model: {model_name}_{dataset}
voc_name = args.voc[:args.voc.rindex('_')]
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
if voc_name != 'wavernn':
voc = voc_class(**voc_config["generator_params"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
else:
voc = voc_class(**voc_config["model"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
voc.eval()
voc_mu, voc_std = np.load(args.voc_stat)
voc_mu, voc_std = np.load(voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config):
return voc_inference
# to static
def am_to_static(args, am_inference, am_name, am_dataset):
# dygraph to static graph
def am_to_static(am_inference,
am: str='fastspeech2_csmsc',
inference_dir=Optional[os.PathLike],
speaker_dict: Optional[os.PathLike]=None):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
paddle.jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = paddle.jit.load(os.path.join(inference_dir, am))
return am_inference
def voc_to_static(args, voc_inference):
def voc_to_static(voc_inference,
voc: str='pwgan_csmsc',
inference_dir=Optional[os.PathLike]):
voc_inference = jit.to_static(
voc_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
return voc_inference
# inference
def get_predictor(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
def get_predictor(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
params_file: Optional[os.PathLike]=None,
device: str='cpu'):
config = inference.Config(
str(Path(args.inference_dir) / (full_name + ".pdmodel")),
str(Path(args.inference_dir) / (full_name + ".pdiparams")))
if args.device == "gpu":
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
if device == "gpu":
config.enable_use_gpu(100, 0)
elif args.device == "cpu":
elif device == "cpu":
config.disable_gpu()
config.enable_memory_optim()
predictor = inference.create_predictor(config)
return predictor
def get_am_output(args, am_predictor, frontend, merge_sentences, input):
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
def get_am_output(
input: str,
am_predictor,
am,
frontend,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names()
get_tone_ids = False
get_spk_id = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True
spk_id = np.array([args.spk_id])
if args.lang == 'zh':
spk_id = np.array([spk_id])
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'en':
elif lang == 'en':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input):
return wav
# streaming am
def get_streaming_am_predictor(args):
full_name = args.am
am_encoder_infer_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdiparams")))
am_decoder_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdiparams")))
am_postnet_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdiparams")))
if args.device == "gpu":
am_encoder_infer_config.enable_use_gpu(100, 0)
am_decoder_config.enable_use_gpu(100, 0)
am_postnet_config.enable_use_gpu(100, 0)
elif args.device == "cpu":
am_encoder_infer_config.disable_gpu()
am_decoder_config.disable_gpu()
am_postnet_config.disable_gpu()
am_encoder_infer_config.enable_memory_optim()
am_decoder_config.enable_memory_optim()
am_postnet_config.enable_memory_optim()
am_encoder_infer_predictor = inference.create_predictor(
am_encoder_infer_config)
am_decoder_predictor = inference.create_predictor(am_decoder_config)
am_postnet_predictor = inference.create_predictor(am_postnet_config)
return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor
def get_am_sublayer_output(am_sublayer_predictor, input):
am_sublayer_input_names = am_sublayer_predictor.get_input_names()
input_handle = am_sublayer_predictor.get_input_handle(
@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input):
return am_sublayer_output
def get_streaming_am_output(args, am_encoder_infer_predictor,
am_decoder_predictor, am_postnet_predictor,
frontend, merge_sentences, input):
def get_streaming_am_output(input: str,
am_encoder_infer_predictor,
am_decoder_predictor,
am_postnet_predictor,
frontend,
lang: str='zh',
merge_sentences: bool=True):
get_tone_ids = False
if args.lang == 'zh':
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor,
return normalized_mel
def get_sess(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
# onnx
def get_sess(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
device: str='cpu',
cpu_threads: int=1,
use_trt: bool=False):
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
if device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
if use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
elif device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession(
model_dir, providers=providers, sess_options=sess_options)
return sess
# streaming am
def get_streaming_am_sess(args):
full_name = args.am
am_encoder_infer_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx"))
am_decoder_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx"))
am_postnet_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx"))
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
am_encoder_infer_sess = ort.InferenceSession(
am_encoder_infer_model_dir,
providers=providers,
sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
am_decoder_model_dir, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
am_postnet_model_dir, providers=providers, sess_options=sess_options)
return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess
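A small illustration of the {model_name}_{dataset} convention that the refactored helpers now parse internally; the identifier below is a placeholder.

# model identifier follows {model_name}_{dataset}
am = "fastspeech2_aishell3"
am_name = am[:am.rindex('_')]           # "fastspeech2"
am_dataset = am[am.rindex('_') + 1:]    # "aishell3"
fs = 24000 if am_dataset != 'ljspeech' else 22050  # sample rate picked from the dataset part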

@ -50,11 +50,29 @@ def evaluate(args):
print(voc_config)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
speaker_dict=args.speaker_dict,
voice_cloning=args.voice_cloning)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
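get_test_dataset now derives the field list from the model identifier itself rather than from args. A minimal sketch of calling it on a normalized metadata file; the jsonlines loading is assumed to mirror the script's usual pattern, and the paths and model name are placeholders.

import jsonlines
from paddlespeech.t2s.exps.syn_utils import get_test_dataset

# placeholder path; dump/test/norm/metadata.jsonl is the usual layout
with jsonlines.open("dump/test/norm/metadata.jsonl", 'r') as reader:
    test_metadata = list(reader)
test_dataset = get_test_dataset(
    test_metadata=test_metadata,
    am="fastspeech2_csmsc",
    speaker_dict=None,      # set for aishell3/vctk multi-speaker models
    voice_cloning=False)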

@ -42,24 +42,48 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
# acoustic model
am_inference = am_to_static(args, am_inference, am_name, am_dataset)
am_inference = am_to_static(
am_inference=am_inference,
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
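For the dygraph path, the same keyword-argument style applies to model construction and the optional dygraph-to-static export. A sketch under assumed placeholder config and checkpoint paths; the CfgNode-from-YAML loading follows the common pattern in these scripts and is not part of this diff.

import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import am_to_static
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import voc_to_static

with open("fastspeech2_csmsc/default.yaml") as f:   # placeholder config
    am_config = CfgNode(yaml.safe_load(f))
with open("hifigan_csmsc/default.yaml") as f:        # placeholder config
    voc_config = CfgNode(yaml.safe_load(f))

am_inference = get_am_inference(
    am="fastspeech2_csmsc",
    am_config=am_config,
    am_ckpt="fastspeech2_csmsc/snapshot.pdz",         # placeholder checkpoint
    am_stat="fastspeech2_csmsc/speech_stats.npy",
    phones_dict="dump/phone_id_map.txt")
voc_inference = get_voc_inference(
    voc="hifigan_csmsc",
    voc_config=voc_config,
    voc_ckpt="hifigan_csmsc/snapshot.pdz",
    voc_stat="hifigan_csmsc/feats_stats.npy")

# optional dygraph -> static graph export, consumed by the predictor-based scripts
am_inference = am_to_static(
    am_inference=am_inference,
    am="fastspeech2_csmsc",
    inference_dir="exp/default/inference")
voc_inference = voc_to_static(
    voc_inference=voc_inference,
    voc="hifigan_csmsc",
    inference_dir="exp/default/inference")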

@ -49,10 +49,13 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
@ -60,7 +63,6 @@ def evaluate(args):
print("vocab_size:", vocab_size)
# acoustic model, only support fastspeech2 here now!
# am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
@ -80,7 +82,11 @@ def evaluate(args):
am_postnet = am.postnet
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
@ -115,7 +121,10 @@ def evaluate(args):
os.path.join(args.inference_dir, args.am + "_am_postnet"))
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -66,10 +66,19 @@ def voice_cloning(args):
print("frontend done!")
# acoustic model
am_inference, *_ = get_am_inference(args, am_config)
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -58,8 +58,7 @@ def main():
else:
print("ngpu should >= 0 !")
model = WaveRNN(
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
model = WaveRNN(**config["model"])
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"])
