parent
ef8e61813a
commit
dd36eafe34
@ -0,0 +1,136 @@
|
|||||||
|
# This is the configuration file for CSMSC dataset.This configuration is based
|
||||||
|
# on StyleMelGAN paper but uses MSE loss instead of Hinge loss. And I found that
|
||||||
|
# batch_size = 8 is also working good. So maybe if you want to accelerate the training,
|
||||||
|
# you can reduce the batch size (e.g. 8 or 16). Upsampling scales is modified to
|
||||||
|
# fit the shift size 300 pt.
|
||||||
|
# NOTE: batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300)
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
fs: 24000 # Sampling rate.
|
||||||
|
n_fft: 2048 # FFT size. (in samples)
|
||||||
|
n_shift: 300 # Hop size. (in samples)
|
||||||
|
win_length: 1200 # Window length. (in samples)
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
n_mels: 80 # Number of mel basis.
|
||||||
|
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
|
||||||
|
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_params:
|
||||||
|
in_channels: 128 # Number of input channels.
|
||||||
|
aux_channels: 80
|
||||||
|
channels: 64 # Initial number of channels for conv layers.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_size: 9 # Kernel size of initial and final conv layers.
|
||||||
|
dilation: 2
|
||||||
|
bias: True
|
||||||
|
noise_upsample_scales: [10, 2, 2, 2]
|
||||||
|
noise_upsample_activation: "leakyrelu"
|
||||||
|
noise_upsample_activation_params:
|
||||||
|
negative_slope: 0.2
|
||||||
|
upsample_scales: [5, 1, 5, 1, 3, 1, 2, 2, 1] # List of Upsampling scales. prod(upsample_scales) == n_shift
|
||||||
|
upsample_mode: "nearest"
|
||||||
|
gated_function: "softmax"
|
||||||
|
use_weight_norm: True # Whether to use weight normalization.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_params:
|
||||||
|
repeats: 4
|
||||||
|
window_sizes: [512, 1024, 2048, 4096]
|
||||||
|
pqmf_params:
|
||||||
|
- [1, None, None, None]
|
||||||
|
- [2, 62, 0.26700, 9.0]
|
||||||
|
- [4, 62, 0.14200, 9.0]
|
||||||
|
- [8, 62, 0.07949, 9.0]
|
||||||
|
discriminator_params:
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_sizes: [5, 3] # List of kernel size.
|
||||||
|
channels: 16 # Number of channels of the initial conv layer.
|
||||||
|
max_downsample_channels: 512 # Maximum number of channels of downsampling layers.
|
||||||
|
bias: True
|
||||||
|
downsample_scales: [4, 4, 4, 1] # List of downsampling scales.
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation function.
|
||||||
|
nonlinear_activation_params: # Parameters of nonlinear activation function.
|
||||||
|
negative_slope: 0.2
|
||||||
|
use_weight_norm: True # Whether to use weight norm.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# STFT LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
use_stft_loss: true
|
||||||
|
stft_loss_params:
|
||||||
|
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
|
||||||
|
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
|
||||||
|
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
|
||||||
|
window: "hann" # Window function for STFT-based loss
|
||||||
|
lambda_aux: 1.0 # Loss balancing coefficient for aux loss.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# ADVERSARIAL LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
lambda_adv: 1.0 # Loss balancing coefficient for adv loss.
|
||||||
|
generator_adv_loss_params:
|
||||||
|
average_by_discriminators: false # Whether to average loss by #discriminators.
|
||||||
|
discriminator_adv_loss_params:
|
||||||
|
average_by_discriminators: false # Whether to average loss by #discriminators.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 32 # Batch size.
|
||||||
|
# batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300, n_shift)
|
||||||
|
batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift.
|
||||||
|
num_workers: 2 # Number of workers in Pytorch DataLoader.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER & SCHEDULER SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Generator's weight decay coefficient.
|
||||||
|
generator_scheduler_params:
|
||||||
|
learning_rate: 1.0e-4 # Generator's learning rate.
|
||||||
|
gamma: 0.5 # Generator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 100000
|
||||||
|
- 300000
|
||||||
|
- 500000
|
||||||
|
- 700000
|
||||||
|
- 900000
|
||||||
|
generator_grad_norm: -1 # Generator's gradient norm.
|
||||||
|
discriminator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
||||||
|
discriminator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4 # Discriminator's learning rate.
|
||||||
|
gamma: 0.5 # Discriminator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 200000
|
||||||
|
- 400000
|
||||||
|
- 600000
|
||||||
|
- 800000
|
||||||
|
discriminator_grad_norm: -1 # Discriminator's gradient norm.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# INTERVAL SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator.
|
||||||
|
train_max_steps: 1500000 # Number of training steps.
|
||||||
|
save_interval_steps: 5000 # Interval steps to save checkpoint.
|
||||||
|
eval_interval_steps: 1000 # Interval steps to evaluate the network.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 42 # random seed for paddle, random, and np.random
|
@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# get durations from MFA's result
|
||||||
|
echo "Generate durations.txt from MFA results ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
|
||||||
|
--inputdir=./baker_alignment_tone \
|
||||||
|
--output=durations.txt \
|
||||||
|
--config=${config_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# extract features
|
||||||
|
echo "Extract features ..."
|
||||||
|
python3 ${BIN_DIR}/../preprocess.py \
|
||||||
|
--rootdir=~/datasets/BZNSYP/ \
|
||||||
|
--dataset=baker \
|
||||||
|
--dumpdir=dump \
|
||||||
|
--dur-file=durations.txt \
|
||||||
|
--config=${config_path} \
|
||||||
|
--cut-sil=True \
|
||||||
|
--num-cpu=20
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# get features' stats(mean and std)
|
||||||
|
echo "Get features' stats ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="feats"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# normalize, dev and test should use train's stats
|
||||||
|
echo "Normalize ..."
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/train/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/dev/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/dev/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/test/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/test/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
fi
|
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||||
|
--output-dir=${train_output_path}/test \
|
||||||
|
--generator-type=style_melgan
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
|
||||||
|
FLAGS_cudnn_exhaustive_search=true \
|
||||||
|
FLAGS_conv_workspace_size_limit=4000 \
|
||||||
|
python ${BIN_DIR}/train.py \
|
||||||
|
--train-metadata=dump/train/norm/metadata.jsonl \
|
||||||
|
--dev-metadata=dump/dev/norm/metadata.jsonl \
|
||||||
|
--config=${config_path} \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=1
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=style_melgan
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
|
@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
ckpt_name=snapshot_iter_50000.pdz
|
||||||
|
|
||||||
|
# with the following command, you can choose the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# synthesize
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
@ -1,103 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import jsonlines
|
|
||||||
import numpy as np
|
|
||||||
import paddle
|
|
||||||
import soundfile as sf
|
|
||||||
import yaml
|
|
||||||
from paddle import distributed as dist
|
|
||||||
from timer import timer
|
|
||||||
from yacs.config import CfgNode
|
|
||||||
|
|
||||||
from paddlespeech.t2s.datasets.data_table import DataTable
|
|
||||||
from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Synthesize with parallel wavegan.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config", type=str, help="parallel wavegan config file.")
|
|
||||||
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
|
|
||||||
parser.add_argument("--test-metadata", type=str, help="dev data.")
|
|
||||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
|
||||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with open(args.config) as f:
|
|
||||||
config = CfgNode(yaml.safe_load(f))
|
|
||||||
|
|
||||||
print("========Args========")
|
|
||||||
print(yaml.safe_dump(vars(args)))
|
|
||||||
print("========Config========")
|
|
||||||
print(config)
|
|
||||||
print(
|
|
||||||
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.ngpu == 0:
|
|
||||||
paddle.set_device("cpu")
|
|
||||||
elif args.ngpu > 0:
|
|
||||||
paddle.set_device("gpu")
|
|
||||||
else:
|
|
||||||
print("ngpu should >= 0 !")
|
|
||||||
generator = PWGGenerator(**config["generator_params"])
|
|
||||||
state_dict = paddle.load(args.checkpoint)
|
|
||||||
generator.set_state_dict(state_dict["generator_params"])
|
|
||||||
|
|
||||||
generator.remove_weight_norm()
|
|
||||||
generator.eval()
|
|
||||||
with jsonlines.open(args.test_metadata, 'r') as reader:
|
|
||||||
metadata = list(reader)
|
|
||||||
|
|
||||||
test_dataset = DataTable(
|
|
||||||
metadata,
|
|
||||||
fields=['utt_id', 'feats'],
|
|
||||||
converters={
|
|
||||||
'utt_id': None,
|
|
||||||
'feats': np.load,
|
|
||||||
})
|
|
||||||
output_dir = Path(args.output_dir)
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
N = 0
|
|
||||||
T = 0
|
|
||||||
for example in test_dataset:
|
|
||||||
utt_id = example['utt_id']
|
|
||||||
mel = example['feats']
|
|
||||||
mel = paddle.to_tensor(mel) # (T, C)
|
|
||||||
with timer() as t:
|
|
||||||
with paddle.no_grad():
|
|
||||||
wav = generator.inference(c=mel)
|
|
||||||
wav = wav.numpy()
|
|
||||||
N += wav.size
|
|
||||||
T += t.elapse
|
|
||||||
speed = wav.size / t.elapse
|
|
||||||
rtf = config.fs / speed
|
|
||||||
print(
|
|
||||||
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
|
|
||||||
)
|
|
||||||
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
|
|
||||||
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,258 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import yaml
|
||||||
|
from paddle import DataParallel
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
from paddle.optimizer.lr import MultiStepDecay
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip
|
||||||
|
from paddlespeech.t2s.models.melgan import StyleMelGANDiscriminator
|
||||||
|
from paddlespeech.t2s.models.melgan import StyleMelGANEvaluator
|
||||||
|
from paddlespeech.t2s.models.melgan import StyleMelGANGenerator
|
||||||
|
from paddlespeech.t2s.models.melgan import StyleMelGANUpdater
|
||||||
|
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import MultiResolutionSTFTLoss
|
||||||
|
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
|
||||||
|
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
|
||||||
|
from paddlespeech.t2s.training.seeding import seed_everything
|
||||||
|
from paddlespeech.t2s.training.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def train_sp(args, config):
|
||||||
|
# decides device type and whether to run in parallel
|
||||||
|
# setup running environment correctly
|
||||||
|
world_size = paddle.distributed.get_world_size()
|
||||||
|
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
else:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
if world_size > 1:
|
||||||
|
paddle.distributed.init_parallel_env()
|
||||||
|
|
||||||
|
# set the random seed, it is a must for multiprocess training
|
||||||
|
seed_everything(config.seed)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# dataloader has been too verbose
|
||||||
|
logging.getLogger("DataLoader").disabled = True
|
||||||
|
|
||||||
|
# construct dataset for training and validation
|
||||||
|
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||||
|
train_metadata = list(reader)
|
||||||
|
train_dataset = DataTable(
|
||||||
|
data=train_metadata,
|
||||||
|
fields=["wave", "feats"],
|
||||||
|
converters={
|
||||||
|
"wave": np.load,
|
||||||
|
"feats": np.load,
|
||||||
|
}, )
|
||||||
|
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||||
|
dev_metadata = list(reader)
|
||||||
|
dev_dataset = DataTable(
|
||||||
|
data=dev_metadata,
|
||||||
|
fields=["wave", "feats"],
|
||||||
|
converters={
|
||||||
|
"wave": np.load,
|
||||||
|
"feats": np.load,
|
||||||
|
}, )
|
||||||
|
|
||||||
|
# collate function and dataloader
|
||||||
|
train_sampler = DistributedBatchSampler(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True)
|
||||||
|
dev_sampler = DistributedBatchSampler(
|
||||||
|
dev_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False)
|
||||||
|
print("samplers done!")
|
||||||
|
|
||||||
|
if "aux_context_window" in config.generator_params:
|
||||||
|
aux_context_window = config.generator_params.aux_context_window
|
||||||
|
else:
|
||||||
|
aux_context_window = 0
|
||||||
|
train_batch_fn = Clip(
|
||||||
|
batch_max_steps=config.batch_max_steps,
|
||||||
|
hop_size=config.n_shift,
|
||||||
|
aux_context_window=aux_context_window)
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
|
||||||
|
dev_dataloader = DataLoader(
|
||||||
|
dev_dataset,
|
||||||
|
batch_sampler=dev_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
print("dataloaders done!")
|
||||||
|
|
||||||
|
generator = StyleMelGANGenerator(**config["generator_params"])
|
||||||
|
discriminator = StyleMelGANDiscriminator(**config["discriminator_params"])
|
||||||
|
if world_size > 1:
|
||||||
|
generator = DataParallel(generator)
|
||||||
|
discriminator = DataParallel(discriminator)
|
||||||
|
print("models done!")
|
||||||
|
criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
|
||||||
|
|
||||||
|
criterion_gen_adv = GeneratorAdversarialLoss(
|
||||||
|
**config["generator_adv_loss_params"])
|
||||||
|
criterion_dis_adv = DiscriminatorAdversarialLoss(
|
||||||
|
**config["discriminator_adv_loss_params"])
|
||||||
|
print("criterions done!")
|
||||||
|
|
||||||
|
lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"])
|
||||||
|
# Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used
|
||||||
|
generator_grad_norm = config["generator_grad_norm"]
|
||||||
|
gradient_clip_g = nn.ClipGradByGlobalNorm(
|
||||||
|
generator_grad_norm) if generator_grad_norm > 0 else None
|
||||||
|
print("gradient_clip_g:", gradient_clip_g)
|
||||||
|
|
||||||
|
optimizer_g = Adam(
|
||||||
|
learning_rate=lr_schedule_g,
|
||||||
|
grad_clip=gradient_clip_g,
|
||||||
|
parameters=generator.parameters(),
|
||||||
|
**config["generator_optimizer_params"])
|
||||||
|
lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"])
|
||||||
|
discriminator_grad_norm = config["discriminator_grad_norm"]
|
||||||
|
gradient_clip_d = nn.ClipGradByGlobalNorm(
|
||||||
|
discriminator_grad_norm) if discriminator_grad_norm > 0 else None
|
||||||
|
print("gradient_clip_d:", gradient_clip_d)
|
||||||
|
optimizer_d = Adam(
|
||||||
|
learning_rate=lr_schedule_d,
|
||||||
|
grad_clip=gradient_clip_d,
|
||||||
|
parameters=discriminator.parameters(),
|
||||||
|
**config["discriminator_optimizer_params"])
|
||||||
|
print("optimizers done!")
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
config_name = args.config.split("/")[-1]
|
||||||
|
# copy conf to output_dir
|
||||||
|
shutil.copyfile(args.config, output_dir / config_name)
|
||||||
|
|
||||||
|
updater = StyleMelGANUpdater(
|
||||||
|
models={
|
||||||
|
"generator": generator,
|
||||||
|
"discriminator": discriminator,
|
||||||
|
},
|
||||||
|
optimizers={
|
||||||
|
"generator": optimizer_g,
|
||||||
|
"discriminator": optimizer_d,
|
||||||
|
},
|
||||||
|
criterions={
|
||||||
|
"stft": criterion_stft,
|
||||||
|
"gen_adv": criterion_gen_adv,
|
||||||
|
"dis_adv": criterion_dis_adv,
|
||||||
|
},
|
||||||
|
schedulers={
|
||||||
|
"generator": lr_schedule_g,
|
||||||
|
"discriminator": lr_schedule_d,
|
||||||
|
},
|
||||||
|
dataloader=train_dataloader,
|
||||||
|
discriminator_train_start_steps=config.discriminator_train_start_steps,
|
||||||
|
lambda_adv=config.lambda_adv,
|
||||||
|
output_dir=output_dir)
|
||||||
|
|
||||||
|
evaluator = StyleMelGANEvaluator(
|
||||||
|
models={
|
||||||
|
"generator": generator,
|
||||||
|
"discriminator": discriminator,
|
||||||
|
},
|
||||||
|
criterions={
|
||||||
|
"stft": criterion_stft,
|
||||||
|
"gen_adv": criterion_gen_adv,
|
||||||
|
"dis_adv": criterion_dis_adv,
|
||||||
|
},
|
||||||
|
dataloader=dev_dataloader,
|
||||||
|
lambda_adv=config.lambda_adv,
|
||||||
|
output_dir=output_dir)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
updater,
|
||||||
|
stop_trigger=(config.train_max_steps, "iteration"),
|
||||||
|
out=output_dir)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
trainer.extend(
|
||||||
|
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
||||||
|
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
|
||||||
|
trainer.extend(
|
||||||
|
Snapshot(max_size=config.num_snapshots),
|
||||||
|
trigger=(config.save_interval_steps, 'iteration'))
|
||||||
|
|
||||||
|
print("Trainer Done!")
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse args and config and redirect to train_sp
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Train a Multi-Band MelGAN model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config.")
|
||||||
|
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||||
|
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config, 'rt') as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# dispatch
|
||||||
|
if args.ngpu > 1:
|
||||||
|
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
|
||||||
|
else:
|
||||||
|
train_sp(args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,221 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.optimizer import Optimizer
|
||||||
|
from paddle.optimizer.lr import LRScheduler
|
||||||
|
|
||||||
|
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||||||
|
from paddlespeech.t2s.training.reporter import report
|
||||||
|
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||||||
|
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
|
||||||
|
logging.basicConfig(
|
||||||
|
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||||
|
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class StyleMelGANUpdater(StandardUpdater):
    """One training iteration of StyleMelGAN.

    Each call to ``update_core`` updates the generator on the
    multi-resolution STFT loss (plus, once ``self.state.iteration`` exceeds
    ``discriminator_train_start_steps``, an adversarial term), and then
    updates the discriminator on real vs. re-generated audio.
    """

    def __init__(self,
                 models: Dict[str, Layer],
                 optimizers: Dict[str, Optimizer],
                 criterions: Dict[str, Layer],
                 schedulers: Dict[str, LRScheduler],
                 dataloader: DataLoader,
                 discriminator_train_start_steps: int,
                 lambda_adv: float,
                 lambda_aux: float=1.0,
                 output_dir: Path=None):
        """Initialize the updater.

        Parameters
        ----------
        models : Dict[str, Layer]
            Must contain 'generator' and 'discriminator'.
        optimizers : Dict[str, Optimizer]
            Must contain 'generator' and 'discriminator'.
        criterions : Dict[str, Layer]
            Must contain 'stft', 'gen_adv' and 'dis_adv'.
        schedulers : Dict[str, LRScheduler]
            Must contain 'generator' and 'discriminator'.
        dataloader : DataLoader
            Training dataloader yielding (wav, mel) batches.
        discriminator_train_start_steps : int
            Iteration after which the discriminator (and the adversarial
            term of the generator loss) starts training.
        lambda_adv : float
            Weight of the adversarial term in the generator loss.
        lambda_aux : float
            Weight of the auxiliary (STFT) term in the generator loss.
        output_dir : Path
            Directory for the per-worker log file.
            NOTE(review): a ``None`` default is declared but ``output_dir /
            ...`` below would raise on ``None`` — confirm callers always
            pass a Path.
        """
        # sub-models
        self.models = models
        self.generator: Layer = models['generator']
        self.discriminator: Layer = models['discriminator']

        # one optimizer per sub-model
        self.optimizers = optimizers
        self.optimizer_g: Optimizer = optimizers['generator']
        self.optimizer_d: Optimizer = optimizers['discriminator']

        # loss modules
        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]

        # one LR scheduler per sub-model
        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        self.dataloader = dataloader

        self.discriminator_train_start_steps = discriminator_train_start_steps
        self.lambda_adv = lambda_adv
        self.lambda_aux = lambda_aux
        self.state = UpdaterState(iteration=0, epoch=0)

        self.train_iterator = iter(self.dataloader)

        # each distributed worker writes its own log file
        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        """Run one generator + (conditionally) discriminator update on ``batch``."""
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}

        # parse batch
        wav, mel = batch

        # Generator forward:
        # (B, out_channels, T * prod(upsample_scales))
        wav_ = self.generator(mel)

        # initialize (becomes a tensor after the first `+=` below)
        gen_loss = 0.0

        # full band multi-resolution STFT loss (auxiliary loss)
        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
        gen_loss += sc_loss + mag_loss
        report("train/spectral_convergence_loss", float(sc_loss))
        report("train/log_stft_magnitude_loss", float(mag_loss))
        losses_dict["spectral_convergence_loss"] = float(sc_loss)
        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)

        gen_loss *= self.lambda_aux

        # adversarial loss, only after the discriminator warm-up period
        if self.state.iteration > self.discriminator_train_start_steps:
            p_ = self.discriminator(wav_)
            adv_loss = self.criterion_gen_adv(p_)
            report("train/adversarial_loss", float(adv_loss))
            losses_dict["adversarial_loss"] = float(adv_loss)
            gen_loss += self.lambda_adv * adv_loss

        report("train/generator_loss", float(gen_loss))
        losses_dict["generator_loss"] = float(gen_loss)

        self.optimizer_g.clear_grad()
        gen_loss.backward()

        self.optimizer_g.step()
        self.scheduler_g.step()

        # Discriminator
        if self.state.iteration > self.discriminator_train_start_steps:
            # re-compute wav_ (after the generator step) which leads to
            # better quality; no_grad since only the discriminator trains here
            with paddle.no_grad():
                wav_ = self.generator(mel)

            p = self.discriminator(wav)
            p_ = self.discriminator(wav_.detach())
            # criterion takes (fake outputs, real outputs)
            real_loss, fake_loss = self.criterion_dis_adv(p_, p)
            dis_loss = real_loss + fake_loss
            report("train/real_loss", float(real_loss))
            report("train/fake_loss", float(fake_loss))
            report("train/discriminator_loss", float(dis_loss))
            losses_dict["real_loss"] = float(real_loss)
            losses_dict["fake_loss"] = float(fake_loss)
            losses_dict["discriminator_loss"] = float(dis_loss)

            self.optimizer_d.clear_grad()
            dis_loss.backward()

            self.optimizer_d.step()
            self.scheduler_d.step()

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
||||||
|
|
||||||
|
class StyleMelGANEvaluator(StandardEvaluator):
    """Compute StyleMelGAN generator/discriminator losses on validation data.

    Unlike the training updater, nothing is optimized here; every loss term
    is only reported and appended to the per-worker log.
    """

    def __init__(self,
                 models: Dict[str, Layer],
                 criterions: Dict[str, Layer],
                 dataloader: DataLoader,
                 lambda_adv: float,
                 lambda_aux: float=1.0,
                 output_dir: Path=None):
        """Initialize the evaluator.

        Parameters
        ----------
        models : Dict[str, Layer]
            Must contain 'generator' and 'discriminator'.
        criterions : Dict[str, Layer]
            Must contain 'stft', 'gen_adv' and 'dis_adv'.
        dataloader : DataLoader
            Validation dataloader yielding (wav, mel) batches.
        lambda_adv : float
            Weight of the adversarial term in the generator loss.
        lambda_aux : float
            Weight of the auxiliary (STFT) term in the generator loss.
        output_dir : Path
            Directory for the per-worker log file.
        """
        # sub-models
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        # loss modules
        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]

        self.dataloader = dataloader
        # weights of the adversarial / auxiliary generator-loss terms
        self.lambda_adv = lambda_adv
        self.lambda_aux = lambda_aux

        # each distributed worker writes its own log file
        log_file = output_dir / f'worker_{dist.get_rank()}.log'
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        """Evaluate one (wav, mel) batch and report/log every loss term."""
        self.msg = "Evaluate: "
        losses_dict = {}

        def _record(name, loss):
            # report for aggregation and remember for the log line
            value = float(loss)
            report("eval/" + name, value)
            losses_dict[name] = value

        wav, mel = batch

        # generator forward: (B, out_channels, T * prod(upsample_scales))
        # NOTE(review): runs without paddle.no_grad(); presumably acceptable
        # since no backward pass happens here — confirm memory is not a concern
        wav_ = self.generator(mel)

        # adversarial part of the generator loss
        p_ = self.discriminator(wav_)
        adv_loss = self.criterion_gen_adv(p_)
        _record("adversarial_loss", adv_loss)
        gen_loss = self.lambda_adv * adv_loss

        # auxiliary part: multi-resolution STFT loss
        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
        _record("spectral_convergence_loss", sc_loss)
        _record("log_stft_magnitude_loss", mag_loss)
        gen_loss += (sc_loss + mag_loss) * self.lambda_aux
        _record("generator_loss", gen_loss)

        # discriminator loss on real vs. generated audio
        p = self.discriminator(wav)
        real_loss, fake_loss = self.criterion_dis_adv(p_, p)
        _record("real_loss", real_loss)
        _record("fake_loss", fake_loss)
        _record("discriminator_loss", real_loss + fake_loss)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
