add style_melgan

pull/1068/head
TianYuan 3 years ago
parent ef8e61813a
commit dd36eafe34

@ -35,7 +35,7 @@ generator_params:
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
use_weight_norm: true # Whether to use weight norm.
# If set to true, it will be applied to all of the conv layers.
upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size.
upsample_scales: [4, 5, 3, 5] # Upsampling scales. prod(upsample_scales) == n_shift
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
# DATA LOADER SETTING #
###########################################################
batch_size: 8 # Batch size.
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 4 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.

@ -6,8 +6,9 @@ ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
--output-dir=${train_output_path}/test \
--generator-type=pwgan

@ -9,3 +9,4 @@
* voc1 - Parallel WaveGAN
* voc2 - MelGAN
* voc3 - MultiBand MelGAN
* voc4 - Style MelGAN

@ -78,7 +78,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
# DATA LOADER SETTING #
###########################################################
batch_size: 8 # Batch size.
batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by n_shift.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 2 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
@ -88,23 +88,23 @@ allow_cache: true # Whether to allow cache in dataset. If true, it requ
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
epsilon: 1.0e-6 # Generator's epsilon.
epsilon: 1.0e-6 # Generator's epsilon.
weight_decay: 0.0 # Generator's weight decay coefficient.
generator_scheduler_params:
learning_rate: 0.0001 # Generator's learning rate.
learning_rate: 0.0001 # Generator's learning rate.
step_size: 200000 # Generator's scheduler step size.
gamma: 0.5 # Generator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
generator_grad_norm: 10 # Generator's gradient norm.
discriminator_optimizer_params:
epsilon: 1.0e-6 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
learning_rate: 0.00005 # Discriminator's learning rate.
step_size: 200000 # Discriminator's scheduler step size.
gamma: 0.5 # Discriminator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
discriminator_grad_norm: 1 # Discriminator's gradient norm.
learning_rate: 0.00005 # Discriminator's learning rate.
step_size: 200000 # Discriminator's scheduler step size.
gamma: 0.5 # Discriminator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
discriminator_grad_norm: 1 # Discriminator's gradient norm.
###########################################################
# INTERVAL SETTING #

@ -6,8 +6,9 @@ ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
--output-dir=${train_output_path}/test \
--generator-type=pwgan

@ -6,8 +6,7 @@
# This configuration is based on full-band MelGAN but the hop size and sampling
# rate is different from the paper (16kHz vs 24kHz). The number of iteraions
# is not shown in the paper so currently we train 1M iterations (not sure enough
# to converge). The optimizer setting is based on @dathudeptrai advice.
# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
# to converge).
###########################################################
# FEATURE EXTRACTION SETTING #
@ -30,7 +29,7 @@ generator_params:
out_channels: 4 # Number of output channels.
kernel_size: 7 # Kernel size of initial and final conv layers.
channels: 384 # Initial number of channels for conv layers.
upsample_scales: [5, 5, 3] # List of Upsampling scales.
upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift
stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack.
stacks: 4 # Number of stacks in a single residual stack module.
use_weight_norm: True # Whether to use weight normalization.
@ -67,7 +66,7 @@ discriminator_params:
use_stft_loss: true
stft_loss_params:
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss.
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
use_subband_stft_loss: true
@ -87,7 +86,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss.
# DATA LOADER SETTING #
###########################################################
batch_size: 64 # Batch size.
batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by n_shift.
num_workers: 2 # Number of workers in DataLoader.
###########################################################
@ -109,7 +108,7 @@ generator_scheduler_params:
- 500000
- 600000
discriminator_optimizer_params:
epsilon: 1.0e-7 # Discriminator's epsilon.
epsilon: 1.0e-7 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_grad_norm: -1 # Discriminator's gradient norm.
@ -129,7 +128,7 @@ discriminator_scheduler_params:
###########################################################
discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator.
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 5000 # Interval steps to save checkpoint.
save_interval_steps: 5000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.
###########################################################

@ -6,8 +6,7 @@
# This configuration is based on full-band MelGAN but the hop size and sampling
# rate is different from the paper (16kHz vs 24kHz). The number of iteraions
# is not shown in the paper so currently we train 1M iterations (not sure enough
# to converge). The optimizer setting is based on @dathudeptrai advice.
# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
# to converge).
###########################################################
# FEATURE EXTRACTION SETTING #
@ -30,7 +29,7 @@ generator_params:
out_channels: 4 # Number of output channels.
kernel_size: 7 # Kernel size of initial and final conv layers.
channels: 384 # Initial number of channels for conv layers.
upsample_scales: [5, 5, 3] # List of Upsampling scales.
upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift
stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack.
stacks: 4 # Number of stacks in a single residual stack module.
use_weight_norm: True # Whether to use weight normalization.
@ -73,7 +72,7 @@ stft_loss_params:
use_subband_stft_loss: true
subband_stft_loss_params:
fft_sizes: [384, 683, 171] # List of FFT size for STFT-based loss.
hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss
hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss.
win_lengths: [150, 300, 60] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
@ -87,7 +86,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss.
# DATA LOADER SETTING #
###########################################################
batch_size: 64 # Batch size.
batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by n_shift.
num_workers: 2 # Number of workers in DataLoader.
###########################################################
@ -109,7 +108,7 @@ generator_scheduler_params:
- 500000
- 600000
discriminator_optimizer_params:
epsilon: 1.0e-7 # Discriminator's epsilon.
epsilon: 1.0e-7 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_grad_norm: -1 # Discriminator's gradient norm.
@ -129,7 +128,7 @@ discriminator_scheduler_params:
###########################################################
discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator.
train_max_steps: 2000000 # Number of training steps.
save_interval_steps: 1000 # Interval steps to save checkpoint.
save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.
###########################################################

