diffsinger, test=tts

pull/3005/head
lym0302 3 years ago
parent d7928d712d
commit 1d1e859de9

@ -23,8 +23,8 @@ f0max: 750 # Maximum f0 for pitch extraction.
########################################################### ###########################################################
# DATA SETTING # # DATA SETTING #
########################################################### ###########################################################
batch_size: 24 # batch size batch_size: 48 # batch size
num_workers: 4 # number of gpu num_workers: 1 # number of gpu
########################################################### ###########################################################
@ -98,13 +98,14 @@ model:
use_weight_norm: False # Whether to use weight norm in all convolutions use_weight_norm: False # Whether to use weight norm in all convolutions
init_type: "kaiming_normal" # Type of initialize weights of a neural network module init_type: "kaiming_normal" # Type of initialize weights of a neural network module
# diffusion module
diffusion_params: diffusion_params:
num_train_timesteps: 100 # The number of timesteps between the noise and the real during training num_train_timesteps: 100 # The number of timesteps between the noise and the real during training
beta_start: 0.0001 # beta start parameter for the scheduler beta_start: 0.0001 # beta start parameter for the scheduler
beta_end: 0.06 # beta end parameter for the scheduler beta_end: 0.06 # beta end parameter for the scheduler
beta_schedule: "linear" # beta schedule parameter for the scheduler beta_schedule: "linear" # beta schedule parameter for the scheduler
num_max_timesteps: 60 # The max timestep transition from real to noise num_max_timesteps: 100 # The max timestep transition from real to noise
########################################################### ###########################################################
@ -134,18 +135,18 @@ ds_optimizer_params:
ds_scheduler_params: ds_scheduler_params:
learning_rate: 0.001 learning_rate: 0.001
gamma: 0.5 gamma: 0.5
step_size: 10000 step_size: 50000
ds_grad_norm: 1 ds_grad_norm: 1
########################################################### ###########################################################
# INTERVAL SETTING # # INTERVAL SETTING #
########################################################### ###########################################################
ds_train_start_steps: 100000 # Number of steps to start to train diffusion module. ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
train_max_steps: 200000 # Number of training steps. train_max_steps: 320000 # Number of training steps.
save_interval_steps: 500 # Interval steps to save checkpoint. save_interval_steps: 2000 # Interval steps to save checkpoint.
eval_interval_steps: 100 # Interval steps to evaluate the network. eval_interval_steps: 2000 # Interval steps to evaluate the network.
num_snapshots: 10 # Number of saved models num_snapshots: 5 # Number of saved models
########################################################### ###########################################################

@ -2,9 +2,9 @@
config_path=$1 config_path=$1
train_output_path=$2 train_output_path=$2
#ckpt_name=$3 ckpt_name=$3
iter=$3 #iter=$3
ckpt_name=snapshot_iter_${iter}.pdz #ckpt_name=snapshot_iter_${iter}.pdz
stage=0 stage=0
stop_stage=0 stop_stage=0
@ -21,7 +21,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--voc_config=pwgan_opencpop/default.yaml \ --voc_config=pwgan_opencpop/default.yaml \
--voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \ --voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
--voc_stat=pwgan_opencpop/feats_stats.npy \ --voc_stat=pwgan_opencpop/feats_stats.npy \
--test_metadata=test1.jsonl \ --test_metadata=test.jsonl \
--output_dir=${train_output_path}/test_${iter} \ --output_dir=${train_output_path}/test_${iter} \
--phones_dict=dump/phone_id_map.txt --phones_dict=dump/phone_id_map.txt
fi fi

@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=4 \ --ngpu=1 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt

@ -3,9 +3,9 @@
set -e set -e
source path.sh source path.sh
gpus=4,5,6,7 gpus=0
stage=1 stage=0
stop_stage=1 stop_stage=100
conf_path=conf/default.yaml conf_path=conf/default.yaml
train_output_path=exp/default train_output_path=exp/default

