diff --git a/examples/csmsc/vits/local/synthesize_e2e.sh b/examples/csmsc/vits/local/synthesize_e2e.sh
index 3f3bf651..7430f52a 100755
--- a/examples/csmsc/vits/local/synthesize_e2e.sh
+++ b/examples/csmsc/vits/local/synthesize_e2e.sh
@@ -18,5 +18,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --phones_dict=dump/phone_id_map.txt \
         --output_dir=${train_output_path}/test_e2e \
         --text=${BIN_DIR}/../sentences.txt \
-        --add-blank=${add_blank}
+        --add-blank=${add_blank} #\
+        # --inference_dir=${train_output_path}/inference
 fi
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index dd3b4d55..ebe104c1 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -452,9 +452,19 @@ def am_to_static(am_inference,
     elif am_name == 'tacotron2':
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
-
-    paddle.jit.save(am_inference, os.path.join(inference_dir, am))
-    am_inference = paddle.jit.load(os.path.join(inference_dir, am))
+    elif am_name == 'vits':
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64),
+                ])
+        else:
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+    jit.save(am_inference, os.path.join(inference_dir, am))
+    am_inference = jit.load(os.path.join(inference_dir, am))
     return am_inference
 
 
@@ -465,8 +475,8 @@ def voc_to_static(voc_inference,
         voc_inference, input_spec=[
             InputSpec([-1, 80], dtype=paddle.float32),
         ])
-    paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
-    voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
+    jit.save(voc_inference, os.path.join(inference_dir, voc))
+    voc_inference = jit.load(os.path.join(inference_dir, voc))
     return voc_inference
 
 
diff --git a/paddlespeech/t2s/exps/vits/synthesize_e2e.py b/paddlespeech/t2s/exps/vits/synthesize_e2e.py
index f9d10ea6..9768a16e 100644
--- a/paddlespeech/t2s/exps/vits/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/vits/synthesize_e2e.py
@@ -20,14 +20,15 @@ import yaml
 from timer import timer
 from yacs.config import CfgNode
 
+from paddlespeech.t2s.exps.syn_utils import am_to_static
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.models.vits import VITS
+from paddlespeech.t2s.models.vits import VITSInference
 from paddlespeech.t2s.utils import str2bool
 
 
 def evaluate(args):
-    # Init body.
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
 
@@ -41,6 +42,9 @@ def evaluate(args):
 
     # frontend
     frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
+    # acoustic model
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
 
     spk_num = None
     if args.speaker_dict is not None:
@@ -64,6 +68,15 @@ def evaluate(args):
     vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
     vits.eval()
 
+    vits_inference = VITSInference(vits)
+    # whether dygraph to static
+    if args.inference_dir:
+        vits_inference = am_to_static(
+            am_inference=vits_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
@@ -90,10 +103,12 @@ def evaluate(args):
             for i in range(len(phone_ids)):
                 part_phone_ids = phone_ids[i]
                 spk_id = None
-                if spk_num is not None:
+                if am_dataset in {"aishell3",
+                                  "vctk"} and spk_num is not None:
                     spk_id = paddle.to_tensor(args.spk_id)
-                out = vits.inference(text=part_phone_ids, sids=spk_id)
-                wav = out["wav"]
+                    wav = vits_inference(part_phone_ids, spk_id)
+                else:
+                    wav = vits_inference(part_phone_ids)
                 if flags == 0:
                     wav_all = wav
                     flags = 1
@@ -155,6 +170,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="whether to add blank between phones")
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='vits_csmsc',
+        help='Choose acoustic model type of tts task.')
 
     args = parser.parse_args()
     return args
diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py
index fec80377..ea333dcf 100644
--- a/paddlespeech/t2s/models/vits/transform.py
+++ b/paddlespeech/t2s/models/vits/transform.py
@@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
         inverse=False,
         tails=None,
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     if tails is None:
         spline_fn = rational_quadratic_spline
         spline_kwargs = {}
@@ -74,14 +75,17 @@ def unconstrained_rational_quadratic_spline(
         inverse=False,
         tails="linear",
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
     outside_interval_mask = ~inside_interval_mask
-
-    outputs = paddle.zeros(paddle.shape(inputs))
-    logabsdet = paddle.zeros(paddle.shape(inputs))
+    # for dygraph to static
+    # calling zeros() on paddle.shape(x) yields a var whose shape is all -1,
+    # while x.shape preserves the statically known dimensions
+    outputs = paddle.zeros(inputs.shape)
+    logabsdet = paddle.zeros(inputs.shape)
     if tails == "linear":
         unnormalized_derivatives = F.pad(
             unnormalized_derivatives,
@@ -89,8 +93,9 @@
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
         unnormalized_derivatives[..., -1] = constant
-
-        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        # for dygraph to static
+        tmp = inputs[outside_interval_mask]
+        outputs[outside_interval_mask] = tmp
         logabsdet[outside_interval_mask] = 0
     else:
         raise RuntimeError("{} tails are not implemented.".format(tails))
@@ -130,18 +135,20 @@
         right=1.0,
         bottom=0.0,
         top=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
-    if paddle.min(inputs) < left or paddle.max(inputs) > right:
-        raise ValueError("Input to a transform is not within its domain")
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
+    # for dygraph to static
+    # if paddle.min(inputs) < left or paddle.max(inputs) > right:
+    #     raise ValueError("Input to a transform is not within its domain")
 
     num_bins = unnormalized_widths.shape[-1]
-
-    if min_bin_width * num_bins > 1.0:
-        raise ValueError("Minimal bin width too large for the number of bins")
-    if min_bin_height * num_bins > 1.0:
-        raise ValueError("Minimal bin height too large for the number of bins")
+    # for dygraph to static
+    # if min_bin_width * num_bins > 1.0:
+    #     raise ValueError("Minimal bin width too large for the number of bins")
+    # if min_bin_height * num_bins > 1.0:
+    #     raise ValueError("Minimal bin height too large for the number of bins")
 
     widths = F.softmax(unnormalized_widths, axis=-1)
     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py
index e68ed564..7013e06c 100644
--- a/paddlespeech/t2s/models/vits/vits.py
+++ b/paddlespeech/t2s/models/vits/vits.py
@@ -532,3 +532,14 @@ class VITS(nn.Layer):
                 module.weight[module._padding_idx] = 0
 
         self.apply(_reset_parameters)
+
+
+class VITSInference(nn.Layer):
+    def __init__(self, model):
+        super().__init__()
+        self.acoustic_model = model
+
+    def forward(self, text, sids=None):
+        out = self.acoustic_model.inference(
+            text, sids=sids)
+        wav = out['wav']
+        return wav
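
For reference, below is a minimal sketch (not part of the patch) of driving the exported static-graph model directly, assuming the diff above is applied and synthesize_e2e.py was run with --inference_dir so that jit.save() wrote a vits_csmsc model. The inference path, the phone-ID values, and the 22050 Hz sample rate are illustrative assumptions; real phone IDs come from the text frontend plus dump/phone_id_map.txt.

import numpy as np
import paddle
import soundfile as sf

# paddle.jit.load returns a callable TranslatedLayer whose forward matches
# VITSInference.forward; the single-speaker CSMSC export takes only the
# text tensor, per the InputSpec registered in am_to_static above.
# The path below is an assumption about where --inference_dir pointed.
am_inference = paddle.jit.load("exp/default/inference/vits_csmsc")
am_inference.eval()

# Hypothetical phone-ID sequence; real IDs come from the frontend and
# dump/phone_id_map.txt (with blanks interleaved when --add-blank=True).
phone_ids = paddle.to_tensor(np.array([3, 25, 41, 7, 19], dtype="int64"))

with paddle.no_grad():
    wav = am_inference(phone_ids)  # 1-D waveform tensor, as in VITSInference

# 22050 Hz is assumed from the CSMSC config; use config.fs in practice.
sf.write("demo.wav", wav.numpy(), samplerate=22050)

For the multi-speaker exports (aishell3/vctk), the traced forward additionally takes an int64 speaker-id tensor of shape [1], matching the second InputSpec in the am_to_static branch for 'vits'.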