@@ -20,14 +20,15 @@ import yaml
 from timer import timer
 from yacs.config import CfgNode
 
 from paddlespeech.t2s.exps.syn_utils import am_to_static
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.models.vits import VITS
+from paddlespeech.t2s.models.vits import VITSInference
 from paddlespeech.t2s.utils import str2bool
 
 
 def evaluate(args):
     # Init body.
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
@@ -63,7 +64,22 @@ def evaluate(args):
     vits = VITS(idim=vocab_size, odim=odim, **config["model"])
     vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
     vits.eval()
 
+    vits_inference = VITSInference(vits)
+    # whether dygraph to static
+    if args.inference_dir:
+        # acoustic model
+        # vits = jit.to_static(
+        #     vits, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+        # jit.save(vits, os.path.join(inference_dir, args.am))
+        # vits = jit.load(os.path.join(inference_dir, args.am))
+        vits_inference = am_to_static(
+            am_inference=vits_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
@@ -92,8 +108,7 @@ def evaluate(args):
                     spk_id = None
                     if spk_num is not None:
                         spk_id = paddle.to_tensor(args.spk_id)
-                    out = vits.inference(text=part_phone_ids, sids=spk_id)
-                    wav = out["wav"]
+                    wav = vits_inference(text=part_phone_ids, sids=spk_id)
                     if flags == 0:
                         wav_all = wav
                         flags = 1
@@ -155,6 +170,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="whether to add blank between phones")
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='vits_csmsc',
+        help='Choose acoustic model type of tts task.')
 
     args = parser.parse_args()
     return args
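
For reviewers who want to sanity-check the exported static graph, here is a minimal sketch (not part of this diff) of how the model saved by `am_to_static` might be loaded back and run, following the commented-out `jit.save`/`jit.load` lines above. The `inference_dir` value and the dummy phone-id tensor are placeholder assumptions; real phone ids would come from the text frontend (`get_frontend` / `get_input_ids`) used in `evaluate()`.

```python
import os

import paddle

# Hypothetical values: --inference_dir and the new --am default from this diff.
inference_dir = "exp/default/inference"
am = "vits_csmsc"

# Load the layer that am_to_static exported with paddle.jit.save.
vits_static = paddle.jit.load(os.path.join(inference_dir, am))
vits_static.eval()

# Dummy int64 phone ids, only to illustrate the expected input dtype;
# in practice they come from the text frontend.
phone_ids = paddle.to_tensor([1, 2, 3, 4, 5], dtype="int64")
with paddle.no_grad():
    wav = vits_static(phone_ids)
print(wav.shape)
```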