From 84a22ffb93ddb281622455c678c972cefcdd9032 Mon Sep 17 00:00:00 2001
From: liangym
Date: Thu, 9 Feb 2023 08:01:33 +0000
Subject: [PATCH] diffsinger_tmp

---
 examples/opencpop/svs1/conf/default.yaml      |  67 +++---
 examples/opencpop/svs1/local/preprocess.sh    |   2 +-
 examples/opencpop/svs1/local/synthesize.sh    |  80 +------
 examples/opencpop/svs1/local/train.sh         |   2 +-
 examples/opencpop/svs1/run.sh                 |   7 +-
 paddlespeech/t2s/datasets/get_feats.py        |   1 -
 paddlespeech/t2s/datasets/preprocess_utils.py |   4 +-
 paddlespeech/t2s/exps/diffsinger/normalize.py |  20 +-
 .../t2s/exps/diffsinger/preprocess.py         |   3 +-
 paddlespeech/t2s/exps/diffsinger/train.py     |   3 -
 paddlespeech/t2s/exps/synthesize.py           |   6 +
 .../t2s/models/diffsinger/diffsinger.py       |  42 +++-
 .../models/diffsinger/diffsinger_updater.py   |  13 +-
 .../t2s/models/diffsinger/fastspeech2midi.py  | 198 ++++++++++--------
 .../t2s/models/fastspeech2/fastspeech2.py     |  13 +-
 paddlespeech/t2s/modules/activation.py        |   3 +-
 paddlespeech/t2s/modules/diffusion.py         |  50 +++--
 paddlespeech/t2s/modules/masked_fill.py       |   2 -
 .../modules/predictor/variance_predictor.py   |   2 +-
 19 files changed, 256 insertions(+), 262 deletions(-)

diff --git a/examples/opencpop/svs1/conf/default.yaml b/examples/opencpop/svs1/conf/default.yaml
index 7e48190c7..7729889e5 100644
--- a/examples/opencpop/svs1/conf/default.yaml
+++ b/examples/opencpop/svs1/conf/default.yaml
@@ -23,7 +23,7 @@ f0max: 750        # Maximum f0 for pitch extraction.
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 32
+batch_size: 48
 num_workers: 1
@@ -37,33 +37,32 @@ model:
     # fastspeech2 module
     fastspeech2_params:
-        adim: 256                                    # attention dimension
-        aheads: 2                                    # number of attention heads
-        elayers: 4                                   # number of encoder layers
-        eunits: 1536                                 # number of encoder ff units
-        dlayers: 4                                   # number of decoder layers
-        dunits: 1536                                 # number of decoder ff units
-        positionwise_layer_type: conv1d              # type of position-wise layer
-        positionwise_conv_kernel_size: 3             # kernel size of position wise conv layer
-        duration_predictor_layers: 2                 # number of layers of duration predictor
-        duration_predictor_chans: 256                # number of channels of duration predictor
-        duration_predictor_kernel_size: 3            # filter size of duration predictor
-        postnet_layers: 5                            # number of layers of postnset
-        postnet_filts: 5                             # filter size of conv layers in postnet
-        postnet_chans: 256                           # number of channels of conv layers in postnet
-        use_scaled_pos_enc: True                     # whether to use scaled positional encoding
+        adim: 256                                    # attention dimension # lym check
+        aheads: 2                                    # number of attention heads # lym check
+        elayers: 4                                   # number of encoder layers # lym check
+        eunits: 1024                                 # number of encoder ff units # lym check adim * 4
+        dlayers: 4                                   # number of decoder layers # lym check
+        dunits: 1024                                 # number of decoder ff units # lym check
+        positionwise_layer_type: conv1d-linear       # type of position-wise layer # lym check
+        positionwise_conv_kernel_size: 9             # kernel size of position wise conv layer # lym check
+        transformer_enc_dropout_rate: 0.1            # dropout rate for transformer encoder layer # lym check
+        transformer_enc_positional_dropout_rate: 0.1 # dropout rate for transformer encoder positional encoding # lym check
+        transformer_enc_attn_dropout_rate: 0.0       # dropout rate for transformer encoder attention layer # lym check
+        transformer_activation_type: "gelu"
         encoder_normalize_before: True               # whether to perform layer normalization before the input
         decoder_normalize_before: True               # whether to perform layer normalization before the input
         reduction_factor: 1                          # reduction factor
         init_type: xavier_uniform                    # initialization type
         init_enc_alpha: 1.0                          # initial value of alpha of encoder scaled position encoding
         init_dec_alpha: 1.0                          # initial value of alpha of decoder scaled position encoding
-        transformer_enc_dropout_rate: 0.2            # dropout rate for transformer encoder layer
-        transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
-        transformer_enc_attn_dropout_rate: 0.2       # dropout rate for transformer encoder attention layer
-        transformer_dec_dropout_rate: 0.2            # dropout rate for transformer decoder layer
-        transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
-        transformer_dec_attn_dropout_rate: 0.2       # dropout rate for transformer decoder attention layer
+        use_scaled_pos_enc: True                     # whether to use scaled positional encoding
+        transformer_dec_dropout_rate: 0.1            # dropout rate for transformer decoder layer
+        transformer_dec_positional_dropout_rate: 0.1 # dropout rate for transformer decoder positional encoding
+        transformer_dec_attn_dropout_rate: 0.0       # dropout rate for transformer decoder attention layer
+        duration_predictor_layers: 5                 # number of layers of duration predictor
+        duration_predictor_chans: 256                # number of channels of duration predictor
+        duration_predictor_kernel_size: 3            # filter size of duration predictor
+        duration_predictor_dropout_rate: 0.5         # dropout rate in energy predictor
         pitch_predictor_layers: 5                    # number of conv layers in pitch predictor
         pitch_predictor_chans: 256                   # number of channels of conv layers in pitch predictor
         pitch_predictor_kernel_size: 5               # kernel size of conv leyers in pitch predictor
@@ -71,6 +70,11 @@ model:
         pitch_embed_kernel_size: 1                   # kernel size of conv embedding layer for pitch
         pitch_embed_dropout: 0.0                     # dropout rate after conv embedding layer for pitch
         stop_gradient_from_pitch_predictor: True     # whether to stop the gradient from pitch predictor to encoder
+
+
+        postnet_layers: 5                            # number of layers of postnset
+        postnet_filts: 5                             # filter size of conv layers in postnet
+        postnet_chans: 256                           # number of channels of conv layers in postnet
         energy_predictor_layers: 2                   # number of conv layers in energy predictor
         energy_predictor_chans: 256                  # number of channels of conv layers in energy predictor
         energy_predictor_kernel_size: 3              # kernel size of conv leyers in energy predictor
@@ -131,19 +135,24 @@ ds_optimizer_params:
 ds_scheduler_params:
     learning_rate: 0.001
     gamma: 0.5
-    step_size: 25000
+    step_size: 10000
 ds_grad_norm: 1
 ###########################################################
 #                     INTERVAL SETTING                    #
 ###########################################################
-ds_train_start_steps: 160000  # Number of steps to start to train diffusion module.
-train_max_steps: 320000       # Number of training steps.
-save_interval_steps: 1000     # Interval steps to save checkpoint.
-eval_interval_steps: 1000     # Interval steps to evaluate the network.
-num_snapshots: 5
-
+ds_train_start_steps: 32500   # Number of steps to start to train diffusion module.
+train_max_steps: 65000        # Number of training steps.
+save_interval_steps: 500      # Interval steps to save checkpoint.
+eval_interval_steps: 500      # Interval steps to evaluate the network.
+num_snapshots: 20
+
+# ds_train_start_steps: 4     # Number of steps to start to train diffusion module.
+# train_max_steps: 8          # Number of training steps.
+# save_interval_steps: 1      # Interval steps to save checkpoint.
+# eval_interval_steps: 2      # Interval steps to evaluate the network.
+# num_snapshots: 5 ########################################################### # OTHER SETTING # diff --git a/examples/opencpop/svs1/local/preprocess.sh b/examples/opencpop/svs1/local/preprocess.sh index 78803b83f..1c98ca84d 100755 --- a/examples/opencpop/svs1/local/preprocess.sh +++ b/examples/opencpop/svs1/local/preprocess.sh @@ -1,6 +1,6 @@ #!/bin/bash -stage=1 +stage=0 stop_stage=100 config_path=$1 diff --git a/examples/opencpop/svs1/local/synthesize.sh b/examples/opencpop/svs1/local/synthesize.sh index 238437b4d..cc58b58ce 100755 --- a/examples/opencpop/svs1/local/synthesize.sh +++ b/examples/opencpop/svs1/local/synthesize.sh @@ -2,7 +2,9 @@ config_path=$1 train_output_path=$2 -ckpt_name=$3 +#ckpt_name=$3 +iter=$3 +ckpt_name=snapshot_iter_${iter}.pdz stage=0 stop_stage=0 @@ -20,81 +22,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \ --voc_stat=pwgan_opencpop/feats_stats.npy \ --test_metadata=dump/test/norm/metadata.jsonl \ - --output_dir=${train_output_path}/test \ + --output_dir=${train_output_path}/test_${iter} \ --phones_dict=dump/phone_id_map.txt fi -# for more GAN Vocoders -# multi band melgan -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - FLAGS_allocator_strategy=naive_best_fit \ - FLAGS_fraction_of_gpu_memory_to_use=0.01 \ - python3 ${BIN_DIR}/../synthesize.py \ - --am=fastspeech2_csmsc \ - --am_config=${config_path} \ - --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ - --voc=mb_melgan_csmsc \ - --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \ - --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\ - --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ - --test_metadata=dump/test/norm/metadata.jsonl \ - --output_dir=${train_output_path}/test \ - --phones_dict=dump/phone_id_map.txt -fi - -# style melgan -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - FLAGS_allocator_strategy=naive_best_fit \ - FLAGS_fraction_of_gpu_memory_to_use=0.01 \ - python3 ${BIN_DIR}/../synthesize.py \ - --am=fastspeech2_csmsc \ - --am_config=${config_path} \ - --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ - --voc=style_melgan_csmsc \ - --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \ - --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \ - --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ - --test_metadata=dump/test/norm/metadata.jsonl \ - --output_dir=${train_output_path}/test \ - --phones_dict=dump/phone_id_map.txt -fi - -# hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - echo "in hifigan syn" - FLAGS_allocator_strategy=naive_best_fit \ - FLAGS_fraction_of_gpu_memory_to_use=0.01 \ - python3 ${BIN_DIR}/../synthesize.py \ - --am=fastspeech2_csmsc \ - --am_config=${config_path} \ - --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ - --voc=hifigan_csmsc \ - --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \ - --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \ - --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \ - --test_metadata=dump/test/norm/metadata.jsonl \ - --output_dir=${train_output_path}/test \ - --phones_dict=dump/phone_id_map.txt -fi - -# wavernn -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - echo "in wavernn syn" - FLAGS_allocator_strategy=naive_best_fit \ - FLAGS_fraction_of_gpu_memory_to_use=0.01 \ - python3 ${BIN_DIR}/../synthesize.py \ - 
--am=fastspeech2_csmsc \ - --am_config=${config_path} \ - --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ - --voc=wavernn_csmsc \ - --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \ - --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \ - --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \ - --test_metadata=dump/test/norm/metadata.jsonl \ - --output_dir=${train_output_path}/test \ - --phones_dict=dump/phone_id_map.txt -fi diff --git a/examples/opencpop/svs1/local/train.sh b/examples/opencpop/svs1/local/train.sh index d1302f99f..42fff26ca 100755 --- a/examples/opencpop/svs1/local/train.sh +++ b/examples/opencpop/svs1/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=4 \ --phones-dict=dump/phone_id_map.txt diff --git a/examples/opencpop/svs1/run.sh b/examples/opencpop/svs1/run.sh index 7bde38518..10a2b5290 100755 --- a/examples/opencpop/svs1/run.sh +++ b/examples/opencpop/svs1/run.sh @@ -3,9 +3,10 @@ set -e source path.sh -gpus=0 -stage=0 -stop_stage=100 +gpus=4,5,6,7 +#gpus=0 +stage=1 +stop_stage=1 conf_path=conf/default.yaml train_output_path=exp/default diff --git a/paddlespeech/t2s/datasets/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py index bad1d43ce..871699c81 100644 --- a/paddlespeech/t2s/datasets/get_feats.py +++ b/paddlespeech/t2s/datasets/get_feats.py @@ -105,7 +105,6 @@ class Pitch(): if (f0 == 0).all(): print("All frames seems to be unvoiced.") return f0 - # padding start and end of f0 sequence start_f0 = f0[f0 != 0][0] end_f0 = f0[f0 != 0][-1] diff --git a/paddlespeech/t2s/datasets/preprocess_utils.py b/paddlespeech/t2s/datasets/preprocess_utils.py index e8e4083a4..bf813b22a 100644 --- a/paddlespeech/t2s/datasets/preprocess_utils.py +++ b/paddlespeech/t2s/datasets/preprocess_utils.py @@ -101,7 +101,7 @@ def get_sentences_svs( dataset (str): dataset name Returns: Dict: the information of sentence, include [phone id (int)], [the frame of phone (int)], [note id (int)], [note duration (float)], [is slur (int)], text(str), speaker name (str) - tunple: speaker name + tuple: speaker name ''' f = open(file_name, 'r') sentence = {} @@ -115,7 +115,7 @@ def get_sentences_svs( ph = line_list[2].split() midi = note2midi(line_list[3].split()) midi_dur = line_list[4].split() - ph_dur = time2frame([float(t) for t in line_list[5].split()]) + ph_dur = time2frame([float(t) for t in line_list[5].split()], sample_rate=sample_rate, n_shift=n_shift) is_slur = line_list[6].split() assert len(ph) == len(midi) == len(midi_dur) == len(is_slur) sentence[utt] = (ph, [int(i) for i in ph_dur], diff --git a/paddlespeech/t2s/exps/diffsinger/normalize.py b/paddlespeech/t2s/exps/diffsinger/normalize.py index ec80a229e..f0f4f0b0d 100644 --- a/paddlespeech/t2s/exps/diffsinger/normalize.py +++ b/paddlespeech/t2s/exps/diffsinger/normalize.py @@ -80,20 +80,27 @@ def main(): # restore scaler speech_scaler = StandardScaler() - speech_scaler.mean_ = np.load(args.speech_stats)[0] - speech_scaler.scale_ = np.load(args.speech_stats)[1] + # speech_scaler.mean_ = np.load(args.speech_stats)[0] + # speech_scaler.scale_ = np.load(args.speech_stats)[1] + speech_scaler.mean_ = np.zeros(np.load(args.speech_stats)[0].shape, dtype="float32") + speech_scaler.scale_ = np.ones(np.load(args.speech_stats)[1].shape, dtype="float32") speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0] pitch_scaler = StandardScaler() - 
pitch_scaler.mean_ = np.load(args.pitch_stats)[0] - pitch_scaler.scale_ = np.load(args.pitch_stats)[1] + # pitch_scaler.mean_ = np.load(args.pitch_stats)[0] + # pitch_scaler.scale_ = np.load(args.pitch_stats)[1] + pitch_scaler.mean_ = np.zeros(np.load(args.pitch_stats)[0].shape, dtype="float32") + pitch_scaler.scale_ = np.ones(np.load(args.pitch_stats)[1].shape, dtype="float32") pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0] energy_scaler = StandardScaler() - energy_scaler.mean_ = np.load(args.energy_stats)[0] - energy_scaler.scale_ = np.load(args.energy_stats)[1] + # energy_scaler.mean_ = np.load(args.energy_stats)[0] + # energy_scaler.scale_ = np.load(args.energy_stats)[1] + energy_scaler.mean_ = np.zeros(np.load(args.energy_stats)[0].shape, dtype="float32") + energy_scaler.scale_ = np.ones(np.load(args.energy_stats)[1].shape, dtype="float32") energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0] + vocab_phones = {} with open(args.phones_dict, 'rt') as f: phn_id = [line.strip().split() for line in f.readlines()] @@ -111,6 +118,7 @@ def main(): for item in tqdm(dataset): utt_id = item['utt_id'] + print(utt_id) speech = item['speech'] pitch = item['pitch'] energy = item['energy'] diff --git a/paddlespeech/t2s/exps/diffsinger/preprocess.py b/paddlespeech/t2s/exps/diffsinger/preprocess.py index f9322dc98..7d47c61a5 100644 --- a/paddlespeech/t2s/exps/diffsinger/preprocess.py +++ b/paddlespeech/t2s/exps/diffsinger/preprocess.py @@ -34,7 +34,6 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_ from paddlespeech.t2s.datasets.preprocess_utils import get_input_token from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map -from paddlespeech.t2s.datasets.preprocess_utils import merge_silence from paddlespeech.t2s.utils import str2bool ALL_INITIALS = [ @@ -106,6 +105,7 @@ def process_sentence( pitch_dir = output_dir / "data_pitch" pitch_dir.mkdir(parents=True, exist_ok=True) pitch_path = pitch_dir / (utt_id + "_pitch.npy") + # print(pitch, pitch.shape) np.save(pitch_path, pitch) energy = energy_extractor.get_energy(wav) assert energy.shape[0] == num_frames @@ -271,7 +271,6 @@ def main(): sample_rate=config.fs, n_shift=config.n_shift, ) - # merge_silence(sentences) phone_id_map_path = dumpdir / "phone_id_map.txt" speaker_id_map_path = dumpdir / "speaker_id_map.txt" get_input_token(sentences, phone_id_map_path, args.dataset) diff --git a/paddlespeech/t2s/exps/diffsinger/train.py b/paddlespeech/t2s/exps/diffsinger/train.py index c7940ad40..41b9f2a84 100644 --- a/paddlespeech/t2s/exps/diffsinger/train.py +++ b/paddlespeech/t2s/exps/diffsinger/train.py @@ -43,9 +43,6 @@ from paddlespeech.t2s.training.extensions.visualizer import VisualDL from paddlespeech.t2s.training.optimizer import build_optimizers from paddlespeech.t2s.training.seeding import seed_everything from paddlespeech.t2s.training.trainer import Trainer -from paddlespeech.t2s.utils import str2bool - -# from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss def train_sp(args, config): diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index 009e33a16..23d7e24dc 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -120,7 +120,11 @@ def evaluate(args): note_dur=note_dur, is_slur=is_slur, get_mel_fs2=get_mel_fs2) + # import numpy as np + # mel = np.load("/home/liangyunming/others_code/DiffSinger_lym/diffsinger_mel.npy") + # mel 
= paddle.to_tensor(mel) wav = voc_inference(mel) + wav = wav.numpy() N += wav.size @@ -131,8 +135,10 @@ def evaluate(args): f"{utt_id}, mel: {mel.shape}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." ) sf.write( + # str(output_dir / ("xiaojiuwo_diffsinger" + ".wav")), wav, samplerate=am_config.fs) str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs) print(f"{utt_id} done!") + # break print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py index de9b39602..2faf80cb2 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py @@ -48,12 +48,12 @@ class DiffSinger(nn.Layer): note_num: int=300, is_slur_num: int=2, fastspeech2_params: Dict[str, Any]={ - "adim": 384, - "aheads": 4, - "elayers": 6, - "eunits": 1536, - "dlayers": 6, - "dunits": 1536, + "adim": 256, + "aheads": 2, + "elayers": 4, + "eunits": 1024, + "dlayers": 4, + "dunits": 1024, "postnet_layers": 5, "postnet_chans": 512, "postnet_filts": 5, @@ -74,6 +74,7 @@ class DiffSinger(nn.Layer): "transformer_dec_dropout_rate": 0.1, "transformer_dec_positional_dropout_rate": 0.1, "transformer_dec_attn_dropout_rate": 0.1, + "transformer_activation_type": "gelu", # duration predictor "duration_predictor_layers": 2, "duration_predictor_chans": 384, @@ -149,7 +150,7 @@ class DiffSinger(nn.Layer): self.fs2 = FastSpeech2MIDI( idim=idim, odim=odim, - fastspeech2_config=fastspeech2_params, + fastspeech2_params=fastspeech2_params, note_num=note_num, is_slur_num=is_slur_num) denoiser = WaveNetDenoiser(**denoiser_params) @@ -260,7 +261,7 @@ class DiffSinger(nn.Layer): Whether to get mel from fastspeech2 module. Returns: - _type_: _description_ + """ mel_fs2, _, _, _ = self.fs2.inference(text, note, note_dur, is_slur) if get_mel_fs2: @@ -268,7 +269,9 @@ class DiffSinger(nn.Layer): mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1)) cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur) cond_fs2 = cond_fs2.transpose((0, 2, 1)) - mel, _ = self.diffusion(mel_fs2, cond_fs2) + # mel, _ = self.diffusion(mel_fs2, cond_fs2) + noise = paddle.randn(mel_fs2.shape) + mel = self.diffusion.inference(noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100) mel = mel.transpose((0, 2, 1)) return mel[0] @@ -280,13 +283,32 @@ class DiffSingerInference(nn.Layer): self.acoustic_model = model def forward(self, text, note, note_dur, is_slur, get_mel_fs2: bool=False): + """Calculate forward propagation. + + Args: + text(Tensor(int64)): + Batch of padded token (phone) ids (B, Tmax). + note(Tensor(int64)): + Batch of padded note (element in music score) ids (B, Tmax). + note_dur(Tensor(float32)): + Batch of padded note durations in seconds (element in music score) (B, Tmax). + is_slur(Tensor(int64)): + Batch of padded slur (element in music score) ids (B, Tmax). + get_mel_fs2 (bool, optional): . Defaults to False. + Whether to get mel from fastspeech2 module. 
+ + Returns: + logmel(Tensor(float32)): denorm logmel, [T, mel_bin] + """ normalized_mel = self.acoustic_model.inference( text, note=note, note_dur=note_dur, is_slur=is_slur, get_mel_fs2=get_mel_fs2) - logmel = self.normalizer.inverse(normalized_mel) + print(normalized_mel) + # logmel = self.normalizer.inverse(normalized_mel) + logmel = normalized_mel return logmel diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py index 7ffe3198f..a91fee637 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py @@ -41,7 +41,6 @@ class DiffSingerUpdater(StandardUpdater): optimizers: Dict[str, Optimizer], criterions: Dict[str, Layer], dataloader: DataLoader, - fs2_train_start_steps: int=0, ds_train_start_steps: int=160000, output_dir: Path=None, ): super().__init__(model, optimizers, dataloader, init_state=None) @@ -58,7 +57,6 @@ class DiffSingerUpdater(StandardUpdater): self.dataloader = dataloader - self.fs2_train_start_steps = fs2_train_start_steps self.ds_train_start_steps = ds_train_start_steps self.state = UpdaterState(iteration=0, epoch=0) @@ -81,7 +79,8 @@ class DiffSingerUpdater(StandardUpdater): spk_id = None # only train fastspeech2 module firstly - if self.state.iteration > self.fs2_train_start_steps and self.state.iteration < self.ds_train_start_steps: + if self.state.iteration <= self.ds_train_start_steps: + # print(batch) before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( text=batch["text"], note=batch["note"], @@ -97,7 +96,7 @@ class DiffSingerUpdater(StandardUpdater): spk_emb=spk_emb, train_fs2=True, ) - l1_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2( + l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, @@ -111,8 +110,8 @@ class DiffSingerUpdater(StandardUpdater): olens=olens, spk_logits=spk_logits, spk_ids=spk_id, ) - - loss_fs2 = l1_loss_fs2 + duration_loss + pitch_loss + energy_loss + + loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss self.optimizer_fs2.clear_grad() loss_fs2.backward() @@ -120,11 +119,13 @@ class DiffSingerUpdater(StandardUpdater): report("train/loss_fs2", float(loss_fs2)) report("train/l1_loss_fs2", float(l1_loss_fs2)) + report("train/ssim_loss_fs2", float(ssim_loss_fs2)) report("train/duration_loss", float(duration_loss)) report("train/pitch_loss", float(pitch_loss)) report("train/energy_loss", float(energy_loss)) losses_dict["l1_loss_fs2"] = float(l1_loss_fs2) + losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) diff --git a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py index 85f9a9550..0e24e69fb 100644 --- a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py +++ b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py @@ -22,9 +22,11 @@ from paddle import nn from typeguard import check_argument_types from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.nets_utils import make_pad_mask -from 
paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss +from paddlespeech.t2s.modules.losses import ssim +from paddlespeech.t2s.modules.masked_fill import masked_fill class FastSpeech2MIDI(FastSpeech2): @@ -36,14 +38,14 @@ class FastSpeech2MIDI(FastSpeech2): # fastspeech2 network structure related idim: int, odim: int, - fastspeech2_config: Dict[str, Any], + fastspeech2_params: Dict[str, Any], # note emb note_num: int=300, # is_slur emb is_slur_num: int=2, ): """Initialize FastSpeech2 module for svs. Args: - fastspeech2_config (Dict): + fastspeech2_params (Dict): The config of FastSpeech2 module on DiffSinger model note_num (Optional[int]): Number of note. If not None, assume that the @@ -54,9 +56,9 @@ class FastSpeech2MIDI(FastSpeech2): """ assert check_argument_types() - super().__init__(idim=idim, odim=odim, **fastspeech2_config) + super().__init__(idim=idim, odim=odim, **fastspeech2_params) - self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_config[ + self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[ "adim"] if note_num is not None: @@ -133,15 +135,15 @@ class FastSpeech2MIDI(FastSpeech2): spk_id = paddle.cast(spk_id, 'int64') # forward propagation before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward( - xs, - note, - note_dur, - is_slur, - ilens, - olens, - ds, - ps, - es, + xs=xs, + note=note, + note_dur=note_dur, + is_slur=is_slur, + ilens=ilens, + olens=olens, + ds=ds, + ps=ps, + es=es, is_inference=False, spk_emb=spk_emb, spk_id=spk_id, ) @@ -170,6 +172,8 @@ class FastSpeech2MIDI(FastSpeech2): alpha: float=1.0, spk_emb=None, spk_id=None, ) -> Sequence[paddle.Tensor]: + + before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None # forward encoder x_masks = self._source_mask(ilens) note_emb = self.note_embedding_table(note) @@ -206,16 +210,17 @@ class FastSpeech2MIDI(FastSpeech2): else: pitch_masks = None - # inference for decoder input for duffusion + # inference for decoder input for diffusion if is_train_diffusion: hs = self.length_regulator(hs, ds, is_inference=False) p_outs = self.pitch_predictor(hs.detach(), pitch_masks) - e_outs = self.energy_predictor(hs.detach(), pitch_masks) + # e_outs = self.energy_predictor(hs.detach(), pitch_masks) p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + e_embs + p_embs + # e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( + # (0, 2, 1)) + # hs = hs + p_embs + e_embs + hs = hs + p_embs elif is_inference: # (B, Tmax) @@ -235,19 +240,20 @@ class FastSpeech2MIDI(FastSpeech2): else: p_outs = self.pitch_predictor(hs, pitch_masks) - if es is not None: - e_outs = es - else: - if self.stop_gradient_from_energy_predictor: - e_outs = self.energy_predictor(hs.detach(), pitch_masks) - else: - e_outs = self.energy_predictor(hs, pitch_masks) + # if es is not None: + # e_outs = es + # else: + # if self.stop_gradient_from_energy_predictor: + # e_outs = self.energy_predictor(hs.detach(), pitch_masks) + # else: + # e_outs = self.energy_predictor(hs, pitch_masks) p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + e_embs + p_embs + # e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( + # (0, 2, 1)) + # hs = hs + p_embs + e_embs + hs = hs + p_embs # training else: @@ -258,15 +264,16 @@ class 
FastSpeech2MIDI(FastSpeech2): p_outs = self.pitch_predictor(hs.detach(), pitch_masks) else: p_outs = self.pitch_predictor(hs, pitch_masks) - if self.stop_gradient_from_energy_predictor: - e_outs = self.energy_predictor(hs.detach(), pitch_masks) - else: - e_outs = self.energy_predictor(hs, pitch_masks) + # if self.stop_gradient_from_energy_predictor: + # e_outs = self.energy_predictor(hs.detach(), pitch_masks) + # else: + # e_outs = self.energy_predictor(hs, pitch_masks) p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose( (0, 2, 1)) - e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + e_embs + p_embs + # e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( + # (0, 2, 1)) + # hs = hs + p_embs + e_embs + hs = hs + p_embs # forward decoder if olens is not None and not is_inference: @@ -295,11 +302,12 @@ class FastSpeech2MIDI(FastSpeech2): (paddle.shape(zs)[0], -1, self.odim)) # postnet -> (B, Lmax//r * r, odim) - if self.postnet is None: - after_outs = before_outs - else: - after_outs = before_outs + self.postnet( - before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + # if self.postnet is None: + # after_outs = before_outs + # else: + # after_outs = before_outs + self.postnet( + # before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + after_outs = before_outs return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits @@ -326,11 +334,11 @@ class FastSpeech2MIDI(FastSpeech2): # (1, L, odim) # use *_ to avoid bug in dygraph to static graph hs, _ = self._forward( - xs, - note, - note_dur, - is_slur, - ilens, + xs=xs, + note=note, + note_dur=note_dur, + is_slur=is_slur, + ilens=ilens, is_inference=True, return_after_enc=True, alpha=alpha, @@ -367,15 +375,15 @@ class FastSpeech2MIDI(FastSpeech2): # (1, L, odim) # use *_ to avoid bug in dygraph to static graph hs, h_masks = self._forward( - xs, - note, - note_dur, - is_slur, - ilens, - olens, - ds, - ps, - es, + xs=xs, + note=note, + note_dur=note_dur, + is_slur=is_slur, + ilens=ilens, + olens=olens, + ds=ds, + ps=ps, + es=es, return_after_enc=True, is_train_diffusion=True, alpha=alpha, @@ -446,11 +454,11 @@ class FastSpeech2MIDI(FastSpeech2): # (1, L, odim) _, outs, d_outs, p_outs, e_outs, _ = self._forward( - xs, - note, - note_dur, - is_slur, - ilens, + xs=xs, + note=note, + note_dur=note_dur, + is_slur=is_slur, + ilens=ilens, ds=ds, ps=ps, es=es, @@ -460,20 +468,21 @@ class FastSpeech2MIDI(FastSpeech2): else: # (1, L, odim) _, outs, d_outs, p_outs, e_outs, _ = self._forward( - xs, - note, - note_dur, - is_slur, - ilens, + xs=xs, + note=note, + note_dur=note_dur, + is_slur=is_slur, + ilens=ilens, is_inference=True, alpha=alpha, spk_emb=spk_emb, spk_id=spk_id, ) - return outs[0], d_outs[0], p_outs[0], e_outs[0] + # return outs[0], d_outs[0], p_outs[0], e_outs[0] + return outs[0], d_outs[0], p_outs[0], None -class FastSpeech2MIDILoss(nn.Layer): +class FastSpeech2MIDILoss(FastSpeech2Loss): """Loss function module for DiffSinger.""" def __init__(self, use_masking: bool=True, @@ -486,18 +495,7 @@ class FastSpeech2MIDILoss(nn.Layer): Whether to weighted masking in loss calculation. 
""" assert check_argument_types() - super().__init__() - - assert (use_masking != use_weighted_masking) or not use_masking - self.use_masking = use_masking - self.use_weighted_masking = use_weighted_masking - - # define criterions - reduction = "none" if self.use_weighted_masking else "mean" - self.l1_criterion = nn.L1Loss(reduction=reduction) - self.mse_criterion = nn.MSELoss(reduction=reduction) - self.duration_criterion = DurationPredictorLoss(reduction=reduction) - self.ce_criterion = nn.CrossEntropyLoss() + super().__init__(use_masking, use_weighted_masking) def forward( self, @@ -551,15 +549,23 @@ class FastSpeech2MIDILoss(nn.Layer): """ - speaker_loss = 0.0 + l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0 + + out_pad_masks = make_pad_mask(olens).unsqueeze(-1) + before_outs_batch = masked_fill(before_outs, out_pad_masks, 0.0) + # print(before_outs.shape, ys.shape) + ssim_loss = 1.0 - ssim(before_outs_batch.unsqueeze(1), ys.unsqueeze(1)) + ssim_loss = ssim_loss * 0.5 + # apply mask to remove padded part if self.use_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1) before_outs = before_outs.masked_select( out_masks.broadcast_to(before_outs.shape)) - if after_outs is not None: - after_outs = after_outs.masked_select( - out_masks.broadcast_to(after_outs.shape)) + + # if after_outs is not None: + # after_outs = after_outs.masked_select( + # out_masks.broadcast_to(after_outs.shape)) ys = ys.masked_select(out_masks.broadcast_to(ys.shape)) duration_masks = make_non_pad_mask(ilens) d_outs = d_outs.masked_select( @@ -568,8 +574,8 @@ class FastSpeech2MIDILoss(nn.Layer): pitch_masks = out_masks p_outs = p_outs.masked_select( pitch_masks.broadcast_to(p_outs.shape)) - e_outs = e_outs.masked_select( - pitch_masks.broadcast_to(e_outs.shape)) + # e_outs = e_outs.masked_select( + # pitch_masks.broadcast_to(e_outs.shape)) ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape)) es = es.masked_select(pitch_masks.broadcast_to(es.shape)) @@ -585,11 +591,17 @@ class FastSpeech2MIDILoss(nn.Layer): # calculate loss l1_loss = self.l1_criterion(before_outs, ys) - if after_outs is not None: - l1_loss += self.l1_criterion(after_outs, ys) + # if after_outs is not None: + # l1_loss += self.l1_criterion(after_outs, ys) + # ssim_loss += (1.0 - ssim(after_outs, ys)) + l1_loss = l1_loss * 0.5 + duration_loss = self.duration_criterion(d_outs, ds) - pitch_loss = self.mse_criterion(p_outs, ps) - energy_loss = self.mse_criterion(e_outs, es) + # print("ppppppppppoooooooooooo: ", p_outs, p_outs.shape) + # print("ppppppppppssssssssssss: ", ps, ps.shape) + # pitch_loss = self.mse_criterion(p_outs, ps) + # energy_loss = self.mse_criterion(e_outs, es) + pitch_loss = self.l1_criterion(p_outs, ps) if spk_logits is not None and spk_ids is not None: speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size @@ -618,8 +630,8 @@ class FastSpeech2MIDILoss(nn.Layer): pitch_loss = pitch_loss.multiply(pitch_weights) pitch_loss = pitch_loss.masked_select( pitch_masks.broadcast_to(pitch_loss.shape)).sum() - energy_loss = energy_loss.multiply(pitch_weights) - energy_loss = energy_loss.masked_select( - pitch_masks.broadcast_to(energy_loss.shape)).sum() + # energy_loss = energy_loss.multiply(pitch_weights) + # energy_loss = energy_loss.masked_select( + # pitch_masks.broadcast_to(energy_loss.shape)).sum() - return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss + return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss diff --git 
a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 0eb44beb6..c13717578 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -93,6 +93,7 @@ class FastSpeech2(nn.Layer): transformer_dec_dropout_rate: float=0.1, transformer_dec_positional_dropout_rate: float=0.1, transformer_dec_attn_dropout_rate: float=0.1, + transformer_activation_type: str="relu", # for conformer conformer_pos_enc_layer_type: str="rel_pos", conformer_self_attn_layer_type: str="rel_selfattn", @@ -200,6 +201,8 @@ class FastSpeech2(nn.Layer): Dropout rate after decoder positional encoding. transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module. + transformer_activation_type (str): + Activation function type in transformer. conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer. conformer_self_attn_layer_type (str): @@ -250,7 +253,7 @@ class FastSpeech2(nn.Layer): Kernel size of energy embedding. energy_embed_dropout_rate (float): Dropout rate for energy embedding. - stop_gradient_from_energy_predictor(bool): + stop_gradient_from_energy_predictor (bool): Whether to stop gradient from energy predictor to encoder. spk_num (Optional[int]): Number of speakers. If not None, assume that the spk_embed_dim is not None, @@ -269,7 +272,7 @@ class FastSpeech2(nn.Layer): How to integrate tone embedding. init_type (str): How to initialize transformer parameters. - init_enc_alpha (float): + init_enc_alpha (float): Initial value of alpha in scaled pos encoding of the encoder. init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the decoder. @@ -344,7 +347,8 @@ class FastSpeech2(nn.Layer): normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, - positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + activation_type=transformer_activation_type) elif encoder_type == "conformer": self.encoder = ConformerEncoder( idim=idim, @@ -453,7 +457,8 @@ class FastSpeech2(nn.Layer): normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, positionwise_layer_type=positionwise_layer_type, - positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + activation_type=conformer_activation_type, ) elif decoder_type == "conformer": self.decoder = ConformerEncoder( idim=0, diff --git a/paddlespeech/t2s/modules/activation.py b/paddlespeech/t2s/modules/activation.py index 8d8cd62ef..f1c099b76 100644 --- a/paddlespeech/t2s/modules/activation.py +++ b/paddlespeech/t2s/modules/activation.py @@ -37,7 +37,8 @@ def get_activation(act, **kwargs): "selu": paddle.nn.SELU, "leakyrelu": paddle.nn.LeakyReLU, "swish": paddle.nn.Swish, - "glu": GLU + "glu": GLU, + "gelu": paddle.nn.GELU, } return activation_funcs[act](**kwargs) diff --git a/paddlespeech/t2s/modules/diffusion.py b/paddlespeech/t2s/modules/diffusion.py index 52fe84ceb..66754e3b2 100644 --- a/paddlespeech/t2s/modules/diffusion.py +++ b/paddlespeech/t2s/modules/diffusion.py @@ -40,7 +40,7 @@ class WaveNetDenoiser(nn.Layer): layers (int, optional): Number of residual blocks inside, by default 20 stacks (int, optional): - The number of groups to split the residual blocks into, by default 4 + The number of groups to split the residual blocks into, by default 5 Within each group, the 
dilation of the residual block grows exponentially. residual_channels (int, optional): Residual channel of the residual blocks, by default 256 @@ -64,7 +64,7 @@ class WaveNetDenoiser(nn.Layer): out_channels: int=80, kernel_size: int=3, layers: int=20, - stacks: int=4, + stacks: int=5, residual_channels: int=256, gate_channels: int=512, skip_channels: int=256, @@ -72,7 +72,7 @@ class WaveNetDenoiser(nn.Layer): dropout: float=0., bias: bool=True, use_weight_norm: bool=False, - init_type: str="kaiming_uniform", ): + init_type: str="kaiming_normal", ): super().__init__() # initialize parameters @@ -118,18 +118,15 @@ class WaveNetDenoiser(nn.Layer): bias=bias) self.conv_layers.append(conv) + final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True) + nn.initializer.Constant(0.0)(final_conv.weight) self.last_conv_layers = nn.Sequential(nn.ReLU(), nn.Conv1D( skip_channels, skip_channels, 1, bias_attr=True), - nn.ReLU(), - nn.Conv1D( - skip_channels, - out_channels, - 1, - bias_attr=True)) + nn.ReLU(), final_conv) if use_weight_norm: self.apply_weight_norm() @@ -200,10 +197,6 @@ class GaussianDiffusion(nn.Layer): Args: denoiser (Layer, optional): The model used for denoising noises. - In fact, the denoiser model performs the operation - of producing a output with more noises from the noisy input. - Then we use the diffusion algorithm to calculate - the input with the output to get the denoised result. num_train_timesteps (int, optional): The number of timesteps between the noise and the real during training, by default 1000. beta_start (float, optional): @@ -233,7 +226,8 @@ class GaussianDiffusion(nn.Layer): >>> def callback(index, timestep, num_timesteps, sample): >>> nonlocal pbar >>> if pbar is None: - >>> pbar = tqdm(total=num_timesteps-index) + >>> pbar = tqdm(total=num_timesteps) + >>> pbar.update(index) >>> pbar.update() >>> >>> return callback @@ -247,7 +241,7 @@ class GaussianDiffusion(nn.Layer): >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> with paddle.no_grad(): >>> sample = diffusion.inference( - >>> paddle.randn(x.shape), c, x, + >>> paddle.randn(x.shape), c, ref_x=x_in, >>> num_inference_steps=infer_steps, >>> scheduler_type=scheduler_type, >>> callback=create_progress_callback()) @@ -262,7 +256,7 @@ class GaussianDiffusion(nn.Layer): >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> with paddle.no_grad(): >>> sample = diffusion.inference( - >>> paddle.randn(x.shape), c, x_in, + >>> paddle.randn(x.shape), c, ref_x=x_in, >>> num_inference_steps=infer_steps, >>> scheduler_type=scheduler_type, >>> callback=create_progress_callback()) @@ -277,11 +271,11 @@ class GaussianDiffusion(nn.Layer): >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> with paddle.no_grad(): >>> sample = diffusion.inference( - >>> paddle.randn(x.shape), c, None, + >>> paddle.randn(x.shape), c, ref_x=x_in, >>> num_inference_steps=infer_steps, >>> scheduler_type=scheduler_type, >>> callback=create_progress_callback()) - 100%|█████| 25/25 [00:01<00:00, 19.75it/s] + 100%|█████| 34/34 [00:01<00:00, 19.75it/s] >>> >>> # ds=1000, K_step=100, scheduler=pndm, infer_step=50, from aux fs2 mel output >>> ds = 1000 @@ -292,11 +286,11 @@ class GaussianDiffusion(nn.Layer): >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> with paddle.no_grad(): >>> sample = diffusion.inference( - >>> paddle.randn(x.shape), c, x, + >>> 
paddle.randn(x.shape), c, ref_x=x_in, >>> num_inference_steps=infer_steps, >>> scheduler_type=scheduler_type, >>> callback=create_progress_callback()) - 100%|█████| 5/5 [00:00<00:00, 23.80it/s] + 100%|█████| 14/14 [00:00<00:00, 23.80it/s] """ @@ -366,6 +360,8 @@ class GaussianDiffusion(nn.Layer): num_inference_steps: Optional[int]=1000, strength: Optional[float]=None, scheduler_type: Optional[str]="ddpm", + clip_noise: Optional[bool]=True, + clip_noise_range: Optional[Tuple[float, float]]=(-1, 1), callback: Optional[Callable[[int, int, int, paddle.Tensor], None]]=None, callback_steps: Optional[int]=1): @@ -386,6 +382,10 @@ class GaussianDiffusion(nn.Layer): scheduler_type (str, optional): Noise scheduler for generate noises. Choose a great scheduler can skip many denoising step, by default 'ddpm'. + clip_noise (bool, optional): + Whether to clip each denoised output, by default True. + clip_noise_range (tuple, optional): + denoised output min and max value range after clip, by default (-1, 1). callback (Callable[[int,int,int,Tensor], None], optional): Callback function during denoising steps. @@ -426,6 +426,7 @@ class GaussianDiffusion(nn.Layer): scheduler.set_timesteps(num_inference_steps) # prepare first noise variables + import pdb;pdb.set_trace() noisy_input = noise timesteps = scheduler.timesteps if ref_x is not None: @@ -444,8 +445,13 @@ class GaussianDiffusion(nn.Layer): noisy_input = scheduler.add_noise( ref_x, noise, timesteps[:1].tile([noise.shape[0]])) + + # denoising loop denoised_output = noisy_input + if clip_noise: + n_min, n_max = clip_noise_range + denoised_output = paddle.clip(denoised_output, n_min, n_max) num_warmup_steps = len( timesteps) - num_inference_steps * scheduler.order for i, t in enumerate(timesteps): @@ -457,6 +463,8 @@ class GaussianDiffusion(nn.Layer): # compute the previous noisy sample x_t -> x_t-1 denoised_output = scheduler.step(noise_pred, t, denoised_output).prev_sample + if clip_noise: + denoised_output = paddle.clip(denoised_output, n_min, n_max) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and @@ -464,4 +472,4 @@ class GaussianDiffusion(nn.Layer): if callback is not None and i % callback_steps == 0: callback(i, t, len(timesteps), denoised_output) - return denoised_output + return denoised_output \ No newline at end of file diff --git a/paddlespeech/t2s/modules/masked_fill.py b/paddlespeech/t2s/modules/masked_fill.py index b32222547..1445a926a 100644 --- a/paddlespeech/t2s/modules/masked_fill.py +++ b/paddlespeech/t2s/modules/masked_fill.py @@ -38,11 +38,9 @@ def masked_fill(xs: paddle.Tensor, value: Union[float, int]): # comment following line for converting dygraph to static graph. # assert is_broadcastable(xs.shape, mask.shape) is True - # bshape = paddle.broadcast_shape(xs.shape, mask.shape) bshape = broadcast_shape(xs.shape, mask.shape) mask.stop_gradient = True mask = mask.broadcast_to(bshape) - trues = paddle.ones_like(xs) * value mask = mask.cast(dtype=paddle.bool) xs = paddle.where(mask, trues, xs) diff --git a/paddlespeech/t2s/modules/predictor/variance_predictor.py b/paddlespeech/t2s/modules/predictor/variance_predictor.py index 4c2a67cc4..197f73595 100644 --- a/paddlespeech/t2s/modules/predictor/variance_predictor.py +++ b/paddlespeech/t2s/modules/predictor/variance_predictor.py @@ -96,7 +96,7 @@ class VariancePredictor(nn.Layer): xs = f(xs) # (B, Tmax, 1) xs = self.linear(xs.transpose([0, 2, 1])) - + if x_masks is not None: xs = masked_fill(xs, x_masks, 0.0) return xs
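
Reviewer note: a minimal sketch of how the patched diffusion inference path is meant to be driven, mirroring the updated docstring examples in paddlespeech/t2s/modules/diffusion.py and the new call in DiffSinger.inference. The tensor shapes, batch size, and frame count below are illustrative assumptions, not values taken from a real Opencpop run.

    import paddle

    from paddlespeech.t2s.modules.diffusion import GaussianDiffusion
    from paddlespeech.t2s.modules.diffusion import WaveNetDenoiser

    # Denoiser with the module defaults (80-bin mel in/out); GaussianDiffusion only
    # denoises over the first num_max_timesteps steps, as in the K_step=100
    # docstring examples of this patch.
    denoiser = WaveNetDenoiser()
    diffusion = GaussianDiffusion(
        denoiser, num_train_timesteps=1000, num_max_timesteps=100)

    # Assumed shapes: (B, odim, T) coarse mel from the FastSpeech2MIDI branch and
    # (B, adim, T) encoder condition; both tensors are placeholders here.
    mel_fs2 = paddle.randn([1, 80, 320])
    cond_fs2 = paddle.randn([1, 256, 320])

    with paddle.no_grad():
        mel = diffusion.inference(
            noise=paddle.randn(mel_fs2.shape),
            cond=cond_fs2,
            ref_x=mel_fs2,                 # start denoising from the coarse mel
            num_inference_steps=100,
            scheduler_type="ddpm",
            clip_noise=True,               # new argument in this patch
            clip_noise_range=(-1, 1))      # new argument in this patch
    # mel: (B, odim, T), still in the normalized mel domain.

Only clip_noise and clip_noise_range are introduced by this patch; the other arguments already exist on GaussianDiffusion.inference.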