|
@ -0,0 +1,164 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
"""StyleMelGAN's TADEResBlock Modules."""
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class TADELayer(nn.Layer):
    """TADE (Temporal Adaptive DE-normalization) layer.

    Instance-normalizes the input and modulates it with a scale/shift pair
    predicted from the (upsampled) auxiliary features.
    """

    def __init__(
            self,
            in_channels: int=64,
            aux_channels: int=80,
            kernel_size: int=9,
            bias: bool=True,
            upsample_factor: int=2,
            upsample_mode: str="nearest", ):
        """Initialize TADE layer."""
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        self.norm = nn.InstanceNorm1D(
            in_channels, momentum=0.1, data_format="NCL")
        # projects aux features into the input's channel space
        self.aux_conv = nn.Sequential(
            nn.Conv1D(
                aux_channels,
                in_channels,
                kernel_size,
                1,
                bias_attr=bias,
                padding=same_pad, ), )
        # predicts the scale/shift pair (hence 2x channels)
        self.gated_conv = nn.Sequential(
            nn.Conv1D(
                in_channels,
                in_channels * 2,
                kernel_size,
                1,
                bias_attr=bias,
                padding=same_pad, ), )
        self.upsample = nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode)

    def _upsample_1d(self, t):
        # 'bilinear', 'bicubic' and 'nearest' upsampling only support 4-D
        # tensors, so add and drop a trailing singleton axis around the call.
        return self.upsample(t.unsqueeze(-1))[:, :, :, 0]

    def forward(self, x, c):
        """Calculate forward propagation.

        Parameters
        ----------
        x : Tensor
            Input tensor (B, in_channels, T).
        c : Tensor
            Auxiliary input tensor (B, aux_channels, T).

        Returns
        ----------
        Tensor
            Output tensor (B, in_channels, T * upsample_factor).
        Tensor
            Upsampled aux tensor (B, in_channels, T * upsample_factor).
        """
        normed = self.norm(x)
        c_up = self._upsample_1d(c)
        c_up = self.aux_conv(c_up)
        mods = self.gated_conv(c_up)
        scale, shift = mods.split(2, axis=1)
        out = scale * self._upsample_1d(normed) + shift
        return out, c_up