@ -23,6 +23,7 @@ from sklearn.preprocessing import StandardScaler
from tqdm import tqdm from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.utils import str2bool
def main(): def main():
@ -58,6 +59,11 @@ def main():
"--phones-dict", type=str, default=None, help="phone vocabulary file.") "--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument( parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.") "--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--norm-feats",
type=str2bool,
default=False,
help="whether to norm features")
args = parser.parse_args() args = parser.parse_args()
@ -80,18 +86,36 @@ def main():
# restore scaler # restore scaler
speech_scaler = StandardScaler() speech_scaler = StandardScaler()
speech_scaler.mean_ = np.load(args.speech_stats)[0] if args.norm_feats:
speech_scaler.scale_ = np.load(args.speech_stats)[1] speech_scaler.mean_ = np.load(args.speech_stats)[0]
speech_scaler.scale_ = np.load(args.speech_stats)[1]
else:
speech_scaler.mean_ = np.zeros(
np.load(args.speech_stats)[0].shape, dtype="float32")
speech_scaler.scale_ = np.ones(
np.load(args.speech_stats)[1].shape, dtype="float32")
speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0] speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
pitch_scaler = StandardScaler() pitch_scaler = StandardScaler()
pitch_scaler.mean_ = np.load(args.pitch_stats)[0] if args.norm_feats:
pitch_scaler.scale_ = np.load(args.pitch_stats)[1] pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
else:
pitch_scaler.mean_ = np.zeros(
np.load(args.pitch_stats)[0].shape, dtype="float32")
pitch_scaler.scale_ = np.ones(
np.load(args.pitch_stats)[1].shape, dtype="float32")
pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0] pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
energy_scaler = StandardScaler() energy_scaler = StandardScaler()
energy_scaler.mean_ = np.load(args.energy_stats)[0] if args.norm_feats:
energy_scaler.scale_ = np.load(args.energy_stats)[1] energy_scaler.mean_ = np.load(args.energy_stats)[0]
energy_scaler.scale_ = np.load(args.energy_stats)[1]
else:
energy_scaler.mean_ = np.zeros(
np.load(args.energy_stats)[0].shape, dtype="float32")
energy_scaler.scale_ = np.ones(
np.load(args.energy_stats)[1].shape, dtype="float32")
energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0] energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
vocab_phones = {} vocab_phones = {}
@ -111,7 +135,6 @@ def main():
for item in tqdm(dataset): for item in tqdm(dataset):
utt_id = item['utt_id'] utt_id = item['utt_id']
print(utt_id)
speech = item['speech'] speech = item['speech']
pitch = item['pitch'] pitch = item['pitch']
energy = item['energy'] energy = item['energy']

@ -143,9 +143,17 @@ def train_sp(args, config):
print("criterions done!") print("criterions done!")
optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"]) optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
# gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
# optimizer_ds = AdamW(
# learning_rate=config["ds_scheduler_params"]["learning_rate"],
# grad_clip=gradient_clip_ds,
# parameters=model_ds.parameters(),
# **config["ds_optimizer_params"])
lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"]) gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
optimizer_ds = AdamW( optimizer_ds = AdamW(
learning_rate=config["ds_scheduler_params"]["learning_rate"], learning_rate=lr_schedule_ds,
grad_clip=gradient_clip_ds, grad_clip=gradient_clip_ds,
parameters=model_ds.parameters(), parameters=model_ds.parameters(),
**config["ds_optimizer_params"]) **config["ds_optimizer_params"])

@ -114,29 +114,44 @@ def evaluate(args):
is_slur = paddle.to_tensor(datum["is_slur"]) is_slur = paddle.to_tensor(datum["is_slur"])
get_mel_fs2 = False get_mel_fs2 = False
# mel: [T, mel_bin] # mel: [T, mel_bin]
mel = am_inference( mel1 = am_inference(
phone_ids, phone_ids,
note=note, note=note,
note_dur=note_dur, note_dur=note_dur,
is_slur=is_slur, is_slur=is_slur,
get_mel_fs2=get_mel_fs2) get_mel_fs2=True)
mel2 = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=False)
# import numpy as np # import numpy as np
# mel = np.load("/home/liangyunming/others_code/DiffSinger_lym/diffsinger_mel.npy") # mel = np.load("/home/liangyunming/others_code/DiffSinger_lym/diffsinger_mel.npy")
# mel = paddle.to_tensor(mel) # mel = paddle.to_tensor(mel)
wav = voc_inference(mel) wav1 = voc_inference(mel1)
wav2 = voc_inference(mel2)
wav = wav.numpy() wav1 = wav1.numpy()
N += wav.size wav2 = wav2.numpy()
N += wav1.size
N += wav2.size
T += t.elapse T += t.elapse
speed = wav.size / t.elapse speed = 2 * wav1.size / t.elapse
rtf = am_config.fs / speed rtf = am_config.fs / speed
print( print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." f"{utt_id}, mel: {mel1.shape}, wave: {wav1.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
) )
sf.write( sf.write(
# str(output_dir / ("xiaojiuwo_diffsinger" + ".wav")), wav, samplerate=am_config.fs) # str(output_dir / ("xiaojiuwo_diffsinger" + ".wav")), wav, samplerate=am_config.fs)
str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs) str(output_dir / (utt_id + "_fs2.wav")),
wav1,
samplerate=am_config.fs)
sf.write(
str(output_dir / (utt_id + "_diffusion.wav")),
wav2,
samplerate=am_config.fs)
print(f"{utt_id} done!") print(f"{utt_id} done!")
# break # break
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")

