diffsinger, test=tts

pull/3005/head
lym0302 3 years ago
parent d7928d712d
commit 1d1e859de9

@@ -23,8 +23,8 @@ f0max: 750 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 24 # batch size
num_workers: 4 # number of workers
batch_size: 48 # batch size
num_workers: 1 # number of workers
###########################################################
@@ -98,13 +98,14 @@ model:
use_weight_norm: False # Whether to use weight norm in all convolutions
init_type: "kaiming_normal" # Type of initialize weights of a neural network module
# diffusion module
diffusion_params:
num_train_timesteps: 100 # Number of diffusion timesteps between clean data and pure noise during training
beta_start: 0.0001 # beta start parameter for the scheduler
beta_end: 0.06 # beta end parameter for the scheduler
beta_schedule: "linear" # beta schedule parameter for the scheduler
num_max_timesteps: 60 # Max timestep for the transition from clean data to noise
num_max_timesteps: 100 # Max timestep for the transition from clean data to noise
###########################################################
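For reference, the "linear" beta schedule configured above can be reproduced in a few lines; a minimal numpy sketch, assuming the scheduler interpolates beta linearly from beta_start to beta_end (scheduler internals may differ):

import numpy as np

# Linear beta schedule matching the config above:
# num_train_timesteps=100, beta_start=0.0001, beta_end=0.06.
betas = np.linspace(0.0001, 0.06, 100)
alphas_cumprod = np.cumprod(1.0 - betas)
# alphas_cumprod[t] is the signal variance kept after t noising steps;
# with num_max_timesteps=100 the full schedule is used during training.
print(alphas_cumprod[-1])  # near 0: the last step is almost pure noise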
@@ -134,18 +135,18 @@ ds_optimizer_params:
ds_scheduler_params:
learning_rate: 0.001
gamma: 0.5
step_size: 10000
step_size: 50000
ds_grad_norm: 1
###########################################################
# INTERVAL SETTING #
###########################################################
ds_train_start_steps: 100000 # Number of steps to start to train diffusion module.
train_max_steps: 200000 # Number of training steps.
save_interval_steps: 500 # Interval steps to save checkpoint.
eval_interval_steps: 100 # Interval steps to evaluate the network.
num_snapshots: 10 # Number of saved models
ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
train_max_steps: 320000 # Number of training steps.
save_interval_steps: 2000 # Interval steps to save checkpoint.
eval_interval_steps: 2000 # Interval steps to evaluate the network.
num_snapshots: 5 # Number of saved models
###########################################################
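These intervals imply a two-stage schedule: the FastSpeech2MIDI branch trains alone for the first 160k steps, then the diffusion module trains until 320k. A hypothetical sketch of the gating an updater could apply (the actual updater code is not part of this diff):

ds_train_start_steps = 160000
train_max_steps = 320000
for step in range(1, train_max_steps + 1):
    training_fs2 = step <= ds_train_start_steps  # acoustic-model phase
    training_ds = step > ds_train_start_steps    # diffusion phase
    # ... run the optimizer step for the active module ...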

@@ -2,9 +2,9 @@
config_path=$1
train_output_path=$2
#ckpt_name=$3
iter=$3
ckpt_name=snapshot_iter_${iter}.pdz
ckpt_name=$3
#iter=$3
#ckpt_name=snapshot_iter_${iter}.pdz
stage=0
stop_stage=0
@@ -21,7 +21,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--voc_config=pwgan_opencpop/default.yaml \
--voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
--voc_stat=pwgan_opencpop/feats_stats.npy \
--test_metadata=test1.jsonl \
--test_metadata=test.jsonl \
--output_dir=${train_output_path}/test_${iter} \
--phones_dict=dump/phone_id_map.txt
fi

@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=4 \
--ngpu=1 \
--phones-dict=dump/phone_id_map.txt

@@ -3,9 +3,9 @@
set -e
source path.sh
gpus=4,5,6,7
stage=1
stop_stage=1
gpus=0
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default

@@ -23,6 +23,7 @@ from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.utils import str2bool
def main():
@@ -58,6 +59,11 @@ def main():
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--norm-feats",
type=str2bool,
default=False,
help="whether to norm features")
args = parser.parse_args()
@@ -80,18 +86,36 @@ def main():
# restore scaler
speech_scaler = StandardScaler()
speech_scaler.mean_ = np.load(args.speech_stats)[0]
speech_scaler.scale_ = np.load(args.speech_stats)[1]
if args.norm_feats:
speech_scaler.mean_ = np.load(args.speech_stats)[0]
speech_scaler.scale_ = np.load(args.speech_stats)[1]
else:
speech_scaler.mean_ = np.zeros(
np.load(args.speech_stats)[0].shape, dtype="float32")
speech_scaler.scale_ = np.ones(
np.load(args.speech_stats)[1].shape, dtype="float32")
speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
pitch_scaler = StandardScaler()
pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
if args.norm_feats:
pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
else:
pitch_scaler.mean_ = np.zeros(
np.load(args.pitch_stats)[0].shape, dtype="float32")
pitch_scaler.scale_ = np.ones(
np.load(args.pitch_stats)[1].shape, dtype="float32")
pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
energy_scaler = StandardScaler()
energy_scaler.mean_ = np.load(args.energy_stats)[0]
energy_scaler.scale_ = np.load(args.energy_stats)[1]
if args.norm_feats:
energy_scaler.mean_ = np.load(args.energy_stats)[0]
energy_scaler.scale_ = np.load(args.energy_stats)[1]
else:
energy_scaler.mean_ = np.zeros(
np.load(args.energy_stats)[0].shape, dtype="float32")
energy_scaler.scale_ = np.ones(
np.load(args.energy_stats)[1].shape, dtype="float32")
energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
vocab_phones = {}
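When --norm-feats is false, the scalers above degenerate to identity transforms (mean 0, scale 1), so StandardScaler.transform leaves the features untouched. A minimal sketch of the shared pattern, using a hypothetical build_scaler helper:

import numpy as np
from sklearn.preprocessing import StandardScaler

def build_scaler(stats_path: str, norm_feats: bool) -> StandardScaler:
    # stats files store [mean, scale] rows, as loaded above
    stats = np.load(stats_path)
    scaler = StandardScaler()
    if norm_feats:
        scaler.mean_, scaler.scale_ = stats[0], stats[1]
    else:
        # identity scaler: transform(x) == x
        scaler.mean_ = np.zeros(stats[0].shape, dtype="float32")
        scaler.scale_ = np.ones(stats[1].shape, dtype="float32")
    scaler.n_features_in_ = scaler.mean_.shape[0]
    return scaler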
@@ -111,7 +135,6 @@ def main():
for item in tqdm(dataset):
utt_id = item['utt_id']
print(utt_id)
speech = item['speech']
pitch = item['pitch']
energy = item['energy']

@@ -143,9 +143,17 @@ def train_sp(args, config):
print("criterions done!")
optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
# gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
# optimizer_ds = AdamW(
# learning_rate=config["ds_scheduler_params"]["learning_rate"],
# grad_clip=gradient_clip_ds,
# parameters=model_ds.parameters(),
# **config["ds_optimizer_params"])
lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
optimizer_ds = AdamW(
learning_rate=config["ds_scheduler_params"]["learning_rate"],
learning_rate=lr_schedule_ds,
grad_clip=gradient_clip_ds,
parameters=model_ds.parameters(),
**config["ds_optimizer_params"])
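The bug fixed here: lr_schedule_ds was constructed but the optimizer still received the raw float, so the StepDecay never took effect. A minimal sketch of the corrected wiring against Paddle's API:

import paddle
from paddle import nn
from paddle.optimizer import AdamW
from paddle.optimizer.lr import StepDecay

model_ds = nn.Linear(8, 8)  # stand-in for the diffusion module
# ds_scheduler_params: learning_rate=0.001, gamma=0.5, step_size=50000
lr_schedule_ds = StepDecay(learning_rate=0.001, step_size=50000, gamma=0.5)
optimizer_ds = AdamW(
    learning_rate=lr_schedule_ds,  # pass the scheduler, not the float
    grad_clip=nn.ClipGradByGlobalNorm(1.0),  # ds_grad_norm: 1
    parameters=model_ds.parameters())
# call lr_schedule_ds.step() after each optimizer step to advance the decay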

@@ -114,29 +114,44 @@ def evaluate(args):
is_slur = paddle.to_tensor(datum["is_slur"])
get_mel_fs2 = False
# mel: [T, mel_bin]
mel = am_inference(
mel1 = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=get_mel_fs2)
get_mel_fs2=True)
mel2 = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=False)
# import numpy as np
# mel = np.load("/home/liangyunming/others_code/DiffSinger_lym/diffsinger_mel.npy")
# mel = paddle.to_tensor(mel)
wav = voc_inference(mel)
wav1 = voc_inference(mel1)
wav2 = voc_inference(mel2)
wav = wav.numpy()
N += wav.size
wav1 = wav1.numpy()
wav2 = wav2.numpy()
N += wav1.size
N += wav2.size
T += t.elapse
speed = wav.size / t.elapse
speed = 2 * wav1.size / t.elapse
rtf = am_config.fs / speed
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
f"{utt_id}, mel: {mel1.shape}, wave: {wav1.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(
# str(output_dir / ("xiaojiuwo_diffsinger" + ".wav")), wav, samplerate=am_config.fs)
str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
str(output_dir / (utt_id + "_fs2.wav")),
wav1,
samplerate=am_config.fs)
sf.write(
str(output_dir / (utt_id + "_diffusion.wav")),
wav2,
samplerate=am_config.fs)
print(f"{utt_id} done!")
# break
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")

@@ -17,13 +17,14 @@ from typing import Any
from typing import Dict
from typing import Tuple
import numpy as np
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI
from paddlespeech.t2s.modules.diffnet import DiffNet
from paddlespeech.t2s.modules.diffusion import GaussianDiffusion
from paddlespeech.t2s.modules.diffusion import WaveNetDenoiser
class DiffSinger(nn.Layer):
@@ -134,7 +135,8 @@ class DiffSinger(nn.Layer):
"beta_end": 0.06,
"beta_schedule": "squaredcos_cap_v2",
"num_max_timesteps": 60
}, ):
},
stretch: bool=True, ):
"""Initialize DiffSinger module.
Args:
@@ -156,8 +158,40 @@
fastspeech2_params=fastspeech2_params,
note_num=note_num,
is_slur_num=is_slur_num)
denoiser = WaveNetDenoiser(**denoiser_params)
self.diffusion = GaussianDiffusion(denoiser, **diffusion_params)
denoiser = DiffNet(**denoiser_params)
spec_min = paddle.to_tensor(
np.array([
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
]))
spec_max = paddle.to_tensor(
np.array([
-0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
]))
self.diffusion = GaussianDiffusion(
denoiser,
**diffusion_params,
stretch=stretch,
min_values=spec_min,
max_values=spec_max, )
def forward(
self,
@@ -279,7 +313,7 @@ class DiffSinger(nn.Layer):
cond=cond_fs2,
ref_x=mel_fs2,
scheduler_type="ddpm",
num_inference_steps=25)
num_inference_steps=60)
mel = mel.transpose((0, 2, 1))
return mel[0]
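A construction sketch mirroring the new wiring, with flat stand-in bounds instead of the 80-bin Opencpop statistics (signature taken from the GaussianDiffusion changes below):

import paddle
from paddlespeech.t2s.modules.diffnet import DiffNet
from paddlespeech.t2s.modules.diffusion import GaussianDiffusion

denoiser = DiffNet(in_channels=80, out_channels=80)
diffusion = GaussianDiffusion(
    denoiser,
    num_train_timesteps=100,
    beta_start=0.0001,
    beta_end=0.06,
    beta_schedule="linear",
    num_max_timesteps=100,
    stretch=True,                          # map mels into [-1, 1] first
    min_values=paddle.full([80], -6.0),    # stand-in for spec_min
    max_values=paddle.full([80], -0.8), )  # stand-in for spec_max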

@@ -0,0 +1,188 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import uniform_
from paddlespeech.utils.initialize import zeros_
def Conv1D(*args, **kwargs):
layer = nn.Conv1D(*args, **kwargs)
# Initialize the weight to be consistent with the official implementation
kaiming_normal_(layer.weight)
# Bias initialization is consistent with torch's default
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
uniform_(layer.bias, -bound, bound)
return layer
# Initialization is consistent with torch's nn.Linear defaults
def Linear(*args, **kwargs):
layer = nn.Linear(*args, **kwargs)
kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(layer.bias, -bound, bound)
return layer
class ResidualBlock(nn.Layer):
"""ResidualBlock
"""
def __init__(self, encoder_hidden, residual_channels, gate_channels,
kernel_size, dilation):
super().__init__()
self.dilated_conv = Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=dilation,
dilation=dilation)
self.diffusion_projection = Linear(residual_channels, residual_channels)
self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
self.output_projection = Conv1D(residual_channels, gate_channels, 1)
def forward(self, x, conditioner, diffusion_step):
"""_summary_
Args:
nn (_type_): _description_
"""
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
conditioner = self.conditioner_projection(conditioner)
y = x + diffusion_step
y = self.dilated_conv(y) + conditioner
gate, filter = paddle.chunk(y, 2, axis=1)
y = F.sigmoid(gate) * paddle.tanh(filter)
y = self.output_projection(y)
residual, skip = paddle.chunk(y, 2, axis=1)
return (x + residual) / math.sqrt(2.0), skip
class SinusoidalPosEmb(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
"""_summary_
Args:
nn (_type_): _description_
"""
x = paddle.cast(x, 'float32')
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = paddle.exp(paddle.arange(half_dim) * -emb)
emb = x[:, None] * emb[None, :]
emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
return emb
class DiffNet(nn.Layer):
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = layers
self.aux_channels = aux_channels
self.residual_channels = residual_channels
self.gate_channels = gate_channels
self.kernel_size = kernel_size
self.dilation_cycle_length = layers // stacks
self.skip_channels = skip_channels
self.input_projection = Conv1D(self.in_channels, self.residual_channels,
1)
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
dim = self.residual_channels
self.mlp = nn.Sequential(
Linear(dim, dim * 4), nn.Mish(), Linear(dim * 4, dim))
self.residual_layers = nn.LayerList([
ResidualBlock(
encoder_hidden=self.aux_channels,
residual_channels=self.residual_channels,
gate_channels=self.gate_channels,
kernel_size=self.kernel_size,
dilation=2**(i % self.dilation_cycle_length))
for i in range(self.layers)
])
self.skip_projection = Conv1D(self.residual_channels,
self.skip_channels, 1)
self.output_projection = Conv1D(self.residual_channels,
self.out_channels, 1)
zeros_(self.output_projection.weight)
def forward(self, spec, diffusion_step, cond):
"""
:param spec: [B, M, T]
:param diffusion_step: [B, 1]
:param cond: [B, C_aux, T]
:return: [B, M, T]
"""
x = spec
x = self.input_projection(x) # x [B, residual_channel, T]
x = F.relu(x)
diffusion_step = self.diffusion_embedding(diffusion_step)
diffusion_step = self.mlp(diffusion_step)
skip = []
for layer_id, layer in enumerate(self.residual_layers):
x, skip_connection = layer(x, cond, diffusion_step)
skip.append(skip_connection)
x = paddle.sum(
paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))
x = self.skip_projection(x)
x = F.relu(x)
x = self.output_projection(x) # [B, 80, T]
return x
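A minimal smoke test for the new DiffNet; a sketch assuming the defaults above, with shapes following the forward docstring:

import paddle
from paddlespeech.t2s.modules.diffnet import DiffNet

net = DiffNet(in_channels=80, out_channels=80, aux_channels=256)
spec = paddle.randn([2, 80, 100])       # noisy mel: [B, M, T]
t = paddle.randint(0, 100, shape=[2])   # diffusion steps: [B]
cond = paddle.randn([2, 256, 100])      # encoder output: [B, C_aux, T]
out = net(spec, t, cond)                # predicted noise: [B, M, T]
print(out.shape)                        # [2, 80, 100]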

@@ -17,6 +17,7 @@ from typing import Callable
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import ppdiffusers
from paddle import nn
@@ -27,170 +28,6 @@ from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
class GaussianDiffusion(nn.Layer):
"""Common Gaussian Diffusion Denoising Model Module
@@ -294,13 +131,17 @@
"""
def __init__(self,
denoiser: nn.Layer,
num_train_timesteps: Optional[int]=1000,
beta_start: Optional[float]=0.0001,
beta_end: Optional[float]=0.02,
beta_schedule: Optional[str]="squaredcos_cap_v2",
num_max_timesteps: Optional[int]=None):
def __init__(
self,
denoiser: nn.Layer,
num_train_timesteps: Optional[int]=1000,
beta_start: Optional[float]=0.0001,
beta_end: Optional[float]=0.02,
beta_schedule: Optional[str]="squaredcos_cap_v2",
num_max_timesteps: Optional[int]=None,
stretch: bool=True,
min_values: paddle.Tensor=None,
max_values: paddle.Tensor=None, ):
super().__init__()
self.num_train_timesteps = num_train_timesteps
@@ -315,6 +156,22 @@
beta_end=beta_end,
beta_schedule=beta_schedule)
self.num_max_timesteps = num_max_timesteps
self.stretch = stretch
self.min_values = min_values
self.max_values = max_values
def norm_spec(self, x):
"""
Linearly map x to [-1, 1]
Args:
x: [B, T, N]
"""
return (x - self.min_values) / (self.max_values - self.min_values
) * 2 - 1
def denorm_spec(self, x):
return (x + 1) / 2 * (self.max_values - self.min_values
) + self.min_values
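A quick numeric check of the stretch mapping with toy two-bin bounds; denorm_spec inverts norm_spec exactly:

import paddle

min_values = paddle.to_tensor([-6.0, -6.0])  # toy stand-in for spec_min
max_values = paddle.to_tensor([-0.8, -1.6])  # toy stand-in for spec_max
x = paddle.to_tensor([[-3.0, -4.0]])         # a mel frame, shape [T, N]

norm = (x - min_values) / (max_values - min_values) * 2 - 1
denorm = (norm + 1) / 2 * (max_values - min_values) + min_values
print(norm.numpy())    # values inside [-1, 1]
print(denorm.numpy())  # recovers x: [[-3., -4.]]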
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor]:
@@ -333,6 +190,12 @@
The noises which is added to the input.
"""
if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
noise_scheduler = self.noise_scheduler
# Sample noise that we'll add to the mel-spectrograms
@@ -369,9 +232,9 @@
Args:
noise (Tensor):
The input tensor as a starting point for denoising.
The input tensor as a starting point for denoising.
cond (Tensor, optional):
Conditional input used to compute noises.
Conditional input used to compute noises. (N, C_aux, T)
ref_x (Tensor, optional):
The real output for the denoising process to refer.
num_inference_steps (int, optional):
@@ -382,6 +245,7 @@
scheduler_type (str, optional):
Noise scheduler for generate noises.
Choosing a good scheduler can skip many denoising steps, by default 'ddpm'.
Only 'ddpm' is supported for now.
clip_noise (bool, optional):
Whether to clip each denoised output, by default True.
clip_noise_range (tuple, optional):
@@ -425,48 +289,30 @@
# set timesteps
scheduler.set_timesteps(num_inference_steps)
# prepare first noise variables
noisy_input = noise
timesteps = scheduler.timesteps
if ref_x is not None:
init_timestep = None
if strength is None or strength < 0. or strength > 1.:
strength = None
if self.num_max_timesteps is not None:
strength = self.num_max_timesteps / self.num_train_timesteps
if strength is not None:
# get the original timestep using init_timestep
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start:]
num_inference_steps = num_inference_steps - t_start
noisy_input = scheduler.add_noise(
ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
# denoising loop
if self.stretch and ref_x is not None:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
ref_x = ref_x.transpose((0, 2, 1))
ref_x = self.norm_spec(ref_x)
ref_x = ref_x.transpose((0, 2, 1))
# for ddpm
timesteps = paddle.to_tensor(
np.flipud(np.arange(num_inference_steps)))
noisy_input = scheduler.add_noise(ref_x, noise, timesteps[0])
denoised_output = noisy_input
if clip_noise:
n_min, n_max = clip_noise_range
denoised_output = paddle.clip(denoised_output, n_min, n_max)
num_warmup_steps = len(
timesteps) - num_inference_steps * scheduler.order
for i, t in enumerate(timesteps):
denoised_output = scheduler.scale_model_input(denoised_output, t)
# predict the noise residual
noise_pred = self.denoiser(denoised_output, t, cond)
# compute the previous noisy sample x_t -> x_t-1
denoised_output = scheduler.step(noise_pred, t,
denoised_output).prev_sample
if clip_noise:
denoised_output = paddle.clip(denoised_output, n_min, n_max)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
(i + 1) % scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
callback(i, t, len(timesteps), denoised_output)
if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
denoised_output = denoised_output.transpose((0, 2, 1))
denoised_output = self.denorm_spec(denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
return denoised_output
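The rewritten loop always starts ddpm denoising from ref_x noised at the largest remaining timestep instead of from pure noise (shallow diffusion). A numpy sketch of the timestep bookkeeping:

import numpy as np

num_inference_steps = 60  # matches the inference call in diffsinger.py
# descending timesteps 59, 58, ..., 0 for the ddpm loop
timesteps = np.flipud(np.arange(num_inference_steps)).copy()
# scheduler.add_noise(ref_x, noise, timesteps[0]) noises the
# FastSpeech2MIDI mel up to step 59, so the loop refines ref_x
# rather than sampling from scratch
print(timesteps[:5])  # [59 58 57 56 55]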

@@ -0,0 +1,184 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from paddle import nn
from ppdiffusers.models.embeddings import Timesteps
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)