|
||||||
|
|
||||||
|
|
||||||
|
class TADEResBlock(nn.Layer):
    """Residual block built from two TADE layers with gated-tanh activations."""

    def __init__(
            self,
            in_channels: int=64,
            aux_channels: int=80,
            kernel_size: int=9,
            dilation: int=2,
            bias: bool=True,
            upsample_factor: int=2,
            # this is a diff in paddle, the mode only can be "linear" when input is 3D
            upsample_mode: str="nearest",
            gated_function: str="softmax", ):
        """Initialize TADEResBlock module."""
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        # first TADE stage keeps the temporal resolution
        self.tade1 = TADELayer(
            in_channels=in_channels,
            aux_channels=aux_channels,
            kernel_size=kernel_size,
            bias=bias,
            upsample_factor=1,
            upsample_mode=upsample_mode, )
        self.gated_conv1 = nn.Conv1D(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias_attr=bias,
            padding=same_pad, )
        # second TADE stage upsamples; its aux input is the first stage's output
        self.tade2 = TADELayer(
            in_channels=in_channels,
            aux_channels=in_channels,
            kernel_size=kernel_size,
            bias=bias,
            upsample_factor=upsample_factor,
            upsample_mode=upsample_mode, )
        self.gated_conv2 = nn.Conv1D(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias_attr=bias,
            dilation=dilation,
            padding=same_pad * dilation, )
        self.upsample = nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode)
        gate_fns = {
            "softmax": partial(F.softmax, axis=1),
            "sigmoid": F.sigmoid,
        }
        if gated_function not in gate_fns:
            raise ValueError(f"{gated_function} is not supported.")
        self.gated_function = gate_fns[gated_function]

    def _gated_unit(self, conv, h):
        # GLU-style activation: gate(a) * tanh(b) over a channel split
        a, b = conv(h).split(2, axis=1)
        return self.gated_function(a) * F.tanh(b)

    def forward(self, x, c):
        """Calculate forward propagation.

        Parameters
        ----------
        x : Tensor
            Input tensor (B, in_channels, T).
        c : Tensor
            Auxiliary input tensor (B, aux_channels, T).

        Returns
        ----------
        Tensor
            Output tensor (B, in_channels, T * upsample_factor).
        Tensor
            Upsampled auxiliary tensor (B, in_channels, T * upsample_factor).
        """
        residual = x
        x, c = self.tade1(x, c)
        x = self._gated_unit(self.gated_conv1, x)
        x, c = self.tade2(x, c)
        x = self._gated_unit(self.gated_conv2, x)
        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensors
        skip = self.upsample(residual.unsqueeze(-1))[:, :, :, 0]
        return skip + x, c
|
Loading…
Reference in new issue