@ -17,13 +17,14 @@ from typing import Any
from typing import Dict from typing import Dict
from typing import Tuple from typing import Tuple
import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from typeguard import check_argument_types from typeguard import check_argument_types
from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI
from paddlespeech.t2s.modules.diffnet import DiffNet
from paddlespeech.t2s.modules.diffusion import GaussianDiffusion from paddlespeech.t2s.modules.diffusion import GaussianDiffusion
from paddlespeech.t2s.modules.diffusion import WaveNetDenoiser
class DiffSinger(nn.Layer): class DiffSinger(nn.Layer):
@ -134,7 +135,8 @@ class DiffSinger(nn.Layer):
"beta_end": 0.06, "beta_end": 0.06,
"beta_schedule": "squaredcos_cap_v2", "beta_schedule": "squaredcos_cap_v2",
"num_max_timesteps": 60 "num_max_timesteps": 60
}, ): },
stretch: bool=True, ):
"""Initialize DiffSinger module. """Initialize DiffSinger module.
Args: Args:
@ -156,8 +158,40 @@ class DiffSinger(nn.Layer):
fastspeech2_params=fastspeech2_params, fastspeech2_params=fastspeech2_params,
note_num=note_num, note_num=note_num,
is_slur_num=is_slur_num) is_slur_num=is_slur_num)
denoiser = WaveNetDenoiser(**denoiser_params) denoiser = DiffNet(**denoiser_params)
self.diffusion = GaussianDiffusion(denoiser, **diffusion_params) spec_min = paddle.to_tensor(
np.array([
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
]))
spec_max = paddle.to_tensor(
np.array([
-0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
]))
self.diffusion = GaussianDiffusion(
denoiser,
**diffusion_params,
stretch=stretch,
min_values=spec_min,
max_values=spec_max, )
def forward( def forward(
self, self,
@ -279,7 +313,7 @@ class DiffSinger(nn.Layer):
cond=cond_fs2, cond=cond_fs2,
ref_x=mel_fs2, ref_x=mel_fs2,
scheduler_type="ddpm", scheduler_type="ddpm",
num_inference_steps=25) num_inference_steps=60)
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
return mel[0] return mel[0]

