vits dygraph to static

pull/2883/head
TianYuan 3 years ago
parent 6b00ad6064
commit 71eabceedd

@ -18,5 +18,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--phones_dict=dump/phone_id_map.txt \
--output_dir=${train_output_path}/test_e2e \
--text=${BIN_DIR}/../sentences.txt \
--add-blank=${add_blank}
--add-blank=${add_blank} #\
# --inference_dir=${train_output_path}/inference
fi

@ -9,7 +9,7 @@ stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
ckpt_name=snapshot_iter_333000.pdz
add_blank=true
# with the following command, you can choose the stage range you want to run

@ -445,9 +445,11 @@ def am_to_static(am_inference,
elif am_name == 'tacotron2':
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = paddle.jit.load(os.path.join(inference_dir, am))
elif am_name == 'vits':
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = jit.load(os.path.join(inference_dir, am))
return am_inference
@ -458,8 +460,8 @@ def voc_to_static(voc_inference,
voc_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = jit.load(os.path.join(inference_dir, voc))
return voc_inference

@ -20,14 +20,15 @@ import yaml
from timer import timer
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import am_to_static
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSInference
from paddlespeech.t2s.utils import str2bool
def evaluate(args):
# Init body.
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
@ -63,7 +64,22 @@ def evaluate(args):
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval()
VITSInference
vits_inference = VITSInference(vits)
# whether dygraph to static
if args.inference_dir:
# acoustic model
# vits = jit.to_static(
# vits, input_spec=[InputSpec([-1], dtype=paddle.int64)])
# jit.save(vits, os.path.join(inference_dir, args.am))
# vits = jit.load(os.path.join(inference_dir, args.am))
vits_inference = am_to_static(
am_inference=vits_inference,
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
@ -92,8 +108,7 @@ def evaluate(args):
spk_id = None
if spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
out = vits.inference(text=part_phone_ids, sids=spk_id)
wav = out["wav"]
wav = vits_inference(text=part_phone_ids, sids=spk_id)
if flags == 0:
wav_all = wav
flags = 1
@ -155,6 +170,11 @@ def parse_args():
type=str2bool,
default=True,
help="whether to add blank between phones")
parser.add_argument(
'--am',
type=str,
default='vits_csmsc',
help='Choose acoustic model type of tts task.')
args = parser.parse_args()
return args

@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
# for dygraph-to-static
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3, ):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
@ -74,9 +75,10 @@ def unconstrained_rational_quadratic_spline(
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
# for dygraph-to-static
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
@ -89,8 +91,12 @@ def unconstrained_rational_quadratic_spline(
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
# import pdb
# pdb.set_trace()
# print("inputs:",inputs)
# print("outside_interval_mask:",outside_interval_mask)
a = inputs[outside_interval_mask]
outputs[outside_interval_mask] = a
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
@ -130,9 +136,10 @@ def rational_quadratic_spline(
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
# for dygraph-to-static
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")

@ -532,3 +532,14 @@ class VITS(nn.Layer):
module.weight[module._padding_idx] = 0
self.apply(_reset_parameters)
class VITSInference(nn.Layer):
def __init__(self, model):
super().__init__()
self.acoustic_model = model
def forward(self, text, sids=None):
out = self.acoustic_model.inference(
text, sids=sids)
wav = out['wav']
return wav

Loading…
Cancel
Save