@ -6,8 +6,9 @@ ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
--output-dir=${train_output_path}/test \
--generator-type=mb_melgan

@ -0,0 +1,136 @@
# This is the configuration file for CSMSC dataset.This configuration is based
# on StyleMelGAN paper but uses MSE loss instead of Hinge loss. And I found that
# batch_size = 8 is also working good. So maybe if you want to accelerate the training,
# you can reduce the batch size (e.g. 8 or 16). Upsampling scales is modified to
# fit the shift size 300 pt.
# NOTE: batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300)
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # Sampling rate.
n_fft: 2048 # FFT size. (in samples)
n_shift: 300 # Hop size. (in samples)
win_length: 1200 # Window length. (in samples)
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
n_mels: 80 # Number of mel basis.
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
###########################################################
# GENERATOR NETWORK ARCHITECTURE SETTING #
###########################################################
generator_params:
in_channels: 128 # Number of input channels.
aux_channels: 80
channels: 64 # Initial number of channels for conv layers.
out_channels: 1 # Number of output channels.
kernel_size: 9 # Kernel size of initial and final conv layers.
dilation: 2
bias: True
noise_upsample_scales: [10, 2, 2, 2]
noise_upsample_activation: "leakyrelu"
noise_upsample_activation_params:
negative_slope: 0.2
upsample_scales: [5, 1, 5, 1, 3, 1, 2, 2, 1] # List of Upsampling scales. prod(upsample_scales) == n_shift
upsample_mode: "nearest"
gated_function: "softmax"
use_weight_norm: True # Whether to use weight normalization.
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
###########################################################
discriminator_params:
repeats: 4
window_sizes: [512, 1024, 2048, 4096]
pqmf_params:
- [1, None, None, None]
- [2, 62, 0.26700, 9.0]
- [4, 62, 0.14200, 9.0]
- [8, 62, 0.07949, 9.0]
discriminator_params:
out_channels: 1 # Number of output channels.
kernel_sizes: [5, 3] # List of kernel size.
channels: 16 # Number of channels of the initial conv layer.
max_downsample_channels: 512 # Maximum number of channels of downsampling layers.
bias: True
downsample_scales: [4, 4, 4, 1] # List of downsampling scales.
nonlinear_activation: "leakyrelu" # Nonlinear activation function.
nonlinear_activation_params: # Parameters of nonlinear activation function.
negative_slope: 0.2
use_weight_norm: True # Whether to use weight norm.
###########################################################
# STFT LOSS SETTING #
###########################################################
use_stft_loss: true
stft_loss_params:
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
lambda_aux: 1.0 # Loss balancing coefficient for aux loss.
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
lambda_adv: 1.0 # Loss balancing coefficient for adv loss.
generator_adv_loss_params:
average_by_discriminators: false # Whether to average loss by #discriminators.
discriminator_adv_loss_params:
average_by_discriminators: false # Whether to average loss by #discriminators.
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32 # Batch size.
# batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300, n_shift)
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift.
num_workers: 2 # Number of workers in Pytorch DataLoader.
###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
beta1: 0.5
beta2: 0.9
weight_decay: 0.0 # Generator's weight decay coefficient.
generator_scheduler_params:
learning_rate: 1.0e-4 # Generator's learning rate.
gamma: 0.5 # Generator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 100000
- 300000
- 500000
- 700000
- 900000
generator_grad_norm: -1 # Generator's gradient norm.
discriminator_optimizer_params:
beta1: 0.5
beta2: 0.9
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
learning_rate: 2.0e-4 # Discriminator's learning rate.
gamma: 0.5 # Discriminator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 200000
- 400000
- 600000
- 800000
discriminator_grad_norm: -1 # Discriminator's gradient norm.
###########################################################
# INTERVAL SETTING #
###########################################################
discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator.
train_max_steps: 1500000 # Number of training steps.
save_interval_steps: 5000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.
###########################################################
# OTHER SETTING #
###########################################################
num_snapshots: 10 # max number of snapshots to keep while training
seed: 42 # random seed for paddle, random, and np.random

@ -0,0 +1,55 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./baker_alignment_tone \
--output=durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/../preprocess.py \
--rootdir=~/datasets/BZNSYP/ \
--dataset=baker \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--cut-sil=True \
--num-cpu=20
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--stats=dump/train/feats_stats.npy
fi

@ -0,0 +1,14 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test \
--generator-type=style_melgan

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=style_melgan
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}

