From a6ac497f8e112696ca4d3c3ca787efbcebc2ccca Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 15 Nov 2021 06:45:15 +0000 Subject: [PATCH] add multi-band melgan finetune scripts --- demos/style_fs2/style_syn.py | 120 +------------ examples/csmsc/voc3/conf/finetune.yaml | 139 +++++++++++++++ examples/csmsc/voc3/finetune.sh | 65 +++++++ examples/csmsc/voc3/local/link_wav.py | 85 +++++++++ paddlespeech/t2s/datasets/vocoder_batch_fn.py | 8 +- .../exps/fastspeech2/gen_gt_duration_mel.py | 167 ++++++++++++++++++ .../t2s/models/fastspeech2/fastspeech2.py | 125 +++++++++++++ 7 files changed, 586 insertions(+), 123 deletions(-) create mode 100644 examples/csmsc/voc3/conf/finetune.yaml create mode 100755 examples/csmsc/voc3/finetune.sh create mode 100644 examples/csmsc/voc3/local/link_wav.py create mode 100644 paddlespeech/t2s/exps/fastspeech2/gen_gt_duration_mel.py diff --git a/demos/style_fs2/style_syn.py b/demos/style_fs2/style_syn.py index db15b7ef..5b8ce351 100644 --- a/demos/style_fs2/style_syn.py +++ b/demos/style_fs2/style_syn.py @@ -13,7 +13,6 @@ # limitations under the License. import argparse from pathlib import Path -from typing import Union import numpy as np import paddle @@ -23,129 +22,12 @@ from yacs.config import CfgNode from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 -from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference +from paddlespeech.t2s.models.fastspeech2 import StyleFastSpeech2Inference from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator from paddlespeech.t2s.models.parallel_wavegan import PWGInference from paddlespeech.t2s.modules.normalizer import ZScore -class StyleFastSpeech2Inference(FastSpeech2Inference): - def __init__(self, normalizer, model, pitch_stats_path, energy_stats_path): - super().__init__(normalizer, model) - pitch_mean, pitch_std = np.load(pitch_stats_path) - self.pitch_mean = paddle.to_tensor(pitch_mean) - self.pitch_std = paddle.to_tensor(pitch_std) - energy_mean, energy_std = np.load(energy_stats_path) - self.energy_mean = paddle.to_tensor(energy_mean) - self.energy_std = paddle.to_tensor(energy_std) - - def denorm(self, data, mean, std): - return data * std + mean - - def norm(self, data, mean, std): - return (data - mean) / std - - def forward(self, - text: paddle.Tensor, - durations: Union[paddle.Tensor, np.ndarray]=None, - durations_scale: Union[int, float]=None, - durations_bias: Union[int, float]=None, - pitch: Union[paddle.Tensor, np.ndarray]=None, - pitch_scale: Union[int, float]=None, - pitch_bias: Union[int, float]=None, - energy: Union[paddle.Tensor, np.ndarray]=None, - energy_scale: Union[int, float]=None, - energy_bias: Union[int, float]=None, - robot: bool=False): - """ - Parameters - ---------- - text : Tensor(int64) - Input sequence of characters (T,). - speech : Tensor, optional - Feature sequence to extract style (N, idim). - durations : paddle.Tensor/np.ndarray, optional (int64) - Groundtruth of duration (T,), this will overwrite the set of durations_scale and durations_bias - durations_scale: int/float, optional - durations_bias: int/float, optional - pitch : paddle.Tensor/np.ndarray, optional - Groundtruth of token-averaged pitch (T, 1), this will overwrite the set of pitch_scale and pitch_bias - pitch_scale: int/float, optional - In denormed HZ domain. - pitch_bias: int/float, optional - In denormed HZ domain. 
-            energy : paddle.Tensor/np.ndarray, optional
-                Groundtruth of token-averaged energy (T, 1), this will overwrite the set of energy_scale and energy_bias
-            energy_scale: int/float, optional
-                In denormed domain.
-            energy_bias: int/float, optional
-                In denormed domain.
-            robot : bool, optional
-                Weather output robot style
-        Returns
-        ----------
-            Tensor
-                Output sequence of features (L, odim).
-        """
-        normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
-            text, durations=None, pitch=None, energy=None)
-        # priority: groundtruth > scale/bias > previous output
-        # set durations
-        if isinstance(durations, np.ndarray):
-            durations = paddle.to_tensor(durations)
-        elif isinstance(durations, paddle.Tensor):
-            durations = durations
-        elif durations_scale or durations_bias:
-            durations_scale = durations_scale if durations_scale is not None else 1
-            durations_bias = durations_bias if durations_bias is not None else 0
-            durations = durations_scale * d_outs + durations_bias
-        else:
-            durations = d_outs
-
-        if robot:
-            # set normed pitch to zeros have the same effect with set denormd ones to mean
-            pitch = paddle.zeros(p_outs.shape)
-
-        # set pitch, can overwrite robot set
-        if isinstance(pitch, np.ndarray):
-            pitch = paddle.to_tensor(pitch)
-        elif isinstance(pitch, paddle.Tensor):
-            pitch = pitch
-        elif pitch_scale or pitch_bias:
-            pitch_scale = pitch_scale if pitch_scale is not None else 1
-            pitch_bias = pitch_bias if pitch_bias is not None else 0
-            p_Hz = paddle.exp(
-                self.denorm(p_outs, self.pitch_mean, self.pitch_std))
-            p_HZ = pitch_scale * p_Hz + pitch_bias
-            pitch = self.norm(paddle.log(p_HZ), self.pitch_mean, self.pitch_std)
-        else:
-            pitch = p_outs
-
-        # set energy
-        if isinstance(energy, np.ndarray):
-            energy = paddle.to_tensor(energy)
-        elif isinstance(energy, paddle.Tensor):
-            energy = energy
-        elif energy_scale or energy_bias:
-            energy_scale = energy_scale if energy_scale is not None else 1
-            energy_bias = energy_bias if energy_bias is not None else 0
-            e_dnorm = self.denorm(e_outs, self.energy_mean, self.energy_std)
-            e_dnorm = energy_scale * e_dnorm + energy_bias
-            energy = self.norm(e_dnorm, self.energy_mean, self.energy_std)
-        else:
-            energy = e_outs
-
-        normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
-            text,
-            durations=durations,
-            pitch=pitch,
-            energy=energy,
-            use_teacher_forcing=True)
-
-        logmel = self.normalizer.inverse(normalized_mel)
-        return logmel
-
-
 def evaluate(args, fastspeech2_config, pwg_config):
 
     # construct dataset for evaluation
diff --git a/examples/csmsc/voc3/conf/finetune.yaml b/examples/csmsc/voc3/conf/finetune.yaml
new file mode 100644
index 00000000..e02f3e22
--- /dev/null
+++ b/examples/csmsc/voc3/conf/finetune.yaml
@@ -0,0 +1,139 @@
+# This is the hyperparameter configuration file for Multi-Band MelGAN finetuning.
+# Please make sure this is adjusted for the CSMSC dataset. If you want to
+# apply it to another dataset, you might need to carefully change some parameters.
+# This configuration requires ~8GB GPU memory and will finish within 7 days on a Titan V.
+
+# This configuration is based on full-band MelGAN, but the hop size and sampling
+# rate differ from the paper (16kHz vs 24kHz). The number of iterations
+# is not given in the paper, so we currently train for 1M iterations (which may
+# not be enough to converge). The optimizer setting is based on @dathudeptrai's advice.
+# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 24000                # Sampling rate.
+n_fft: 2048              # FFT size. (in samples)
+n_shift: 300             # Hop size. (in samples)
+win_length: 1200         # Window length. (in samples)
+                         # If set to null, it will be the same as fft_size.
+window: "hann"           # Window function.
+n_mels: 80               # Number of mel basis.
+fmin: 80                 # Minimum frequency in mel basis calculation. (Hz)
+fmax: 7600               # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    in_channels: 80               # Number of input channels.
+    out_channels: 4               # Number of output channels (number of sub-bands).
+    kernel_size: 7                # Kernel size of initial and final conv layers.
+    channels: 384                 # Initial number of channels for conv layers.
+    upsample_scales: [5, 5, 3]    # List of upsampling scales.
+    stack_kernel_size: 3          # Kernel size of dilated conv layers in residual stack.
+    stacks: 4                     # Number of stacks in a single residual stack module.
+    use_weight_norm: True         # Whether to use weight normalization.
+    use_causal_conv: False        # Whether to use causal convolution.
+    use_final_nonlinear_activation: True  # Whether to apply a nonlinear activation to the output.
+
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+    in_channels: 1                    # Number of input channels.
+    out_channels: 1                   # Number of output channels.
+    scales: 3                         # Number of multi-scales.
+    downsample_pooling: "AvgPool1D"   # Pooling type for the input downsampling.
+    downsample_pooling_params:        # Parameters of the above pooling function.
+        kernel_size: 4
+        stride: 2
+        padding: 1
+        exclusive: True
+    kernel_sizes: [5, 3]              # List of kernel sizes.
+    channels: 16                      # Number of channels of the initial conv layer.
+    max_downsample_channels: 512      # Maximum number of channels of downsampling layers.
+    downsample_scales: [4, 4, 4]      # List of downsampling scales.
+    nonlinear_activation: "LeakyReLU" # Nonlinear activation function.
+    nonlinear_activation_params:      # Parameters of nonlinear activation function.
+        negative_slope: 0.2
+    use_weight_norm: True             # Whether to use weight norm.
+
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+use_stft_loss: true
+stft_loss_params:
+    fft_sizes: [1024, 2048, 512]  # List of FFT sizes for STFT-based loss.
+    hop_sizes: [120, 240, 50]     # List of hop sizes for STFT-based loss.
+    win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
+    window: "hann"                # Window function for STFT-based loss.
+use_subband_stft_loss: true
+subband_stft_loss_params:
+    fft_sizes: [384, 683, 171]    # List of FFT sizes for sub-band STFT-based loss.
+    hop_sizes: [30, 60, 10]       # List of hop sizes for sub-band STFT-based loss.
+    win_lengths: [150, 300, 60]   # List of window lengths for sub-band STFT-based loss.
+    window: "hann"                # Window function for sub-band STFT-based loss.
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+use_feat_match_loss: false    # Whether to use feature matching loss.
+lambda_adv: 2.5               # Loss balancing coefficient for adversarial loss.
+
+###########################################################
+#                  DATA LOADER SETTING                    #
+###########################################################
+batch_size: 64           # Batch size.
+batch_max_steps: 16200   # Length of each audio clip in the batch. Make sure it is divisible by hop_size.
+num_workers: 2           # Number of workers in DataLoader.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    epsilon: 1.0e-7          # Generator's epsilon.
+    weight_decay: 0.0        # Generator's weight decay coefficient.
+
+generator_grad_norm: -1      # Generator's gradient norm.
+generator_scheduler_params:
+    learning_rate: 1.0e-3    # Generator's learning rate.
+    gamma: 0.5               # Generator's scheduler gamma.
+    milestones:              # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+discriminator_optimizer_params:
+    epsilon: 1.0e-7          # Discriminator's epsilon.
+    weight_decay: 0.0        # Discriminator's weight decay coefficient.
+
+discriminator_grad_norm: -1  # Discriminator's gradient norm.
+discriminator_scheduler_params:
+    learning_rate: 1.0e-3    # Discriminator's learning rate.
+    gamma: 0.5               # Discriminator's scheduler gamma.
+    milestones:              # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+
+###########################################################
+#                    INTERVAL SETTING                     #
+###########################################################
+discriminator_train_start_steps: 200000  # Number of steps before starting to train the discriminator.
+train_max_steps: 1200000                 # Number of training steps.
+save_interval_steps: 1000                # Interval steps to save checkpoint.
+eval_interval_steps: 1000                # Interval steps to evaluate the network.
+
+###########################################################
+#                     OTHER SETTING                       #
+###########################################################
+num_snapshots: 10   # Max number of snapshots to keep while training.
+seed: 42            # Random seed for paddle, random, and np.random.
\ No newline at end of file
diff --git a/examples/csmsc/voc3/finetune.sh b/examples/csmsc/voc3/finetune.sh
new file mode 100755
index 00000000..1e10e402
--- /dev/null
+++ b/examples/csmsc/voc3/finetune.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+source path.sh
+
+gpus=0
+stage=0
+stop_stage=100
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # generate mels from ground-truth durations with a pretrained fastspeech2
+    python3 ${MAIN_ROOT}/paddlespeech/t2s/exps/fastspeech2/gen_gt_duration_mel.py \
+        --fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \
+        --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \
+        --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \
+        --dur-file=durations.txt \
+        --output-dir=dump_finetune \
+        --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # link the ground-truth waves into the finetune dump dir
+    python3 local/link_wav.py \
+        --old-dump-dir=dump \
+        --dump-dir=dump_finetune
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # get features' stats (mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump_finetune/train/raw/metadata.jsonl \
+        --field-name="feats"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # normalize; dev and test should use train's stats
+    echo "Normalize ..."
+ + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/train/raw/metadata.jsonl \ + --dumpdir=dump_finetune/train/norm \ + --stats=dump_finetune/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/dev/raw/metadata.jsonl \ + --dumpdir=dump_finetune/dev/norm \ + --stats=dump_finetune/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/test/raw/metadata.jsonl \ + --dumpdir=dump_finetune/test/norm \ + --stats=dump_finetune/train/feats_stats.npy +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + CUDA_VISIBLE_DEVICES=${gpus} \ + FLAGS_cudnn_exhaustive_search=true \ + FLAGS_conv_workspace_size_limit=4000 \ + python ${BIN_DIR}/train.py \ + --train-metadata=dump_finetune/train/norm/metadata.jsonl \ + --dev-metadata=dump_finetune/dev/norm/metadata.jsonl \ + --config=conf/finetune.yaml \ + --output-dir=exp/finetune \ + --ngpu=1 +fi \ No newline at end of file diff --git a/examples/csmsc/voc3/local/link_wav.py b/examples/csmsc/voc3/local/link_wav.py new file mode 100644 index 00000000..c81e0d4b --- /dev/null +++ b/examples/csmsc/voc3/local/link_wav.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
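+# Symlink the ground-truth wave files (*_wave.npy) from the original dump
+# directory into the finetune dump directory, and rebuild metadata.jsonl so
+# that each generated mel (*_feats.npy) is paired with its ground-truth wave.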
+import argparse
+import os
+from operator import itemgetter
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Link ground-truth waves to the finetune dump dir and regenerate metadata.")
+
+    parser.add_argument(
+        "--old-dump-dir",
+        type=str,
+        required=True,
+        help="directory of the original dump, containing the ground-truth wave files.")
+    parser.add_argument(
+        "--dump-dir",
+        type=str,
+        required=True,
+        help="directory of the finetune dump, containing the generated mel files.")
+    args = parser.parse_args()
+
+    old_dump_dir = Path(args.old_dump_dir).expanduser()
+    old_dump_dir = old_dump_dir.resolve()
+    dump_dir = Path(args.dump_dir).expanduser()
+    # use absolute path
+    dump_dir = dump_dir.resolve()
+    dump_dir.mkdir(parents=True, exist_ok=True)
+
+    assert old_dump_dir.is_dir()
+    assert dump_dir.is_dir()
+
+    for sub in ["train", "dev", "test"]:
+        # symlink the *_wave.npy files in old_dump_dir to the corresponding locations in dump_dir
+        output_dir = dump_dir / sub
+        output_dir.mkdir(parents=True, exist_ok=True)
+        results = []
+        for name in os.listdir(output_dir / "raw"):
+            # e.g. 003918_feats.npy
+            utt_id = name.split("_")[0]
+            mel_path = output_dir / ("raw/" + name)
+            gen_mel = np.load(mel_path)
+            wave_name = utt_id + "_wave.npy"
+            wav = np.load(old_dump_dir / sub / ("raw/" + wave_name))
+            os.symlink(old_dump_dir / sub / ("raw/" + wave_name),
+                       output_dir / ("raw/" + wave_name))
+            num_sample = wav.shape[0]
+            num_frames = gen_mel.shape[0]
+            wav_path = output_dir / ("raw/" + wave_name)
+
+            record = {
+                "utt_id": utt_id,
+                "num_samples": num_sample,
+                "num_frames": num_frames,
+                "feats": str(mel_path),
+                "wave": str(wav_path),
+            }
+            results.append(record)
+
+        results.sort(key=itemgetter("utt_id"))
+
+        with jsonlines.open(output_dir / "raw/metadata.jsonl", 'w') as writer:
+            for item in results:
+                writer.write(item)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/datasets/vocoder_batch_fn.py b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
index 2de4fb12..2e4f740f 100644
--- a/paddlespeech/t2s/datasets/vocoder_batch_fn.py
+++ b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
@@ -110,10 +110,10 @@ class Clip(object):
         if len(x) < c.shape[0] * self.hop_size:
             x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge")
         elif len(x) > c.shape[0] * self.hop_size:
-            print(
-                f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size })"
-            )
-            x = x[:c.shape[1] * self.hop_size]
+            # print(
+            #     f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size })"
+            # )
+            x = x[:c.shape[0] * self.hop_size]
 
         # check the legnth is valid
         assert len(x) == c.shape[
diff --git a/paddlespeech/t2s/exps/fastspeech2/gen_gt_duration_mel.py b/paddlespeech/t2s/exps/fastspeech2/gen_gt_duration_mel.py
new file mode 100644
index 00000000..8a9ef370
--- /dev/null
+++ b/paddlespeech/t2s/exps/fastspeech2/gen_gt_duration_mel.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
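+# Used by examples/csmsc/voc3/finetune.sh (stage 0) to prepare inputs for
+# multi-band MelGAN finetuning.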
+# generate mels using durations.txt
+# for multi-band melgan finetune
+# what if the length is inconsistent with the original mel?
+import argparse
+from pathlib import Path
+
+import numpy as np
+import paddle
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
+from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
+from paddlespeech.t2s.models.fastspeech2 import StyleFastSpeech2Inference
+from paddlespeech.t2s.modules.normalizer import ZScore
+
+
+def evaluate(args, fastspeech2_config):
+
+    # construct dataset for evaluation
+    with open(args.phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
+
+    phone_dict = {}
+    for phn, idx in phn_id:
+        phone_dict[phn] = int(idx)
+
+    odim = fastspeech2_config.n_mels
+    model = FastSpeech2(
+        idim=vocab_size, odim=odim, **fastspeech2_config["model"])
+
+    model.set_state_dict(
+        paddle.load(args.fastspeech2_checkpoint)["main_params"])
+    model.eval()
+
+    stat = np.load(args.fastspeech2_stat)
+    mu, std = stat
+    mu = paddle.to_tensor(mu)
+    std = paddle.to_tensor(std)
+    fastspeech2_normalizer = ZScore(mu, std)
+
+    fastspeech2_inference = StyleFastSpeech2Inference(fastspeech2_normalizer,
+                                                      model)
+    fastspeech2_inference.eval()
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    sentences, speaker_set = get_phn_dur(args.dur_file)
+    merge_silence(sentences)
+
+    for i, utt_id in enumerate(sentences):
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
+        # trim the leading and trailing sil
+        if args.cut_sil:
+            if phones[0] == "sil" and len(durations) > 1:
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                durations = durations[:-1]
+                phones = phones[:-1]
+            # sentences[utt_id][0] = phones
+            # sentences[utt_id][1] = durations
+
+        phone_ids = [phone_dict[phn] for phn in phones]
+        phone_ids = paddle.to_tensor(np.array(phone_ids))
+        durations = paddle.to_tensor(np.array(durations))
+        # the generated mel may differ from the ground truth by 1-2 frames,
+        # but batch_fn will fix this
+        # split data into 3 sections; only baker's split sizes are defined for now
+        if args.dataset == "baker":
+            num_train = 9800
+            num_dev = 100
+        if i < num_train:
+            sub_output_dir = output_dir / ("train/raw")
+        elif i < num_train + num_dev:
+            sub_output_dir = output_dir / ("dev/raw")
+        else:
+            sub_output_dir = output_dir / ("test/raw")
+        sub_output_dir.mkdir(parents=True, exist_ok=True)
+        with paddle.no_grad():
+            mel = fastspeech2_inference(phone_ids, durations=durations)
+        np.save(sub_output_dir / (utt_id + "_feats.npy"), mel)
+
+
+def main():
+    # parse args and config
+    parser = argparse.ArgumentParser(
+        description="Generate mels with ground-truth durations using FastSpeech2 for vocoder finetuning."
+    )
+    parser.add_argument(
+        "--dataset",
+        default="baker",
+        type=str,
+        help="name of the dataset; should be one of {baker, ljspeech, vctk} for now")
+    parser.add_argument(
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
+    parser.add_argument(
+        "--fastspeech2-checkpoint",
+        type=str,
+        help="fastspeech2 checkpoint to load.")
+    parser.add_argument(
+        "--fastspeech2-stat",
+        type=str,
+        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
+    )
+
+    parser.add_argument(
+        "--phones-dict",
+        type=str,
+        default="phone_id_map.txt",
+        help="phone vocabulary file.")
+
+    parser.add_argument(
+        "--dur-file", default=None, type=str, help="path to durations.txt.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
+
+    def str2bool(s):
+        return s.lower() == 'true'
+
+    parser.add_argument(
+        "--cut-sil",
+        type=str2bool,
+        default=True,
+        help="whether to cut silence at the edges of audio")
+
+    args = parser.parse_args()
+
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0!")
+
+    with open(args.fastspeech2_config) as f:
+        fastspeech2_config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(fastspeech2_config)
+
+    evaluate(args, fastspeech2_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 2202d156..e650e255 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -16,7 +16,9 @@
 from typing import Dict
 from typing import Sequence
 from typing import Tuple
+from typing import Union
 
+import numpy as np
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
@@ -687,6 +689,129 @@ class FastSpeech2Inference(nn.Layer):
         return logmel
 
 
+class StyleFastSpeech2Inference(FastSpeech2Inference):
+    def __init__(self,
+                 normalizer,
+                 model,
+                 pitch_stats_path=None,
+                 energy_stats_path=None):
+        super().__init__(normalizer, model)
+        if pitch_stats_path:
+            pitch_mean, pitch_std = np.load(pitch_stats_path)
+            self.pitch_mean = paddle.to_tensor(pitch_mean)
+            self.pitch_std = paddle.to_tensor(pitch_std)
+        if energy_stats_path:
+            energy_mean, energy_std = np.load(energy_stats_path)
+            self.energy_mean = paddle.to_tensor(energy_mean)
+            self.energy_std = paddle.to_tensor(energy_std)
+
+    def denorm(self, data, mean, std):
+        return data * std + mean
+
+    def norm(self, data, mean, std):
+        return (data - mean) / std
+
+    def forward(self,
+                text: paddle.Tensor,
+                durations: Union[paddle.Tensor, np.ndarray]=None,
+                durations_scale: Union[int, float]=None,
+                durations_bias: Union[int, float]=None,
+                pitch: Union[paddle.Tensor, np.ndarray]=None,
+                pitch_scale: Union[int, float]=None,
+                pitch_bias: Union[int, float]=None,
+                energy: Union[paddle.Tensor, np.ndarray]=None,
+                energy_scale: Union[int, float]=None,
+                energy_bias: Union[int, float]=None,
+                robot: bool=False):
+        """
+        Parameters
+        ----------
+        text : Tensor(int64)
+            Input sequence of characters (T,).
+        durations : paddle.Tensor/np.ndarray, optional (int64)
+            Ground truth of duration (T,); this will override durations_scale and durations_bias
+        durations_scale: int/float, optional
+        durations_bias: int/float, optional
+        pitch : paddle.Tensor/np.ndarray, optional
+            Ground truth of token-averaged pitch (T, 1); this will override pitch_scale and pitch_bias
+        pitch_scale: int/float, optional
+            In denormed Hz domain.
+        pitch_bias: int/float, optional
+            In denormed Hz domain.
+        energy : paddle.Tensor/np.ndarray, optional
+            Ground truth of token-averaged energy (T, 1); this will override energy_scale and energy_bias
+        energy_scale: int/float, optional
+            In denormed domain.
+        energy_bias: int/float, optional
+            In denormed domain.
+        robot : bool, optional
+            Whether to output robot style (monotone pitch).
+        Returns
+        ----------
+        Tensor
+            Output sequence of features (L, odim).
+        """
+        normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
+            text, durations=None, pitch=None, energy=None)
+        # priority: groundtruth > scale/bias > previous output
+        # set durations
+        if isinstance(durations, np.ndarray):
+            durations = paddle.to_tensor(durations)
+        elif isinstance(durations, paddle.Tensor):
+            durations = durations
+        elif durations_scale or durations_bias:
+            durations_scale = durations_scale if durations_scale is not None else 1
+            durations_bias = durations_bias if durations_bias is not None else 0
+            durations = durations_scale * d_outs + durations_bias
+        else:
+            durations = d_outs
+
+        if robot:
+            # setting normed pitch to zeros has the same effect as setting denormed pitch to the mean
+            pitch = paddle.zeros(p_outs.shape)
+
+        # set pitch; this can override the robot setting
+        if isinstance(pitch, np.ndarray):
+            pitch = paddle.to_tensor(pitch)
+        elif isinstance(pitch, paddle.Tensor):
+            pitch = pitch
+        elif pitch_scale or pitch_bias:
+            pitch_scale = pitch_scale if pitch_scale is not None else 1
+            pitch_bias = pitch_bias if pitch_bias is not None else 0
+            p_Hz = paddle.exp(
+                self.denorm(p_outs, self.pitch_mean, self.pitch_std))
+            p_Hz = pitch_scale * p_Hz + pitch_bias
+            pitch = self.norm(paddle.log(p_Hz), self.pitch_mean, self.pitch_std)
+        else:
+            pitch = p_outs
+
+        # set energy
+        if isinstance(energy, np.ndarray):
+            energy = paddle.to_tensor(energy)
+        elif isinstance(energy, paddle.Tensor):
+            energy = energy
+        elif energy_scale or energy_bias:
+            energy_scale = energy_scale if energy_scale is not None else 1
+            energy_bias = energy_bias if energy_bias is not None else 0
+            e_denorm = self.denorm(e_outs, self.energy_mean, self.energy_std)
+            e_denorm = energy_scale * e_denorm + energy_bias
+            energy = self.norm(e_denorm, self.energy_mean, self.energy_std)
+        else:
+            energy = e_outs
+
+        normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
+            text,
+            durations=durations,
+            pitch=pitch,
+            energy=energy,
+            use_teacher_forcing=True)
+
+        logmel = self.normalizer.inverse(normalized_mel)
+        return logmel
+
+
 class FastSpeech2Loss(nn.Layer):
     """Loss function module for FastSpeech2."""
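
A minimal usage sketch of the relocated StyleFastSpeech2Inference (not part of the patch): the config/checkpoint/stats paths, the vocab size, and the phone-id sequence below are placeholders, and the loading steps mirror gen_gt_duration_mel.py above.

# Minimal sketch: style-controlled inference with StyleFastSpeech2Inference.
# All paths, vocab_size, and the phone-id sequence are placeholders.
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import StyleFastSpeech2Inference
from paddlespeech.t2s.modules.normalizer import ZScore

with open("default.yaml") as f:  # placeholder fastspeech2 config
    config = CfgNode(yaml.safe_load(f))

vocab_size = 268  # placeholder; len() of phone_id_map.txt
model = FastSpeech2(idim=vocab_size, odim=config.n_mels, **config["model"])
model.set_state_dict(paddle.load("snapshot.pdz")["main_params"])  # placeholder
model.eval()

mu, std = np.load("speech_stats.npy")  # placeholder
normalizer = ZScore(paddle.to_tensor(mu), paddle.to_tensor(std))
inference = StyleFastSpeech2Inference(
    normalizer,
    model,
    pitch_stats_path="pitch_stats.npy",     # placeholder; needed for pitch_scale/bias
    energy_stats_path="energy_stats.npy")   # placeholder; needed for energy_scale/bias
inference.eval()

phone_ids = paddle.to_tensor([2, 15, 33, 7])  # placeholder phone-id sequence
with paddle.no_grad():
    # 1.2x duration with pitch raised by 40 Hz in the denormed domain
    mel = inference(phone_ids, durations_scale=1.2, pitch_bias=40)
    # monotone "robot" style
    mel_robot = inference(phone_ids, robot=True)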