update diffsinger, test=tts

pull/3005/head
lym0302 3 years ago
parent d1173b9cdb
commit d7928d712d

@@ -23,8 +23,8 @@ f0max: 750 # Maximum f0 for pitch extraction.
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 48
-num_workers: 1
+batch_size: 24 # batch size
+num_workers: 4 # number of dataloader workers
 ###########################################################
@@ -32,23 +32,23 @@ num_workers: 1
 ###########################################################
 model:
     # music score related
-    note_num: 300
-    is_slur_num: 2
+    note_num: 300 # number of notes
+    is_slur_num: 2 # number of slur states
     # fastspeech2 module
     fastspeech2_params:
-        adim: 256 # attention dimension # lym check
-        aheads: 2 # number of attention heads # lym check
-        elayers: 4 # number of encoder layers # lym check
-        eunits: 1024 # number of encoder ff units # lym check adim * 4
-        dlayers: 4 # number of decoder layers # lym check
-        dunits: 1024 # number of decoder ff units # lym check
-        positionwise_layer_type: conv1d-linear # type of position-wise layer # lym check
-        positionwise_conv_kernel_size: 9 # kernel size of position-wise conv layer # lym check
-        transformer_enc_dropout_rate: 0.1 # dropout rate for transformer encoder layer # lym check
-        transformer_enc_positional_dropout_rate: 0.1 # dropout rate for transformer encoder positional encoding # lym check
-        transformer_enc_attn_dropout_rate: 0.0 # dropout rate for transformer encoder attention layer # lym check
-        transformer_activation_type: "gelu"
+        adim: 256 # attention dimension
+        aheads: 2 # number of attention heads
+        elayers: 4 # number of encoder layers
+        eunits: 1024 # number of encoder ff units
+        dlayers: 4 # number of decoder layers
+        dunits: 1024 # number of decoder ff units
+        positionwise_layer_type: conv1d-linear # type of position-wise layer
+        positionwise_conv_kernel_size: 9 # kernel size of position-wise conv layer
+        transformer_enc_dropout_rate: 0.1 # dropout rate for transformer encoder layer
+        transformer_enc_positional_dropout_rate: 0.1 # dropout rate for transformer encoder positional encoding
+        transformer_enc_attn_dropout_rate: 0.0 # dropout rate for transformer encoder attention layer
+        transformer_activation_type: "gelu" # activation function type in transformer
         encoder_normalize_before: True # whether to perform layer normalization before the input
         decoder_normalize_before: True # whether to perform layer normalization before the input
         reduction_factor: 1 # reduction factor
@@ -70,11 +70,6 @@ model:
         pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
         pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
         stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
-        postnet_layers: 5 # number of layers of postnet
-        postnet_filts: 5 # filter size of conv layers in postnet
-        postnet_chans: 256 # number of channels of conv layers in postnet
         energy_predictor_layers: 2 # number of conv layers in energy predictor
         energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
         energy_predictor_kernel_size: 3 # kernel size of conv layers in energy predictor
@@ -82,30 +77,34 @@ model:
         energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
         energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
         stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
+        postnet_layers: 5 # number of layers of postnet
+        postnet_filts: 5 # filter size of conv layers in postnet
+        postnet_chans: 256 # number of channels of conv layers in postnet
+        postnet_dropout_rate: 0.5 # dropout rate for postnet
     # denoiser module
     denoiser_params:
-        in_channels: 80
-        out_channels: 80
-        kernel_size: 3
-        layers: 20
-        stacks: 5
-        residual_channels: 256
-        gate_channels: 512
-        skip_channels: 256
-        aux_channels: 256
-        dropout: 0.1
-        bias: True
-        use_weight_norm: False
-        init_type: "kaiming_normal"
+        in_channels: 80 # number of channels of the input mel-spectrogram
+        out_channels: 80 # number of channels of the output mel-spectrogram
+        kernel_size: 3 # kernel size of the residual blocks
+        layers: 20 # number of residual blocks
+        stacks: 5 # number of groups the residual blocks are split into
+        residual_channels: 256 # residual channels of the residual blocks
+        gate_channels: 512 # gate channels of the residual blocks
+        skip_channels: 256 # skip channels of the residual blocks
+        aux_channels: 256 # auxiliary (condition) channels of the residual blocks
+        dropout: 0.1 # dropout rate of the residual blocks
+        bias: True # whether to use bias in the residual blocks
+        use_weight_norm: False # whether to use weight norm in all convolutions
+        init_type: "kaiming_normal" # type of weight initialization
     # diffusion module
     diffusion_params:
-        num_train_timesteps: 100
-        beta_start: 0.0001
-        beta_end: 0.06
-        beta_schedule: "squaredcos_cap_v2"
-        num_max_timesteps: 60
+        num_train_timesteps: 100 # number of diffusion timesteps used during training
+        beta_start: 0.0001 # beta start of the noise schedule
+        beta_end: 0.06 # beta end of the noise schedule
+        beta_schedule: "linear" # beta schedule type of the noise scheduler
+        num_max_timesteps: 60 # max timestep of the transition from real data to noise
 ###########################################################
@@ -142,17 +141,12 @@ ds_grad_norm: 1
 ###########################################################
 #                     INTERVAL SETTING                    #
 ###########################################################
-ds_train_start_steps: 32500 # Number of steps to start to train diffusion module.
-train_max_steps: 65000 # Number of training steps.
+ds_train_start_steps: 100000 # Number of steps to start to train diffusion module.
+train_max_steps: 200000 # Number of training steps.
 save_interval_steps: 500 # Interval steps to save checkpoint.
-eval_interval_steps: 500 # Interval steps to evaluate the network.
-num_snapshots: 20
-# ds_train_start_steps: 4 # Number of steps to start to train diffusion module.
-# train_max_steps: 8 # Number of training steps.
-# save_interval_steps: 1 # Interval steps to save checkpoint.
-# eval_interval_steps: 2 # Interval steps to evaluate the network.
-# num_snapshots: 5
+eval_interval_steps: 100 # Interval steps to evaluate the network.
+num_snapshots: 10 # Number of saved models
 ###########################################################
 #                       OTHER SETTING                     #

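Note: the diffusion_params block above is handed straight to the DDPM noise scheduler, so switching beta_schedule from "squaredcos_cap_v2" to "linear" changes the actual noise schedule, not just a comment. A minimal sketch of the mapping, assuming ppdiffusers mirrors the diffusers scheduler API (an assumption, not confirmed by this diff):

    from ppdiffusers import DDPMScheduler  # assumed top-level export, as in diffusers

    # diffusion_params from the config above, passed through unchanged
    scheduler = DDPMScheduler(
        num_train_timesteps=100,
        beta_start=0.0001,
        beta_end=0.06,
        beta_schedule="linear")  # diffusers also accepts "scaled_linear" and "squaredcos_cap_v2"
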
@@ -21,7 +21,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --voc_config=pwgan_opencpop/default.yaml \
         --voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
         --voc_stat=pwgan_opencpop/feats_stats.npy \
-        --test_metadata=dump/test/norm/metadata.jsonl \
+        --test_metadata=test1.jsonl \
         --output_dir=${train_output_path}/test_${iter} \
         --phones_dict=dump/phone_id_map.txt
 fi

@@ -4,7 +4,6 @@ set -e
 source path.sh
 gpus=4,5,6,7
-#gpus=0
 stage=1
 stop_stage=1

@@ -126,7 +126,7 @@ class Pitch():
             input: np.ndarray,
             use_continuous_f0: bool=True,
             use_log_f0: bool=True) -> np.ndarray:
-        input = input.astype(np.float)
+        input = input.astype(float)
         frame_period = 1000 * self.hop_length / self.sr
         f0, timeaxis = pyworld.dio(
             input,

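Note: the astype change above is the standard NumPy deprecation fix: np.float was deprecated in NumPy 1.20 and removed in 1.24, while the builtin float aliases np.float64, which is also the dtype pyworld.dio expects. A one-line check:

    import numpy as np

    x = np.zeros(3, dtype=np.float32)
    assert x.astype(float).dtype == np.float64  # builtin float == float64; np.float is gone in NumPy >= 1.24
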
@@ -80,27 +80,20 @@ def main():
     # restore scaler
     speech_scaler = StandardScaler()
-    # speech_scaler.mean_ = np.load(args.speech_stats)[0]
-    # speech_scaler.scale_ = np.load(args.speech_stats)[1]
-    speech_scaler.mean_ = np.zeros(np.load(args.speech_stats)[0].shape, dtype="float32")
-    speech_scaler.scale_ = np.ones(np.load(args.speech_stats)[1].shape, dtype="float32")
+    speech_scaler.mean_ = np.load(args.speech_stats)[0]
+    speech_scaler.scale_ = np.load(args.speech_stats)[1]
     speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
     pitch_scaler = StandardScaler()
-    # pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
-    # pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
-    pitch_scaler.mean_ = np.zeros(np.load(args.pitch_stats)[0].shape, dtype="float32")
-    pitch_scaler.scale_ = np.ones(np.load(args.pitch_stats)[1].shape, dtype="float32")
+    pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
+    pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
     pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
     energy_scaler = StandardScaler()
-    # energy_scaler.mean_ = np.load(args.energy_stats)[0]
-    # energy_scaler.scale_ = np.load(args.energy_stats)[1]
-    energy_scaler.mean_ = np.zeros(np.load(args.energy_stats)[0].shape, dtype="float32")
-    energy_scaler.scale_ = np.ones(np.load(args.energy_stats)[1].shape, dtype="float32")
+    energy_scaler.mean_ = np.load(args.energy_stats)[0]
+    energy_scaler.scale_ = np.load(args.energy_stats)[1]
     energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
     vocab_phones = {}
     with open(args.phones_dict, 'rt') as f:
         phn_id = [line.strip().split() for line in f.readlines()]

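Note: this hunk reverts a debugging stub (zero mean, unit scale, i.e. normalization effectively disabled) back to restoring the real statistics. For reference, a sklearn StandardScaler rebuilt this way is directly usable for transform; a minimal sketch with a hypothetical stats file path:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    stats = np.load("feats_stats.npy")  # hypothetical path; row 0 = mean, row 1 = scale, as in the diff
    scaler = StandardScaler()
    scaler.mean_ = stats[0]
    scaler.scale_ = stats[1]
    scaler.n_features_in_ = scaler.mean_.shape[0]
    normalized = scaler.transform(np.random.randn(5, scaler.n_features_in_))  # (x - mean) / scale
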
@@ -88,10 +88,7 @@ def process_sentence(
     phones = sentences[utt_id][0]
     durations = sentences[utt_id][1]
     num_frames = logmel.shape[0]
-    word_boundary = [
-        1 if x in ALL_FINALS + ['AP', 'SP'] else 0 for x in phones
-    ]
-    # print(sum(durations), num_frames)
     assert sum(
         durations
     ) == num_frames, "the sum of durations doesn't equal to the num of mel frames. "
@@ -105,7 +102,6 @@ def process_sentence(
     pitch_dir = output_dir / "data_pitch"
     pitch_dir.mkdir(parents=True, exist_ok=True)
     pitch_path = pitch_dir / (utt_id + "_pitch.npy")
-    # print(pitch, pitch.shape)
     np.save(pitch_path, pitch)
     energy = energy_extractor.get_energy(wav)
     assert energy.shape[0] == num_frames

@@ -138,20 +138,17 @@ def train_sp(args, config):
     model_ds = model._layers.diffusion
     print("models done!")
-    # criterion_fs2 = FastSpeech2Loss(**config["fs2_updater"])
     criterion_fs2 = FastSpeech2MIDILoss(**config["fs2_updater"])
     criterion_ds = DiffusionLoss(**config["ds_updater"])
     print("criterions done!")
     optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
-    lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
     gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
     optimizer_ds = AdamW(
-        learning_rate=lr_schedule_ds,
+        learning_rate=config["ds_scheduler_params"]["learning_rate"],
         grad_clip=gradient_clip_ds,
         parameters=model_ds.parameters(),
         **config["ds_optimizer_params"])
-    # optimizer_ds = build_optimizers(ds, **config["ds_optimizer"])
     print("optimizer done!")
     output_dir = Path(args.output_dir)
@@ -182,7 +179,7 @@ def train_sp(args, config):
             "ds": criterion_ds,
         },
         dataloader=dev_dataloader,
-        output_dir=output_dir,)
+        output_dir=output_dir, )
     trainer = Trainer(
         updater,

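Note: dropping StepDecay means the diffusion optimizer now runs at a constant learning rate read directly from ds_scheduler_params. A minimal sketch of the resulting setup, with a stand-in model and assumed values:

    import paddle

    model_ds = paddle.nn.Linear(8, 8)  # stand-in for the diffusion sub-module
    optimizer_ds = paddle.optimizer.AdamW(
        learning_rate=1e-3,  # plain float from ds_scheduler_params, no schedule object
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),  # ds_grad_norm: 1
        parameters=model_ds.parameters())
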
@@ -54,10 +54,6 @@ class DiffSinger(nn.Layer):
             "eunits": 1024,
             "dlayers": 4,
             "dunits": 1024,
-            "postnet_layers": 5,
-            "postnet_chans": 512,
-            "postnet_filts": 5,
-            "postnet_dropout_rate": 0.5,
             "positionwise_layer_type": "conv1d",
             "positionwise_conv_kernel_size": 1,
             "use_scaled_pos_enc": True,
@@ -80,15 +76,8 @@ class DiffSinger(nn.Layer):
             "duration_predictor_chans": 384,
             "duration_predictor_kernel_size": 3,
             "duration_predictor_dropout_rate": 0.1,
-            # energy predictor
-            "energy_predictor_layers": 2,
-            "energy_predictor_chans": 384,
-            "energy_predictor_kernel_size": 3,
-            "energy_predictor_dropout": 0.5,
-            "energy_embed_kernel_size": 9,
-            "energy_embed_dropout": 0.5,
-            "stop_gradient_from_energy_predictor": False,
             # pitch predictor
-            "use_pitch_embed": True,
             "pitch_predictor_layers": 2,
             "pitch_predictor_chans": 384,
             "pitch_predictor_kernel_size": 3,
@@ -96,6 +85,20 @@ class DiffSinger(nn.Layer):
             "pitch_embed_kernel_size": 9,
             "pitch_embed_dropout": 0.5,
             "stop_gradient_from_pitch_predictor": False,
+            # energy predictor
+            "use_energy_embed": False,
+            "energy_predictor_layers": 2,
+            "energy_predictor_chans": 384,
+            "energy_predictor_kernel_size": 3,
+            "energy_predictor_dropout": 0.5,
+            "energy_embed_kernel_size": 9,
+            "energy_embed_dropout": 0.5,
+            "stop_gradient_from_energy_predictor": False,
+            # postnet
+            "postnet_layers": 5,
+            "postnet_chans": 512,
+            "postnet_filts": 5,
+            "postnet_dropout_rate": 0.5,
             # spk emb
             "spk_num": None,
             "spk_embed_dim": None,
@@ -170,7 +173,7 @@ class DiffSinger(nn.Layer):
             energy: paddle.Tensor,
             spk_emb: paddle.Tensor=None,
             spk_id: paddle.Tensor=None,
-            train_fs2: bool=True,
+            only_train_fs2: bool=True,
     ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
         """Calculate forward propagation.
@@ -199,7 +202,7 @@ class DiffSinger(nn.Layer):
                 Batch of speaker embeddings (B, spk_embed_dim).
             spk_id(Tensor[int64], optional):
                 Batch of speaker ids (B,)
-            train_fs2(bool):
+            only_train_fs2(bool):
                 Whether to train only the fastspeech2 module
         Returns:
@@ -219,7 +222,7 @@ class DiffSinger(nn.Layer):
             energy=energy,
             spk_id=spk_id,
             spk_emb=spk_emb)
-        if train_fs2:
+        if only_train_fs2:
             return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
         # get the encoder output from fastspeech2 as the condition of denoiser module
@@ -236,9 +239,9 @@ class DiffSinger(nn.Layer):
         cond_fs2 = cond_fs2.transpose((0, 2, 1))
         # get the output(final mel) from diffusion module
-        mel, mel_ref = self.diffusion(
-            speech.transpose((0, 2, 1)), cond_fs2.detach())
-        return mel, mel_ref, mel_masks
+        noise_pred, noise_target = self.diffusion(
+            speech.transpose((0, 2, 1)), cond_fs2)
+        return noise_pred, noise_target, mel_masks
     def inference(
             self,
@@ -270,10 +273,13 @@ class DiffSinger(nn.Layer):
         mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1))
         cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur)
         cond_fs2 = cond_fs2.transpose((0, 2, 1))
-        # mel, _ = self.diffusion(mel_fs2, cond_fs2)
         noise = paddle.randn(mel_fs2.shape)
         mel = self.diffusion.inference(
-            noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
+            noise=noise,
+            cond=cond_fs2,
+            ref_x=mel_fs2,
+            scheduler_type="ddpm",
+            num_inference_steps=25)
         mel = mel.transpose((0, 2, 1))
         return mel[0]
@@ -308,9 +314,7 @@ class DiffSingerInference(nn.Layer):
             note_dur=note_dur,
             is_slur=is_slur,
             get_mel_fs2=get_mel_fs2)
-        print(normalized_mel)
-        # logmel = self.normalizer.inverse(normalized_mel)
-        logmel = normalized_mel
+        logmel = self.normalizer.inverse(normalized_mel)
         return logmel
@@ -339,16 +343,16 @@ class DiffusionLoss(nn.Layer):
     def forward(
             self,
-            ref_mels: paddle.Tensor,
-            out_mels: paddle.Tensor,
+            noise_pred: paddle.Tensor,
+            noise_target: paddle.Tensor,
             mel_masks: paddle.Tensor, ) -> paddle.Tensor:
         """Calculate forward propagation.
         Args:
-            ref_mels(Tensor):
-                Batch of real mel (B, Lmax, odim).
-            out_mels(Tensor):
-                Batch of outputs mel (B, Lmax, odim).
+            noise_pred(Tensor):
+                Batch of predicted noise (B, Lmax, odim).
+            noise_target(Tensor):
+                Batch of target noise (B, Lmax, odim).
             mel_masks(Tensor):
                 Batch of mask of real mel (B, Lmax, 1).
         Returns:
@@ -356,13 +360,13 @@ class DiffusionLoss(nn.Layer):
         """
         # apply mask to remove padded part
         if self.use_masking:
-            out_mels = out_mels.masked_select(
-                mel_masks.broadcast_to(out_mels.shape))
-            ref_mels = ref_mels.masked_select(
-                mel_masks.broadcast_to(ref_mels.shape))
+            noise_pred = noise_pred.masked_select(
+                mel_masks.broadcast_to(noise_pred.shape))
+            noise_target = noise_target.masked_select(
+                mel_masks.broadcast_to(noise_target.shape))
         # calculate loss
-        l1_loss = self.l1_criterion(out_mels, ref_mels)
+        l1_loss = self.l1_criterion(noise_pred, noise_target)
         # make weighted mask and apply it
         if self.use_weighted_masking:
@@ -370,7 +374,7 @@ class DiffusionLoss(nn.Layer):
             out_weights = mel_masks.cast(dtype=paddle.float32) / mel_masks.cast(
                 dtype=paddle.float32).sum(
                     axis=1, keepdim=True)
-            out_weights /= ref_mels.shape[0] * ref_mels.shape[2]
+            out_weights /= noise_target.shape[0] * noise_target.shape[2]
             # apply weight
             l1_loss = l1_loss.multiply(out_weights)

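Note: two things change in this file. Training now returns (noise_pred, noise_target) instead of (mel, mel_ref), i.e. DiffusionLoss moves from mel space to epsilon (noise) prediction, and inference explicitly selects a DDPM sampler with 25 steps instead of 100. A minimal sketch of the 25-step reverse-diffusion loop that diffusion.inference is expected to run, assuming a diffusers-style scheduler API in ppdiffusers (stand-in denoiser, assumed shapes):

    import paddle
    from ppdiffusers import DDPMScheduler  # assumed top-level export, as in diffusers

    scheduler = DDPMScheduler(num_train_timesteps=100, beta_start=0.0001,
                              beta_end=0.06, beta_schedule="linear")
    denoiser = lambda x, t, c: paddle.zeros_like(x)  # stand-in for the WaveNet-style denoiser

    cond = paddle.randn((1, 256, 50))   # fastspeech2 encoder output, (B, adim, T)
    sample = paddle.randn((1, 80, 50))  # start from pure noise, (B, mel_dim, T)
    scheduler.set_timesteps(25)         # num_inference_steps=25
    for t in scheduler.timesteps:
        eps = denoiser(sample, t, cond)                      # predicted noise at step t
        sample = scheduler.step(eps, t, sample).prev_sample  # one reverse-diffusion step
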
@@ -34,17 +34,18 @@ logger.setLevel(logging.INFO)
 class DiffSingerUpdater(StandardUpdater):
-    def __init__(
-            self,
+    def __init__(self,
                  model: Layer,
                  optimizers: Dict[str, Optimizer],
                  criterions: Dict[str, Layer],
                  dataloader: DataLoader,
                  ds_train_start_steps: int=160000,
-                 output_dir: Path=None, ):
+                 output_dir: Path=None,
+                 only_train_diffusion: bool=True):
         super().__init__(model, optimizers, dataloader, init_state=None)
         self.model = model._layers if isinstance(model,
                                                  paddle.DataParallel) else model
+        self.only_train_diffusion = only_train_diffusion
         self.optimizers = optimizers
         self.optimizer_fs2: Optimizer = optimizers['fs2']
@@ -78,8 +79,7 @@ class DiffSingerUpdater(StandardUpdater):
             spk_id = None
         # only train fastspeech2 module firstly
-        if self.state.iteration <= self.ds_train_start_steps:
-            # print(batch)
+        if self.state.iteration < self.ds_train_start_steps:
             before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
                 text=batch["text"],
                 note=batch["note"],
@@ -93,7 +93,7 @@ class DiffSingerUpdater(StandardUpdater):
                 energy=batch["energy"],
                 spk_id=spk_id,
                 spk_emb=spk_emb,
-                train_fs2=True, )
+                only_train_fs2=True, )
             l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
                 after_outs=after_outs,
@@ -110,7 +110,7 @@ class DiffSingerUpdater(StandardUpdater):
                 spk_logits=spk_logits,
                 spk_ids=spk_id, )
-            loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss
+            loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss
             self.optimizer_fs2.clear_grad()
             loss_fs2.backward()
@@ -128,7 +128,10 @@ class DiffSingerUpdater(StandardUpdater):
             losses_dict["duration_loss"] = float(duration_loss)
             losses_dict["pitch_loss"] = float(pitch_loss)
             losses_dict["energy_loss"] = float(energy_loss)
+            if speaker_loss != 0.:
+                report("train/speaker_loss", float(speaker_loss))
+                losses_dict["speaker_loss"] = float(speaker_loss)
             losses_dict["loss_fs2"] = float(loss_fs2)
             self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
@@ -136,10 +139,11 @@ class DiffSingerUpdater(StandardUpdater):
         # Then only train diffusion module, freeze fastspeech2 parameters.
         if self.state.iteration > self.ds_train_start_steps:
+            if self.only_train_diffusion:
                 for param in self.model.fs2.parameters():
                     param.trainable = False
-            mel, mel_ref, mel_masks = self.model(
+            noise_pred, noise_target, mel_masks = self.model(
                 text=batch["text"],
                 note=batch["note"],
                 note_dur=batch["note_dur"],
@@ -152,14 +156,14 @@ class DiffSingerUpdater(StandardUpdater):
                 energy=batch["energy"],
                 spk_id=spk_id,
                 spk_emb=spk_emb,
-                train_fs2=False, )
-            mel = mel.transpose((0, 2, 1))
-            mel_ref = mel_ref.transpose((0, 2, 1))
+                only_train_fs2=False, )
+            noise_pred = noise_pred.transpose((0, 2, 1))
+            noise_target = noise_target.transpose((0, 2, 1))
             mel_masks = mel_masks.transpose((0, 2, 1))
             l1_loss_ds = self.criterion_ds(
-                ref_mels=mel_ref,
-                out_mels=mel,
+                noise_pred=noise_pred,
+                noise_target=noise_target,
                 mel_masks=mel_masks, )
             loss_ds = l1_loss_ds
@@ -210,7 +214,7 @@ class DiffSingerEvaluator(StandardEvaluator):
             spk_id = None
         # Here show diffsinger eval
-        mel, mel_ref, mel_masks = self.model(
+        noise_pred, noise_target, mel_masks = self.model(
             text=batch["text"],
             note=batch["note"],
             note_dur=batch["note_dur"],
@@ -223,14 +227,14 @@ class DiffSingerEvaluator(StandardEvaluator):
             energy=batch["energy"],
             spk_id=spk_id,
             spk_emb=spk_emb,
-            train_fs2=False, )
-        mel = mel.transpose((0, 2, 1))
-        mel_ref = mel_ref.transpose((0, 2, 1))
+            only_train_fs2=False, )
+        noise_pred = noise_pred.transpose((0, 2, 1))
+        noise_target = noise_target.transpose((0, 2, 1))
         mel_masks = mel_masks.transpose((0, 2, 1))
         l1_loss_ds = self.criterion_ds(
-            ref_mels=mel_ref,
-            out_mels=mel,
+            noise_pred=noise_pred,
+            noise_target=noise_target,
             mel_masks=mel_masks, )
         loss_ds = l1_loss_ds

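Note: the updater implements a two-stage schedule: fastspeech2 alone before ds_train_start_steps, diffusion after it, with the new only_train_diffusion flag controlling whether fastspeech2 is frozen in stage two. A minimal sketch of the gating; observe that the comparison was changed from <= to <, so the single step equal to ds_train_start_steps now matches neither branch:

    def training_stage(iteration: int, ds_train_start_steps: int) -> str:
        # mirrors the two `if` branches of DiffSingerUpdater.update_core
        if iteration < ds_train_start_steps:
            return "fs2"        # only_train_fs2=True path: fastspeech2 losses (+ speaker_loss)
        if iteration > ds_train_start_steps:
            return "diffusion"  # only_train_fs2=False path: epsilon-prediction L1 loss
        return "skipped"        # iteration == ds_train_start_steps falls through both branches

    assert training_stage(99_999, 100_000) == "fs2"
    assert training_stage(100_001, 100_000) == "diffusion"
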
@@ -23,10 +23,10 @@ from typeguard import check_argument_types
 from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
 from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
-from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
-from paddlespeech.t2s.modules.nets_utils import make_pad_mask
 from paddlespeech.t2s.modules.losses import ssim
 from paddlespeech.t2s.modules.masked_fill import masked_fill
+from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
+from paddlespeech.t2s.modules.nets_utils import make_pad_mask
 class FastSpeech2MIDI(FastSpeech2):
@@ -61,14 +61,14 @@ class FastSpeech2MIDI(FastSpeech2):
         self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
             "adim"]
-        if note_num is not None:
+        # note embed
         self.note_embedding_table = nn.Embedding(
             num_embeddings=note_num,
             embedding_dim=self.note_embed_dim,
             padding_idx=self.padding_idx)
         self.note_dur_layer = nn.Linear(1, self.note_embed_dim)
-        if is_slur_num is not None:
+        # slur embed
         self.is_slur_embedding_table = nn.Embedding(
             num_embeddings=is_slur_num,
             embedding_dim=self.is_slur_embed_dim,
@@ -203,7 +203,7 @@ class FastSpeech2MIDI(FastSpeech2):
             spk_emb = self.spk_embedding_table(spk_id)
             hs = self._integrate_with_spk_embed(hs, spk_emb)
-        # forward duration predictor and variance predictors
+        # forward duration predictor (phone-level) and variance predictors (frame-level)
         d_masks = make_pad_mask(ilens)
         if olens is not None:
             pitch_masks = make_pad_mask(olens).unsqueeze(-1)
@@ -214,13 +214,12 @@ class FastSpeech2MIDI(FastSpeech2):
         if is_train_diffusion:
             hs = self.length_regulator(hs, ds, is_inference=False)
             p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
-            # e_outs = self.energy_predictor(hs.detach(), pitch_masks)
+            e_outs = self.energy_predictor(hs.detach(), pitch_masks)
             p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                 (0, 2, 1))
-            # e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
-            #     (0, 2, 1))
-            # hs = hs + p_embs + e_embs
-            hs = hs + p_embs
+            e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
+                (0, 2, 1))
+            hs = hs + p_embs + e_embs
         elif is_inference:
             # (B, Tmax)
@@ -240,20 +239,19 @@ class FastSpeech2MIDI(FastSpeech2):
             else:
                 p_outs = self.pitch_predictor(hs, pitch_masks)
-            # if es is not None:
-            #     e_outs = es
-            # else:
-            #     if self.stop_gradient_from_energy_predictor:
-            #         e_outs = self.energy_predictor(hs.detach(), pitch_masks)
-            #     else:
-            #         e_outs = self.energy_predictor(hs, pitch_masks)
+            if es is not None:
+                e_outs = es
+            else:
+                if self.stop_gradient_from_energy_predictor:
+                    e_outs = self.energy_predictor(hs.detach(), pitch_masks)
+                else:
+                    e_outs = self.energy_predictor(hs, pitch_masks)
             p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                 (0, 2, 1))
-            # e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
-            #     (0, 2, 1))
-            # hs = hs + p_embs + e_embs
-            hs = hs + p_embs
+            e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
+                (0, 2, 1))
+            hs = hs + p_embs + e_embs
         # training
         else:
@@ -264,16 +262,15 @@ class FastSpeech2MIDI(FastSpeech2):
                 p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
             else:
                 p_outs = self.pitch_predictor(hs, pitch_masks)
-            # if self.stop_gradient_from_energy_predictor:
-            #     e_outs = self.energy_predictor(hs.detach(), pitch_masks)
-            # else:
-            #     e_outs = self.energy_predictor(hs, pitch_masks)
+            if self.stop_gradient_from_energy_predictor:
+                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
+            else:
+                e_outs = self.energy_predictor(hs, pitch_masks)
             p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                 (0, 2, 1))
-            # e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
-            #     (0, 2, 1))
-            # hs = hs + p_embs + e_embs
-            hs = hs + p_embs
+            e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
+                (0, 2, 1))
+            hs = hs + p_embs + e_embs
         # forward decoder
         if olens is not None and not is_inference:
@@ -302,11 +299,11 @@ class FastSpeech2MIDI(FastSpeech2):
                 (paddle.shape(zs)[0], -1, self.odim))
         # postnet -> (B, Lmax//r * r, odim)
-        # if self.postnet is None:
-        #     after_outs = before_outs
-        # else:
-        #     after_outs = before_outs + self.postnet(
-        #         before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
-        after_outs = before_outs
+        if self.postnet is None:
+            after_outs = before_outs
+        else:
+            after_outs = before_outs + self.postnet(
+                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
         return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
@@ -478,8 +475,7 @@ class FastSpeech2MIDI(FastSpeech2):
             spk_emb=spk_emb,
             spk_id=spk_id, )
-        # return outs[0], d_outs[0], p_outs[0], e_outs[0]
-        return outs[0], d_outs[0], p_outs[0], None
+        return outs[0], d_outs[0], p_outs[0], e_outs[0]
 class FastSpeech2MIDILoss(FastSpeech2Loss):
@@ -551,21 +547,21 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
         """
         l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0
-        out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
-        before_outs_batch = masked_fill(before_outs, out_pad_masks, 0.0)
-        # print(before_outs.shape, ys.shape)
-        ssim_loss = 1.0 - ssim(before_outs_batch.unsqueeze(1), ys.unsqueeze(1))
-        ssim_loss = ssim_loss * 0.5
         # apply mask to remove padded part
         if self.use_masking:
+            # make feature for ssim loss
+            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
+            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
+            if after_outs is not None:
+                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
+            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)
             out_masks = make_non_pad_mask(olens).unsqueeze(-1)
             before_outs = before_outs.masked_select(
                 out_masks.broadcast_to(before_outs.shape))
-            # if after_outs is not None:
-            #     after_outs = after_outs.masked_select(
-            #         out_masks.broadcast_to(after_outs.shape))
+            if after_outs is not None:
+                after_outs = after_outs.masked_select(
+                    out_masks.broadcast_to(after_outs.shape))
             ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
             duration_masks = make_non_pad_mask(ilens)
             d_outs = d_outs.masked_select(
@@ -574,8 +570,8 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
             pitch_masks = out_masks
             p_outs = p_outs.masked_select(
                 pitch_masks.broadcast_to(p_outs.shape))
-            # e_outs = e_outs.masked_select(
-            #     pitch_masks.broadcast_to(e_outs.shape))
+            e_outs = e_outs.masked_select(
+                pitch_masks.broadcast_to(e_outs.shape))
             ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
             es = es.masked_select(pitch_masks.broadcast_to(es.shape))
@@ -591,17 +587,18 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
         # calculate loss
         l1_loss = self.l1_criterion(before_outs, ys)
-        # if after_outs is not None:
-        #     l1_loss += self.l1_criterion(after_outs, ys)
-        #     ssim_loss += (1.0 - ssim(after_outs, ys))
+        ssim_loss = 1.0 - ssim(
+            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
+        if after_outs is not None:
+            l1_loss += self.l1_criterion(after_outs, ys)
+            ssim_loss += (
+                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
         l1_loss = l1_loss * 0.5
+        ssim_loss = ssim_loss * 0.5
         duration_loss = self.duration_criterion(d_outs, ds)
-        # print("ppppppppppoooooooooooo: ", p_outs, p_outs.shape)
-        # print("ppppppppppssssssssssss: ", ps, ps.shape)
-        # pitch_loss = self.mse_criterion(p_outs, ps)
-        # energy_loss = self.mse_criterion(e_outs, es)
         pitch_loss = self.l1_criterion(p_outs, ps)
+        energy_loss = self.l1_criterion(e_outs, es)
         if spk_logits is not None and spk_ids is not None:
             speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
@@ -623,6 +620,9 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
             l1_loss = l1_loss.multiply(out_weights)
             l1_loss = l1_loss.masked_select(
                 out_masks.broadcast_to(l1_loss.shape)).sum()
+            ssim_loss = ssim_loss.multiply(out_weights)
+            ssim_loss = ssim_loss.masked_select(
+                out_masks.broadcast_to(ssim_loss.shape)).sum()
             duration_loss = (duration_loss.multiply(duration_weights)
                              .masked_select(duration_masks).sum())
             pitch_masks = out_masks
@@ -630,8 +630,8 @@ class FastSpeech2MIDILoss(FastSpeech2Loss):
             pitch_loss = pitch_loss.multiply(pitch_weights)
             pitch_loss = pitch_loss.masked_select(
                 pitch_masks.broadcast_to(pitch_loss.shape)).sum()
-            # energy_loss = energy_loss.multiply(pitch_weights)
-            # energy_loss = energy_loss.masked_select(
-            #     pitch_masks.broadcast_to(energy_loss.shape)).sum()
+            energy_loss = energy_loss.multiply(pitch_weights)
+            energy_loss = energy_loss.masked_select(
+                pitch_masks.broadcast_to(energy_loss.shape)).sum()
         return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss

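Note: the SSIM loss now moves inside the use_masking branch and works on masked_fill copies, so padded frames are zeroed while the (B, T, odim) layout survives; SSIM needs an image-shaped input, whereas masked_select flattens the tensor. A minimal sketch of that masking pattern, reusing the helpers this file imports (shapes are assumptions):

    import paddle
    from paddlespeech.t2s.modules.losses import ssim
    from paddlespeech.t2s.modules.masked_fill import masked_fill
    from paddlespeech.t2s.modules.nets_utils import make_pad_mask

    before_outs = paddle.randn((2, 100, 80))  # predicted mel, (B, T, odim)
    ys = paddle.randn((2, 100, 80))           # target mel
    olens = paddle.to_tensor([100, 70])       # true lengths per utterance

    out_pad_masks = make_pad_mask(olens).unsqueeze(-1)          # True on padded frames
    before_ssim = masked_fill(before_outs, out_pad_masks, 0.0)  # zero padding, keep shape
    ys_ssim = masked_fill(ys, out_pad_masks, 0.0)
    # ssim treats each mel as a 1-channel image: (B, 1, T, odim)
    ssim_loss = 1.0 - ssim(before_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
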
@@ -17,7 +17,6 @@ from typing import Callable
 from typing import Optional
 from typing import Tuple
-import numpy as np
 import paddle
 import ppdiffusers
 from paddle import nn
@@ -316,46 +315,8 @@ class GaussianDiffusion(nn.Layer):
             beta_end=beta_end,
             beta_schedule=beta_schedule)
         self.num_max_timesteps = num_max_timesteps
-        self.spec_min = paddle.to_tensor(
-            np.array([
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-                -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
-            ]))
-        self.spec_max = paddle.to_tensor(
-            np.array([
-                -0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-                -0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-                -0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-                -0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-                -0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-                -0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-                -0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-                -0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-                -0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-                -0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-                -0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-                -1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-                -1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
-            ]))
-    def norm_spec(self, x):
-        """
-        Linearly map x to [-1, 1]
-        Args:
-            x: [B, T, N]
-        """
-        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
-    def denorm_spec(self, x):
-        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
-    def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None,
-                is_infer: bool=False,
+    def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None
                 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Generate random timesteps noised x.
@@ -372,9 +333,6 @@ class GaussianDiffusion(nn.Layer):
                 The noises which is added to the input.
         """
-        x = x.transpose((0, 2, 1))
-        x = self.norm_spec(x)
-        x = x.transpose((0, 2, 1))
         noise_scheduler = self.noise_scheduler
         # Sample noise that we'll add to the mel-spectrograms
@@ -392,13 +350,6 @@ class GaussianDiffusion(nn.Layer):
         y = self.denoiser(noisy_images, timesteps, cond)
-        if is_infer:
-            y = y.transpose((0, 2, 1))
-            y = self.denorm_spec(y)
-            y = y.transpose((0, 2, 1))
-        # y = self.denorm_spec(y)
         # then compute loss use output y and noisy target for prediction_type == "epsilon"
         return y, target
@@ -478,9 +429,6 @@ class GaussianDiffusion(nn.Layer):
         noisy_input = noise
         timesteps = scheduler.timesteps
         if ref_x is not None:
-            ref_x = ref_x.transpose((0, 2, 1))
-            ref_x = self.norm_spec(ref_x)
-            ref_x = ref_x.transpose((0, 2, 1))
             init_timestep = None
             if strength is None or strength < 0. or strength > 1.:
                 strength = None
@@ -521,10 +469,4 @@ class GaussianDiffusion(nn.Layer):
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, len(timesteps), denoised_output)
-        denoised_output = denoised_output.transpose((0, 2, 1))
-        denoised_output = self.denorm_spec(denoised_output)
-        denoised_output = denoised_output.transpose((0, 2, 1))
         return denoised_output

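Note: removing spec_min/spec_max with norm_spec/denorm_spec means the diffusion module no longer rescales mels internally; it now relies on the input being normalized upstream (cf. the DiffSingerInference change that re-enables normalizer.inverse). What remains of forward is the standard epsilon-prediction setup; a minimal sketch under the assumption that ppdiffusers mirrors diffusers' add_noise:

    import paddle
    from ppdiffusers import DDPMScheduler  # assumed top-level export, as in diffusers

    scheduler = DDPMScheduler(num_train_timesteps=100, beta_start=0.0001,
                              beta_end=0.06, beta_schedule="linear")

    x0 = paddle.randn((4, 80, 200))  # ground-truth mel, already normalized upstream
    noise = paddle.randn(x0.shape)
    t = paddle.randint(0, 60, (x0.shape[0],))  # capped by num_max_timesteps: 60
    xt = scheduler.add_noise(x0, noise, t)     # sample q(x_t | x_0)
    # the denoiser is trained so that denoiser(xt, t, cond) approximates `noise`,
    # which is what the (noise_pred, noise_target) pair handed to DiffusionLoss carries
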