@ -0,0 +1,32 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_50000.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@ -35,7 +35,7 @@ generator_params:
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
use_weight_norm: true # Whether to use weight norm.
# If set to true, it will be applied to all of the conv layers.
upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size.
upsample_scales: [4, 4, 4, 4] # Upsampling scales. prod(upsample_scales) == n_shift
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
@ -60,7 +60,7 @@ stft_loss_params:
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
window: "hann" # Window function for STFT-based loss
###########################################################
# ADVERSARIAL LOSS SETTING #
@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
# DATA LOADER SETTING #
###########################################################
batch_size: 8 # Batch size.
batch_max_steps: 25600 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 25600 # Length of each audio in batch. Make sure dividable by n_shift.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 4 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
@ -84,20 +84,20 @@ generator_optimizer_params:
epsilon: 1.0e-6 # Generator's epsilon.
weight_decay: 0.0 # Generator's weight decay coefficient.
generator_scheduler_params:
learning_rate: 0.0001 # Generator's learning rate.
learning_rate: 0.0001 # Generator's learning rate.
step_size: 200000 # Generator's scheduler step size.
gamma: 0.5 # Generator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
generator_grad_norm: 10 # Generator's gradient norm.
discriminator_optimizer_params:
epsilon: 1.0e-6 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
epsilon: 1.0e-6 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
learning_rate: 0.00005 # Discriminator's learning rate.
step_size: 200000 # Discriminator's scheduler step size.
gamma: 0.5 # Discriminator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
discriminator_grad_norm: 1 # Discriminator's gradient norm.
learning_rate: 0.00005 # Discriminator's learning rate.
step_size: 200000 # Discriminator's scheduler step size.
gamma: 0.5 # Discriminator's scheduler gamma.
# At each step size, lr will be multiplied by this parameter.
discriminator_grad_norm: 1 # Discriminator's gradient norm.
###########################################################
# INTERVAL SETTING #

@ -6,8 +6,9 @@ ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
--output-dir=${train_output_path}/test \
--generator-type=pwgan

@ -35,7 +35,7 @@ generator_params:
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
use_weight_norm: true # Whether to use weight norm.
# If set to true, it will be applied to all of the conv layers.
upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size.
upsample_scales: [4, 5, 3, 5] # Upsampling scales. prod(upsample_scales) == n_shift
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
# DATA LOADER SETTING #
###########################################################
batch_size: 8 # Batch size.
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by hop_size.
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 4 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.

@ -6,8 +6,9 @@ ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
--output-dir=${train_output_path}/test \
--generator-type=pwgan

@ -1,103 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import distributed as dist
from timer import timer
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
def main():
parser = argparse.ArgumentParser(
description="Synthesize with parallel wavegan.")
parser.add_argument(
"--config", type=str, help="parallel wavegan config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument("--test-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
generator = PWGGenerator(**config["generator_params"])
state_dict = paddle.load(args.checkpoint)
generator.set_state_dict(state_dict["generator_params"])
generator.remove_weight_norm()
generator.eval()
with jsonlines.open(args.test_metadata, 'r') as reader:
metadata = list(reader)
test_dataset = DataTable(
metadata,
fields=['utt_id', 'feats'],
converters={
'utt_id': None,
'feats': np.load,
})
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
N = 0
T = 0
for example in test_dataset:
utt_id = example['utt_id']
mel = example['feats']
mel = paddle.to_tensor(mel) # (T, C)
with timer() as t:
with paddle.no_grad():
wav = generator.inference(c=mel)
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = config.fs / speed
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
if __name__ == "__main__":
main()

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,258 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from paddle.optimizer.lr import MultiStepDecay
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip
from paddlespeech.t2s.models.melgan import StyleMelGANDiscriminator
from paddlespeech.t2s.models.melgan import StyleMelGANEvaluator
from paddlespeech.t2s.models.melgan import StyleMelGANGenerator
from paddlespeech.t2s.models.melgan import StyleMelGANUpdater
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
from paddlespeech.t2s.modules.losses import MultiResolutionSTFTLoss
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
world_size = paddle.distributed.get_world_size()
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
dev_sampler = DistributedBatchSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False)
print("samplers done!")
if "aux_context_window" in config.generator_params:
aux_context_window = config.generator_params.aux_context_window
else:
aux_context_window = 0
train_batch_fn = Clip(
batch_max_steps=config.batch_max_steps,
hop_size=config.n_shift,
aux_context_window=aux_context_window)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
print("dataloaders done!")
generator = StyleMelGANGenerator(**config["generator_params"])
discriminator = StyleMelGANDiscriminator(**config["discriminator_params"])
if world_size > 1:
generator = DataParallel(generator)
discriminator = DataParallel(discriminator)
print("models done!")
criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
criterion_gen_adv = GeneratorAdversarialLoss(
**config["generator_adv_loss_params"])
criterion_dis_adv = DiscriminatorAdversarialLoss(
**config["discriminator_adv_loss_params"])
print("criterions done!")
lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"])
# Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used
generator_grad_norm = config["generator_grad_norm"]
gradient_clip_g = nn.ClipGradByGlobalNorm(
generator_grad_norm) if generator_grad_norm > 0 else None
print("gradient_clip_g:", gradient_clip_g)
optimizer_g = Adam(
learning_rate=lr_schedule_g,
grad_clip=gradient_clip_g,
parameters=generator.parameters(),
**config["generator_optimizer_params"])
lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"])
discriminator_grad_norm = config["discriminator_grad_norm"]
gradient_clip_d = nn.ClipGradByGlobalNorm(
discriminator_grad_norm) if discriminator_grad_norm > 0 else None
print("gradient_clip_d:", gradient_clip_d)
optimizer_d = Adam(
learning_rate=lr_schedule_d,
grad_clip=gradient_clip_d,
parameters=discriminator.parameters(),
**config["discriminator_optimizer_params"])
print("optimizers done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
updater = StyleMelGANUpdater(
models={
"generator": generator,
"discriminator": discriminator,
},
optimizers={
"generator": optimizer_g,
"discriminator": optimizer_d,
},
criterions={
"stft": criterion_stft,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
},
schedulers={
"generator": lr_schedule_g,
"discriminator": lr_schedule_d,
},
dataloader=train_dataloader,
discriminator_train_start_steps=config.discriminator_train_start_steps,
lambda_adv=config.lambda_adv,
output_dir=output_dir)
evaluator = StyleMelGANEvaluator(
models={
"generator": generator,
"discriminator": discriminator,
},
criterions={
"stft": criterion_stft,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
},
dataloader=dev_dataloader,
lambda_adv=config.lambda_adv,
output_dir=output_dir)
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir)
if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Train a Multi-Band MelGAN model.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
else:
train_sp(args, config)
if __name__ == "__main__":
main()