@ -0,0 +1,188 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import uniform_
from paddlespeech.utils.initialize import zeros_
def Conv1D(*args, **kwargs):
layer = nn.Conv1D(*args, **kwargs)
# Initialize the weight to be consistent with the official
kaiming_normal_(layer.weight)
# Initialization is consistent with torch
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
uniform_(layer.bias, -bound, bound)
return layer
# Initialization is consistent with torch
def Linear(*args, **kwargs):
layer = nn.Linear(*args, **kwargs)
kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(layer.bias, -bound, bound)
return layer
class ResidualBlock(nn.Layer):
"""ResidualBlock
"""
def __init__(self, encoder_hidden, residual_channels, gate_channels,
kernel_size, dilation):
super().__init__()
self.dilated_conv = Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=dilation,
dilation=dilation)
self.diffusion_projection = Linear(residual_channels, residual_channels)
self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
self.output_projection = Conv1D(residual_channels, gate_channels, 1)
def forward(self, x, conditioner, diffusion_step):
"""_summary_
Args:
nn (_type_): _description_
"""
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
conditioner = self.conditioner_projection(conditioner)
y = x + diffusion_step
y = self.dilated_conv(y) + conditioner
gate, filter = paddle.chunk(y, 2, axis=1)
y = F.sigmoid(gate) * paddle.tanh(filter)
y = self.output_projection(y)
residual, skip = paddle.chunk(y, 2, axis=1)
return (x + residual) / math.sqrt(2.0), skip
class SinusoidalPosEmb(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
"""_summary_
Args:
nn (_type_): _description_
"""
x = paddle.cast(x, 'float32')
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = paddle.exp(paddle.arange(half_dim) * -emb)
emb = x[:, None] * emb[None, :]
emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
return emb
class DiffNet(nn.Layer):
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = layers
self.aux_channels = aux_channels
self.residual_channels = residual_channels
self.gate_channels = gate_channels
self.kernel_size = kernel_size
self.dilation_cycle_length = layers // stacks
self.skip_channels = skip_channels
self.input_projection = Conv1D(self.in_channels, self.residual_channels,
1)
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
dim = self.residual_channels
self.mlp = nn.Sequential(
Linear(dim, dim * 4), nn.Mish(), Linear(dim * 4, dim))
self.residual_layers = nn.LayerList([
ResidualBlock(
encoder_hidden=self.aux_channels,
residual_channels=self.residual_channels,
gate_channels=self.gate_channels,
kernel_size=self.kernel_size,
dilation=2**(i % self.dilation_cycle_length))
for i in range(self.layers)
])
self.skip_projection = Conv1D(self.residual_channels,
self.skip_channels, 1)
self.output_projection = Conv1D(self.residual_channels,
self.out_channels, 1)
zeros_(self.output_projection.weight)
def forward(self, spec, diffusion_step, cond):
"""
:param spec: [B, M, T]
:param diffusion_step: [B, 1]
:param cond: [B, M, T]
:return:
"""
x = spec
x = self.input_projection(x) # x [B, residual_channel, T]
x = F.relu(x)
diffusion_step = self.diffusion_embedding(diffusion_step)
diffusion_step = self.mlp(diffusion_step)
skip = []
for layer_id, layer in enumerate(self.residual_layers):
x, skip_connection = layer(x, cond, diffusion_step)
skip.append(skip_connection)
x = paddle.sum(
paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))
x = self.skip_projection(x)
x = F.relu(x)
x = self.output_projection(x) # [B, 80, T]
return x

