diff --git a/examples/opencpop/svs1/conf/default.yaml b/examples/opencpop/svs1/conf/default.yaml
index 13b803b5e..57ffde9a2 100644
--- a/examples/opencpop/svs1/conf/default.yaml
+++ b/examples/opencpop/svs1/conf/default.yaml
@@ -23,8 +23,8 @@ f0max: 750 # Maximum f0 for pitch extraction.
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 24 # batch size
-num_workers: 4 # number of gpu
+batch_size: 48 # batch size
+num_workers: 1 # number of DataLoader workers
 
 ###########################################################
@@ -98,13 +98,14 @@ model:
     use_weight_norm: False # Whether to use weight norm in all convolutions
     init_type: "kaiming_normal" # Type of initialize weights of a neural network module
 
-    # diffusion module
+
     diffusion_params:
         num_train_timesteps: 100 # The number of timesteps between the noise and the real during training
         beta_start: 0.0001 # beta start parameter for the scheduler
         beta_end: 0.06 # beta end parameter for the scheduler
         beta_schedule: "linear" # beta schedule parameter for the scheduler
-        num_max_timesteps: 60 # The max timestep transition from real to noise
+        num_max_timesteps: 100 # The max timestep transition from real to noise
+
 
 ###########################################################
@@ -134,18 +135,18 @@ ds_optimizer_params:
 ds_scheduler_params:
     learning_rate: 0.001
     gamma: 0.5
-    step_size: 10000
+    step_size: 50000
 ds_grad_norm: 1
 
 ###########################################################
 #                     INTERVAL SETTING                    #
 ###########################################################
-ds_train_start_steps: 100000 # Number of steps to start to train diffusion module.
-train_max_steps: 200000 # Number of training steps.
-save_interval_steps: 500 # Interval steps to save checkpoint.
-eval_interval_steps: 100 # Interval steps to evaluate the network.
-num_snapshots: 10 # Number of saved models
+ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
+train_max_steps: 320000 # Number of training steps.
+save_interval_steps: 2000 # Interval steps to save checkpoint.
+eval_interval_steps: 2000 # Interval steps to evaluate the network.
+num_snapshots: 5 # Number of saved models
 
 ###########################################################
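Context for the `diffusion_params` change: with a `linear` schedule the betas are evenly spaced between `beta_start` and `beta_end`, and the fraction of signal kept at step t is the cumulative product of (1 - beta). The sketch below is illustrative only (plain NumPy mirroring the YAML values, not code from this patch); it shows why raising `num_max_timesteps` from 60 to 100 makes the shallow-diffusion decoder start its denoising from an almost fully noised mel.

```python
import numpy as np

# Values mirror diffusion_params in conf/default.yaml.
num_train_timesteps = 100
beta_start, beta_end = 0.0001, 0.06

# "linear" schedule: betas evenly spaced between beta_start and beta_end.
betas = np.linspace(beta_start, beta_end, num_train_timesteps)
# alpha_bar[t] = prod_{s<=t}(1 - beta_s): fraction of signal kept at step t.
alpha_bar = np.cumprod(1.0 - betas)

# At the old cap (t=59) a noticeable amount of signal remains; at the new
# cap (t=99) the sample is close to pure noise.
print(f"alpha_bar[59]={alpha_bar[59]:.4f}, alpha_bar[99]={alpha_bar[99]:.4f}")
```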
diff --git a/examples/opencpop/svs1/local/synthesize.sh b/examples/opencpop/svs1/local/synthesize.sh
index 37f8893a9..dae0c6323 100755
--- a/examples/opencpop/svs1/local/synthesize.sh
+++ b/examples/opencpop/svs1/local/synthesize.sh
@@ -2,9 +2,9 @@
 config_path=$1
 train_output_path=$2
-#ckpt_name=$3
-iter=$3
-ckpt_name=snapshot_iter_${iter}.pdz
+ckpt_name=$3
+#iter=$3
+#ckpt_name=snapshot_iter_${iter}.pdz
 
 stage=0
 stop_stage=0
@@ -21,7 +21,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
       --voc_config=pwgan_opencpop/default.yaml \
       --voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
      --voc_stat=pwgan_opencpop/feats_stats.npy \
-      --test_metadata=test1.jsonl \
-      --output_dir=${train_output_path}/test_${iter} \
+      --test_metadata=test.jsonl \
+      --output_dir=${train_output_path}/test_${ckpt_name} \
       --phones_dict=dump/phone_id_map.txt
 fi
diff --git a/examples/opencpop/svs1/local/train.sh b/examples/opencpop/svs1/local/train.sh
index 42fff26ca..d1302f99f 100755
--- a/examples/opencpop/svs1/local/train.sh
+++ b/examples/opencpop/svs1/local/train.sh
@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=4 \
+    --ngpu=1 \
     --phones-dict=dump/phone_id_map.txt
diff --git a/examples/opencpop/svs1/run.sh b/examples/opencpop/svs1/run.sh
index 7f25a15bd..7bde38518 100755
--- a/examples/opencpop/svs1/run.sh
+++ b/examples/opencpop/svs1/run.sh
@@ -3,9 +3,9 @@
 set -e
 source path.sh
 
-gpus=4,5,6,7
-stage=1
-stop_stage=1
+gpus=0
+stage=0
+stop_stage=100
 
 conf_path=conf/default.yaml
 train_output_path=exp/default
diff --git a/paddlespeech/t2s/exps/diffsinger/normalize.py b/paddlespeech/t2s/exps/diffsinger/normalize.py
index 0a54cfbb6..dec6127e1 100644
--- a/paddlespeech/t2s/exps/diffsinger/normalize.py
+++ b/paddlespeech/t2s/exps/diffsinger/normalize.py
@@ -23,6 +23,7 @@ from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm
 
 from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.utils import str2bool
 
 
 def main():
@@ -58,6 +59,11 @@ def main():
         "--phones-dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--speaker-dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        "--norm-feats",
+        type=str2bool,
+        default=False,
+        help="whether to norm features")
 
     args = parser.parse_args()
 
@@ -80,18 +86,36 @@ def main():
 
     # restore scaler
     speech_scaler = StandardScaler()
-    speech_scaler.mean_ = np.load(args.speech_stats)[0]
-    speech_scaler.scale_ = np.load(args.speech_stats)[1]
+    if args.norm_feats:
+        speech_scaler.mean_ = np.load(args.speech_stats)[0]
+        speech_scaler.scale_ = np.load(args.speech_stats)[1]
+    else:
+        speech_scaler.mean_ = np.zeros(
+            np.load(args.speech_stats)[0].shape, dtype="float32")
+        speech_scaler.scale_ = np.ones(
+            np.load(args.speech_stats)[1].shape, dtype="float32")
     speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
 
     pitch_scaler = StandardScaler()
-    pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
-    pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
+    if args.norm_feats:
+        pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
+        pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
+    else:
+        pitch_scaler.mean_ = np.zeros(
+            np.load(args.pitch_stats)[0].shape, dtype="float32")
+        pitch_scaler.scale_ = np.ones(
+            np.load(args.pitch_stats)[1].shape, dtype="float32")
     pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
 
     energy_scaler = StandardScaler()
-    energy_scaler.mean_ = np.load(args.energy_stats)[0]
-    energy_scaler.scale_ = np.load(args.energy_stats)[1]
+    if args.norm_feats:
+        energy_scaler.mean_ = np.load(args.energy_stats)[0]
+        energy_scaler.scale_ = np.load(args.energy_stats)[1]
+    else:
+        energy_scaler.mean_ = np.zeros(
+            np.load(args.energy_stats)[0].shape, dtype="float32")
+        energy_scaler.scale_ = np.ones(
+            np.load(args.energy_stats)[1].shape, dtype="float32")
     energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
 
     vocab_phones = {}
@@ -111,7 +135,6 @@ def main():
 
     for item in tqdm(dataset):
         utt_id = item['utt_id']
-        print(utt_id)
         speech = item['speech']
         pitch = item['pitch']
         energy = item['energy']
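On the new `--norm-feats` flag: with `mean_ = 0` and `scale_ = 1`, `StandardScaler.transform` becomes an identity map, so the un-normalized path reuses exactly the same code path as the normalized one. A toy check (illustrative, not part of the patch):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# An identity scaler, built the same way normalize.py does when --norm-feats=False.
scaler = StandardScaler()
scaler.mean_ = np.zeros(80, dtype="float32")   # 80 mel bins
scaler.scale_ = np.ones(80, dtype="float32")
scaler.n_features_in_ = 80

mel = np.random.randn(100, 80).astype("float32")  # (frames, bins)
assert np.allclose(scaler.transform(mel), mel)    # transform is a no-op
```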
diff --git a/paddlespeech/t2s/exps/diffsinger/train.py b/paddlespeech/t2s/exps/diffsinger/train.py
index 2444b0610..5e834b3d3 100644
--- a/paddlespeech/t2s/exps/diffsinger/train.py
+++ b/paddlespeech/t2s/exps/diffsinger/train.py
@@ -143,9 +143,10 @@ def train_sp(args, config):
     print("criterions done!")
 
     optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
+    lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
     gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
     optimizer_ds = AdamW(
-        learning_rate=config["ds_scheduler_params"]["learning_rate"],
+        learning_rate=lr_schedule_ds,
         grad_clip=gradient_clip_ds,
         parameters=model_ds.parameters(),
         **config["ds_optimizer_params"])
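The diffusion optimizer now takes a `StepDecay` schedule instead of a fixed float. Its behavior has the closed form sketched below (assumed semantics of `paddle.optimizer.lr.StepDecay`, stepped once per training iteration, using the `ds_scheduler_params` values; not code from this patch). With `step_size: 50000`, the learning rate halves three times over the 160k diffusion-training steps.

```python
def ds_lr(step, base=0.001, gamma=0.5, step_size=50000):
    # Closed form of StepDecay: lr = base * gamma ** (step // step_size),
    # assuming the scheduler is stepped once per training iteration.
    return base * gamma**(step // step_size)

for s in (0, 50000, 100000, 150000):
    print(s, ds_lr(s))  # 0.001, 0.0005, 0.00025, 0.000125
```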
diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py
index 23d7e24dc..adf59a96a 100644
--- a/paddlespeech/t2s/exps/synthesize.py
+++ b/paddlespeech/t2s/exps/synthesize.py
@@ -114,29 +114,44 @@ def evaluate(args):
         is_slur = paddle.to_tensor(datum["is_slur"])
         get_mel_fs2 = False
         # mel: [T, mel_bin]
-        mel = am_inference(
+        mel1 = am_inference(
             phone_ids,
             note=note,
             note_dur=note_dur,
             is_slur=is_slur,
-            get_mel_fs2=get_mel_fs2)
+            get_mel_fs2=True)
+        mel2 = am_inference(
+            phone_ids,
+            note=note,
+            note_dur=note_dur,
+            is_slur=is_slur,
+            get_mel_fs2=False)
         # import numpy as np
         # mel = np.load("/home/liangyunming/others_code/DiffSinger_lym/diffsinger_mel.npy")
         # mel = paddle.to_tensor(mel)
-        wav = voc_inference(mel)
-
+        wav1 = voc_inference(mel1)
+        wav2 = voc_inference(mel2)
 
-        wav = wav.numpy()
-        N += wav.size
+        wav1 = wav1.numpy()
+        wav2 = wav2.numpy()
+        N += wav1.size
+        N += wav2.size
         T += t.elapse
-        speed = wav.size / t.elapse
+        speed = 2 * wav1.size / t.elapse
         rtf = am_config.fs / speed
         print(
-            f"{utt_id}, mel: {mel.shape}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+            f"{utt_id}, mel: {mel1.shape}, wave: {wav1.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
         )
         sf.write(
             # str(output_dir / ("xiaojiuwo_diffsinger" + ".wav")), wav, samplerate=am_config.fs)
-            str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
+            str(output_dir / (utt_id + "_fs2.wav")),
+            wav1,
+            samplerate=am_config.fs)
+        sf.write(
+            str(output_dir / (utt_id + "_diffusion.wav")),
+            wav2,
+            samplerate=am_config.fs)
+        print(f"{utt_id} done!")
         # break
 
     print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py
index 496237c7b..1fa4dfd39 100644
--- a/paddlespeech/t2s/models/diffsinger/diffsinger.py
+++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py
@@ -17,13 +17,14 @@ from typing import Any
 from typing import Dict
 from typing import Tuple
 
+import numpy as np
 import paddle
 from paddle import nn
 from typeguard import check_argument_types
 
 from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI
+from paddlespeech.t2s.modules.diffnet import DiffNet
 from paddlespeech.t2s.modules.diffusion import GaussianDiffusion
-from paddlespeech.t2s.modules.diffusion import WaveNetDenoiser
 
 
 class DiffSinger(nn.Layer):
@@ -134,7 +135,8 @@ class DiffSinger(nn.Layer):
                 "beta_end": 0.06,
                 "beta_schedule": "squaredcos_cap_v2",
                 "num_max_timesteps": 60
-            }, ):
+            },
+            stretch: bool=True, ):
         """Initialize DiffSinger module.
 
         Args:
@@ -156,8 +158,40 @@
             fastspeech2_params=fastspeech2_params,
             note_num=note_num,
             is_slur_num=is_slur_num)
-        denoiser = WaveNetDenoiser(**denoiser_params)
-        self.diffusion = GaussianDiffusion(denoiser, **diffusion_params)
+        denoiser = DiffNet(**denoiser_params)
+        spec_min = paddle.to_tensor(
+            np.array([
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
+                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
+            ]))
+        spec_max = paddle.to_tensor(
+            np.array([
+                -0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
+                -0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
+                -0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
+                -0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
+                -0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
+                -0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
+                -0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
+                -0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
+                -0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
+                -0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
+                -0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
+                -1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
+                -1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
+            ]))
+        self.diffusion = GaussianDiffusion(
+            denoiser,
+            **diffusion_params,
+            stretch=stretch,
+            min_values=spec_min,
+            max_values=spec_max, )
 
     def forward(
             self,
@@ -279,7 +313,7 @@
             cond=cond_fs2,
             ref_x=mel_fs2,
             scheduler_type="ddpm",
-            num_inference_steps=25)
+            num_inference_steps=60)
         mel = mel.transpose((0, 2, 1))
 
         return mel[0]
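The hard-coded `spec_min`/`spec_max` tensors drive the new `stretch` step in `GaussianDiffusion`: each of the 80 mel bins is linearly mapped into [-1, 1] before diffusion and mapped back afterwards (the `norm_spec`/`denorm_spec` pair in diffusion.py below). A round-trip sketch with stand-in statistics (illustrative, not code from this patch):

```python
import paddle

# Stand-ins for the per-bin statistics hard-coded in diffsinger.py.
spec_min = paddle.full([80], -6.0)
spec_max = paddle.full([80], -0.5)

def norm_spec(x):    # x: (B, T, 80) -> linearly mapped into [-1, 1] per bin
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1

def denorm_spec(x):  # inverse map, back to the original mel range
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min

x = paddle.uniform([2, 10, 80], min=-6.0, max=-0.5)
assert float(paddle.abs(denorm_spec(norm_spec(x)) - x).max()) < 1e-5
```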
diff --git a/paddlespeech/t2s/modules/diffnet.py b/paddlespeech/t2s/modules/diffnet.py
new file mode 100644
index 000000000..6a87c5537
--- /dev/null
+++ b/paddlespeech/t2s/modules/diffnet.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+from paddlespeech.t2s.modules.nets_utils import initialize
+from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
+from paddlespeech.utils.initialize import kaiming_normal_
+from paddlespeech.utils.initialize import kaiming_uniform_
+from paddlespeech.utils.initialize import uniform_
+from paddlespeech.utils.initialize import zeros_
+
+
+def Conv1D(*args, **kwargs):
+    layer = nn.Conv1D(*args, **kwargs)
+    # Initialize the weight to be consistent with the official implementation.
+    kaiming_normal_(layer.weight)
+    # Initialize the bias the same way torch does.
+    if layer.bias is not None:
+        fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
+        if fan_in != 0:
+            bound = 1 / math.sqrt(fan_in)
+            uniform_(layer.bias, -bound, bound)
+    return layer
+
+
+# Initialization is consistent with torch's default Linear initialization.
+def Linear(*args, **kwargs):
+    layer = nn.Linear(*args, **kwargs)
+    kaiming_uniform_(layer.weight, a=math.sqrt(5))
+    if layer.bias is not None:
+        fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        uniform_(layer.bias, -bound, bound)
+    return layer
+
+
+class ResidualBlock(nn.Layer):
+    """WaveNet-like residual block conditioned on the diffusion step.
+    """
+
+    def __init__(self, encoder_hidden, residual_channels, gate_channels,
+                 kernel_size, dilation):
+        super().__init__()
+        self.dilated_conv = Conv1D(
+            residual_channels,
+            gate_channels,
+            kernel_size,
+            padding=dilation,  # same-size output assumes kernel_size == 3
+            dilation=dilation)
+        self.diffusion_projection = Linear(residual_channels, residual_channels)
+        self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
+        self.output_projection = Conv1D(residual_channels, gate_channels, 1)
+
+    def forward(self, x, conditioner, diffusion_step):
+        """Gated residual step.
+
+        x: (B, C_res, T); conditioner: (B, C_enc, T); diffusion_step: (B, C_res).
+        Returns the residual output and the skip branch, both (B, C_res, T).
+        """
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        conditioner = self.conditioner_projection(conditioner)
+        y = x + diffusion_step
+
+        y = self.dilated_conv(y) + conditioner
+
+        gate, filter = paddle.chunk(y, 2, axis=1)
+        y = F.sigmoid(gate) * paddle.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = paddle.chunk(y, 2, axis=1)
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class SinusoidalPosEmb(nn.Layer):
+    """Sinusoidal positional embedding for diffusion timesteps.
+
+    Args:
+        dim (int): Dimension of the embedding.
+    """
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        """Embed timesteps.
+
+        Args:
+            x (Tensor): Timesteps with shape (B,); returns shape (B, dim).
+        """
+        x = paddle.cast(x, 'float32')
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = paddle.exp(paddle.arange(half_dim) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
+        return emb
+
+
+class DiffNet(nn.Layer):
+    """The denoiser network used by DiffSinger: a WaveNet-style stack of
+    dilated-convolution residual blocks."""
+
+    def __init__(
+            self,
+            in_channels: int=80,
+            out_channels: int=80,
+            kernel_size: int=3,
+            layers: int=20,
+            stacks: int=5,
+            residual_channels: int=256,
+            gate_channels: int=512,
+            skip_channels: int=256,
+            aux_channels: int=256,
+            dropout: float=0.,
+            bias: bool=True,
+            use_weight_norm: bool=False,
+            init_type: str="kaiming_normal", ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.layers = layers
+        self.aux_channels = aux_channels
+        self.residual_channels = residual_channels
+        self.gate_channels = gate_channels
+        self.kernel_size = kernel_size
+        self.dilation_cycle_length = layers // stacks
+        self.skip_channels = skip_channels
+
+        self.input_projection = Conv1D(self.in_channels, self.residual_channels,
+                                       1)
+        self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
+        dim = self.residual_channels
+        self.mlp = nn.Sequential(
+            Linear(dim, dim * 4), nn.Mish(), Linear(dim * 4, dim))
+        self.residual_layers = nn.LayerList([
+            ResidualBlock(
+                encoder_hidden=self.aux_channels,
+                residual_channels=self.residual_channels,
+                gate_channels=self.gate_channels,
+                kernel_size=self.kernel_size,
+                dilation=2**(i % self.dilation_cycle_length))
+            for i in range(self.layers)
+        ])
+        self.skip_projection = Conv1D(self.residual_channels,
+                                      self.skip_channels, 1)
+        self.output_projection = Conv1D(self.residual_channels,
+                                        self.out_channels, 1)
+        zeros_(self.output_projection.weight)
+
+    def forward(self, spec, diffusion_step, cond):
+        """Denoise the mel-spectrogram; returns shape (B, M, T).
+
+        Args:
+            spec (Tensor): Noised mel, shape (B, M, T).
+            diffusion_step (Tensor): Diffusion timesteps, shape (B,).
+            cond (Tensor): Conditioning features, shape (B, C_aux, T).
+        """
+        x = spec
+        x = self.input_projection(x)  # x [B, residual_channel, T]
+
+        x = F.relu(x)
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        diffusion_step = self.mlp(diffusion_step)
+        skip = []
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, cond, diffusion_step)
+            skip.append(skip_connection)
+        x = paddle.sum(
+            paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)  # [B, 80, T]
+        return x
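A quick shape check for the new `DiffNet` (runnable once this patch is applied; batch and length values are arbitrary). Note that `diffusion_step` is a 1-D tensor of shape (B,), matching how `GaussianDiffusion.forward` samples timesteps:

```python
import paddle
from paddlespeech.t2s.modules.diffnet import DiffNet

net = DiffNet(in_channels=80, out_channels=80, aux_channels=256)
spec = paddle.randn([2, 80, 100])        # noised mel: (B, M, T)
t = paddle.full([2], 10, dtype="int64")  # diffusion step per sample: (B,)
cond = paddle.randn([2, 256, 100])       # FastSpeech2 encoder output: (B, C_aux, T)
out = net(spec, t, cond)
assert out.shape == [2, 80, 100]
```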
diff --git a/paddlespeech/t2s/modules/diffusion.py b/paddlespeech/t2s/modules/diffusion.py
index be684ce38..d70fba4be 100644
--- a/paddlespeech/t2s/modules/diffusion.py
+++ b/paddlespeech/t2s/modules/diffusion.py
@@ -17,6 +17,7 @@ from typing import Callable
 from typing import Optional
 from typing import Tuple
 
+import numpy as np
 import paddle
 import ppdiffusers
 from paddle import nn
@@ -27,170 +28,6 @@
 from paddlespeech.t2s.modules.nets_utils import initialize
 from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
 
 
-class WaveNetDenoiser(nn.Layer):
-    """A Mel-Spectrogram Denoiser modified from WaveNet
-
-    Args:
-        in_channels (int, optional):
-            Number of channels of the input mel-spectrogram, by default 80
-        out_channels (int, optional):
-            Number of channels of the output mel-spectrogram, by default 80
-        kernel_size (int, optional):
-            Kernel size of the residual blocks inside, by default 3
-        layers (int, optional):
-            Number of residual blocks inside, by default 20
-        stacks (int, optional):
-            The number of groups to split the residual blocks into, by default 5
-            Within each group, the dilation of the residual block grows exponentially.
-        residual_channels (int, optional):
-            Residual channel of the residual blocks, by default 256
-        gate_channels (int, optional):
-            Gate channel of the residual blocks, by default 512
-        skip_channels (int, optional):
-            Skip channel of the residual blocks, by default 256
-        aux_channels (int, optional):
-            Auxiliary channel of the residual blocks, by default 256
-        dropout (float, optional):
-            Dropout of the residual blocks, by default 0.
-        bias (bool, optional):
-            Whether to use bias in residual blocks, by default True
-        use_weight_norm (bool, optional):
-            Whether to use weight norm in all convolutions, by default False
-    """
-
-    def __init__(
-            self,
-            in_channels: int=80,
-            out_channels: int=80,
-            kernel_size: int=3,
-            layers: int=20,
-            stacks: int=5,
-            residual_channels: int=256,
-            gate_channels: int=512,
-            skip_channels: int=256,
-            aux_channels: int=256,
-            dropout: float=0.,
-            bias: bool=True,
-            use_weight_norm: bool=False,
-            init_type: str="kaiming_normal", ):
-        super().__init__()
-
-        # initialize parameters
-        initialize(self, init_type)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.aux_channels = aux_channels
-        self.layers = layers
-        self.stacks = stacks
-        self.kernel_size = kernel_size
-
-        assert layers % stacks == 0
-        layers_per_stack = layers // stacks
-
-        self.first_t_emb = nn.Sequential(
-            Timesteps(
-                residual_channels,
-                flip_sin_to_cos=False,
-                downscale_freq_shift=1),
-            nn.Linear(residual_channels, residual_channels * 4),
-            nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
-        self.t_emb_layers = nn.LayerList([
-            nn.Linear(residual_channels, residual_channels)
-            for _ in range(layers)
-        ])
-
-        self.first_conv = nn.Conv1D(
-            in_channels, residual_channels, 1, bias_attr=True)
-        self.first_act = nn.ReLU()
-
-        self.conv_layers = nn.LayerList()
-        for layer in range(layers):
-            dilation = 2**(layer % layers_per_stack)
-            conv = WaveNetResidualBlock(
-                kernel_size=kernel_size,
-                residual_channels=residual_channels,
-                gate_channels=gate_channels,
-                skip_channels=skip_channels,
-                aux_channels=aux_channels,
-                dilation=dilation,
-                dropout=dropout,
-                bias=bias)
-            self.conv_layers.append(conv)
-
-        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
-        nn.initializer.Constant(0.0)(final_conv.weight)
-        self.last_conv_layers = nn.Sequential(nn.ReLU(),
-                                              nn.Conv1D(
-                                                  skip_channels,
-                                                  skip_channels,
-                                                  1,
-                                                  bias_attr=True),
-                                              nn.ReLU(), final_conv)
-
-        if use_weight_norm:
-            self.apply_weight_norm()
-
-    def forward(self, x, t, c):
-        """Denoise mel-spectrogram.
-
-        Args:
-            x(Tensor):
-                Shape (N, C_in, T), The input mel-spectrogram.
-            t(Tensor):
-                Shape (N), The timestep input.
-            c(Tensor):
-                Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
-
-        Returns:
-            Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
-        """
-        assert c.shape[-1] == x.shape[-1]
-
-        if t.shape[0] != x.shape[0]:
-            t = t.tile([x.shape[0]])
-        t_emb = self.first_t_emb(t)
-        t_embs = [
-            t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
-        ]
-
-        x = self.first_conv(x)
-        x = self.first_act(x)
-        skips = 0
-        for f, t in zip(self.conv_layers, t_embs):
-            x = x + t
-            x, s = f(x, c)
-            skips += s
-        skips *= math.sqrt(1.0 / len(self.conv_layers))
-
-        x = self.last_conv_layers(skips)
-        return x
-
-    def apply_weight_norm(self):
-        """Recursively apply weight normalization to all the Convolution layers
-        in the sublayers.
-        """
-
-        def _apply_weight_norm(layer):
-            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
-                nn.utils.weight_norm(layer)
-
-        self.apply(_apply_weight_norm)
-
-    def remove_weight_norm(self):
-        """Recursively remove weight normalization from all the Convolution
-        layers in the sublayers.
-        """
-
-        def _remove_weight_norm(layer):
-            try:
-                nn.utils.remove_weight_norm(layer)
-            except ValueError:
-                pass
-
-        self.apply(_remove_weight_norm)
-
-
 class GaussianDiffusion(nn.Layer):
     """Common Gaussian Diffusion Denoising Model Module
- """ - - def _apply_weight_norm(layer): - if isinstance(layer, (nn.Conv1D, nn.Conv2D)): - nn.utils.weight_norm(layer) - - self.apply(_apply_weight_norm) - - def remove_weight_norm(self): - """Recursively remove weight normalization from all the Convolution - layers in the sublayers. - """ - - def _remove_weight_norm(layer): - try: - nn.utils.remove_weight_norm(layer) - except ValueError: - pass - - self.apply(_remove_weight_norm) - - class GaussianDiffusion(nn.Layer): """Common Gaussian Diffusion Denoising Model Module @@ -294,13 +131,17 @@ class GaussianDiffusion(nn.Layer): """ - def __init__(self, - denoiser: nn.Layer, - num_train_timesteps: Optional[int]=1000, - beta_start: Optional[float]=0.0001, - beta_end: Optional[float]=0.02, - beta_schedule: Optional[str]="squaredcos_cap_v2", - num_max_timesteps: Optional[int]=None): + def __init__( + self, + denoiser: nn.Layer, + num_train_timesteps: Optional[int]=1000, + beta_start: Optional[float]=0.0001, + beta_end: Optional[float]=0.02, + beta_schedule: Optional[str]="squaredcos_cap_v2", + num_max_timesteps: Optional[int]=None, + stretch: bool=True, + min_values: paddle.Tensor=None, + max_values: paddle.Tensor=None, ): super().__init__() self.num_train_timesteps = num_train_timesteps @@ -315,6 +156,22 @@ class GaussianDiffusion(nn.Layer): beta_end=beta_end, beta_schedule=beta_schedule) self.num_max_timesteps = num_max_timesteps + self.stretch = stretch + self.min_values = min_values + self.max_values = max_values + + def norm_spec(self, x): + """ + Linearly map x to [-1, 1] + Args: + x: [B, T, N] + """ + return (x - self.min_values) / (self.max_values - self.min_values + ) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.max_values - self.min_values + ) + self.min_values def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -333,6 +190,12 @@ class GaussianDiffusion(nn.Layer): The noises which is added to the input. """ + if self.stretch: + assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None." + x = x.transpose((0, 2, 1)) + x = self.norm_spec(x) + x = x.transpose((0, 2, 1)) + noise_scheduler = self.noise_scheduler # Sample noise that we'll add to the mel-spectrograms @@ -369,9 +232,9 @@ class GaussianDiffusion(nn.Layer): Args: noise (Tensor): - The input tensor as a starting point for denoising. + The input tensor as a starting point for denoising. cond (Tensor, optional): - Conditional input for compute noises. + Conditional input for compute noises. (N, C_aux, T) ref_x (Tensor, optional): The real output for the denoising process to refer. num_inference_steps (int, optional): @@ -382,6 +245,7 @@ class GaussianDiffusion(nn.Layer): scheduler_type (str, optional): Noise scheduler for generate noises. Choose a great scheduler can skip many denoising step, by default 'ddpm'. + only support 'ddpm' now ! clip_noise (bool, optional): Whether to clip each denoised output, by default True. clip_noise_range (tuple, optional): @@ -425,48 +289,30 @@ class GaussianDiffusion(nn.Layer): # set timesteps scheduler.set_timesteps(num_inference_steps) - # prepare first noise variables noisy_input = noise - timesteps = scheduler.timesteps - if ref_x is not None: - init_timestep = None - if strength is None or strength < 0. 
@@ -425,48 +289,30 @@
         # set timesteps
         scheduler.set_timesteps(num_inference_steps)
 
-        # prepare first noise variables
         noisy_input = noise
-        timesteps = scheduler.timesteps
-        if ref_x is not None:
-            init_timestep = None
-            if strength is None or strength < 0. or strength > 1.:
-                strength = None
-            if self.num_max_timesteps is not None:
-                strength = self.num_max_timesteps / self.num_train_timesteps
-            if strength is not None:
-                # get the original timestep using init_timestep
-                init_timestep = min(
-                    int(num_inference_steps * strength), num_inference_steps)
-                t_start = max(num_inference_steps - init_timestep, 0)
-                timesteps = scheduler.timesteps[t_start:]
-                num_inference_steps = num_inference_steps - t_start
-                noisy_input = scheduler.add_noise(
-                    ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
-
-        # denoising loop
+        if self.stretch and ref_x is not None:
+            assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
+            ref_x = ref_x.transpose((0, 2, 1))
+            ref_x = self.norm_spec(ref_x)
+            ref_x = ref_x.transpose((0, 2, 1))
+
+        # for ddpm
+        timesteps = paddle.to_tensor(
+            np.flipud(np.arange(num_inference_steps)))
+        noisy_input = scheduler.add_noise(ref_x, noise, timesteps[0])
+
         denoised_output = noisy_input
-        if clip_noise:
-            n_min, n_max = clip_noise_range
-            denoised_output = paddle.clip(denoised_output, n_min, n_max)
-        num_warmup_steps = len(
-            timesteps) - num_inference_steps * scheduler.order
         for i, t in enumerate(timesteps):
             denoised_output = scheduler.scale_model_input(denoised_output, t)
-
-            # predict the noise residual
             noise_pred = self.denoiser(denoised_output, t, cond)
-
             # compute the previous noisy sample x_t -> x_t-1
             denoised_output = scheduler.step(noise_pred, t,
                                              denoised_output).prev_sample
-            if clip_noise:
-                denoised_output = paddle.clip(denoised_output, n_min, n_max)
-
-            # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
-                                           (i + 1) % scheduler.order == 0):
-                if callback is not None and i % callback_steps == 0:
-                    callback(i, t, len(timesteps), denoised_output)
+
+        if self.stretch:
+            assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
+            denoised_output = denoised_output.transpose((0, 2, 1))
+            denoised_output = self.denorm_spec(denoised_output)
+            denoised_output = denoised_output.transpose((0, 2, 1))
 
         return denoised_output
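The rewritten `inference` is now a plain DDPM loop: `ref_x` is noised up to the deepest timestep, then the denoiser and `scheduler.step` walk it back down to t = 0. Below is a condensed sketch of that control flow (it assumes `ppdiffusers` re-exports `DDPMScheduler` at the package top level, as diffusers does; it is not a drop-in replacement for the method above):

```python
import numpy as np
import paddle
from ppdiffusers import DDPMScheduler


def ddpm_inference(denoiser, noise, cond, ref_x, num_inference_steps=60):
    scheduler = DDPMScheduler(
        num_train_timesteps=100,
        beta_start=0.0001,
        beta_end=0.06,
        beta_schedule="linear")
    scheduler.set_timesteps(num_inference_steps)
    # Descending timesteps T-1, ..., 1, 0 (the .copy() avoids handing paddle
    # a negative-stride view from np.flipud).
    timesteps = paddle.to_tensor(np.flipud(np.arange(num_inference_steps)).copy())
    # Shallow diffusion: start from the reference mel noised to the deepest
    # step, rather than from pure noise.
    x = scheduler.add_noise(ref_x, noise, timesteps[0])
    for t in timesteps:
        noise_pred = denoiser(x, t, cond)
        x = scheduler.step(noise_pred, t, x).prev_sample  # x_t -> x_{t-1}
    return x
```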
diff --git a/paddlespeech/t2s/modules/wavenet_denoiser.py b/paddlespeech/t2s/modules/wavenet_denoiser.py
new file mode 100644
index 000000000..1cdd6ad9d
--- /dev/null
+++ b/paddlespeech/t2s/modules/wavenet_denoiser.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+from paddle import nn
+from ppdiffusers.models.embeddings import Timesteps
+
+from paddlespeech.t2s.modules.nets_utils import initialize
+from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
+
+
+class WaveNetDenoiser(nn.Layer):
+    """A Mel-Spectrogram Denoiser modified from WaveNet
+
+    Args:
+        in_channels (int, optional):
+            Number of channels of the input mel-spectrogram, by default 80
+        out_channels (int, optional):
+            Number of channels of the output mel-spectrogram, by default 80
+        kernel_size (int, optional):
+            Kernel size of the residual blocks inside, by default 3
+        layers (int, optional):
+            Number of residual blocks inside, by default 20
+        stacks (int, optional):
+            The number of groups to split the residual blocks into, by default 5
+            Within each group, the dilation of the residual block grows exponentially.
+        residual_channels (int, optional):
+            Residual channel of the residual blocks, by default 256
+        gate_channels (int, optional):
+            Gate channel of the residual blocks, by default 512
+        skip_channels (int, optional):
+            Skip channel of the residual blocks, by default 256
+        aux_channels (int, optional):
+            Auxiliary channel of the residual blocks, by default 256
+        dropout (float, optional):
+            Dropout of the residual blocks, by default 0.
+        bias (bool, optional):
+            Whether to use bias in residual blocks, by default True
+        use_weight_norm (bool, optional):
+            Whether to use weight norm in all convolutions, by default False
+    """
+
+    def __init__(
+            self,
+            in_channels: int=80,
+            out_channels: int=80,
+            kernel_size: int=3,
+            layers: int=20,
+            stacks: int=5,
+            residual_channels: int=256,
+            gate_channels: int=512,
+            skip_channels: int=256,
+            aux_channels: int=256,
+            dropout: float=0.,
+            bias: bool=True,
+            use_weight_norm: bool=False,
+            init_type: str="kaiming_normal", ):
+        super().__init__()
+
+        # initialize parameters
+        initialize(self, init_type)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.aux_channels = aux_channels
+        self.layers = layers
+        self.stacks = stacks
+        self.kernel_size = kernel_size
+
+        assert layers % stacks == 0
+        layers_per_stack = layers // stacks
+
+        self.first_t_emb = nn.Sequential(
+            Timesteps(
+                residual_channels,
+                flip_sin_to_cos=False,
+                downscale_freq_shift=1),
+            nn.Linear(residual_channels, residual_channels * 4),
+            nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
+        self.t_emb_layers = nn.LayerList([
+            nn.Linear(residual_channels, residual_channels)
+            for _ in range(layers)
+        ])
+
+        self.first_conv = nn.Conv1D(
+            in_channels, residual_channels, 1, bias_attr=True)
+        self.first_act = nn.ReLU()
+
+        self.conv_layers = nn.LayerList()
+        for layer in range(layers):
+            dilation = 2**(layer % layers_per_stack)
+            conv = WaveNetResidualBlock(
+                kernel_size=kernel_size,
+                residual_channels=residual_channels,
+                gate_channels=gate_channels,
+                skip_channels=skip_channels,
+                aux_channels=aux_channels,
+                dilation=dilation,
+                dropout=dropout,
+                bias=bias)
+            self.conv_layers.append(conv)
+
+        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
+        nn.initializer.Constant(0.0)(final_conv.weight)
+        self.last_conv_layers = nn.Sequential(nn.ReLU(),
+                                              nn.Conv1D(
+                                                  skip_channels,
+                                                  skip_channels,
+                                                  1,
+                                                  bias_attr=True),
+                                              nn.ReLU(), final_conv)
+
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+    def forward(self, x, t, c):
+        """Denoise mel-spectrogram.
+
+        Args:
+            x(Tensor):
+                Shape (N, C_in, T), The input mel-spectrogram.
+            t(Tensor):
+                Shape (N), The timestep input.
+            c(Tensor):
+                Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
+
+        Returns:
+            Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
+        """
+        assert c.shape[-1] == x.shape[-1]
+
+        if t.shape[0] != x.shape[0]:
+            t = t.tile([x.shape[0]])
+        t_emb = self.first_t_emb(t)
+        t_embs = [
+            t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
+        ]
+
+        x = self.first_conv(x)
+        x = self.first_act(x)
+        skips = 0
+        for f, t in zip(self.conv_layers, t_embs):
+            x = x + t
+            x, s = f(x, c)
+            skips += s
+        skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+        x = self.last_conv_layers(skips)
+        return x
+
+    def apply_weight_norm(self):
+        """Recursively apply weight normalization to all the Convolution layers
+        in the sublayers.
+        """
+
+        def _apply_weight_norm(layer):
+            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
+                nn.utils.weight_norm(layer)
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Recursively remove weight normalization from all the Convolution
+        layers in the sublayers.
+        """
+
+        def _remove_weight_norm(layer):
+            try:
+                nn.utils.remove_weight_norm(layer)
+            except ValueError:
+                pass
+
+        self.apply(_remove_weight_norm)