@ -24,15 +24,19 @@ from paddle import distributed as dist
from timer import timer
from yacs.config import CfgNode
import paddlespeech
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.melgan import MelGANGenerator
def main():
parser = argparse.ArgumentParser(
description="Synthesize with multi band melgan.")
parser = argparse.ArgumentParser(description="Synthesize with GANVocoder.")
parser.add_argument(
"--config", type=str, help="multi band melgan config file.")
"--generator-type",
type=str,
default="pwgan",
help="type of GANVocoder, should in {pwgan, mb_melgan, style_melgan, } now"
)
parser.add_argument("--config", type=str, help="GANVocoder config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument("--test-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
@ -59,15 +63,29 @@ def main():
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
generator = MelGANGenerator(**config["generator_params"])
class_map = {
"hifigan": "HiFiGANGenerator",
"mb_melgan": "MelGANGenerator",
"pwgan": "PWGGenerator",
"style_melgan": "StyleMelGANGenerator",
}
generator_type = args.generator_type
assert generator_type in class_map
print("generator_type:", generator_type)
generator_class = getattr(paddlespeech.t2s.models,
class_map[generator_type])
generator = generator_class(**config["generator_params"])
state_dict = paddle.load(args.checkpoint)
generator.set_state_dict(state_dict["generator_params"])
generator.remove_weight_norm()
generator.eval()
with jsonlines.open(args.test_metadata, 'r') as reader:
metadata = list(reader)
test_dataset = DataTable(
metadata,
fields=['utt_id', 'feats'],

@ -14,6 +14,7 @@
from .fastspeech2 import *
from .melgan import *
from .parallel_wavegan import *
from .speedyspeech import *
from .tacotron2 import *
from .transformer_tts import *
from .waveflow import *

@ -13,3 +13,5 @@
# limitations under the License.
from .melgan import *
from .multi_band_melgan_updater import *
from .style_melgan import *
from .style_melgan_updater import *

@ -21,6 +21,7 @@ import numpy as np
import paddle
from paddle import nn
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.causal_conv import CausalConv1D
from paddlespeech.t2s.modules.causal_conv import CausalConv1DTranspose
from paddlespeech.t2s.modules.nets_utils import initialize
@ -41,7 +42,7 @@ class MelGANGenerator(nn.Layer):
upsample_scales: List[int]=[8, 8, 2, 2],
stack_kernel_size: int=3,
stacks: int=3,
nonlinear_activation: str="LeakyReLU",
nonlinear_activation: str="leakyrelu",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
@ -88,16 +89,19 @@ class MelGANGenerator(nn.Layer):
"""
super().__init__()
# initialize parameters
initialize(self, init_type)
# for compatibility
if nonlinear_activation == "LeakyReLU":
nonlinear_activation = "leakyrelu"
# check hyper parameters is valid
assert channels >= np.prod(upsample_scales)
assert channels % (2**len(upsample_scales)) == 0
if not use_causal_conv:
assert (kernel_size - 1
) % 2 == 0, "Not support even number kernel size."
# initialize parameters
initialize(self, init_type)
layers = []
if not use_causal_conv:
layers += [
@ -118,7 +122,8 @@ class MelGANGenerator(nn.Layer):
for i, upsample_scale in enumerate(upsample_scales):
# add upsampling layer
layers += [
getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
get_activation(nonlinear_activation,
**nonlinear_activation_params)
]
if not use_causal_conv:
layers += [
@ -158,7 +163,7 @@ class MelGANGenerator(nn.Layer):
# add final layer
layers += [
getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
get_activation(nonlinear_activation, **nonlinear_activation_params)
]
if not use_causal_conv:
layers += [
@ -242,7 +247,6 @@ class MelGANGenerator(nn.Layer):
This initialization follows official implementation manner.
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
"""
# 定义参数为float的正态分布。
dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
@ -287,10 +291,11 @@ class MelGANDiscriminator(nn.Layer):
max_downsample_channels: int=1024,
bias: bool=True,
downsample_scales: List[int]=[4, 4, 4, 4],
nonlinear_activation: str="LeakyReLU",
nonlinear_activation: str="leakyrelu",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"}, ):
pad_params: Dict[str, Any]={"mode": "reflect"},
init_type: str="xavier_uniform", ):
"""Initilize MelGAN discriminator module.
Parameters
----------
@ -321,6 +326,14 @@ class MelGANDiscriminator(nn.Layer):
Hyperparameters for padding function.
"""
super().__init__()
# for compatibility
if nonlinear_activation == "LeakyReLU":
nonlinear_activation = "leakyrelu"
# initialize parameters
initialize(self, init_type)
self.layers = nn.LayerList()
# check kernel size is valid
@ -338,8 +351,8 @@ class MelGANDiscriminator(nn.Layer):
channels,
int(np.prod(kernel_sizes)),
bias_attr=bias),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
get_activation(nonlinear_activation, **
nonlinear_activation_params), ))
# add downsample layers
in_chs = channels
@ -355,8 +368,8 @@ class MelGANDiscriminator(nn.Layer):
padding=downsample_scale * 5,
groups=in_chs // 4,
bias_attr=bias, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
get_activation(nonlinear_activation, **
nonlinear_activation_params), ))
in_chs = out_chs
# add final layers
@ -369,8 +382,8 @@ class MelGANDiscriminator(nn.Layer):
kernel_sizes[0],
padding=(kernel_sizes[0] - 1) // 2,
bias_attr=bias, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
get_activation(nonlinear_activation, **
nonlinear_activation_params), ))
self.layers.append(
nn.Conv1D(
out_chs,
@ -419,7 +432,7 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
max_downsample_channels: int=1024,
bias: bool=True,
downsample_scales: List[int]=[4, 4, 4, 4],
nonlinear_activation: str="LeakyReLU",
nonlinear_activation: str="leakyrelu",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
@ -461,9 +474,14 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
Whether to use causal convolution.
"""
super().__init__()
# initialize parameters
initialize(self, init_type)
# for compatibility
if nonlinear_activation == "LeakyReLU":
nonlinear_activation = "leakyrelu"
self.discriminators = nn.LayerList()
# add discriminators

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict
import paddle
@ -41,7 +42,7 @@ class MBMelGANUpdater(StandardUpdater):
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
output_dir=None):
output_dir: Path=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
@ -159,11 +160,11 @@ class MBMelGANUpdater(StandardUpdater):
class MBMelGANEvaluator(StandardEvaluator):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
models: Dict[str, Layer],
criterions: Dict[str, Layer],
dataloader: DataLoader,
lambda_adv: float,
output_dir: Path=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']

@ -0,0 +1,404 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""StyleMelGAN Modules."""
import copy
import math
from typing import Any
from typing import Dict
from typing import List
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.melgan import MelGANDiscriminator as BaseDiscriminator
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.pqmf import PQMF
from paddlespeech.t2s.modules.tade_res_block import TADEResBlock
class StyleMelGANGenerator(nn.Layer):
"""Style MelGAN generator module."""
def __init__(
self,
in_channels: int=128,
aux_channels: int=80,
channels: int=64,
out_channels: int=1,
kernel_size: int=9,
dilation: int=2,
bias: bool=True,
noise_upsample_scales: List[int]=[11, 2, 2, 2],
noise_upsample_activation: str="leakyrelu",
noise_upsample_activation_params: Dict[str,
Any]={"negative_slope": 0.2},
upsample_scales: List[int]=[2, 2, 2, 2, 2, 2, 2, 2, 1],
upsample_mode: str="linear",
gated_function: str="softmax",
use_weight_norm: bool=True,
init_type: str="xavier_uniform", ):
"""Initilize Style MelGAN generator.
Parameters
----------
in_channels : int
Number of input noise channels.
aux_channels : int
Number of auxiliary input channels.
channels : int
Number of channels for conv layer.
out_channels : int
Number of output channels.
kernel_size : int
Kernel size of conv layers.
dilation : int
Dilation factor for conv layers.
bias : bool
Whether to add bias parameter in convolution layers.
noise_upsample_scales : list
List of noise upsampling scales.
noise_upsample_activation : str
Activation function module name for noise upsampling.
noise_upsample_activation_params : dict
Hyperparameters for the above activation function.
upsample_scales : list
List of upsampling scales.
upsample_mode : str
Upsampling mode in TADE layer.
gated_function : str
Gated function in TADEResBlock ("softmax" or "sigmoid").
use_weight_norm : bool
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
"""
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
noise_upsample = []
in_chs = in_channels
for noise_upsample_scale in noise_upsample_scales:
noise_upsample.append(
nn.Conv1DTranspose(
in_chs,
channels,
noise_upsample_scale * 2,
stride=noise_upsample_scale,
padding=noise_upsample_scale // 2 + noise_upsample_scale %
2,
output_padding=noise_upsample_scale % 2,
bias_attr=bias, ))
noise_upsample.append(
get_activation(noise_upsample_activation, **
noise_upsample_activation_params))
in_chs = channels
self.noise_upsample = nn.Sequential(*noise_upsample)
self.noise_upsample_factor = np.prod(noise_upsample_scales)
self.blocks = nn.LayerList()
aux_chs = aux_channels
for upsample_scale in upsample_scales:
self.blocks.append(
TADEResBlock(
in_channels=channels,
aux_channels=aux_chs,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
upsample_factor=upsample_scale,
upsample_mode=upsample_mode,
gated_function=gated_function, ), )
aux_chs = channels
self.upsample_factor = np.prod(upsample_scales)
self.output_conv = nn.Sequential(
nn.Conv1D(
channels,
out_channels,
kernel_size,
1,
bias_attr=bias,
padding=(kernel_size - 1) // 2, ),
nn.Tanh(), )
nn.initializer.set_global_initializer(None)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(self, c, z=None):
"""Calculate forward propagation.
Parameters
----------
c : Tensor
Auxiliary input tensor (B, channels, T).
z : Tensor
Input noise tensor (B, in_channels, 1).
Returns
----------
Tensor
Output tensor (B, out_channels, T ** prod(upsample_scales)).
"""
# batch_max_steps(24000) == noise_upsample_factor(80) * upsample_factor(300)
if z is None:
z = paddle.randn([paddle.shape(c)[0], self.in_channels, 1])
# (B, in_channels, noise_upsample_factor).
x = self.noise_upsample(z)
for block in self.blocks:
x, c = block(x, c)
x = self.output_conv(x)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
if layer:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
def reset_parameters(self):
"""Reset parameters.
This initialization follows official implementation manner.
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
"""
# 定义参数为float的正态分布。
dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
def _reset_parameters(m):
if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
w = dist.sample(m.weight.shape)
m.weight.set_value(w)
self.apply(_reset_parameters)
def inference(self, c):
"""Perform inference.
Parameters
----------
c : Tensor
Input tensor (T, in_channels).
Returns
----------
Tensor
Output tensor (T ** prod(upsample_scales), out_channels).
"""
# (1, in_channels, T)
c = c.transpose([1, 0]).unsqueeze(0)
c_shape = paddle.shape(c)
# prepare noise input
# there is a bug in Paddle int division, we must convert a int tensor to int here
noise_size = (1, self.in_channels,
math.ceil(int(c_shape[2]) / self.noise_upsample_factor))
# (1, in_channels, T/noise_upsample_factor)
noise = paddle.randn(noise_size)
# (1, in_channels, T)
x = self.noise_upsample(noise)
x_shape = paddle.shape(x)
total_length = c_shape[2] * self.upsample_factor
c = F.pad(
c, (0, x_shape[2] - c_shape[2]), "replicate", data_format="NCL")
# c.shape[2] == x.shape[2] here
# (1, in_channels, T*prod(upsample_scales))
for block in self.blocks:
x, c = block(x, c)
x = self.output_conv(x)[..., :total_length]
return x.squeeze(0).transpose([1, 0])
# StyleMelGANDiscriminator 不需要 remove weight norm 嘛?
class StyleMelGANDiscriminator(nn.Layer):
"""Style MelGAN disciminator module."""
def __init__(
self,
repeats: int=2,
window_sizes: List[int]=[512, 1024, 2048, 4096],
pqmf_params: List[List[int]]=[
[1, None, None, None],
[2, 62, 0.26700, 9.0],
[4, 62, 0.14200, 9.0],
[8, 62, 0.07949, 9.0],
],
discriminator_params: Dict[str, Any]={
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 16,
"max_downsample_channels": 512,
"bias": True,
"downsample_scales": [4, 4, 4, 1],
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.2
},
"pad": "Pad1D",
"pad_params": {
"mode": "reflect"
},
},
use_weight_norm: bool=True,
init_type: str="xavier_uniform", ):
"""Initilize Style MelGAN discriminator.
Parameters
----------
repeats : int
Number of repititons to apply RWD.
window_sizes : list
List of random window sizes.
pqmf_params : list
List of list of Parameters for PQMF modules
discriminator_params : dict
Parameters for base discriminator module.
use_weight_nom : bool
Whether to apply weight normalization.
"""
super().__init__()
# initialize parameters
initialize(self, init_type)
# window size check
assert len(window_sizes) == len(pqmf_params)
sizes = [ws // p[0] for ws, p in zip(window_sizes, pqmf_params)]
assert len(window_sizes) == sum([sizes[0] == size for size in sizes])
self.repeats = repeats
self.window_sizes = window_sizes
self.pqmfs = nn.LayerList()
self.discriminators = nn.LayerList()
for pqmf_param in pqmf_params:
d_params = copy.deepcopy(discriminator_params)
d_params["in_channels"] = pqmf_param[0]
if pqmf_param[0] == 1:
self.pqmfs.append(nn.Identity())
else:
self.pqmfs.append(PQMF(*pqmf_param))
self.discriminators.append(BaseDiscriminator(**d_params))
nn.initializer.set_global_initializer(None)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input tensor (B, 1, T).
Returns
----------
List
List of discriminator outputs, #items in the list will be
equal to repeats * #discriminators.
"""
outs = []
for _ in range(self.repeats):
outs += self._forward(x)
return outs
def _forward(self, x):
outs = []
for idx, (ws, pqmf, disc) in enumerate(
zip(self.window_sizes, self.pqmfs, self.discriminators)):
start_idx = int(np.random.randint(paddle.shape(x)[-1] - ws))
x_ = x[:, :, start_idx:start_idx + ws]
if idx == 0:
# nn.Identity()
x_ = pqmf(x_)
else:
x_ = pqmf.analysis(x_)
outs += [disc(x_)]
return outs
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
def reset_parameters(self):
"""Reset parameters.
This initialization follows official implementation manner.
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
"""
# 定义参数为float的正态分布。
dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
def _reset_parameters(m):
if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
w = dist.sample(m.weight.shape)
m.weight.set_value(w)
self.apply(_reset_parameters)
class StyleMelGANInference(nn.Layer):
def __init__(self, normalizer, style_melgan_generator):
super().__init__()
self.normalizer = normalizer
self.style_melgan_generator = style_melgan_generator
def forward(self, logmel):
normalized_mel = self.normalizer(logmel)
wav = self.style_melgan_generator.inference(normalized_mel)
return wav

@ -0,0 +1,221 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class StyleMelGANUpdater(StandardUpdater):
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
lambda_aux: float=1.0,
output_dir: Path=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.lambda_aux = lambda_aux
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
wav, mel = batch
# Generator
# (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel)
# initialize
gen_loss = 0.0
# full band Multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
gen_loss += sc_loss + mag_loss
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss *= self.lambda_aux
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_)
report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
self.optimizer_g.clear_grad()
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
# re-compute wav_ which leads better quality
with paddle.no_grad():
wav_ = self.generator(mel)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss, fake_loss = self.criterion_dis_adv(p_, p)
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class StyleMelGANEvaluator(StandardEvaluator):
def __init__(self,
models: Dict[str, Layer],
criterions: Dict[str, Layer],
dataloader: DataLoader,
lambda_adv: float,
lambda_aux: float=1.0,
output_dir: Path=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.dataloader = dataloader
self.lambda_adv = lambda_adv
self.lambda_aux = lambda_aux
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch
# Generator
# (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel)
## Adversarial loss
p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_)
report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# initialize
aux_loss = 0.0
# Multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
aux_loss += sc_loss + mag_loss
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
aux_loss *= self.lambda_aux
gen_loss += aux_loss
report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator
p = self.discriminator(wav)
real_loss, fake_loss = self.criterion_dis_adv(p_, p)
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

@ -10,8 +10,10 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations
# under the License.
import logging
from pathlib import Path
from typing import Dict
import paddle
@ -42,7 +44,7 @@ class PWGUpdater(StandardUpdater):
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
output_dir=None):
output_dir: Path=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
@ -155,11 +157,11 @@ class PWGUpdater(StandardUpdater):
class PWGEvaluator(StandardEvaluator):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
models: Dict[str, Layer],
criterions: Dict[str, Layer],
dataloader: DataLoader,
lambda_adv: float,
output_dir: Path=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']

@ -27,7 +27,7 @@ class GLU(nn.Layer):
return F.glu(xs, axis=self.dim)
def get_activation(act):
def get_activation(act, **kwargs):
"""Return activation function."""
activation_funcs = {
@ -35,8 +35,9 @@ def get_activation(act):
"tanh": paddle.nn.Tanh,
"relu": paddle.nn.ReLU,
"selu": paddle.nn.SELU,
"leakyrelu": paddle.nn.LeakyReLU,
"swish": paddle.nn.Swish,
"glu": GLU
}
return activation_funcs[act]()
return activation_funcs[act](**kwargs)

@ -18,6 +18,7 @@ from typing import Dict
from paddle import nn
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.causal_conv import CausalConv1D
@ -30,7 +31,7 @@ class ResidualStack(nn.Layer):
channels: int=32,
dilation: int=1,
bias: bool=True,
nonlinear_activation: str="LeakyReLU",
nonlinear_activation: str="leakyrelu",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
@ -58,14 +59,17 @@ class ResidualStack(nn.Layer):
Whether to use causal convolution.
"""
super().__init__()
# for compatibility
if nonlinear_activation == "LeakyReLU":
nonlinear_activation = "leakyrelu"
# defile residual stack part
if not use_causal_conv:
assert (kernel_size - 1
) % 2 == 0, "Not support even number kernel size."
self.stack = nn.Sequential(
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
get_activation(nonlinear_activation,
**nonlinear_activation_params),
getattr(nn, pad)((kernel_size - 1) // 2 * dilation,
**pad_params),
nn.Conv1D(
@ -74,13 +78,13 @@ class ResidualStack(nn.Layer):
kernel_size,
dilation=dilation,
bias_attr=bias),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
get_activation(nonlinear_activation,
**nonlinear_activation_params),
nn.Conv1D(channels, channels, 1, bias_attr=bias), )
else:
self.stack = nn.Sequential(
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
get_activation(nonlinear_activation,
**nonlinear_activation_params),
CausalConv1D(
channels,
channels,
@ -89,8 +93,8 @@ class ResidualStack(nn.Layer):
bias=bias,
pad=pad,
pad_params=pad_params, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
get_activation(nonlinear_activation,
**nonlinear_activation_params),
nn.Conv1D(channels, channels, 1, bias_attr=bias), )
# defile extra layer for skip connection

@ -298,8 +298,8 @@ class MultiHeadedAttention(BaseMultiHeadedAttention):
def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
"""Initialize multi head attention module."""
# NOTE(kan-bayashi): Do not use super().__init__() here since we want to
# overwrite BaseMultiHeadedAttention.__init__() method.
# Do not use super().__init__() here since we want to
# overwrite BaseMultiHeadedAttention.__init__() method.
nn.Layer.__init__(self)
assert n_feat % n_head == 0
# We assume d_v always equals d_k

@ -0,0 +1,164 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""StyleMelGAN's TADEResBlock Modules."""
from functools import partial
import paddle.nn.functional as F
from paddle import nn
class TADELayer(nn.Layer):
"""TADE Layer module."""
def __init__(
self,
in_channels: int=64,
aux_channels: int=80,
kernel_size: int=9,
bias: bool=True,
upsample_factor: int=2,
upsample_mode: str="nearest", ):
"""Initilize TADE layer."""
super().__init__()
self.norm = nn.InstanceNorm1D(
in_channels, momentum=0.1, data_format="NCL")
self.aux_conv = nn.Sequential(
nn.Conv1D(
aux_channels,
in_channels,
kernel_size,
1,
bias_attr=bias,
padding=(kernel_size - 1) // 2, ), )
self.gated_conv = nn.Sequential(
nn.Conv1D(
in_channels,
in_channels * 2,
kernel_size,
1,
bias_attr=bias,
padding=(kernel_size - 1) // 2, ), )
self.upsample = nn.Upsample(
scale_factor=upsample_factor, mode=upsample_mode)
def forward(self, x, c):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input tensor (B, in_channels, T).
c : Tensor
Auxiliary input tensor (B, aux_channels, T).
Returns
----------
Tensor
Output tensor (B, in_channels, T * upsample_factor).
Tensor
Upsampled aux tensor (B, in_channels, T * upsample_factor).
"""
x = self.norm(x)
# 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
c = self.upsample(c.unsqueeze(-1))
c = c[:, :, :, 0]
c = self.aux_conv(c)
cg = self.gated_conv(c)
cg1, cg2 = cg.split(2, axis=1)
# 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
y = cg1 * self.upsample(x.unsqueeze(-1))[:, :, :, 0] + cg2
return y, c
class TADEResBlock(nn.Layer):
"""TADEResBlock module."""
def __init__(
self,
in_channels: int=64,
aux_channels: int=80,
kernel_size: int=9,
dilation: int=2,
bias: bool=True,
upsample_factor: int=2,
# this is a diff in paddle, the mode only can be "linear" when input is 3D
upsample_mode: str="nearest",
gated_function: str="softmax", ):
"""Initialize TADEResBlock module."""
super().__init__()
self.tade1 = TADELayer(
in_channels=in_channels,
aux_channels=aux_channels,
kernel_size=kernel_size,
bias=bias,
upsample_factor=1,
upsample_mode=upsample_mode, )
self.gated_conv1 = nn.Conv1D(
in_channels,
in_channels * 2,
kernel_size,
1,
bias_attr=bias,
padding=(kernel_size - 1) // 2, )
self.tade2 = TADELayer(
in_channels=in_channels,
aux_channels=in_channels,
kernel_size=kernel_size,
bias=bias,
upsample_factor=upsample_factor,
upsample_mode=upsample_mode, )
self.gated_conv2 = nn.Conv1D(
in_channels,
in_channels * 2,
kernel_size,
1,
bias_attr=bias,
dilation=dilation,
padding=(kernel_size - 1) // 2 * dilation, )
self.upsample = nn.Upsample(
scale_factor=upsample_factor, mode=upsample_mode)
if gated_function == "softmax":
self.gated_function = partial(F.softmax, axis=1)
elif gated_function == "sigmoid":
self.gated_function = F.sigmoid
else:
raise ValueError(f"{gated_function} is not supported.")
def forward(self, x, c):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input tensor (B, in_channels, T).
c : Tensor
Auxiliary input tensor (B, aux_channels, T).
Returns
----------
Tensor
Output tensor (B, in_channels, T * upsample_factor).
Tensor
Upsampled auxirialy tensor (B, in_channels, T * upsample_factor).
"""
residual = x
x, c = self.tade1(x, c)
x = self.gated_conv1(x)
xa, xb = x.split(2, axis=1)
x = self.gated_function(xa) * F.tanh(xb)
x, c = self.tade2(x, c)
x = self.gated_conv2(x)
xa, xb = x.split(2, axis=1)
x = self.gated_function(xa) * F.tanh(xb)
# 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
return self.upsample(residual.unsqueeze(-1))[:, :, :, 0] + x, c
Loading…
Cancel
Save