@ -17,6 +17,7 @@ from typing import Callable
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
import numpy as np
import paddle import paddle
import ppdiffusers import ppdiffusers
from paddle import nn from paddle import nn
@ -27,170 +28,6 @@ from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
class GaussianDiffusion(nn.Layer): class GaussianDiffusion(nn.Layer):
"""Common Gaussian Diffusion Denoising Model Module """Common Gaussian Diffusion Denoising Model Module
@ -294,13 +131,17 @@ class GaussianDiffusion(nn.Layer):
""" """
def __init__(self, def __init__(
denoiser: nn.Layer, self,
num_train_timesteps: Optional[int]=1000, denoiser: nn.Layer,
beta_start: Optional[float]=0.0001, num_train_timesteps: Optional[int]=1000,
beta_end: Optional[float]=0.02, beta_start: Optional[float]=0.0001,
beta_schedule: Optional[str]="squaredcos_cap_v2", beta_end: Optional[float]=0.02,
num_max_timesteps: Optional[int]=None): beta_schedule: Optional[str]="squaredcos_cap_v2",
num_max_timesteps: Optional[int]=None,
stretch: bool=True,
min_values: paddle.Tensor=None,
max_values: paddle.Tensor=None, ):
super().__init__() super().__init__()
self.num_train_timesteps = num_train_timesteps self.num_train_timesteps = num_train_timesteps
@ -315,6 +156,22 @@ class GaussianDiffusion(nn.Layer):
beta_end=beta_end, beta_end=beta_end,
beta_schedule=beta_schedule) beta_schedule=beta_schedule)
self.num_max_timesteps = num_max_timesteps self.num_max_timesteps = num_max_timesteps
self.stretch = stretch
self.min_values = min_values
self.max_values = max_values
def norm_spec(self, x):
"""
Linearly map x to [-1, 1]
Args:
x: [B, T, N]
"""
return (x - self.min_values) / (self.max_values - self.min_values
) * 2 - 1
def denorm_spec(self, x):
return (x + 1) / 2 * (self.max_values - self.min_values
) + self.min_values
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
@ -333,6 +190,12 @@ class GaussianDiffusion(nn.Layer):
The noises which is added to the input. The noises which is added to the input.
""" """
if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
noise_scheduler = self.noise_scheduler noise_scheduler = self.noise_scheduler
# Sample noise that we'll add to the mel-spectrograms # Sample noise that we'll add to the mel-spectrograms
@ -369,9 +232,9 @@ class GaussianDiffusion(nn.Layer):
Args: Args:
noise (Tensor): noise (Tensor):
The input tensor as a starting point for denoising. The input tensor as a starting point for denoising.
cond (Tensor, optional): cond (Tensor, optional):
Conditional input for compute noises. Conditional input for compute noises. (N, C_aux, T)
ref_x (Tensor, optional): ref_x (Tensor, optional):
The real output for the denoising process to refer. The real output for the denoising process to refer.
num_inference_steps (int, optional): num_inference_steps (int, optional):
@ -382,6 +245,7 @@ class GaussianDiffusion(nn.Layer):
scheduler_type (str, optional): scheduler_type (str, optional):
Noise scheduler for generate noises. Noise scheduler for generate noises.
Choose a great scheduler can skip many denoising step, by default 'ddpm'. Choose a great scheduler can skip many denoising step, by default 'ddpm'.
only support 'ddpm' now !
clip_noise (bool, optional): clip_noise (bool, optional):
Whether to clip each denoised output, by default True. Whether to clip each denoised output, by default True.
clip_noise_range (tuple, optional): clip_noise_range (tuple, optional):
@ -425,48 +289,30 @@ class GaussianDiffusion(nn.Layer):
# set timesteps # set timesteps
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
# prepare first noise variables
noisy_input = noise noisy_input = noise
timesteps = scheduler.timesteps if self.stretch and ref_x is not None:
if ref_x is not None: assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
init_timestep = None ref_x = ref_x.transpose((0, 2, 1))
if strength is None or strength < 0. or strength > 1.: ref_x = self.norm_spec(ref_x)
strength = None ref_x = ref_x.transpose((0, 2, 1))
if self.num_max_timesteps is not None:
strength = self.num_max_timesteps / self.num_train_timesteps # for ddpm
if strength is not None: timesteps = paddle.to_tensor(
# get the original timestep using init_timestep np.flipud(np.arange(num_inference_steps)))
init_timestep = min( noisy_input = scheduler.add_noise(ref_x, noise, timesteps[0])
int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start:]
num_inference_steps = num_inference_steps - t_start
noisy_input = scheduler.add_noise(
ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
# denoising loop
denoised_output = noisy_input denoised_output = noisy_input
if clip_noise:
n_min, n_max = clip_noise_range
denoised_output = paddle.clip(denoised_output, n_min, n_max)
num_warmup_steps = len(
timesteps) - num_inference_steps * scheduler.order
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
denoised_output = scheduler.scale_model_input(denoised_output, t) denoised_output = scheduler.scale_model_input(denoised_output, t)
# predict the noise residual
noise_pred = self.denoiser(denoised_output, t, cond) noise_pred = self.denoiser(denoised_output, t, cond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
denoised_output = scheduler.step(noise_pred, t, denoised_output = scheduler.step(noise_pred, t,
denoised_output).prev_sample denoised_output).prev_sample
if clip_noise:
denoised_output = paddle.clip(denoised_output, n_min, n_max) if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
# call the callback, if provided denoised_output = denoised_output.transpose((0, 2, 1))
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and denoised_output = self.denorm_spec(denoised_output)
(i + 1) % scheduler.order == 0): denoised_output = denoised_output.transpose((0, 2, 1))
if callback is not None and i % callback_steps == 0:
callback(i, t, len(timesteps), denoised_output)
return denoised_output return denoised_output

@ -0,0 +1,184 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from paddle import nn
from ppdiffusers.models.embeddings import Timesteps
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
Loading…
Cancel
Save