fix comment

pull/3005/head
lym0302 3 years ago
parent 3df69e7502
commit 9acc85205a

@@ -34,6 +34,7 @@ model:
# music score related
note_num: 300 # number of notes
is_slur_num: 2 # number of slur states
stretch: True # whether to stretch before diffusion
# fastspeech2 module
fastspeech2_params:
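For context, `stretch: True` min-max normalizes each mel bin into [-1, 1] before diffusion and undoes it afterwards. A minimal sketch of the idea (the names norm_spec/denorm_spec are illustrative, not necessarily the exact module API):

import numpy as np

def norm_spec(spec: np.ndarray, spec_min: np.ndarray, spec_max: np.ndarray) -> np.ndarray:
    # stretch each mel bin from [spec_min, spec_max] to [-1, 1]
    return (spec - spec_min) / (spec_max - spec_min) * 2.0 - 1.0

def denorm_spec(spec: np.ndarray, spec_min: np.ndarray, spec_max: np.ndarray) -> np.ndarray:
    # invert the stretch back to the original mel range
    return (spec + 1.0) / 2.0 * (spec_max - spec_min) + spec_min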
@@ -142,15 +143,14 @@ ds_grad_norm: 1
###########################################################
# INTERVAL SETTING #
###########################################################
only_train_diffusion: True # Whether to freeze fastspeech2 parameters when training diffusion
ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
train_max_steps: 320000 # Number of training steps.
save_interval_steps: 2000 # Interval steps to save checkpoint.
eval_interval_steps: 2000 # Interval steps to evaluate the network.
num_snapshots: 5 # Number of saved models
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
find_unused_parameters: True
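Read together, these settings mean fastspeech2 trains alone for the first ds_train_start_steps iterations, after which the diffusion module trains, with fastspeech2 frozen when only_train_diffusion is True. A rough pseudologic sketch (train_fastspeech2, train_diffusion and freeze are hypothetical helpers, not functions in the repo):

for step in range(train_max_steps):        # 320000
    if step <= ds_train_start_steps:       # first 160000 steps
        train_fastspeech2(batch)           # hypothetical helper
    else:
        if only_train_diffusion:
            freeze(model.fs2)              # fs2 parameters stop updating
        train_diffusion(batch)             # hypothetical helper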

@@ -64,3 +64,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# Get feature (mel) extremum for diffusion stretch
echo "Get feature (mel) extremum ..."
python3 ${BIN_DIR}/computer_extremum.py \
--metadata=dump/train/norm/metadata.jsonl \
--speech-stretchs=dump/train/speech_stretchs.npy
fi
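The resulting speech_stretchs.npy holds a (2, n_mel) array with the per-bin minimum in row 0 and maximum in row 1; a quick sanity check, assuming that layout:

import numpy as np

stretchs = np.load("dump/train/speech_stretchs.npy")   # shape (2, 80)
spec_min, spec_max = stretchs[0], stretchs[1]
assert (spec_max >= spec_min).all(), "extremum rows look swapped"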

@@ -3,8 +3,6 @@
config_path=$1
train_output_path=$2
ckpt_name=$3
#iter=$3
#ckpt_name=snapshot_iter_${iter}.pdz
stage=0
stop_stage=0
@@ -21,8 +19,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--voc_config=pwgan_opencpop/default.yaml \
--voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
--voc_stat=pwgan_opencpop/feats_stats.npy \
--test_metadata=test.jsonl \
--output_dir=${train_output_path}/test_${iter} \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi

@@ -9,4 +9,5 @@ python3 ${BIN_DIR}/train.py \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--phones-dict=dump/phone_id_map.txt
--phones-dict=dump/phone_id_map.txt \
--speech-stretchs=dump/train/speech_stretchs.npy

@@ -12,18 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import List
from typing import Optional
from typing import Union
import librosa
import numpy as np
import pyworld
from scipy.interpolate import interp1d
from typing import List
from typing import Optional
from typing import Union
from typing_extensions import Literal
class LogMelFBank():
def __init__(self,
sr: int=24000,
@@ -80,7 +79,7 @@ class LogMelFBank():
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D) ** self.power
return np.abs(D)**self.power
def _mel_spectrogram(self, wav: np.ndarray):
S = self._spectrogram(wav)
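For orientation, _spectrogram computes |STFT(wav)|**power before the mel projection; a standalone sketch of the same steps with plain librosa (parameter values illustrative, not the class defaults):

import librosa
import numpy as np

wav, sr = librosa.load("sample.wav", sr=24000)
D = librosa.stft(wav, n_fft=2048, hop_length=300, win_length=1200)
S = np.abs(D) ** 1.0                                   # power=1.0 -> magnitude
mel = librosa.feature.melspectrogram(S=S, sr=sr, n_mels=80)
logmel = np.log10(np.maximum(mel, 1e-10))              # floor before log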
@@ -139,7 +138,7 @@ class Pitch():
input: np.ndarray,
use_continuous_f0: bool=True,
use_log_f0: bool=True) -> np.ndarray:
input = input.astype(float)
input = input.astype(np.float64)  # pyworld expects float64 input
frame_period = 1000 * self.hop_length / self.sr
f0, timeaxis = pyworld.dio(
input,
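Note that pyworld.dio operates on float64 audio (hence the cast above) with frame_period given in milliseconds; a minimal standalone sketch:

import numpy as np
import pyworld

sr, hop_length = 24000, 300
wav = np.random.randn(sr).astype(np.float64)           # pyworld expects float64
frame_period = 1000 * hop_length / sr                  # 12.5 ms per frame
f0, timeaxis = pyworld.dio(wav, sr, frame_period=frame_period)
f0 = pyworld.stonemask(wav, f0, timeaxis, sr)          # refine the raw estimate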

@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -0,0 +1,83 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import jsonlines
import numpy as np
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
def find_min_max_spec(spec, min_spec, max_spec):
# spec: [T, 80]
for i in range(spec.shape[1]):
min_value = np.min(spec[:, i])
max_value = np.max(spec[:, i])
min_spec[i] = min(min_value, min_spec[i])
max_spec[i] = max(max_value, max_spec[i])
return min_spec, max_spec
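The same per-bin update can be written without the explicit loop; an equivalent vectorized sketch:

def find_min_max_spec_vectorized(spec, min_spec, max_spec):
    # spec: [T, n_mel]; reduce over time, then fold into the running extrema
    min_spec = np.minimum(min_spec, spec.min(axis=0))
    max_spec = np.maximum(max_spec, spec.max(axis=0))
    return min_spec, max_spec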
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Compute the per-bin min/max extremum of dumped mel features for diffusion stretch."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--speech-stretchs",
type=str,
required=True,
help="min max spec file. only computer on train data")
args = parser.parse_args()
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata, converters={
"speech": np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
n_mel = 80
min_spec = 100 * np.ones(shape=(n_mel), dtype=np.float32)
max_spec = -100 * np.ones(shape=(n_mel), dtype=np.float32)
for item in tqdm(dataset):
spec = item['speech']
min_spec, max_spec = find_min_max_spec(spec, min_spec, max_spec)
print(min_spec)
print(max_spec)
min_max_spec = np.stack([min_spec, max_spec], axis=0)
np.save(
str(args.speech_stretchs),
min_max_spec.astype(np.float32),
allow_pickle=False)
if __name__ == "__main__":
main()

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -127,9 +127,21 @@ def train_sp(args, config):
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
speech_stretchs = np.load(args.speech_stretchs)
spec_min = speech_stretchs[0]
spec_max = speech_stretchs[1]
spec_min = paddle.to_tensor(spec_min)
spec_max = paddle.to_tensor(spec_max)
print("min and max spec done!")
odim = config.n_mels
config["model"]["fastspeech2_params"]["spk_num"] = spk_num
model = DiffSinger(idim=vocab_size, odim=odim, **config["model"])
model = DiffSinger(
idim=vocab_size,
odim=odim,
**config["model"],
spec_min=spec_min,
spec_max=spec_max)
model_fs2 = model.fs2
model_ds = model.diffusion
if world_size > 1:
@@ -143,13 +155,6 @@ def train_sp(args, config):
print("criterions done!")
optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
# gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
# optimizer_ds = AdamW(
# learning_rate=config["ds_scheduler_params"]["learning_rate"],
# grad_clip=gradient_clip_ds,
# parameters=model_ds.parameters(),
# **config["ds_optimizer_params"])
lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
optimizer_ds = AdamW(
@@ -178,7 +183,8 @@
},
dataloader=train_dataloader,
ds_train_start_steps=config.ds_train_start_steps,
output_dir=output_dir)
output_dir=output_dir,
only_train_diffusion=config["only_train_diffusion"])
evaluator = DiffSingerEvaluator(
model=model,
@@ -222,6 +228,10 @@ def main():
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
parser.add_argument(
"--speech-stretchs",
type=str,
help="The min and max values of the mel spectrum.")
args = parser.parse_args()

@@ -149,8 +149,6 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
print("single speaker fastspeech2!")
elif am_name == 'diffsinger':
fields = ["utt_id", "text", "note", "note_dur", "is_slur"]
elif am_name == 'fastspeech2midi':
fields = ["utt_id", "text", "note", "note_dur", "is_slur"]
elif am_name == 'speedyspeech':
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':

@@ -112,44 +112,29 @@ def evaluate(args):
note = paddle.to_tensor(datum["note"])
note_dur = paddle.to_tensor(datum["note_dur"])
is_slur = paddle.to_tensor(datum["is_slur"])
# get_mel_fs2=False: the mel comes from the diffusion module; get_mel_fs2=True: the mel comes from fastspeech2.
get_mel_fs2 = False
# mel: [T, mel_bin]
mel1 = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=True)
mel2 = am_inference(
mel = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=False)
wav1 = voc_inference(mel1)
wav2 = voc_inference(mel2)
get_mel_fs2=get_mel_fs2)
# vocoder
wav = voc_inference(mel)
wav1 = wav1.numpy()
wav2 = wav2.numpy()
N += wav1.size
N += wav2.size
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = 2 * wav1.size / t.elapse
speed = wav.size / t.elapse
rtf = am_config.fs / speed
print(
f"{utt_id}, mel: {mel1.shape}, wave: {wav1.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
f"{utt_id}, mel: {mel.shape}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(
str(output_dir / (utt_id + "_fs2.wav")),
wav1,
samplerate=am_config.fs)
sf.write(
str(output_dir / (utt_id + "_diffusion.wav")),
wav2,
samplerate=am_config.fs)
str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
print(f"{utt_id} done!")
# break
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")

@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -136,7 +136,9 @@ class DiffSinger(nn.Layer):
"beta_schedule": "squaredcos_cap_v2",
"num_max_timesteps": 60
},
stretch: bool=True, ):
stretch: bool=True,
spec_min: paddle.Tensor=None,
spec_max: paddle.Tensor=None, ):
"""Initialize DiffSinger module.
Args:
@@ -149,6 +151,7 @@ class DiffSinger(nn.Layer):
fastspeech2_params (Dict[str, Any]): Parameter dict for fastspeech2 module.
denoiser_params (Dict[str, Any]): Parameter dict for denoiser module.
diffusion_params (Dict[str, Any]): Parameter dict for diffusion module.
stretch (bool): Whether to stretch before diffusion. Defaults to True.
spec_min (paddle.Tensor): Per-bin minimum of the mel spectrogram used for stretching.
spec_max (paddle.Tensor): Per-bin maximum of the mel spectrogram used for stretching.
"""
assert check_argument_types()
super().__init__()
@@ -159,33 +162,6 @@ class DiffSinger(nn.Layer):
note_num=note_num,
is_slur_num=is_slur_num)
denoiser = DiffNet(**denoiser_params)
spec_min = paddle.to_tensor(
np.array([
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
]))
spec_max = paddle.to_tensor(
np.array([
-0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
]))
self.diffusion = GaussianDiffusion(
denoiser,
**diffusion_params,
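With the hardcoded tables removed, the extrema now come from the dumped stretch file; a construction sketch under that assumption (vocab_size, odim and config as in the training script above):

import numpy as np
import paddle

stretchs = np.load("dump/train/speech_stretchs.npy")
model = DiffSinger(
    idim=vocab_size,
    odim=odim,
    **config["model"],
    spec_min=paddle.to_tensor(stretchs[0]),
    spec_max=paddle.to_tensor(stretchs[1]))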

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -139,9 +139,8 @@ class DiffSingerUpdater(StandardUpdater):
# Then only train diffusion module, freeze fastspeech2 parameters.
if self.state.iteration > self.ds_train_start_steps:
if self.only_train_diffusion:
for param in self.model.fs2.parameters():
param.trainable = False
for param in self.model.fs2.parameters():
param.trainable = not self.only_train_diffusion
noise_pred, noise_target, mel_masks = self.model(
text=batch["text"],
@@ -213,7 +212,59 @@ class DiffSingerEvaluator(StandardEvaluator):
if spk_emb is not None:
spk_id = None
# Here show diffsinger eval
# Here show fastspeech2 eval
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb,
only_train_fs2=True, )
l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=batch["durations"],
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss
report("eval/loss_fs2", float(loss_fs2))
report("eval/l1_loss_fs2", float(l1_loss_fs2))
report("eval/ssim_loss_fs2", float(ssim_loss_fs2))
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if speaker_loss != 0.:
report("eval/speaker_loss", float(speaker_loss))
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["loss_fs2"] = float(loss_fs2)
# Here show diffusion eval
noise_pred, noise_target, mel_masks = self.model(
text=batch["text"],
note=batch["note"],
@@ -236,6 +287,7 @@ class DiffSingerEvaluator(StandardEvaluator):
noise_pred=noise_pred,
noise_target=noise_target,
mel_masks=mel_masks, )
loss_ds = l1_loss_ds
report("eval/loss_ds", float(loss_ds))

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -175,18 +175,18 @@ class FastSpeech2MIDI(FastSpeech2):
before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
# forward encoder
x_masks = self._source_mask(ilens)
masks = self._source_mask(ilens)
note_emb = self.note_embedding_table(note)
note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
is_slur_emb = self.is_slur_embedding_table(is_slur)
# (B, Tmax, adim)
hs, _ = self.encoder(
xs,
x_masks,
note_emb,
note_dur_emb,
is_slur_emb, )
xs=xs,
masks=masks,
note_emb=note_emb,
note_dur_emb=note_dur_emb,
is_slur_emb=is_slur_emb, )
if self.spk_num and self.enable_speaker_classifier and not is_inference:
hs_for_spk_cls = self.grad_reverse(hs)

@@ -52,10 +52,26 @@ def Linear(*args, **kwargs):
class ResidualBlock(nn.Layer):
"""ResidualBlock
Args:
encoder_hidden (int, optional):
Input feature size of the 1D convolution, by default 256
residual_channels (int, optional):
Feature size of the residual output (and also the input), by default 256
gate_channels (int, optional):
Output feature size of the 1D convolution, by default 512
kernel_size (int, optional):
Kernel size of the 1D convolution, by default 3
dilation (int, optional):
Dilation of the 1D convolution, by default 4
"""
def __init__(self, encoder_hidden, residual_channels, gate_channels,
kernel_size, dilation):
def __init__(self,
encoder_hidden: int=256,
residual_channels: int=256,
gate_channels: int=512,
kernel_size: int=3,
dilation: int=4):
super().__init__()
self.dilated_conv = Conv1D(
residual_channels,
@@ -67,17 +83,26 @@ class ResidualBlock(nn.Layer):
self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
self.output_projection = Conv1D(residual_channels, gate_channels, 1)
def forward(self, x, conditioner, diffusion_step):
"""_summary_
def forward(
self,
x: paddle.Tensor,
diffusion_step: paddle.Tensor,
cond: paddle.Tensor, ):
"""Calculate forward propagation.
Args:
nn (_type_): _description_
x (Tensor(float32)): input feature. (B, residual_channels, T)
diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, residual_channels, T)
Returns:
x (Tensor(float32)): output (B, residual_channels, T)
"""
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
conditioner = self.conditioner_projection(conditioner)
cond = self.conditioner_projection(cond)
y = x + diffusion_step
y = self.dilated_conv(y) + conditioner
y = self.dilated_conv(y) + cond
gate, filter = paddle.chunk(y, 2, axis=1)
y = F.sigmoid(gate) * paddle.tanh(filter)
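The chunk-and-multiply pair above is WaveNet's gated activation, output = sigmoid(gate) * tanh(filter); the same op in isolation:

import paddle
import paddle.nn.functional as F

y = paddle.randn([2, 512, 100])              # (B, gate_channels, T)
gate, filt = paddle.chunk(y, 2, axis=1)      # two (B, 256, T) halves
out = F.sigmoid(gate) * paddle.tanh(filt)    # gated tanh unit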
@@ -88,22 +113,14 @@ class ResidualBlock(nn.Layer):
class SinusoidalPosEmb(nn.Layer):
"""_summary_
Args:
nn (_type_): _description_
"""Positional embedding
"""
def __init__(self, dim):
def __init__(self, dim: int=256):
super().__init__()
self.dim = dim
def forward(self, x):
"""_summary_
Args:
nn (_type_): _description_
"""
def forward(self, x: paddle.Tensor):
x = paddle.cast(x, 'float32')
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
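The hunk cuts the forward off here; for orientation, a self-contained sketch of the standard sinusoidal timestep embedding it computes (not guaranteed to match the file line for line):

import math
import paddle

def sinusoidal_pos_emb(x: paddle.Tensor, dim: int = 256) -> paddle.Tensor:
    # map integer timesteps (B,) to embeddings (B, dim)
    x = paddle.cast(x, 'float32')
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = paddle.exp(paddle.arange(half_dim, dtype='float32') * -scale)
    args = x.unsqueeze(-1) * freqs.unsqueeze(0)        # (B, half_dim)
    return paddle.concat([args.sin(), args.cos()], axis=-1)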
@@ -114,6 +131,36 @@ class SinusoidalPosEmb(nn.Layer):
class DiffNet(nn.Layer):
"""A Mel-Spectrogram Denoiser
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5.
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
@@ -162,13 +209,20 @@
self.out_channels, 1)
zeros_(self.output_projection.weight)
def forward(self, spec, diffusion_step, cond):
"""
def forward(
self,
spec: paddle.Tensor,
diffusion_step: paddle.Tensor,
cond: paddle.Tensor, ):
"""Calculate forward propagation.
Args:
spec (Tensor(float32)): The input mel-spectrogram. (B, n_mel, T)
diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, D_enc_out, T)
Returns:
x (Tensor(float32)): pred noise (B, n_mel, T)
:param spec: [B, M, T]
:param diffusion_step: [B, 1]
:param cond: [B, M, T]
:return:
"""
x = spec
x = self.input_projection(x) # x [B, residual_channel, T]
@@ -178,7 +232,10 @@
diffusion_step = self.mlp(diffusion_step)
skip = []
for layer_id, layer in enumerate(self.residual_layers):
x, skip_connection = layer(x, cond, diffusion_step)
x, skip_connection = layer(
x=x,
diffusion_step=diffusion_step,
cond=cond, )
skip.append(skip_connection)
x = paddle.sum(
paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))

@@ -44,6 +44,13 @@ class GaussianDiffusion(nn.Layer):
beta schedule parameter for the scheduler, by default 'squaredcos_cap_v2' (cosine schedule).
num_max_timesteps (int, optional):
The max timestep transition from real to noise, by default None.
stretch (bool, optional):
Whether to stretch before diffusion, by default True.
min_values (paddle.Tensor):
The minimum value of the feature to stretch.
max_values (paddle.Tensor):
The maximum value of the feature to stretch.
Examples:
>>> import paddle
@@ -191,7 +198,6 @@
"""
if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
@@ -291,7 +297,6 @@
noisy_input = noise
if self.stretch and ref_x is not None:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
ref_x = ref_x.transpose((0, 2, 1))
ref_x = self.norm_spec(ref_x)
ref_x = ref_x.transpose((0, 2, 1))
@@ -315,7 +320,6 @@
denoised_output = paddle.clip(denoised_output, n_min, n_max)
if self.stretch:
assert self.min_values is not None and self.max_values is not None, "self.min_values and self.max_values should not be None."
denoised_output = denoised_output.transpose((0, 2, 1))
denoised_output = self.denorm_spec(denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))

@@ -131,19 +131,19 @@ class WaveNetDenoiser(nn.Layer):
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
def forward(self, x: paddle.Tensor, t: paddle.Tensor, c: paddle.Tensor):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
Shape (B, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
Shape (B), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Shape (B, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
Tensor: Shape (B, C_out, T), the pred noise.
"""
assert c.shape[-1] == x.shape[-1]
@@ -189,4 +189,3 @@ class WaveNetDenoiser(nn.Layer):
pass
self.apply(_remove_weight_norm)
