add csmsc mb melgan example

pull/948/head
TianYuan 3 years ago
parent e395462419
commit 6dbcd7720d

@ -0,0 +1,127 @@
# Multi Band MelGAN with CSMSC
This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html).
## Dataset
### Download and Extract the datasaet
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/BZNSYP`.
### Get MFA results for silence trim
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence in the edge of audio.
You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset,
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
```bash
./run.sh
```
### Preprocess the dataset
```bash
./local/preprocess.sh ${conf_path}
```
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
```text
dump
├── dev
│ ├── norm
│ └── raw
├── test
│ ├── norm
│ └── raw
└── train
├── norm
├── raw
└── feats_stats.npy
```
The dataset is split into 3 parts, namely `train`, `dev` and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains log magnitude of mel spectrogram of each utterances, while the norm folder contains normalized spectrogram. The statistics used to normalize the spectrogram is computed from the training set, which is located in `dump/train/feats_stats.npy`.
Also there is a `metadata.jsonl` in each subfolder. It is a table-like file which contains id and paths to spectrogam of each utterance.
### Train the model
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--device DEVICE] [--nprocs NPROCS] [--verbose VERBOSE]
[--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
[--run-benchmark RUN_BENCHMARK]
[--profiler_options PROFILER_OPTIONS]
Train a ParallelWaveGAN model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG config file to overwrite default config.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--device DEVICE device type to use.
--nprocs NPROCS number of processes.
--verbose VERBOSE verbose.
benchmark:
arguments related to benchmark.
--batch-size BATCH_SIZE
batch size.
--max-iter MAX_ITER train max steps.
--run-benchmark RUN_BENCHMARK
runing benchmark or not, if True, use the --batch-size
and --max-iter.
--profiler_options PROFILER_OPTIONS
The option of profiler, which should be in format
"key1=value1;key2=value2;key3=value3".
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are save in `checkpoints/` inside this directory.
4. `--device` is the type of the device to run the experiment, 'cpu' or 'gpu' are supported.
5. `--nprocs` is the number of processes to run in parallel, note that nprocs > 1 is only supported when `--device` is 'gpu'.
### Synthesize
`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT]
[--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR]
[--device DEVICE] [--verbose VERBOSE]
Synthesize with parallel wavegan.
optional arguments:
-h, --help show this help message and exit
--config CONFIG parallel wavegan config file.
--checkpoint CHECKPOINT
snapshot to load.
--test-metadata TEST_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--device DEVICE device to run.
--verbose VERBOSE verbose.
```
1. `--config` parallel wavegan config file. You should use the same config with which the model is trained.
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
4. `--output-dir` is the directory to save the synthesized audio files.
5. `--device` is the type of device to run synthesis, 'cpu' and 'gpu' are supported.
## Pretrained Models

@ -0,0 +1,139 @@
# This is the hyperparameter configuration file for MelGAN.
# Please make sure this is adjusted for the CSMSC dataset. If you want to
# apply to the other dataset, you might need to carefully change some parameters.
# This configuration requires ~ 8GB memory and will finish within 7 days on Titan V.
# This configuration is based on full-band MelGAN but the hop size and sampling
# rate is different from the paper (16kHz vs 24kHz). The number of iteraions
# is not shown in the paper so currently we train 1M iterations (not sure enough
# to converge). The optimizer setting is based on @dathudeptrai advice.
# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # Sampling rate.
n_fft: 2048 # FFT size. (in samples)
n_shift: 300 # Hop size. (in samples)
win_length: 1200 # Window length. (in samples)
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
n_mels: 80 # Number of mel basis.
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
###########################################################
# GENERATOR NETWORK ARCHITECTURE SETTING #
###########################################################
generator_params:
in_channels: 80 # Number of input channels.
out_channels: 4 # Number of output channels.
kernel_size: 7 # Kernel size of initial and final conv layers.
channels: 384 # Initial number of channels for conv layers.
upsample_scales: [5, 5, 3] # List of Upsampling scales.
stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack.
stacks: 4 # Number of stacks in a single residual stack module.
use_weight_norm: True # Whether to use weight normalization.
use_causal_conv: False # Whether to use causal convolution.
use_final_nonlinear_activation: False # If True, spectral_convergence_loss and sub_spectral_convergence_loss will be too large (eg.30)
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
###########################################################
discriminator_params:
in_channels: 1 # Number of input channels.
out_channels: 1 # Number of output channels.
scales: 3 # Number of multi-scales.
downsample_pooling: "AvgPool1D" # Pooling type for the input downsampling.
downsample_pooling_params: # Parameters of the above pooling function.
kernel_size: 4
stride: 2
padding: 1
exclusive: True
kernel_sizes: [5, 3] # List of kernel size.
channels: 16 # Number of channels of the initial conv layer.
max_downsample_channels: 512 # Maximum number of channels of downsampling layers.
downsample_scales: [4, 4, 4] # List of downsampling scales.
nonlinear_activation: "LeakyReLU" # Nonlinear activation function.
nonlinear_activation_params: # Parameters of nonlinear activation function.
negative_slope: 0.2
use_weight_norm: True # Whether to use weight norm.
###########################################################
# STFT LOSS SETTING #
###########################################################
use_stft_loss: true
stft_loss_params:
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
use_subband_stft_loss: true
subband_stft_loss_params:
fft_sizes: [384, 683, 171] # List of FFT size for STFT-based loss.
hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss
win_lengths: [150, 300, 60] # List of window length for STFT-based loss.
window: "hann" # Window function for STFT-based loss
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
use_feat_match_loss: false # Whether to use feature matching loss.
lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss.
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 64 # Batch size.
batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size.
num_workers: 4 # Number of workers in DataLoader.
###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
epsilon: 1.0e-7 # Generator's epsilon.
weight_decay: 0.0 # Generator's weight decay coefficient.
generator_grad_norm: -1 # Generator's gradient norm.
generator_scheduler_params:
learning_rate: 1.0e-3 # Generator's learning rate.
gamma: 0.5 # Generator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 100000
- 200000
- 300000
- 400000
- 500000
- 600000
discriminator_optimizer_params:
epsilon: 1.0e-7 # Discriminator's epsilon.
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_grad_norm: -1 # Discriminator's gradient norm.
discriminator_scheduler_params:
learning_rate: 1.0e-3 # Discriminator's learning rate.
gamma: 0.5 # Discriminator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 100000
- 200000
- 300000
- 400000
- 500000
- 600000
###########################################################
# INTERVAL SETTING #
###########################################################
discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator.
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 50000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.
###########################################################
# OTHER SETTING #
###########################################################
num_snapshots: 10 # max number of snapshots to keep while training
seed: 42 # random seed for paddle, random, and np.random

@ -0,0 +1,55 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./baker_alignment_tone \
--output=durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/../preprocess.py \
--rootdir=~/datasets/BZNSYP/ \
--dataset=baker \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--cut-sil=True \
--num-cpu=20
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--stats=dump/train/feats_stats.npy
fi

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--nprocs=1

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=multi_band_melgan
export BIN_DIR=${MAIN_ROOT}/parakeet/exps/gan_vocoder/${MODEL}

@ -0,0 +1,32 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_50000.pdz
# with the following command, you can choice the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@ -107,8 +107,13 @@ class Clip(object):
features, this process will be needed.
"""
if len(x) < c.shape[1] * self.hop_size:
x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)), mode="edge")
if len(x) < c.shape[0] * self.hop_size:
x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge")
elif len(x) > c.shape[0] * self.hop_size:
print(
f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size })"
)
x = x[:c.shape[1] * self.hop_size]
# check the legnth is valid
assert len(x) == c.shape[

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import distributed as dist
from timer import timer
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
from parakeet.models.melgan import MelGANGenerator
def main():
parser = argparse.ArgumentParser(
description="Synthesize with parallel wavegan.")
parser.add_argument(
"--config", type=str, help="parallel wavegan config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument("--test-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--device", type=str, default="gpu", help="device to run.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
paddle.set_device(args.device)
generator = MelGANGenerator(**config["generator_params"])
state_dict = paddle.load(args.checkpoint)
generator.set_state_dict(state_dict["generator_params"])
generator.remove_weight_norm()
generator.eval()
with jsonlines.open(args.test_metadata, 'r') as reader:
metadata = list(reader)
test_dataset = DataTable(
metadata,
fields=['utt_id', 'feats'],
converters={
'utt_id': None,
'feats': np.load,
})
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
N = 0
T = 0
for example in test_dataset:
utt_id = example['utt_id']
mel = example['feats']
mel = paddle.to_tensor(mel) # (T, C)
with timer() as t:
with paddle.no_grad():
wav = generator.inference(c=mel)
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}."
)
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
if __name__ == "__main__":
main()

@ -0,0 +1,271 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from paddle.optimizer.lr import MultiStepDecay
from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
from parakeet.datasets.vocoder_batch_fn import Clip
from parakeet.models.melgan import MBMelGANEvaluator
from parakeet.models.melgan import MBMelGANUpdater
from parakeet.models.melgan import MelGANGenerator
from parakeet.models.melgan import MelGANMultiScaleDiscriminator
from parakeet.modules.adversarial_loss import DiscriminatorAdversarialLoss
from parakeet.modules.adversarial_loss import GeneratorAdversarialLoss
from parakeet.modules.pqmf import PQMF
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from parakeet.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
world_size = paddle.distributed.get_world_size()
if not paddle.is_compiled_with_cuda():
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
dev_sampler = DistributedBatchSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False)
print("samplers done!")
if "aux_context_window" in config.generator_params:
aux_context_window = config.generator_params.aux_context_window
else:
aux_context_window = 0
train_batch_fn = Clip(
batch_max_steps=config.batch_max_steps,
hop_size=config.n_shift,
aux_context_window=aux_context_window)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
print("dataloaders done!")
generator = MelGANGenerator(**config["generator_params"])
discriminator = MelGANMultiScaleDiscriminator(
**config["discriminator_params"])
if world_size > 1:
generator = DataParallel(generator)
discriminator = DataParallel(discriminator)
print("models done!")
criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
criterion_sub_stft = MultiResolutionSTFTLoss(
**config["subband_stft_loss_params"])
criterion_gen_adv = GeneratorAdversarialLoss()
criterion_dis_adv = DiscriminatorAdversarialLoss()
# define special module for subband processing
criterion_pqmf = PQMF(subbands=config["generator_params"]["out_channels"])
print("criterions done!")
lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"])
# Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used
generator_grad_norm = config["generator_grad_norm"]
gradient_clip_g = nn.ClipGradByGlobalNorm(
generator_grad_norm) if generator_grad_norm > 0 else None
print("gradient_clip_g:", gradient_clip_g)
optimizer_g = Adam(
learning_rate=lr_schedule_g,
grad_clip=gradient_clip_g,
parameters=generator.parameters(),
**config["generator_optimizer_params"])
lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"])
discriminator_grad_norm = config["discriminator_grad_norm"]
gradient_clip_d = nn.ClipGradByGlobalNorm(
discriminator_grad_norm) if discriminator_grad_norm > 0 else None
print("gradient_clip_d:", gradient_clip_d)
optimizer_d = Adam(
learning_rate=lr_schedule_d,
grad_clip=gradient_clip_d,
parameters=discriminator.parameters(),
**config["discriminator_optimizer_params"])
print("optimizers done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
updater = MBMelGANUpdater(
models={
"generator": generator,
"discriminator": discriminator,
},
optimizers={
"generator": optimizer_g,
"discriminator": optimizer_d,
},
criterions={
"stft": criterion_stft,
"sub_stft": criterion_sub_stft,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
"pqmf": criterion_pqmf
},
schedulers={
"generator": lr_schedule_g,
"discriminator": lr_schedule_d,
},
dataloader=train_dataloader,
discriminator_train_start_steps=config.discriminator_train_start_steps,
lambda_adv=config.lambda_adv,
output_dir=output_dir)
evaluator = MBMelGANEvaluator(
models={
"generator": generator,
"discriminator": discriminator,
},
criterions={
"stft": criterion_stft,
"sub_stft": criterion_sub_stft,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
"pqmf": criterion_pqmf
},
dataloader=dev_dataloader,
lambda_adv=config.lambda_adv,
output_dir=output_dir)
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir)
if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
writer = LogWriter(str(trainer.out))
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Train a Multi-Band MelGAN model.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use.")
parser.add_argument(
"--nprocs", type=int, default=1, help="number of processes.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
if args.device == "cpu" and args.nprocs > 1:
raise RuntimeError("Multiprocess training on CPU is not supported.")
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.nprocs > 1:
dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
else:
train_sp(args, config)
if __name__ == "__main__":
main()

@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .melgan import *
from .multi_band_melgan_updater import *

@ -0,0 +1,541 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MelGAN Modules."""
from typing import Any
from typing import Dict
from typing import List
import numpy as np
import paddle
from paddle import nn
from paddle.fluid.layers import Normal
from parakeet.modules.causal_conv import CausalConv1D
from parakeet.modules.causal_conv import CausalConv1DTranspose
from parakeet.modules.pqmf import PQMF
from parakeet.modules.residual_stack import ResidualStack
class MelGANGenerator(nn.Layer):
"""MelGAN generator module."""
def __init__(
self,
in_channels: int=80,
out_channels: int=1,
kernel_size: int=7,
channels: int=512,
bias: bool=True,
upsample_scales: List[int]=[8, 8, 2, 2],
stack_kernel_size: int=3,
stacks: int=3,
nonlinear_activation: str="LeakyReLU",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
use_final_nonlinear_activation: bool=True,
use_weight_norm: bool=True,
use_causal_conv: bool=False, ):
"""Initialize MelGANGenerator module.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels,
the number of sub-band is out_channels in multi-band melgan.
kernel_size : int
Kernel size of initial and final conv layer.
channels : int
Initial number of channels for conv layer.
bias : bool
Whether to add bias parameter in convolution layers.
upsample_scales : List[int]
List of upsampling scales.
stack_kernel_size : int
Kernel size of dilated conv layers in residual stack.
stacks : int
Number of stacks in a single residual stack.
nonlinear_activation : Optional[str], optional
Non linear activation in upsample network, by default None
nonlinear_activation_params : Dict[str, Any], optional
Parameters passed to the linear activation in the upsample network,
by default {}
pad : str
Padding function module name before dilated convolution layer.
pad_params : dict
Hyperparameters for padding function.
use_final_nonlinear_activation : paddle.nn.Layer
Activation function for the final layer.
use_weight_norm : bool
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_causal_conv : bool
Whether to use causal convolution.
"""
super().__init__()
# check hyper parameters is valid
assert channels >= np.prod(upsample_scales)
assert channels % (2**len(upsample_scales)) == 0
if not use_causal_conv:
assert (kernel_size - 1
) % 2 == 0, "Not support even number kernel size."
# add initial layer
layers = []
if not use_causal_conv:
layers += [
getattr(paddle.nn, pad)((kernel_size - 1) // 2, **pad_params),
nn.Conv1D(in_channels, channels, kernel_size, bias_attr=bias),
]
else:
layers += [
CausalConv1D(
in_channels,
channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params, ),
]
for i, upsample_scale in enumerate(upsample_scales):
# add upsampling layer
layers += [
getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
]
if not use_causal_conv:
layers += [
nn.Conv1DTranspose(
channels // (2**i),
channels // (2**(i + 1)),
upsample_scale * 2,
stride=upsample_scale,
padding=upsample_scale // 2 + upsample_scale % 2,
output_padding=upsample_scale % 2,
bias_attr=bias, )
]
else:
layers += [
CausalConv1DTranspose(
channels // (2**i),
channels // (2**(i + 1)),
upsample_scale * 2,
stride=upsample_scale,
bias=bias, )
]
# add residual stack
for j in range(stacks):
layers += [
ResidualStack(
kernel_size=stack_kernel_size,
channels=channels // (2**(i + 1)),
dilation=stack_kernel_size**j,
bias=bias,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
pad=pad,
pad_params=pad_params,
use_causal_conv=use_causal_conv, )
]
# add final layer
layers += [
getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
]
if not use_causal_conv:
layers += [
getattr(nn, pad)((kernel_size - 1) // 2, **pad_params),
nn.Conv1D(
channels // (2**(i + 1)),
out_channels,
kernel_size,
bias_attr=bias),
]
else:
layers += [
CausalConv1D(
channels // (2**(i + 1)),
out_channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params, ),
]
if use_final_nonlinear_activation:
layers += [nn.Tanh()]
# define the model as a single function
self.melgan = nn.Sequential(*layers)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
# initialize pqmf for multi-band melgan inference
if out_channels > 1:
self.pqmf = PQMF(subbands=out_channels)
else:
self.pqmf = None
def forward(self, c):
"""Calculate forward propagation.
Parameters
----------
c : Tensor
Input tensor (B, in_channels, T).
Returns
----------
Tensor
Output tensor (B, out_channels, T ** prod(upsample_scales)).
"""
out = self.melgan(c)
return out
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
def reset_parameters(self):
"""Reset parameters.
This initialization follows official implementation manner.
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
"""
# 定义参数为float的正态分布。
dist = Normal(loc=0.0, scale=0.02)
def _reset_parameters(m):
if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
w = dist.sample(m.weight.shape)
m.weight.set_value(w)
self.apply(_reset_parameters)
def inference(self, c):
"""Perform inference.
Parameters
----------
c : Union[Tensor, ndarray]
Input tensor (T, in_channels).
Returns
----------
Tensor
Output tensor (out_channels*T ** prod(upsample_scales), 1).
"""
if not isinstance(c, paddle.Tensor):
c = paddle.to_tensor(c, dtype="float32")
# pseudo batch
c = c.transpose([1, 0]).unsqueeze(0)
# (B, out_channels, T ** prod(upsample_scales)
out = self.melgan(c)
if self.pqmf is not None:
# (B, 1, out_channels * T ** prod(upsample_scales)
out = self.pqmf.synthesis(out)
out = out.squeeze(0).transpose([1, 0])
return out
class MelGANDiscriminator(nn.Layer):
"""MelGAN discriminator module."""
def __init__(
self,
in_channels: int=1,
out_channels: int=1,
kernel_sizes: List[int]=[5, 3],
channels: int=16,
max_downsample_channels: int=1024,
bias: bool=True,
downsample_scales: List[int]=[4, 4, 4, 4],
nonlinear_activation: str="LeakyReLU",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"}, ):
"""Initilize MelGAN discriminator module.
Parameters
----------
in_channels :
int): Number of input channels.
out_channels : int
Number of output channels.
kernel_sizes : List[int]
List of two kernel sizes. The prod will be used for the first conv layer,
and the first and the second kernel sizes will be used for the last two layers.
For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
the last two layers' kernel size will be 5 and 3, respectively.
channels : int
Initial number of channels for conv layer.
max_downsample_channels : int
Maximum number of channels for downsampling layers.
bias : bool
Whether to add bias parameter in convolution layers.
downsample_scales : List[int]
List of downsampling scales.
nonlinear_activation : str
Activation function module name.
nonlinear_activation_params : dict
Hyperparameters for activation function.
pad : str
Padding function module name before dilated convolution layer.
pad_params : dict
Hyperparameters for padding function.
"""
super().__init__()
self.layers = nn.LayerList()
# check kernel size is valid
assert len(kernel_sizes) == 2
assert kernel_sizes[0] % 2 == 1
assert kernel_sizes[1] % 2 == 1
# add first layer
self.layers.append(
nn.Sequential(
getattr(nn, pad)((np.prod(kernel_sizes) - 1) // 2, **
pad_params),
nn.Conv1D(
in_channels,
channels,
int(np.prod(kernel_sizes)),
bias_attr=bias),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
# add downsample layers
in_chs = channels
for downsample_scale in downsample_scales:
out_chs = min(in_chs * downsample_scale, max_downsample_channels)
self.layers.append(
nn.Sequential(
nn.Conv1D(
in_chs,
out_chs,
kernel_size=downsample_scale * 10 + 1,
stride=downsample_scale,
padding=downsample_scale * 5,
groups=in_chs // 4,
bias_attr=bias, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
in_chs = out_chs
# add final layers
out_chs = min(in_chs * 2, max_downsample_channels)
self.layers.append(
nn.Sequential(
nn.Conv1D(
in_chs,
out_chs,
kernel_sizes[0],
padding=(kernel_sizes[0] - 1) // 2,
bias_attr=bias, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params), ))
self.layers.append(
nn.Conv1D(
out_chs,
out_channels,
kernel_sizes[1],
padding=(kernel_sizes[1] - 1) // 2,
bias_attr=bias, ), )
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input noise signal (B, 1, T).
Returns
----------
List
List of output tensors of each layer (for feat_match_loss).
"""
outs = []
for f in self.layers:
x = f(x)
outs += [x]
return outs
class MelGANMultiScaleDiscriminator(nn.Layer):
"""MelGAN multi-scale discriminator module."""
def __init__(
self,
in_channels: int=1,
out_channels: int=1,
scales: int=3,
downsample_pooling: str="AvgPool1D",
# follow the official implementation setting
downsample_pooling_params: Dict[str, Any]={
"kernel_size": 4,
"stride": 2,
"padding": 1,
"exclusive": True,
},
kernel_sizes: List[int]=[5, 3],
channels: int=16,
max_downsample_channels: int=1024,
bias: bool=True,
downsample_scales: List[int]=[4, 4, 4, 4],
nonlinear_activation: str="LeakyReLU",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
use_weight_norm: bool=True, ):
"""Initilize MelGAN multi-scale discriminator module.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
scales : int
Number of multi-scales.
downsample_pooling : str
Pooling module name for downsampling of the inputs.
downsample_pooling_params : dict
Parameters for the above pooling module.
kernel_sizes : List[int]
List of two kernel sizes. The sum will be used for the first conv layer,
and the first and the second kernel sizes will be used for the last two layers.
channels : int
Initial number of channels for conv layer.
max_downsample_channels : int
Maximum number of channels for downsampling layers.
bias : bool
Whether to add bias parameter in convolution layers.
downsample_scales : List[int]
List of downsampling scales.
nonlinear_activation : str
Activation function module name.
nonlinear_activation_params : dict
Hyperparameters for activation function.
pad : str
Padding function module name before dilated convolution layer.
pad_params : dict
Hyperparameters for padding function.
use_causal_conv : bool
Whether to use causal convolution.
"""
super().__init__()
self.discriminators = nn.LayerList()
# add discriminators
for _ in range(scales):
self.discriminators.append(
MelGANDiscriminator(
in_channels=in_channels,
out_channels=out_channels,
kernel_sizes=kernel_sizes,
channels=channels,
max_downsample_channels=max_downsample_channels,
bias=bias,
downsample_scales=downsample_scales,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
pad=pad,
pad_params=pad_params, ))
self.pooling = getattr(nn, downsample_pooling)(
**downsample_pooling_params)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input noise signal (B, 1, T).
Returns
----------
List
List of list of each discriminator outputs, which consists of each layer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
x = self.pooling(x)
return outs
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
def reset_parameters(self):
"""Reset parameters.
This initialization follows official implementation manner.
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
"""
# 定义参数为float的正态分布。
dist = Normal(loc=0.0, scale=0.02)
def _reset_parameters(m):
if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
w = dist.sample(m.weight.shape)
m.weight.set_value(w)
self.apply(_reset_parameters)

@ -0,0 +1,231 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from timer import timer
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PWGUpdater(StandardUpdater):
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
output_dir=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_mse = criterions['mse']
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
wav, mel = batch
# Generator
noise = paddle.randn(wav.shape)
with timer() as t:
wav_ = self.generator(noise, mel)
# logging.debug(f"Generator takes {t.elapse}s.")
# initialize
gen_loss = 0.0
## Multi-resolution stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
with timer() as t:
self.optimizer_g.clear_grad()
gen_loss.backward()
# logging.debug(f"Backward takes {t.elapse}s.")
with timer() as t:
self.optimizer_g.step()
self.scheduler_g.step()
# logging.debug(f"Update takes {t.elapse}s.")
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
with paddle.no_grad():
wav_ = self.generator(noise, mel)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class PWGEvaluator(StandardEvaluator):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_mse = criterions['mse']
self.dataloader = dataloader
self.lambda_adv = lambda_adv
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch
noise = paddle.randn(wav.shape)
with timer() as t:
wav_ = self.generator(noise, mel)
# logging.debug(f"Generator takes {t.elapse}s")
## Adversarial loss
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator
p = self.discriminator(wav)
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

@ -0,0 +1,245 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class MBMelGANUpdater(StandardUpdater):
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
output_dir=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_sub_stft = criterions['sub_stft']
self.criterion_pqmf = criterions['pqmf']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
wav, mel = batch
# Generator
# (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel)
wav_mb_ = wav_
# (B, 1, out_channels*T ** prod(upsample_scales)
wav_ = self.criterion_pqmf.synthesis(wav_mb_)
# initialize
gen_loss = 0.0
# full band Multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# for balancing with subband stft loss
# Eq.(9) in paper
gen_loss += 0.5 * (sc_loss + mag_loss)
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
# sub band Multi-resolution stft loss
# (B, subbands, T // subbands)
wav_mb = self.criterion_pqmf.analysis(wav)
sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb)
# Eq.(9) in paper
gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
report("train/sub_spectral_convergence_loss", float(sub_sc_loss))
report("train/sub_log_stft_magnitude_loss", float(sub_mag_loss))
losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss)
losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss)
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_)
report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
self.optimizer_g.clear_grad()
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
# re-compute wav_ which leads better quality
with paddle.no_grad():
wav_ = self.generator(mel)
wav_ = self.criterion_pqmf.synthesis(wav_)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss, fake_loss = self.criterion_dis_adv(p_, p)
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class MBMelGANEvaluator(StandardEvaluator):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_sub_stft = criterions['sub_stft']
self.criterion_pqmf = criterions['pqmf']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.dataloader = dataloader
self.lambda_adv = lambda_adv
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch
# Generator
# (B, out_channels, T ** prod(upsample_scales)
wav_ = self.generator(mel)
wav_mb_ = wav_
# (B, 1, out_channels*T ** prod(upsample_scales)
wav_ = self.criterion_pqmf.synthesis(wav_mb_)
## Adversarial loss
p_ = self.discriminator(wav_)
adv_loss = self.criterion_gen_adv(p_)
report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# Multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# Eq.(9) in paper
gen_loss += 0.5 * (sc_loss + mag_loss)
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
# sub band Multi-resolution stft loss
# (B, subbands, T // subbands)
wav_mb = self.criterion_pqmf.analysis(wav)
sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb)
# Eq.(9) in paper
gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
report("eval/sub_spectral_convergence_loss", float(sub_sc_loss))
report("eval/sub_log_stft_magnitude_loss", float(sub_mag_loss))
losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss)
losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss)
report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator
p = self.discriminator(wav)
real_loss, fake_loss = self.criterion_dis_adv(p_, p)
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

@ -495,26 +495,25 @@ class PWGGenerator(nn.Layer):
self.apply(_remove_weight_norm)
def inference(self, c=None):
def inference(self, c):
"""Waveform generation. This function is used for single instance
inference.
Parameters
----------
c : Tensor, optional
c : Tensor
Shape (T', C_aux), the auxiliary input, by default None
x : Tensor, optional
Shape (T, C_in), the noise waveform, by default None
If not provided, a sample is drawn from a gaussian distribution.
Returns
-------
Tensor
Shape (T, C_out), the generated waveform
"""
# a sample is drawn from a gaussian distribution.
x = paddle.randn(
[1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
# pseudo batch
c = paddle.transpose(c, [1, 0]).unsqueeze(0)
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
out = self(x, c).squeeze(0).transpose([1, 0])
return out

@ -0,0 +1,124 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adversarial loss modules."""
import paddle
import paddle.nn.functional as F
from paddle import nn
class GeneratorAdversarialLoss(nn.Layer):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type="mse", ):
"""Initialize GeneratorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(self, outputs):
"""Calcualate generator adversarial loss.
Parameters
----------
outputs: Tensor or List
Discriminator outputs or list of discriminator outputs.
Returns
----------
Tensor
Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, paddle.ones_like(x))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(nn.Layer):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type="mse", ):
"""Initialize DiscriminatorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
def forward(self, outputs_hat, outputs):
"""Calcualate discriminator adversarial loss.
Parameters
----------
outputs_hat : Tensor or list
Discriminator outputs or list of
discriminator outputs calculated from generator outputs.
outputs : Tensor or list
Discriminator outputs or list of
discriminator outputs calculated from groundtruth.
Returns
----------
Tensor
Discriminator real loss value.
Tensor
Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_,
outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x):
return F.mse_loss(x, paddle.ones_like(x))
def _mse_fake_loss(self, x):
return F.mse_loss(x, paddle.zeros_like(x))

@ -0,0 +1,81 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal convolusion layer modules."""
import paddle
class CausalConv1D(paddle.nn.Layer):
"""CausalConv1D module with customized initialization."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
bias=True,
pad="Pad1D",
pad_params={"value": 0.0}, ):
"""Initialize CausalConv1d module."""
super().__init__()
self.pad = getattr(paddle.nn, pad)((kernel_size - 1) * dilation,
**pad_params)
self.conv = paddle.nn.Conv1D(
in_channels,
out_channels,
kernel_size,
dilation=dilation,
bias_attr=bias)
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input tensor (B, in_channels, T).
Returns
----------
Tensor
Output tensor (B, out_channels, T).
"""
return self.conv(self.pad(x))[:, :, :x.shape[2]]
class CausalConv1DTranspose(paddle.nn.Layer):
"""CausalConv1DTranspose module with customized initialization."""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
bias=True):
"""Initialize CausalConvTranspose1d module."""
super().__init__()
self.deconv = paddle.nn.Conv1DTranspose(
in_channels, out_channels, kernel_size, stride, bias_attr=bias)
self.stride = stride
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x : Tensor
Input tensor (B, in_channels, T_in).
Returns
----------
Tensor
Output tensor (B, out_channels, T_out).
"""
return self.deconv(x)[:, :, :-self.stride]

@ -0,0 +1,140 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pseudo QMF modules."""
import numpy as np
import paddle
import paddle.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Parameters
----------
taps : int
The number of filter taps.
cutoff_ratio : float
Cut-off frequency ratio.
beta : float
Beta coefficient for kaiser window.
Returns
----------
ndarray
Impluse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, "The number of taps mush be even number."
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid="ignore"):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps //
2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
class PQMF(paddle.nn.Layer):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
Parameters
----------
subbands : int
The number of subbands.
taps : int
The number of filter taps.
cutoff_ratio : float
Cut-off frequency ratio.
beta : float
Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = (
2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
np.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4))
h_synthesis[k] = (
2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
np.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4))
# convert to tensor
self.analysis_filter = paddle.to_tensor(
h_analysis, dtype="float32").unsqueeze(1)
self.synthesis_filter = paddle.to_tensor(
h_synthesis, dtype="float32").unsqueeze(0)
# filter for downsampling & upsampling
updown_filter = paddle.zeros(
(subbands, subbands, subbands), dtype="float32")
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.updown_filter = updown_filter
self.subbands = subbands
# keep padding info
self.pad_fn = paddle.nn.Pad1D(taps // 2, mode='constant', value=0.0)
def analysis(self, x):
"""Analysis with PQMF.
Parameters
----------
x : Tensor
Input tensor (B, 1, T).
Returns
----------
Tensor
Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Parameters
----------
x : Tensor
Input tensor (B, subbands, T // subbands).
Returns
----------
Tensor
Output tensor (B, 1, T).
"""
x = F.conv1d_transpose(
x, self.updown_filter * self.subbands, stride=self.subbands)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)

@ -0,0 +1,109 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Residual stack module in MelGAN."""
from typing import Any
from typing import Dict
from paddle import nn
from parakeet.modules.causal_conv import CausalConv1D
class ResidualStack(nn.Layer):
"""Residual stack module introduced in MelGAN."""
def __init__(
self,
kernel_size: int=3,
channels: int=32,
dilation: int=1,
bias: bool=True,
nonlinear_activation: str="LeakyReLU",
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
pad: str="Pad1D",
pad_params: Dict[str, Any]={"mode": "reflect"},
use_causal_conv: bool=False, ):
"""Initialize ResidualStack module.
Parameters
----------
kernel_size : int
Kernel size of dilation convolution layer.
channels : int
Number of channels of convolution layers.
dilation : int
Dilation factor.
bias : bool
Whether to add bias parameter in convolution layers.
nonlinear_activation : str
Activation function module name.
nonlinear_activation_params : Dict[str,Any]
Hyperparameters for activation function.
pad : str
Padding function module name before dilated convolution layer.
pad_params : Dict[str, Any]
Hyperparameters for padding function.
use_causal_conv : bool
Whether to use causal convolution.
"""
super().__init__()
# defile residual stack part
if not use_causal_conv:
assert (kernel_size - 1
) % 2 == 0, "Not support even number kernel size."
self.stack = nn.Sequential(
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
getattr(nn, pad)((kernel_size - 1) // 2 * dilation,
**pad_params),
nn.Conv1D(
channels,
channels,
kernel_size,
dilation=dilation,
bias_attr=bias),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
nn.Conv1D(channels, channels, 1, bias_attr=bias), )
else:
self.stack = nn.Sequential(
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
CausalConv1D(
channels,
channels,
kernel_size,
dilation=dilation,
bias=bias,
pad=pad,
pad_params=pad_params, ),
getattr(nn, nonlinear_activation)(
**nonlinear_activation_params),
nn.Conv1D(channels, channels, 1, bias_attr=bias), )
# defile extra layer for skip connection
self.skip_layer = nn.Conv1D(channels, channels, 1, bias_attr=bias)
def forward(self, c):
"""Calculate forward propagation.
Parameters
----------
c : Tensor
Input tensor (B, channels, T).
Returns
----------
Tensor
Output tensor (B, chennels, T).
"""
return self.stack(c) + self.skip_layer(c)
Loading…
Cancel
Save