[TTS] vits dygraph to static (#2883)

Co-authored-by: 0x45f <wangzhen45@baidu.com>
pull/2964/head
TianYuan 2 years ago committed by GitHub
parent 11bc392617
commit 84f751f529

@@ -18,5 +18,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --phones_dict=dump/phone_id_map.txt \
         --output_dir=${train_output_path}/test_e2e \
         --text=${BIN_DIR}/../sentences.txt \
-        --add-blank=${add_blank}
+        --add-blank=${add_blank} #\
+        # --inference_dir=${train_output_path}/inference
 fi

@@ -452,9 +452,19 @@ def am_to_static(am_inference,
     elif am_name == 'tacotron2':
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+    elif am_name == 'vits':
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64),
+                ])
+        else:
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
 
-    paddle.jit.save(am_inference, os.path.join(inference_dir, am))
-    am_inference = paddle.jit.load(os.path.join(inference_dir, am))
+    jit.save(am_inference, os.path.join(inference_dir, am))
+    am_inference = jit.load(os.path.join(inference_dir, am))
     return am_inference
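The two `InputSpec`s in the new `vits` branch pin down the exported signature: a variable-length 1-D int64 tensor of phone ids, plus a shape-[1] int64 speaker id for the multi-speaker aishell3/vctk models. A minimal sketch of what a call to the converted `am_inference` then looks like; the concrete ids below are placeholders, not values from the PR:

```python
import paddle

phone_ids = paddle.to_tensor([5, 17, 42, 9], dtype='int64')  # matches InputSpec([-1], int64)
spk_id = paddle.to_tensor([0], dtype='int64')                 # matches InputSpec([1], int64)

# multi-speaker vits (aishell3 / vctk): both inputs are expected
wav = am_inference(phone_ids, spk_id)

# single-speaker vits (e.g. csmsc): only the phone ids
wav = am_inference(phone_ids)
```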
@@ -465,8 +475,8 @@ def voc_to_static(voc_inference,
         voc_inference, input_spec=[
             InputSpec([-1, 80], dtype=paddle.float32),
         ])
-    paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
-    voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
+    jit.save(voc_inference, os.path.join(inference_dir, voc))
+    voc_inference = jit.load(os.path.join(inference_dir, voc))
     return voc_inference

@@ -20,14 +20,15 @@ import yaml
 from timer import timer
 from yacs.config import CfgNode
 
+from paddlespeech.t2s.exps.syn_utils import am_to_static
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.models.vits import VITS
+from paddlespeech.t2s.models.vits import VITSInference
 from paddlespeech.t2s.utils import str2bool
 
 
 def evaluate(args):
     # Init body.
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
@@ -41,6 +42,9 @@ def evaluate(args):
 
     # frontend
     frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
+    # acoustic model
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
 
     spk_num = None
     if args.speaker_dict is not None:
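The `--am` string is split at its last underscore, so the dataset suffix can drive the multi-speaker branch further down. A quick worked example with a hypothetical value:

```python
am = 'vits_aishell3'                   # e.g. the value of --am
am_name = am[:am.rindex('_')]          # 'vits'
am_dataset = am[am.rindex('_') + 1:]   # 'aishell3'
```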
@@ -64,6 +68,15 @@ def evaluate(args):
     vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
     vits.eval()
 
+    vits_inference = VITSInference(vits)
+    # whether dygraph to static
+    if args.inference_dir:
+        vits_inference = am_to_static(
+            am_inference=vits_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
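Once `am_to_static` has run, `jit.save` leaves a static-graph program and parameters under `inference_dir`, named after the `--am` value. A minimal sketch of loading that artifact back later, outside this script; the directory and model names here are assumptions, not paths from the PR:

```python
import os
import paddle

inference_dir = "exp/default/inference"   # hypothetical --inference_dir
am = "vits_csmsc"                         # hypothetical --am
static_vits = paddle.jit.load(os.path.join(inference_dir, am))
static_vits.eval()

phone_ids = paddle.to_tensor([5, 17, 42, 9], dtype='int64')
wav = static_vits(phone_ids)
```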
@@ -90,10 +103,12 @@ def evaluate(args):
                 for i in range(len(phone_ids)):
                     part_phone_ids = phone_ids[i]
                     spk_id = None
-                    if spk_num is not None:
+                    if am_dataset in {"aishell3",
+                                      "vctk"} and spk_num is not None:
                         spk_id = paddle.to_tensor(args.spk_id)
-                    out = vits.inference(text=part_phone_ids, sids=spk_id)
-                    wav = out["wav"]
+                        wav = vits_inference(part_phone_ids, spk_id)
+                    else:
+                        wav = vits_inference(part_phone_ids)
                     if flags == 0:
                         wav_all = wav
                         flags = 1
@@ -155,6 +170,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="whether to add blank between phones")
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='vits_csmsc',
+        help='Choose acoustic model type of tts task.')
 
     args = parser.parse_args()
     return args

@@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
         inverse=False,
         tails=None,
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     if tails is None:
         spline_fn = rational_quadratic_spline
         spline_kwargs = {}
@@ -74,14 +75,17 @@ def unconstrained_rational_quadratic_spline(
         inverse=False,
         tails="linear",
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
     outside_interval_mask = ~inside_interval_mask
 
-    outputs = paddle.zeros(paddle.shape(inputs))
-    logabsdet = paddle.zeros(paddle.shape(inputs))
+    # for dygraph to static
+    # calling zeros on paddle.shape(x) yields a var whose shape is all -1;
+    # using x.shape instead keeps the dimensions that are already known
+    outputs = paddle.zeros(inputs.shape)
+    logabsdet = paddle.zeros(inputs.shape)
 
     if tails == "linear":
         unnormalized_derivatives = F.pad(
             unnormalized_derivatives,
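The comment in this hunk is about how the two shape APIs behave once a function is converted: `paddle.shape(x)` is a runtime op, so at graph-build time its value is unknown and the resulting variable's shape is all -1, while `x.shape` keeps whatever dimensions the trace already knows. A small sketch of the difference, assuming a fixed `InputSpec`; the function and shapes are made up for illustration:

```python
import paddle
from paddle.static import InputSpec

def make_buffers(x):
    a = paddle.zeros(paddle.shape(x))  # shape only resolved at run time
    b = paddle.zeros(x.shape)          # keeps the dims known while tracing
    return a, b

static_fn = paddle.jit.to_static(
    make_buffers, input_spec=[InputSpec([2, 3], dtype='float32')])
a, b = static_fn(paddle.rand([2, 3]))
print(a.shape, b.shape)  # both [2, 3] once executed; only b's dims are static
```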
@@ -89,8 +93,9 @@ def unconstrained_rational_quadratic_spline(
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
         unnormalized_derivatives[..., -1] = constant
-
-        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        # for dygraph to static
+        tmp = inputs[outside_interval_mask]
+        outputs[outside_interval_mask] = tmp
         logabsdet[outside_interval_mask] = 0
     else:
         raise RuntimeError("{} tails are not implemented.".format(tails))
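Splitting the masked assignment through a temporary keeps the pattern simpler for the dygraph-to-static converter; the behaviour is unchanged. A tiny standalone illustration of the same gather-then-scatter pattern (the values are made up):

```python
import paddle

inputs = paddle.to_tensor([-2.0, 0.3, 1.7])
outputs = paddle.zeros_like(inputs)
mask = (inputs < 0.0) | (inputs > 1.0)   # "outside the interval"

tmp = inputs[mask]       # gather first ...
outputs[mask] = tmp      # ... then scatter, instead of doing both in one line
```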
@@ -130,18 +135,20 @@ def rational_quadratic_spline(
         right=1.0,
         bottom=0.0,
         top=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
-    if paddle.min(inputs) < left or paddle.max(inputs) > right:
-        raise ValueError("Input to a transform is not within its domain")
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
+    # for dygraph to static
+    # if paddle.min(inputs) < left or paddle.max(inputs) > right:
+    #     raise ValueError("Input to a transform is not within its domain")
 
     num_bins = unnormalized_widths.shape[-1]
 
-    if min_bin_width * num_bins > 1.0:
-        raise ValueError("Minimal bin width too large for the number of bins")
-    if min_bin_height * num_bins > 1.0:
-        raise ValueError("Minimal bin height too large for the number of bins")
+    # for dygraph to static
+    # if min_bin_width * num_bins > 1.0:
+    #     raise ValueError("Minimal bin width too large for the number of bins")
+    # if min_bin_height * num_bins > 1.0:
+    #     raise ValueError("Minimal bin height too large for the number of bins")
 
     widths = F.softmax(unnormalized_widths, axis=-1)
     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
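With the literals inlined, the checks that were commented out still hold by construction: the softmax makes the widths sum to 1, and the rescale `min_bin_width + (1 - min_bin_width * num_bins) * widths` keeps that sum at 1 while giving every bin a floor of `min_bin_width`, as long as `min_bin_width * num_bins <= 1`. A quick numeric check of that invariant (the bin count is arbitrary):

```python
import paddle
import paddle.nn.functional as F

num_bins = 4
min_bin_width = 1e-3

widths = F.softmax(paddle.rand([num_bins]), axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths

print(float(widths.min()) >= min_bin_width)   # True: every bin has a floor
print(float(widths.sum()))                    # ~1.0: the rescale preserves the total
```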

@@ -532,3 +532,14 @@ class VITS(nn.Layer):
                 module.weight[module._padding_idx] = 0
 
         self.apply(_reset_parameters)
+
+
+class VITSInference(nn.Layer):
+    def __init__(self, model):
+        super().__init__()
+        self.acoustic_model = model
+
+    def forward(self, text, sids=None):
+        out = self.acoustic_model.inference(text, sids=sids)
+        wav = out['wav']
+        return wav
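`VITSInference` puts `VITS.inference` behind a plain `forward` that returns just the waveform tensor instead of a dict, which is the form the dygraph-to-static export above works with. A minimal usage sketch, assuming `vits` is a trained `VITS` instance as set up in the synthesis script earlier in this commit; the phone and speaker ids are placeholders:

```python
import paddle

vits_inference = VITSInference(vits)
vits_inference.eval()

phone_ids = paddle.to_tensor([5, 17, 42, 9], dtype='int64')

# single-speaker model
wav = vits_inference(phone_ids)

# multi-speaker model (aishell3 / vctk): also pass a speaker id
wav = vits_inference(phone_ids, sids=paddle.to_tensor([0], dtype='int64'))
```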
