commit 963e906f56
@ -0,0 +1,111 @@
# Style MelGAN with CSMSC

This example contains code used to train a [Style MelGAN](https://arxiv.org/abs/2011.01557) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).

## Dataset

### Download and Extract

Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. The dataset will then be in the directory `~/datasets/BZNSYP`.

### Get MFA Results and Extract

We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of audio.

You can download them from [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/mfa) in our repo.

## Get Started

Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.

Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
    - synthesize waveform from `metadata.jsonl`.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── test
│   ├── norm
│   └── raw
└── train
    ├── norm
    ├── raw
    └── feats_stats.npy
```

The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrograms. The statistics used to normalize the spectrograms are computed from the training set and stored in `dump/train/feats_stats.npy`.

There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id of and the paths to the spectrogram of each utterance.
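As a concrete illustration, here is a minimal sketch of what reading such a file looks like. The file and its values below are made up for the example; the field names (`utt_id`, `feats`, `wave`) follow the ones the training script reads, and in the real files the paths point to `.npy` arrays loaded with `np.load`.

```python
import json
import tempfile
from pathlib import Path

# Build a tiny stand-in metadata.jsonl: one JSON object per line.
# The keys mirror the fields read by train.py, but the values are
# invented for this illustration.
records = [
    {"utt_id": "009901", "feats": "dump/train/norm/feats/009901.npy",
     "wave": "dump/train/raw/wav/009901.npy"},
    {"utt_id": "009902", "feats": "dump/train/norm/feats/009902.npy",
     "wave": "dump/train/raw/wav/009902.npy"},
]

tmp = Path(tempfile.mkdtemp())
meta_path = tmp / "metadata.jsonl"
with open(meta_path, "w") as f:
    for r in records:
        f.write(json.dumps(r) + "\n")

# Reading it back is just line-by-line JSON parsing.
with open(meta_path) as f:
    metadata = [json.loads(line) for line in f]

print(len(metadata), metadata[0]["utt_id"])
```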
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.

```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--verbose VERBOSE]

Train a Style MelGAN model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       config file to overwrite default config.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --verbose VERBOSE     verbose.
```

1. `--config` is a config file in yaml format used to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory in which to save the results of the experiment. Checkpoints are saved in the `checkpoints/` subdirectory inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT]
                     [--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR]
                     [--ngpu NGPU] [--verbose VERBOSE]

Synthesize with style melgan.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       style melgan config file.
  --checkpoint CHECKPOINT
                        snapshot to load.
  --test-metadata TEST_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --verbose VERBOSE     verbose.
```

1. `--config` is the style melgan config file. You should use the same config with which the model was trained.
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder of the processed directory.
4. `--output-dir` is the directory in which to save the synthesized audio files.
5. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
@ -0,0 +1,136 @@
# This is the configuration file for the CSMSC dataset. This configuration is based
# on the StyleMelGAN paper but uses MSE loss instead of Hinge loss. We found that
# batch_size = 8 also works well, so if you want to accelerate training,
# you can reduce the batch size (e.g. to 8 or 16). The upsampling scales are
# modified to fit the shift size of 300 points.
# NOTE: batch_max_steps (24000) == prod(noise_upsample_scales) (80) * prod(upsample_scales) (300)

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
fs: 24000                # Sampling rate.
n_fft: 2048              # FFT size. (in samples)
n_shift: 300             # Hop size. (in samples)
win_length: 1200         # Window length. (in samples)
                         # If set to null, it will be the same as fft_size.
window: "hann"           # Window function.
n_mels: 80               # Number of mel basis.
fmin: 80                 # Minimum frequency in mel basis calculation. (Hz)
fmax: 7600               # Maximum frequency in mel basis calculation. (Hz)

###########################################################
#         GENERATOR NETWORK ARCHITECTURE SETTING          #
###########################################################
generator_params:
    in_channels: 128                       # Number of input (noise) channels.
    aux_channels: 80                       # Number of auxiliary (mel) channels.
    channels: 64                           # Initial number of channels for conv layers.
    out_channels: 1                        # Number of output channels.
    kernel_size: 9                         # Kernel size of initial and final conv layers.
    dilation: 2                            # Dilation factor of conv layers.
    bias: True                             # Whether to use bias in conv layers.
    noise_upsample_scales: [10, 2, 2, 2]   # Upsampling scales of the noise input.
    noise_upsample_activation: "leakyrelu" # Activation after each noise upsampling layer.
    noise_upsample_activation_params:
        negative_slope: 0.2
    upsample_scales: [5, 1, 5, 1, 3, 1, 2, 2, 1] # List of upsampling scales. prod(upsample_scales) == n_shift
    upsample_mode: "nearest"               # Interpolation mode of the upsampling layers.
    gated_function: "softmax"              # Gated function used in the TADE modules.
    use_weight_norm: True                  # Whether to use weight normalization.
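The NOTE in the header couples three settings: one conditioning mel frame must map to exactly `n_shift` output samples, and a training clip of `batch_max_steps` samples must correspond to a whole number of upsampled noise steps. A quick sanity check of that relationship, using the values from this config:

```python
import math

# Values copied from conf/default.yaml.
n_shift = 300
batch_max_steps = 24000
noise_upsample_scales = [10, 2, 2, 2]
upsample_scales = [5, 1, 5, 1, 3, 1, 2, 2, 1]

# The generator first upsamples the noise by prod(noise_upsample_scales),
# then upsamples the conditioning frames by prod(upsample_scales), which
# must equal the hop size so one frame yields exactly n_shift samples.
noise_up = math.prod(noise_upsample_scales)
mel_up = math.prod(upsample_scales)

print(noise_up, mel_up, noise_up * mel_up)
```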
###########################################################
#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
###########################################################
discriminator_params:
    repeats: 4                            # Number of repetitions of the random-window discriminators.
    window_sizes: [512, 1024, 2048, 4096] # List of random window sizes.
    pqmf_params:                          # PQMF analysis parameters: [subbands, taps, cutoff_ratio, beta].
        - [1, None, None, None]
        - [2, 62, 0.26700, 9.0]
        - [4, 62, 0.14200, 9.0]
        - [8, 62, 0.07949, 9.0]
    discriminator_params:
        out_channels: 1                   # Number of output channels.
        kernel_sizes: [5, 3]              # List of kernel sizes.
        channels: 16                      # Number of channels of the initial conv layer.
        max_downsample_channels: 512      # Maximum number of channels of downsampling layers.
        bias: True
        downsample_scales: [4, 4, 4, 1]   # List of downsampling scales.
        nonlinear_activation: "leakyrelu" # Nonlinear activation function.
        nonlinear_activation_params:      # Parameters of the nonlinear activation function.
            negative_slope: 0.2
        use_weight_norm: True             # Whether to use weight norm.

###########################################################
#                    STFT LOSS SETTING                    #
###########################################################
use_stft_loss: true
stft_loss_params:
    fft_sizes: [1024, 2048, 512]  # List of FFT sizes for the STFT-based loss.
    hop_sizes: [120, 240, 50]     # List of hop sizes for the STFT-based loss.
    win_lengths: [600, 1200, 240] # List of window lengths for the STFT-based loss.
    window: "hann"                # Window function for the STFT-based loss.
lambda_aux: 1.0                   # Loss balancing coefficient for the aux loss.

###########################################################
#                 ADVERSARIAL LOSS SETTING                #
###########################################################
lambda_adv: 1.0                      # Loss balancing coefficient for the adv loss.
generator_adv_loss_params:
    average_by_discriminators: false # Whether to average the loss over #discriminators.
discriminator_adv_loss_params:
    average_by_discriminators: false # Whether to average the loss over #discriminators.

###########################################################
#                   DATA LOADER SETTING                   #
###########################################################
batch_size: 32         # Batch size.
# batch_max_steps (24000) == prod(noise_upsample_scales) (80) * prod(upsample_scales) (300, n_shift)
batch_max_steps: 24000 # Length of each audio clip in the batch. Make sure it is divisible by n_shift.
num_workers: 2         # Number of workers in the Paddle DataLoader.
###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
generator_optimizer_params:
    beta1: 0.5
    beta2: 0.9
    weight_decay: 0.0     # Generator's weight decay coefficient.
generator_scheduler_params:
    learning_rate: 1.0e-4 # Generator's learning rate.
    gamma: 0.5            # Generator's scheduler gamma.
    milestones:           # At each milestone, lr will be multiplied by gamma.
        - 100000
        - 300000
        - 500000
        - 700000
        - 900000
generator_grad_norm: -1   # Generator's gradient norm (clipping is disabled if <= 0).
discriminator_optimizer_params:
    beta1: 0.5
    beta2: 0.9
    weight_decay: 0.0     # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
    learning_rate: 2.0e-4 # Discriminator's learning rate.
    gamma: 0.5            # Discriminator's scheduler gamma.
    milestones:           # At each milestone, lr will be multiplied by gamma.
        - 200000
        - 400000
        - 600000
        - 800000
discriminator_grad_norm: -1 # Discriminator's gradient norm (clipping is disabled if <= 0).

###########################################################
#                     INTERVAL SETTING                    #
###########################################################
discriminator_train_start_steps: 100000 # Step at which discriminator training starts.
train_max_steps: 1500000                # Total number of training steps.
save_interval_steps: 5000               # Interval (in steps) between checkpoint saves.
eval_interval_steps: 1000               # Interval (in steps) between evaluations.

###########################################################
#                      OTHER SETTING                      #
###########################################################
num_snapshots: 10 # Max number of snapshots to keep while training.
seed: 42          # Random seed for paddle, random, and np.random.
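To see what the milestone schedules above do to the learning rate, here is a minimal re-implementation of the decay rule (the real training code uses `paddle.optimizer.lr.MultiStepDecay`; `lr_at_step` is a helper invented for this sketch, and the exact behavior at a milestone boundary may differ slightly):

```python
def lr_at_step(base_lr, gamma, milestones, step):
    # lr is multiplied by gamma once for every milestone already passed
    passed = sum(1 for m in milestones if step >= m)
    return base_lr * gamma ** passed

# Generator schedule from this config: lr 1.0e-4, gamma 0.5,
# milestones every 200k steps starting at 100k.
milestones = [100000, 300000, 500000, 700000, 900000]
print(lr_at_step(1.0e-4, 0.5, milestones, 50000))   # before any milestone
print(lr_at_step(1.0e-4, 0.5, milestones, 400000))  # after two milestones
```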
@ -0,0 +1,55 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's results
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./baker_alignment_tone \
        --output=durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/../preprocess.py \
        --rootdir=~/datasets/BZNSYP/ \
        --dataset=baker \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --cut-sil=True \
        --num-cpu=20
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # compute the features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="feats"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --stats=dump/train/feats_stats.npy

    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --stats=dump/train/feats_stats.npy

    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --stats=dump/train/feats_stats.npy
fi
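Stages 2 and 3 boil down to computing per-dimension statistics on the training set and applying them to every split. Below is a self-contained sketch of that idea with toy data, assuming the stats file holds a per-dimension mean and standard deviation (the `normalize` helper and all values here are invented for the example; the real logic lives in `compute_statistics.py` and `normalize.py`):

```python
import numpy as np

# Toy "training set" of mel-spectrogram features: (T, n_mels) arrays.
rng = np.random.default_rng(0)
train_feats = [rng.normal(3.0, 2.0, size=(t, 4)) for t in (50, 80, 120)]

# Stage 2: per-dimension mean/std over all frames of the training set.
stacked = np.concatenate(train_feats, axis=0)
mean, std = stacked.mean(axis=0), stacked.std(axis=0)

# Stage 3: normalize any split (train/dev/test) with the *train* stats.
def normalize(feats, mean, std):
    return (feats - mean) / std

dev_feat = rng.normal(3.0, 2.0, size=(60, 4))
norm = normalize(dev_feat, mean, std)
print(norm.shape)
```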
@ -0,0 +1,14 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3

FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --config=${config_path} \
    --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
    --test-metadata=dump/test/norm/metadata.jsonl \
    --output-dir=${train_output_path}/test \
    --generator-type=style_melgan
@ -0,0 +1,13 @@
#!/bin/bash

config_path=$1
train_output_path=$2

FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1
@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=style_melgan
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_50000.pdz

# With the following command, you can choose the stage range you want to run,
# such as `./run.sh --stage 0 --stop-stage 0`.
# This can not be mixed with positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train the model; all checkpoints are saved under `train_output_path/checkpoints/`
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
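The `stage`/`stop_stage` gating used in this script and in `local/preprocess.sh` is a small reusable pattern: a stage `N` runs only when `stage <= N <= stop_stage`. A standalone sketch:

```shell
#!/bin/bash
# Minimal sketch of the stage gating pattern used in run.sh:
# stage N runs only when stage <= N <= stop_stage.
stage=1
stop_stage=2

ran=""
for n in 0 1 2 3; do
    if [ ${stage} -le ${n} ] && [ ${stop_stage} -ge ${n} ]; then
        ran="${ran}${n}"
    fi
done
echo "${ran}"  # the stages that would run
```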
@ -1,103 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import distributed as dist
from timer import timer
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator


def main():
    parser = argparse.ArgumentParser(
        description="Synthesize with parallel wavegan.")
    parser.add_argument(
        "--config", type=str, help="parallel wavegan config file.")
    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
    parser.add_argument("--test-metadata", type=str, help="test data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")

    args = parser.parse_args()

    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        raise ValueError("ngpu should be >= 0 !")
    generator = PWGGenerator(**config["generator_params"])
    state_dict = paddle.load(args.checkpoint)
    generator.set_state_dict(state_dict["generator_params"])

    generator.remove_weight_norm()
    generator.eval()
    with jsonlines.open(args.test_metadata, 'r') as reader:
        metadata = list(reader)

    test_dataset = DataTable(
        metadata,
        fields=['utt_id', 'feats'],
        converters={
            'utt_id': None,
            'feats': np.load,
        })
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    N = 0  # total number of generated samples
    T = 0  # total generation time
    for example in test_dataset:
        utt_id = example['utt_id']
        mel = example['feats']
        mel = paddle.to_tensor(mel)  # (T, C)
        with timer() as t:
            with paddle.no_grad():
                wav = generator.inference(c=mel)
            wav = wav.numpy()
        N += wav.size
        T += t.elapse
        speed = wav.size / t.elapse
        rtf = config.fs / speed
        print(
            f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
        )
        sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
    print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T)}")


if __name__ == "__main__":
    main()
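The speed and RTF numbers printed in the synthesis loop above follow a simple rule: speed is generated samples per wall-clock second, and RTF is wall-clock time needed per second of generated audio (lower is faster). A tiny sketch with made-up numbers:

```python
def rtf(n_samples, elapsed_seconds, fs):
    # speed: samples generated per second of wall time
    speed = n_samples / elapsed_seconds
    # RTF: wall time spent per second of generated audio
    return fs / speed

# e.g. 480000 samples (20 s of 24 kHz audio) generated in 2 s of wall time
print(rtf(480000, 2.0, fs=24000))
```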
@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -0,0 +1,258 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from paddle.optimizer.lr import MultiStepDecay
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip
from paddlespeech.t2s.models.melgan import StyleMelGANDiscriminator
from paddlespeech.t2s.models.melgan import StyleMelGANEvaluator
from paddlespeech.t2s.models.melgan import StyleMelGANGenerator
from paddlespeech.t2s.models.melgan import StyleMelGANUpdater
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
from paddlespeech.t2s.modules.losses import MultiResolutionSTFTLoss
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer


def train_sp(args, config):
    # decide the device type and whether to run in parallel,
    # and set up the running environment accordingly
    world_size = paddle.distributed.get_world_size()
    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
        if world_size > 1:
            paddle.distributed.init_parallel_env()

    # set the random seed; it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # the dataloader logging is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct datasets for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=["wave", "feats"],
        converters={
            "wave": np.load,
            "feats": np.load,
        }, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=["wave", "feats"],
        converters={
            "wave": np.load,
            "feats": np.load,
        }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")

    if "aux_context_window" in config.generator_params:
        aux_context_window = config.generator_params.aux_context_window
    else:
        aux_context_window = 0
    train_batch_fn = Clip(
        batch_max_steps=config.batch_max_steps,
        hop_size=config.n_shift,
        aux_context_window=aux_context_window)

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    generator = StyleMelGANGenerator(**config["generator_params"])
    discriminator = StyleMelGANDiscriminator(**config["discriminator_params"])
    if world_size > 1:
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)
    print("models done!")
    criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])

    criterion_gen_adv = GeneratorAdversarialLoss(
        **config["generator_adv_loss_params"])
    criterion_dis_adv = DiscriminatorAdversarialLoss(
        **config["discriminator_adv_loss_params"])
    print("criterions done!")

    lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"])
    # compared to the multi_band_melgan.v1 config, an Adam optimizer without
    # gradient clipping is used here (grad norm <= 0 disables clipping)
    generator_grad_norm = config["generator_grad_norm"]
    gradient_clip_g = nn.ClipGradByGlobalNorm(
        generator_grad_norm) if generator_grad_norm > 0 else None
    print("gradient_clip_g:", gradient_clip_g)

    optimizer_g = Adam(
        learning_rate=lr_schedule_g,
        grad_clip=gradient_clip_g,
        parameters=generator.parameters(),
        **config["generator_optimizer_params"])
    lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"])
    discriminator_grad_norm = config["discriminator_grad_norm"]
    gradient_clip_d = nn.ClipGradByGlobalNorm(
        discriminator_grad_norm) if discriminator_grad_norm > 0 else None
    print("gradient_clip_d:", gradient_clip_d)
    optimizer_d = Adam(
        learning_rate=lr_schedule_d,
        grad_clip=gradient_clip_d,
        parameters=discriminator.parameters(),
        **config["discriminator_optimizer_params"])
    print("optimizers done!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if dist.get_rank() == 0:
        config_name = args.config.split("/")[-1]
        # copy the config to output_dir
        shutil.copyfile(args.config, output_dir / config_name)

    updater = StyleMelGANUpdater(
        models={
            "generator": generator,
            "discriminator": discriminator,
        },
        optimizers={
            "generator": optimizer_g,
            "discriminator": optimizer_d,
        },
        criterions={
            "stft": criterion_stft,
            "gen_adv": criterion_gen_adv,
            "dis_adv": criterion_dis_adv,
        },
        schedulers={
            "generator": lr_schedule_g,
            "discriminator": lr_schedule_d,
        },
        dataloader=train_dataloader,
        discriminator_train_start_steps=config.discriminator_train_start_steps,
        lambda_adv=config.lambda_adv,
        output_dir=output_dir)

    evaluator = StyleMelGANEvaluator(
        models={
            "generator": generator,
            "discriminator": discriminator,
        },
        criterions={
            "stft": criterion_stft,
            "gen_adv": criterion_gen_adv,
            "dis_adv": criterion_dis_adv,
        },
        dataloader=dev_dataloader,
        lambda_adv=config.lambda_adv,
        output_dir=output_dir)

    trainer = Trainer(
        updater,
        stop_trigger=(config.train_max_steps, "iteration"),
        out=output_dir)

    if dist.get_rank() == 0:
        trainer.extend(
            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots),
            trigger=(config.save_interval_steps, 'iteration'))

    print("Trainer done!")
    trainer.run()


def main():
    # parse args and config, then dispatch to train_sp
    parser = argparse.ArgumentParser(
        description="Train a Style MelGAN model.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")

    args = parser.parse_args()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.ngpu > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@ -0,0 +1,227 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict

import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler

from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class StyleMelGANUpdater(StandardUpdater):
    def __init__(self,
                 models: Dict[str, Layer],
                 optimizers: Dict[str, Optimizer],
                 criterions: Dict[str, Layer],
                 schedulers: Dict[str, LRScheduler],
                 dataloader: DataLoader,
                 generator_train_start_steps: int=0,
                 discriminator_train_start_steps: int=100000,
                 lambda_adv: float=1.0,
                 lambda_aux: float=1.0,
                 output_dir: Path=None):
        self.models = models
        self.generator: Layer = models['generator']
        self.discriminator: Layer = models['discriminator']

        self.optimizers = optimizers
        self.optimizer_g: Optimizer = optimizers['generator']
        self.optimizer_d: Optimizer = optimizers['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]

        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        self.dataloader = dataloader

        self.generator_train_start_steps = generator_train_start_steps
        self.discriminator_train_start_steps = discriminator_train_start_steps
        self.lambda_adv = lambda_adv
        self.lambda_aux = lambda_aux

        self.state = UpdaterState(iteration=0, epoch=0)
        self.train_iterator = iter(self.dataloader)

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # parse batch
        wav, mel = batch

        # Generator
        if self.state.iteration > self.generator_train_start_steps:
            # (B, out_channels, T * prod(upsample_scales))
            wav_ = self.generator(mel)

            # initialize
            gen_loss = 0.0
            aux_loss = 0.0

            # full band multi-resolution stft loss
            sc_loss, mag_loss = self.criterion_stft(wav_, wav)
            aux_loss += sc_loss + mag_loss
            report("train/spectral_convergence_loss", float(sc_loss))
            report("train/log_stft_magnitude_loss", float(mag_loss))
            losses_dict["spectral_convergence_loss"] = float(sc_loss)
            losses_dict["log_stft_magnitude_loss"] = float(mag_loss)

            gen_loss += aux_loss * self.lambda_aux

            # adversarial loss
            if self.state.iteration > self.discriminator_train_start_steps:
                p_ = self.discriminator(wav_)
                adv_loss = self.criterion_gen_adv(p_)
                report("train/adversarial_loss", float(adv_loss))
                losses_dict["adversarial_loss"] = float(adv_loss)

                gen_loss += self.lambda_adv * adv_loss

            report("train/generator_loss", float(gen_loss))
            losses_dict["generator_loss"] = float(gen_loss)

            self.optimizer_g.clear_grad()
            gen_loss.backward()

            self.optimizer_g.step()
            self.scheduler_g.step()

        # Discriminator
        if self.state.iteration > self.discriminator_train_start_steps:
            # re-compute wav_, which leads to better quality
            with paddle.no_grad():
                wav_ = self.generator(mel)

            p = self.discriminator(wav)
            p_ = self.discriminator(wav_.detach())
            real_loss, fake_loss = self.criterion_dis_adv(p_, p)
            dis_loss = real_loss + fake_loss
            report("train/real_loss", float(real_loss))
            report("train/fake_loss", float(fake_loss))
            report("train/discriminator_loss", float(dis_loss))
            losses_dict["real_loss"] = float(real_loss)
            losses_dict["fake_loss"] = float(fake_loss)
            losses_dict["discriminator_loss"] = float(dis_loss)

            self.optimizer_d.clear_grad()
            dis_loss.backward()

            self.optimizer_d.step()
            self.scheduler_d.step()

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
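The two `state.iteration` guards in `update_core` implement a warm-up schedule: the generator trains on the multi-resolution STFT (aux) loss from `generator_train_start_steps` (0 by default), while the adversarial terms and the discriminator update only join once `discriminator_train_start_steps` is passed. A minimal pure-Python sketch of that schedule — the helper name `active_losses` is ours, not part of the source:

```python
def active_losses(iteration, gen_start=0, dis_start=100000):
    """Return which loss terms are active at a given training iteration,
    mirroring the guards in StyleMelGANUpdater.update_core."""
    losses = []
    if iteration > gen_start:
        # generator warm-up: STFT loss only at first
        losses.append("stft")
        if iteration > dis_start:
            # adversarial term joins the generator loss late
            losses.append("gen_adv")
    if iteration > dis_start:
        # discriminator itself also starts training here
        losses.append("dis_adv")
    return losses
```

With the defaults, the first 100k steps are pure spectral-loss pretraining of the generator; afterwards both networks train adversarially.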


class StyleMelGANEvaluator(StandardEvaluator):
    def __init__(self,
                 models: Dict[str, Layer],
                 criterions: Dict[str, Layer],
                 dataloader: DataLoader,
                 lambda_adv: float=1.0,
                 lambda_aux: float=1.0,
                 output_dir: Path=None):
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]

        self.dataloader = dataloader

        self.lambda_adv = lambda_adv
        self.lambda_aux = lambda_aux

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        self.msg = "Evaluate: "
        losses_dict = {}
        wav, mel = batch

        # Generator
        # (B, out_channels, T * prod(upsample_scales))
        wav_ = self.generator(mel)

        # initialize
        gen_loss = 0.0
        aux_loss = 0.0

        # adversarial loss
        p_ = self.discriminator(wav_)
        adv_loss = self.criterion_gen_adv(p_)
        report("eval/adversarial_loss", float(adv_loss))
        losses_dict["adversarial_loss"] = float(adv_loss)

        gen_loss += self.lambda_adv * adv_loss

        # multi-resolution stft loss
        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
        aux_loss += sc_loss + mag_loss
        report("eval/spectral_convergence_loss", float(sc_loss))
        report("eval/log_stft_magnitude_loss", float(mag_loss))
        losses_dict["spectral_convergence_loss"] = float(sc_loss)
        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)

        gen_loss += aux_loss * self.lambda_aux

        report("eval/generator_loss", float(gen_loss))
        losses_dict["generator_loss"] = float(gen_loss)

        # Discriminator
        p = self.discriminator(wav)
        real_loss, fake_loss = self.criterion_dis_adv(p_, p)
        dis_loss = real_loss + fake_loss
        report("eval/real_loss", float(real_loss))
        report("eval/fake_loss", float(fake_loss))
        report("eval/discriminator_loss", float(dis_loss))

        losses_dict["real_loss"] = float(real_loss)
        losses_dict["fake_loss"] = float(fake_loss)
        losses_dict["discriminator_loss"] = float(dis_loss)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
@ -0,0 +1,164 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""StyleMelGAN's TADEResBlock Modules."""
from functools import partial

import paddle.nn.functional as F
from paddle import nn


class TADELayer(nn.Layer):
    """TADE Layer module."""

    def __init__(
            self,
            in_channels: int=64,
            aux_channels: int=80,
            kernel_size: int=9,
            bias: bool=True,
            upsample_factor: int=2,
            upsample_mode: str="nearest", ):
        """Initialize TADE layer."""
        super().__init__()
        self.norm = nn.InstanceNorm1D(
            in_channels, momentum=0.1, data_format="NCL")
        self.aux_conv = nn.Sequential(
            nn.Conv1D(
                aux_channels,
                in_channels,
                kernel_size,
                1,
                bias_attr=bias,
                padding=(kernel_size - 1) // 2, ), )
        self.gated_conv = nn.Sequential(
            nn.Conv1D(
                in_channels,
                in_channels * 2,
                kernel_size,
                1,
                bias_attr=bias,
                padding=(kernel_size - 1) // 2, ), )
        self.upsample = nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode)

    def forward(self, x, c):
        """Calculate forward propagation.
        Parameters
        ----------
        x : Tensor
            Input tensor (B, in_channels, T).
        c : Tensor
            Auxiliary input tensor (B, aux_channels, T).
        Returns
        ----------
        Tensor
            Output tensor (B, in_channels, T * upsample_factor).
        Tensor
            Upsampled aux tensor (B, in_channels, T * upsample_factor).
        """
        x = self.norm(x)
        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensors.
        c = self.upsample(c.unsqueeze(-1))
        c = c[:, :, :, 0]

        c = self.aux_conv(c)
        cg = self.gated_conv(c)
        cg1, cg2 = cg.split(2, axis=1)
        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensors.
        y = cg1 * self.upsample(x.unsqueeze(-1))[:, :, :, 0] + cg2
        return y, c
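The last two lines of `forward` are the TADE core: instance-normalize `x` over time, then scale and shift it elementwise with `cg1`/`cg2` derived from the conditioning features. A numpy sketch of just that step, with the upsampling omitted for clarity — the helper name `tade_modulate` is ours, not part of the source:

```python
import numpy as np

def tade_modulate(x, cg1, cg2, eps=1e-5):
    """Instance-normalize x over the time axis, then apply the
    conditioning-derived elementwise scale (cg1) and shift (cg2).
    All inputs have shape (B, C, T)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_norm = (x - mean) / np.sqrt(var + eps)
    # modulation: every (channel, frame) position gets its own affine params
    return cg1 * x_norm + cg2
```

Because the scale and shift vary per time step, the conditioning can restyle the signal locally, unlike a plain learned affine in instance norm.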


class TADEResBlock(nn.Layer):
    """TADEResBlock module."""

    def __init__(
            self,
            in_channels: int=64,
            aux_channels: int=80,
            kernel_size: int=9,
            dilation: int=2,
            bias: bool=True,
            upsample_factor: int=2,
            # this differs in paddle: mode can only be "linear"
            # when the input is a 3-D tensor
            upsample_mode: str="nearest",
            gated_function: str="softmax", ):
        """Initialize TADEResBlock module."""
        super().__init__()
        self.tade1 = TADELayer(
            in_channels=in_channels,
            aux_channels=aux_channels,
            kernel_size=kernel_size,
            bias=bias,
            upsample_factor=1,
            upsample_mode=upsample_mode, )
        self.gated_conv1 = nn.Conv1D(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias_attr=bias,
            padding=(kernel_size - 1) // 2, )
        self.tade2 = TADELayer(
            in_channels=in_channels,
            aux_channels=in_channels,
            kernel_size=kernel_size,
            bias=bias,
            upsample_factor=upsample_factor,
            upsample_mode=upsample_mode, )
        self.gated_conv2 = nn.Conv1D(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias_attr=bias,
            dilation=dilation,
            padding=(kernel_size - 1) // 2 * dilation, )
        self.upsample = nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode)
        if gated_function == "softmax":
            self.gated_function = partial(F.softmax, axis=1)
        elif gated_function == "sigmoid":
            self.gated_function = F.sigmoid
        else:
            raise ValueError(f"{gated_function} is not supported.")

    def forward(self, x, c):
        """Calculate forward propagation.
        Parameters
        ----------
        x : Tensor
            Input tensor (B, in_channels, T).
        c : Tensor
            Auxiliary input tensor (B, aux_channels, T).
        Returns
        ----------
        Tensor
            Output tensor (B, in_channels, T * upsample_factor).
        Tensor
            Upsampled auxiliary tensor (B, in_channels, T * upsample_factor).
        """
        residual = x
        x, c = self.tade1(x, c)
        x = self.gated_conv1(x)
        xa, xb = x.split(2, axis=1)
        x = self.gated_function(xa) * F.tanh(xb)
        x, c = self.tade2(x, c)
        x = self.gated_conv2(x)
        xa, xb = x.split(2, axis=1)
        x = self.gated_function(xa) * F.tanh(xb)
        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensors.
        return self.upsample(residual.unsqueeze(-1))[:, :, :, 0] + x, c
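The repeated `self.gated_function(xa) * F.tanh(xb)` lines above are the block's gated activation: the convolution doubles the channel count, the halves are split apart, and (with the default `gated_function="softmax"`) the first half gates the second across channels. A numpy sketch of the default path — the helper name `softmax_gated_tanh` is ours, not part of the source:

```python
import numpy as np

def softmax_gated_tanh(x):
    """Gated activation used in TADEResBlock: split the doubled channels in
    half, softmax-gate the first half over the channel axis, tanh the second.
    x: (B, 2C, T) -> (B, C, T)."""
    xa, xb = np.split(x, 2, axis=1)
    # numerically stable softmax over the channel axis
    e = np.exp(xa - xa.max(axis=1, keepdims=True))
    gate = e / e.sum(axis=1, keepdims=True)
    return gate * np.tanh(xb)
```

Since the gate lies in (0, 1) and tanh in (-1, 1), every output element stays inside (-1, 1), which keeps the upsampling stack's activations bounded.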