From 6dbcd7720d349813258c0a532093b4123432986d Mon Sep 17 00:00:00 2001 From: TianYuan Date: Tue, 26 Oct 2021 11:45:12 +0000 Subject: [PATCH] add csmsc mb melgan example --- examples/csmsc/voc3/README.md | 127 ++++ examples/csmsc/voc3/conf/default.yaml | 139 +++++ examples/csmsc/voc3/local/preprocess.sh | 55 ++ examples/csmsc/voc3/local/synthesize.sh | 13 + examples/csmsc/voc3/local/train.sh | 13 + examples/csmsc/voc3/path.sh | 13 + examples/csmsc/voc3/run.sh | 32 ++ parakeet/datasets/vocoder_batch_fn.py | 9 +- .../gan_vocoder/multi_band_melgan/__init__.py | 13 + .../multi_band_melgan/synthesize.py | 97 ++++ .../gan_vocoder/multi_band_melgan/train.py | 271 +++++++++ parakeet/models/melgan/__init__.py | 15 + parakeet/models/melgan/melgan.py | 541 ++++++++++++++++++ parakeet/models/melgan/melgan_updater.py | 231 ++++++++ .../melgan/multi_band_melgan_updater.py | 245 ++++++++ .../parallel_wavegan/parallel_wavegan.py | 11 +- parakeet/modules/adversarial_loss.py | 124 ++++ parakeet/modules/causal_conv.py | 81 +++ parakeet/modules/pqmf.py | 140 +++++ parakeet/modules/residual_stack.py | 109 ++++ 20 files changed, 2271 insertions(+), 8 deletions(-) create mode 100644 examples/csmsc/voc3/README.md create mode 100644 examples/csmsc/voc3/conf/default.yaml create mode 100755 examples/csmsc/voc3/local/preprocess.sh create mode 100755 examples/csmsc/voc3/local/synthesize.sh create mode 100755 examples/csmsc/voc3/local/train.sh create mode 100755 examples/csmsc/voc3/path.sh create mode 100755 examples/csmsc/voc3/run.sh create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/train.py create mode 100644 parakeet/models/melgan/__init__.py create mode 100644 parakeet/models/melgan/melgan.py create mode 100644 parakeet/models/melgan/melgan_updater.py create mode 100644 parakeet/models/melgan/multi_band_melgan_updater.py create mode 100644 parakeet/modules/adversarial_loss.py create mode 100644 parakeet/modules/causal_conv.py create mode 100644 parakeet/modules/pqmf.py create mode 100644 parakeet/modules/residual_stack.py diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md new file mode 100644 index 00000000..780a8ccd --- /dev/null +++ b/examples/csmsc/voc3/README.md @@ -0,0 +1,127 @@ +# Multi Band MelGAN with CSMSC +This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). +## Dataset +### Download and Extract the datasaet +Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/BZNSYP`. + +### Get MFA results for silence trim +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence in the edge of audio. +You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset, +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. +```bash +./run.sh +``` +### Preprocess the dataset +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev` and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains log magnitude of mel spectrogram of each utterances, while the norm folder contains normalized spectrogram. The statistics used to normalize the spectrogram is computed from the training set, which is located in `dump/train/feats_stats.npy`. + +Also there is a `metadata.jsonl` in each subfolder. It is a table-like file which contains id and paths to spectrogam of each utterance. + +### Train the model +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--device DEVICE] [--nprocs NPROCS] [--verbose VERBOSE] + [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] + [--run-benchmark RUN_BENCHMARK] + [--profiler_options PROFILER_OPTIONS] + +Train a ParallelWaveGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --device DEVICE device type to use. + --nprocs NPROCS number of processes. + --verbose VERBOSE verbose. + +benchmark: + arguments related to benchmark. + + --batch-size BATCH_SIZE + batch size. + --max-iter MAX_ITER train max steps. + --run-benchmark RUN_BENCHMARK + runing benchmark or not, if True, use the --batch-size + and --max-iter. + --profiler_options PROFILER_OPTIONS + The option of profiler, which should be in format + "key1=value1;key2=value2;key3=value3". +``` + +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are save in `checkpoints/` inside this directory. +4. `--device` is the type of the device to run the experiment, 'cpu' or 'gpu' are supported. +5. `--nprocs` is the number of processes to run in parallel, note that nprocs > 1 is only supported when `--device` is 'gpu'. + +### Synthesize +`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT] + [--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR] + [--device DEVICE] [--verbose VERBOSE] + +Synthesize with parallel wavegan. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG parallel wavegan config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --device DEVICE device to run. + --verbose VERBOSE verbose. +``` + +1. `--config` parallel wavegan config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--device` is the type of device to run synthesis, 'cpu' and 'gpu' are supported. + +## Pretrained Models diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml new file mode 100644 index 00000000..87a237c3 --- /dev/null +++ b/examples/csmsc/voc3/conf/default.yaml @@ -0,0 +1,139 @@ +# This is the hyperparameter configuration file for MelGAN. +# Please make sure this is adjusted for the CSMSC dataset. If you want to +# apply to the other dataset, you might need to carefully change some parameters. +# This configuration requires ~ 8GB memory and will finish within 7 days on Titan V. + +# This configuration is based on full-band MelGAN but the hop size and sampling +# rate is different from the paper (16kHz vs 24kHz). The number of iteraions +# is not shown in the paper so currently we train 1M iterations (not sure enough +# to converge). The optimizer setting is based on @dathudeptrai advice. +# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906 + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size. (in samples) +n_shift: 300 # Hop size. (in samples) +win_length: 1200 # Window length. (in samples) + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 4 # Number of output channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + channels: 384 # Initial number of channels for conv layers. + upsample_scales: [5, 5, 3] # List of Upsampling scales. + stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. + stacks: 4 # Number of stacks in a single residual stack module. + use_weight_norm: True # Whether to use weight normalization. + use_causal_conv: False # Whether to use causal convolution. + use_final_nonlinear_activation: False # If True, spectral_convergence_loss and sub_spectral_convergence_loss will be too large (eg.30) + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + scales: 3 # Number of multi-scales. + downsample_pooling: "AvgPool1D" # Pooling type for the input downsampling. + downsample_pooling_params: # Parameters of the above pooling function. + kernel_size: 4 + stride: 2 + padding: 1 + exclusive: True + kernel_sizes: [5, 3] # List of kernel size. + channels: 16 # Number of channels of the initial conv layer. + max_downsample_channels: 512 # Maximum number of channels of downsampling layers. + downsample_scales: [4, 4, 4] # List of downsampling scales. + nonlinear_activation: "LeakyReLU" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + negative_slope: 0.2 + use_weight_norm: True # Whether to use weight norm. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: true +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss + win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss +use_subband_stft_loss: true +subband_stft_loss_params: + fft_sizes: [384, 683, 171] # List of FFT size for STFT-based loss. + hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss + win_lengths: [150, 300, 60] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +use_feat_match_loss: false # Whether to use feature matching loss. +lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 64 # Batch size. +batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 4 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + epsilon: 1.0e-7 # Generator's epsilon. + weight_decay: 0.0 # Generator's weight decay coefficient. + +generator_grad_norm: -1 # Generator's gradient norm. +generator_scheduler_params: + learning_rate: 1.0e-3 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 +discriminator_optimizer_params: + epsilon: 1.0e-7 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. + +discriminator_grad_norm: -1 # Discriminator's gradient norm. +discriminator_scheduler_params: + learning_rate: 1.0e-3 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 + +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator. +train_max_steps: 1000000 # Number of training steps. +save_interval_steps: 50000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random \ No newline at end of file diff --git a/examples/csmsc/voc3/local/preprocess.sh b/examples/csmsc/voc3/local/preprocess.sh new file mode 100755 index 00000000..61d6d62b --- /dev/null +++ b/examples/csmsc/voc3/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/BZNSYP/ \ + --dataset=baker \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/csmsc/voc3/local/synthesize.sh b/examples/csmsc/voc3/local/synthesize.sh new file mode 100755 index 00000000..9f904ac0 --- /dev/null +++ b/examples/csmsc/voc3/local/synthesize.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test diff --git a/examples/csmsc/voc3/local/train.sh b/examples/csmsc/voc3/local/train.sh new file mode 100755 index 00000000..1ef860c3 --- /dev/null +++ b/examples/csmsc/voc3/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --nprocs=1 diff --git a/examples/csmsc/voc3/path.sh b/examples/csmsc/voc3/path.sh new file mode 100755 index 00000000..f6b9fe61 --- /dev/null +++ b/examples/csmsc/voc3/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=multi_band_melgan +export BIN_DIR=${MAIN_ROOT}/parakeet/exps/gan_vocoder/${MODEL} \ No newline at end of file diff --git a/examples/csmsc/voc3/run.sh b/examples/csmsc/voc3/run.sh new file mode 100755 index 00000000..360f6ec2 --- /dev/null +++ b/examples/csmsc/voc3/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_50000.pdz + +# with the following command, you can choice the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/parakeet/datasets/vocoder_batch_fn.py b/parakeet/datasets/vocoder_batch_fn.py index 30adb142..2de4fb12 100644 --- a/parakeet/datasets/vocoder_batch_fn.py +++ b/parakeet/datasets/vocoder_batch_fn.py @@ -107,8 +107,13 @@ class Clip(object): features, this process will be needed. """ - if len(x) < c.shape[1] * self.hop_size: - x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)), mode="edge") + if len(x) < c.shape[0] * self.hop_size: + x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge") + elif len(x) > c.shape[0] * self.hop_size: + print( + f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size })" + ) + x = x[:c.shape[1] * self.hop_size] # check the legnth is valid assert len(x) == c.shape[ diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py new file mode 100644 index 00000000..d48fbbd0 --- /dev/null +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import soundfile as sf +import yaml +from paddle import distributed as dist +from timer import timer +from yacs.config import CfgNode + +from parakeet.datasets.data_table import DataTable +from parakeet.models.melgan import MelGANGenerator + + +def main(): + parser = argparse.ArgumentParser( + description="Synthesize with parallel wavegan.") + parser.add_argument( + "--config", type=str, help="parallel wavegan config file.") + parser.add_argument("--checkpoint", type=str, help="snapshot to load.") + parser.add_argument("--test-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--device", type=str, default="gpu", help="device to run.") + parser.add_argument("--verbose", type=int, default=1, help="verbose.") + + args = parser.parse_args() + + with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + print( + f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + paddle.set_device(args.device) + generator = MelGANGenerator(**config["generator_params"]) + state_dict = paddle.load(args.checkpoint) + generator.set_state_dict(state_dict["generator_params"]) + + generator.remove_weight_norm() + generator.eval() + with jsonlines.open(args.test_metadata, 'r') as reader: + metadata = list(reader) + + test_dataset = DataTable( + metadata, + fields=['utt_id', 'feats'], + converters={ + 'utt_id': None, + 'feats': np.load, + }) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + N = 0 + T = 0 + for example in test_dataset: + utt_id = example['utt_id'] + mel = example['feats'] + mel = paddle.to_tensor(mel) # (T, C) + with timer() as t: + with paddle.no_grad(): + wav = generator.inference(c=mel) + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}." + ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +if __name__ == "__main__": + main() diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/train.py b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py new file mode 100644 index 00000000..bb9b0b8a --- /dev/null +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py @@ -0,0 +1,271 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle import nn +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from paddle.optimizer.lr import MultiStepDecay +from visualdl import LogWriter +from yacs.config import CfgNode + +from parakeet.datasets.data_table import DataTable +from parakeet.datasets.vocoder_batch_fn import Clip +from parakeet.models.melgan import MBMelGANEvaluator +from parakeet.models.melgan import MBMelGANUpdater +from parakeet.models.melgan import MelGANGenerator +from parakeet.models.melgan import MelGANMultiScaleDiscriminator +from parakeet.modules.adversarial_loss import DiscriminatorAdversarialLoss +from parakeet.modules.adversarial_loss import GeneratorAdversarialLoss +from parakeet.modules.pqmf import PQMF +from parakeet.modules.stft_loss import MultiResolutionSTFTLoss +from parakeet.training.extensions.snapshot import Snapshot +from parakeet.training.extensions.visualizer import VisualDL +from parakeet.training.seeding import seed_everything +from parakeet.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if not paddle.is_compiled_with_cuda(): + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + if "aux_context_window" in config.generator_params: + aux_context_window = config.generator_params.aux_context_window + else: + aux_context_window = 0 + train_batch_fn = Clip( + batch_max_steps=config.batch_max_steps, + hop_size=config.n_shift, + aux_context_window=aux_context_window) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + generator = MelGANGenerator(**config["generator_params"]) + discriminator = MelGANMultiScaleDiscriminator( + **config["discriminator_params"]) + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"]) + criterion_sub_stft = MultiResolutionSTFTLoss( + **config["subband_stft_loss_params"]) + criterion_gen_adv = GeneratorAdversarialLoss() + criterion_dis_adv = DiscriminatorAdversarialLoss() + # define special module for subband processing + criterion_pqmf = PQMF(subbands=config["generator_params"]["out_channels"]) + print("criterions done!") + + lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"]) + # Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used + generator_grad_norm = config["generator_grad_norm"] + gradient_clip_g = nn.ClipGradByGlobalNorm( + generator_grad_norm) if generator_grad_norm > 0 else None + print("gradient_clip_g:", gradient_clip_g) + + optimizer_g = Adam( + learning_rate=lr_schedule_g, + grad_clip=gradient_clip_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"]) + discriminator_grad_norm = config["discriminator_grad_norm"] + gradient_clip_d = nn.ClipGradByGlobalNorm( + discriminator_grad_norm) if discriminator_grad_norm > 0 else None + print("gradient_clip_d:", gradient_clip_d) + optimizer_d = Adam( + learning_rate=lr_schedule_d, + grad_clip=gradient_clip_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = MBMelGANUpdater( + models={ + "generator": generator, + "discriminator": discriminator, + }, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "stft": criterion_stft, + "sub_stft": criterion_sub_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "pqmf": criterion_pqmf + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + discriminator_train_start_steps=config.discriminator_train_start_steps, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + evaluator = MBMelGANEvaluator( + models={ + "generator": generator, + "discriminator": discriminator, + }, + criterions={ + "stft": criterion_stft, + "sub_stft": criterion_sub_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "pqmf": criterion_pqmf + }, + dataloader=dev_dataloader, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + trainer = Trainer( + updater, + stop_trigger=(config.train_max_steps, "iteration"), + out=output_dir) + + if dist.get_rank() == 0: + trainer.extend( + evaluator, trigger=(config.eval_interval_steps, 'iteration')) + writer = LogWriter(str(trainer.out)) + trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) + trainer.extend( + Snapshot(max_size=config.num_snapshots), + trigger=(config.save_interval_steps, 'iteration')) + + print("Trainer Done!") + trainer.run() + + +def main(): + # parse args and config and redirect to train_sp + + parser = argparse.ArgumentParser( + description="Train a Multi-Band MelGAN model.") + parser.add_argument( + "--config", type=str, help="config file to overwrite default config.") + parser.add_argument("--train-metadata", type=str, help="training data.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--device", type=str, default="gpu", help="device type to use.") + parser.add_argument( + "--nprocs", type=int, default=1, help="number of processes.") + parser.add_argument("--verbose", type=int, default=1, help="verbose.") + + args = parser.parse_args() + if args.device == "cpu" and args.nprocs > 1: + raise RuntimeError("Multiprocess training on CPU is not supported.") + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + print( + f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + # dispatch + if args.nprocs > 1: + dist.spawn(train_sp, (args, config), nprocs=args.nprocs) + else: + train_sp(args, config) + + +if __name__ == "__main__": + main() diff --git a/parakeet/models/melgan/__init__.py b/parakeet/models/melgan/__init__.py new file mode 100644 index 00000000..d4f557db --- /dev/null +++ b/parakeet/models/melgan/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .melgan import * +from .multi_band_melgan_updater import * diff --git a/parakeet/models/melgan/melgan.py b/parakeet/models/melgan/melgan.py new file mode 100644 index 00000000..ccc19d5f --- /dev/null +++ b/parakeet/models/melgan/melgan.py @@ -0,0 +1,541 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MelGAN Modules.""" +from typing import Any +from typing import Dict +from typing import List + +import numpy as np +import paddle +from paddle import nn +from paddle.fluid.layers import Normal + +from parakeet.modules.causal_conv import CausalConv1D +from parakeet.modules.causal_conv import CausalConv1DTranspose +from parakeet.modules.pqmf import PQMF +from parakeet.modules.residual_stack import ResidualStack + + +class MelGANGenerator(nn.Layer): + """MelGAN generator module.""" + + def __init__( + self, + in_channels: int=80, + out_channels: int=1, + kernel_size: int=7, + channels: int=512, + bias: bool=True, + upsample_scales: List[int]=[8, 8, 2, 2], + stack_kernel_size: int=3, + stacks: int=3, + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, + use_final_nonlinear_activation: bool=True, + use_weight_norm: bool=True, + use_causal_conv: bool=False, ): + """Initialize MelGANGenerator module. + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels, + the number of sub-band is out_channels in multi-band melgan. + kernel_size : int + Kernel size of initial and final conv layer. + channels : int + Initial number of channels for conv layer. + bias : bool + Whether to add bias parameter in convolution layers. + upsample_scales : List[int] + List of upsampling scales. + stack_kernel_size : int + Kernel size of dilated conv layers in residual stack. + stacks : int + Number of stacks in a single residual stack. + nonlinear_activation : Optional[str], optional + Non linear activation in upsample network, by default None + nonlinear_activation_params : Dict[str, Any], optional + Parameters passed to the linear activation in the upsample network, + by default {} + pad : str + Padding function module name before dilated convolution layer. + pad_params : dict + Hyperparameters for padding function. + use_final_nonlinear_activation : paddle.nn.Layer + Activation function for the final layer. + use_weight_norm : bool + Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + + # check hyper parameters is valid + assert channels >= np.prod(upsample_scales) + assert channels % (2**len(upsample_scales)) == 0 + if not use_causal_conv: + assert (kernel_size - 1 + ) % 2 == 0, "Not support even number kernel size." + # add initial layer + layers = [] + if not use_causal_conv: + layers += [ + getattr(paddle.nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D(in_channels, channels, kernel_size, bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + in_channels, + channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + + for i, upsample_scale in enumerate(upsample_scales): + # add upsampling layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + nn.Conv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + bias_attr=bias, ) + ] + else: + layers += [ + CausalConv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + bias=bias, ) + ] + + # add residual stack + for j in range(stacks): + layers += [ + ResidualStack( + kernel_size=stack_kernel_size, + channels=channels // (2**(i + 1)), + dilation=stack_kernel_size**j, + bias=bias, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + use_causal_conv=use_causal_conv, ) + ] + + # add final layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + getattr(nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + if use_final_nonlinear_activation: + layers += [nn.Tanh()] + + # define the model as a single function + self.melgan = nn.Sequential(*layers) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + # initialize pqmf for multi-band melgan inference + if out_channels > 1: + self.pqmf = PQMF(subbands=out_channels) + else: + self.pqmf = None + + def forward(self, c): + """Calculate forward propagation. + Parameters + ---------- + c : Tensor + Input tensor (B, in_channels, T). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T ** prod(upsample_scales)). + """ + out = self.melgan(c) + return out + + def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. + """ + + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) + + def reset_parameters(self): + """Reset parameters. + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + """ + + # 定义参数为float的正态分布。 + dist = Normal(loc=0.0, scale=0.02) + + def _reset_parameters(m): + if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): + w = dist.sample(m.weight.shape) + m.weight.set_value(w) + + self.apply(_reset_parameters) + + def inference(self, c): + """Perform inference. + Parameters + ---------- + c : Union[Tensor, ndarray] + Input tensor (T, in_channels). + Returns + ---------- + Tensor + Output tensor (out_channels*T ** prod(upsample_scales), 1). + """ + if not isinstance(c, paddle.Tensor): + c = paddle.to_tensor(c, dtype="float32") + # pseudo batch + c = c.transpose([1, 0]).unsqueeze(0) + # (B, out_channels, T ** prod(upsample_scales) + out = self.melgan(c) + if self.pqmf is not None: + # (B, 1, out_channels * T ** prod(upsample_scales) + out = self.pqmf.synthesis(out) + out = out.squeeze(0).transpose([1, 0]) + return out + + +class MelGANDiscriminator(nn.Layer): + """MelGAN discriminator module.""" + + def __init__( + self, + in_channels: int=1, + out_channels: int=1, + kernel_sizes: List[int]=[5, 3], + channels: int=16, + max_downsample_channels: int=1024, + bias: bool=True, + downsample_scales: List[int]=[4, 4, 4, 4], + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, ): + """Initilize MelGAN discriminator module. + Parameters + ---------- + in_channels : + int): Number of input channels. + out_channels : int + Number of output channels. + kernel_sizes : List[int] + List of two kernel sizes. The prod will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, + the last two layers' kernel size will be 5 and 3, respectively. + channels : int + Initial number of channels for conv layer. + max_downsample_channels : int + Maximum number of channels for downsampling layers. + bias : bool + Whether to add bias parameter in convolution layers. + downsample_scales : List[int] + List of downsampling scales. + nonlinear_activation : str + Activation function module name. + nonlinear_activation_params : dict + Hyperparameters for activation function. + pad : str + Padding function module name before dilated convolution layer. + pad_params : dict + Hyperparameters for padding function. + """ + super().__init__() + self.layers = nn.LayerList() + + # check kernel size is valid + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + # add first layer + self.layers.append( + nn.Sequential( + getattr(nn, pad)((np.prod(kernel_sizes) - 1) // 2, ** + pad_params), + nn.Conv1D( + in_channels, + channels, + int(np.prod(kernel_sizes)), + bias_attr=bias), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + + # add downsample layers + in_chs = channels + for downsample_scale in downsample_scales: + out_chs = min(in_chs * downsample_scale, max_downsample_channels) + self.layers.append( + nn.Sequential( + nn.Conv1D( + in_chs, + out_chs, + kernel_size=downsample_scale * 10 + 1, + stride=downsample_scale, + padding=downsample_scale * 5, + groups=in_chs // 4, + bias_attr=bias, ), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + in_chs = out_chs + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers.append( + nn.Sequential( + nn.Conv1D( + in_chs, + out_chs, + kernel_sizes[0], + padding=(kernel_sizes[0] - 1) // 2, + bias_attr=bias, ), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + self.layers.append( + nn.Conv1D( + out_chs, + out_channels, + kernel_sizes[1], + padding=(kernel_sizes[1] - 1) // 2, + bias_attr=bias, ), ) + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input noise signal (B, 1, T). + Returns + ---------- + List + List of output tensors of each layer (for feat_match_loss). + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + +class MelGANMultiScaleDiscriminator(nn.Layer): + """MelGAN multi-scale discriminator module.""" + + def __init__( + self, + in_channels: int=1, + out_channels: int=1, + scales: int=3, + downsample_pooling: str="AvgPool1D", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any]={ + "kernel_size": 4, + "stride": 2, + "padding": 1, + "exclusive": True, + }, + kernel_sizes: List[int]=[5, 3], + channels: int=16, + max_downsample_channels: int=1024, + bias: bool=True, + downsample_scales: List[int]=[4, 4, 4, 4], + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, + use_weight_norm: bool=True, ): + """Initilize MelGAN multi-scale discriminator module. + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + scales : int + Number of multi-scales. + downsample_pooling : str + Pooling module name for downsampling of the inputs. + downsample_pooling_params : dict + Parameters for the above pooling module. + kernel_sizes : List[int] + List of two kernel sizes. The sum will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + channels : int + Initial number of channels for conv layer. + max_downsample_channels : int + Maximum number of channels for downsampling layers. + bias : bool + Whether to add bias parameter in convolution layers. + downsample_scales : List[int] + List of downsampling scales. + nonlinear_activation : str + Activation function module name. + nonlinear_activation_params : dict + Hyperparameters for activation function. + pad : str + Padding function module name before dilated convolution layer. + pad_params : dict + Hyperparameters for padding function. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + self.discriminators = nn.LayerList() + + # add discriminators + for _ in range(scales): + self.discriminators.append( + MelGANDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + channels=channels, + max_downsample_channels=max_downsample_channels, + bias=bias, + downsample_scales=downsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, )) + self.pooling = getattr(nn, downsample_pooling)( + **downsample_pooling_params) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input noise signal (B, 1, T). + Returns + ---------- + List + List of list of each discriminator outputs, which consists of each layer output tensors. + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. + """ + + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) + + def reset_parameters(self): + """Reset parameters. + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + """ + + # 定义参数为float的正态分布。 + dist = Normal(loc=0.0, scale=0.02) + + def _reset_parameters(m): + if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): + w = dist.sample(m.weight.shape) + m.weight.set_value(w) + + self.apply(_reset_parameters) diff --git a/parakeet/models/melgan/melgan_updater.py b/parakeet/models/melgan/melgan_updater.py new file mode 100644 index 00000000..7bd59881 --- /dev/null +++ b/parakeet/models/melgan/melgan_updater.py @@ -0,0 +1,231 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler +from timer import timer + +from parakeet.training.extensions.evaluator import StandardEvaluator +from parakeet.training.reporter import report +from parakeet.training.updaters.standard_updater import StandardUpdater +from parakeet.training.updaters.standard_updater import UpdaterState +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class PWGUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + discriminator_train_start_steps: int, + lambda_adv: float, + output_dir=None): + self.models = models + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_mse = criterions['mse'] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.discriminator_train_start_steps = discriminator_train_start_steps + self.lambda_adv = lambda_adv + self.state = UpdaterState(iteration=0, epoch=0) + + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + # parse batch + wav, mel = batch + + # Generator + noise = paddle.randn(wav.shape) + + with timer() as t: + wav_ = self.generator(noise, mel) + # logging.debug(f"Generator takes {t.elapse}s.") + + # initialize + gen_loss = 0.0 + + ## Multi-resolution stft loss + with timer() as t: + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") + + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + gen_loss += sc_loss + mag_loss + + ## Adversarial loss + if self.state.iteration > self.discriminator_train_start_steps: + with timer() as t: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) + # logging.debug( + # f"Discriminator and adversarial loss takes {t.elapse}s") + report("train/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss += self.lambda_adv * adv_loss + + report("train/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + with timer() as t: + self.optimizer_g.clear_grad() + gen_loss.backward() + # logging.debug(f"Backward takes {t.elapse}s.") + + with timer() as t: + self.optimizer_g.step() + self.scheduler_g.step() + # logging.debug(f"Update takes {t.elapse}s.") + + # Disctiminator + if self.state.iteration > self.discriminator_train_start_steps: + with paddle.no_grad(): + wav_ = self.generator(noise, mel) + p = self.discriminator(wav) + p_ = self.discriminator(wav_.detach()) + real_loss = self.criterion_mse(p, paddle.ones_like(p)) + fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) + dis_loss = real_loss + fake_loss + report("train/real_loss", float(real_loss)) + report("train/fake_loss", float(fake_loss)) + report("train/discriminator_loss", float(dis_loss)) + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.optimizer_d.clear_grad() + dis_loss.backward() + + self.optimizer_d.step() + self.scheduler_d.step() + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + +class PWGEvaluator(StandardEvaluator): + def __init__(self, + models, + criterions, + dataloader, + lambda_adv, + output_dir=None): + self.models = models + self.generator = models['generator'] + self.discriminator = models['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_mse = criterions['mse'] + + self.dataloader = dataloader + self.lambda_adv = lambda_adv + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def evaluate_core(self, batch): + # logging.debug("Evaluate: ") + self.msg = "Evaluate: " + losses_dict = {} + + wav, mel = batch + noise = paddle.randn(wav.shape) + + with timer() as t: + wav_ = self.generator(noise, mel) + # logging.debug(f"Generator takes {t.elapse}s") + + ## Adversarial loss + with timer() as t: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) + # logging.debug( + # f"Discriminator and adversarial loss takes {t.elapse}s") + report("eval/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss = self.lambda_adv * adv_loss + + # stft loss + with timer() as t: + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") + + report("eval/spectral_convergence_loss", float(sc_loss)) + report("eval/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + gen_loss += sc_loss + mag_loss + + report("eval/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + # Disctiminator + p = self.discriminator(wav) + real_loss = self.criterion_mse(p, paddle.ones_like(p)) + fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) + dis_loss = real_loss + fake_loss + report("eval/real_loss", float(real_loss)) + report("eval/fake_loss", float(fake_loss)) + report("eval/discriminator_loss", float(dis_loss)) + + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/parakeet/models/melgan/multi_band_melgan_updater.py b/parakeet/models/melgan/multi_band_melgan_updater.py new file mode 100644 index 00000000..0783cb97 --- /dev/null +++ b/parakeet/models/melgan/multi_band_melgan_updater.py @@ -0,0 +1,245 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler + +from parakeet.training.extensions.evaluator import StandardEvaluator +from parakeet.training.reporter import report +from parakeet.training.updaters.standard_updater import StandardUpdater +from parakeet.training.updaters.standard_updater import UpdaterState +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class MBMelGANUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + discriminator_train_start_steps: int, + lambda_adv: float, + output_dir=None): + self.models = models + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_sub_stft = criterions['sub_stft'] + self.criterion_pqmf = criterions['pqmf'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.discriminator_train_start_steps = discriminator_train_start_steps + self.lambda_adv = lambda_adv + self.state = UpdaterState(iteration=0, epoch=0) + + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + # parse batch + wav, mel = batch + # Generator + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) + wav_mb_ = wav_ + # (B, 1, out_channels*T ** prod(upsample_scales) + wav_ = self.criterion_pqmf.synthesis(wav_mb_) + + # initialize + gen_loss = 0.0 + + # full band Multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # for balancing with subband stft loss + # Eq.(9) in paper + gen_loss += 0.5 * (sc_loss + mag_loss) + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + # sub band Multi-resolution stft loss + # (B, subbands, T // subbands) + wav_mb = self.criterion_pqmf.analysis(wav) + sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb) + # Eq.(9) in paper + gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + report("train/sub_spectral_convergence_loss", float(sub_sc_loss)) + report("train/sub_log_stft_magnitude_loss", float(sub_mag_loss)) + losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss) + losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss) + + ## Adversarial loss + if self.state.iteration > self.discriminator_train_start_steps: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_gen_adv(p_) + + report("train/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss += self.lambda_adv * adv_loss + + report("train/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + self.optimizer_g.clear_grad() + gen_loss.backward() + + self.optimizer_g.step() + self.scheduler_g.step() + + # Disctiminator + if self.state.iteration > self.discriminator_train_start_steps: + # re-compute wav_ which leads better quality + with paddle.no_grad(): + wav_ = self.generator(mel) + wav_ = self.criterion_pqmf.synthesis(wav_) + p = self.discriminator(wav) + p_ = self.discriminator(wav_.detach()) + real_loss, fake_loss = self.criterion_dis_adv(p_, p) + dis_loss = real_loss + fake_loss + report("train/real_loss", float(real_loss)) + report("train/fake_loss", float(fake_loss)) + report("train/discriminator_loss", float(dis_loss)) + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.optimizer_d.clear_grad() + dis_loss.backward() + + self.optimizer_d.step() + self.scheduler_d.step() + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + +class MBMelGANEvaluator(StandardEvaluator): + def __init__(self, + models, + criterions, + dataloader, + lambda_adv, + output_dir=None): + self.models = models + self.generator = models['generator'] + self.discriminator = models['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_sub_stft = criterions['sub_stft'] + self.criterion_pqmf = criterions['pqmf'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + + self.dataloader = dataloader + self.lambda_adv = lambda_adv + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def evaluate_core(self, batch): + # logging.debug("Evaluate: ") + self.msg = "Evaluate: " + losses_dict = {} + + wav, mel = batch + # Generator + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) + wav_mb_ = wav_ + # (B, 1, out_channels*T ** prod(upsample_scales) + wav_ = self.criterion_pqmf.synthesis(wav_mb_) + + ## Adversarial loss + p_ = self.discriminator(wav_) + adv_loss = self.criterion_gen_adv(p_) + + report("eval/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss = self.lambda_adv * adv_loss + + # Multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # Eq.(9) in paper + gen_loss += 0.5 * (sc_loss + mag_loss) + report("eval/spectral_convergence_loss", float(sc_loss)) + report("eval/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + # sub band Multi-resolution stft loss + # (B, subbands, T // subbands) + wav_mb = self.criterion_pqmf.analysis(wav) + sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb) + # Eq.(9) in paper + gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + report("eval/sub_spectral_convergence_loss", float(sub_sc_loss)) + report("eval/sub_log_stft_magnitude_loss", float(sub_mag_loss)) + losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss) + losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss) + + report("eval/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + # Disctiminator + p = self.discriminator(wav) + real_loss, fake_loss = self.criterion_dis_adv(p_, p) + dis_loss = real_loss + fake_loss + report("eval/real_loss", float(real_loss)) + report("eval/fake_loss", float(fake_loss)) + report("eval/discriminator_loss", float(dis_loss)) + + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/parakeet/models/parallel_wavegan/parallel_wavegan.py b/parakeet/models/parallel_wavegan/parallel_wavegan.py index bb214653..e166ccde 100644 --- a/parakeet/models/parallel_wavegan/parallel_wavegan.py +++ b/parakeet/models/parallel_wavegan/parallel_wavegan.py @@ -495,26 +495,25 @@ class PWGGenerator(nn.Layer): self.apply(_remove_weight_norm) - def inference(self, c=None): + def inference(self, c): """Waveform generation. This function is used for single instance inference. Parameters ---------- - c : Tensor, optional + c : Tensor Shape (T', C_aux), the auxiliary input, by default None - x : Tensor, optional - Shape (T, C_in), the noise waveform, by default None - If not provided, a sample is drawn from a gaussian distribution. Returns ------- Tensor Shape (T, C_out), the generated waveform """ + # a sample is drawn from a gaussian distribution. x = paddle.randn( [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor]) - c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch + # pseudo batch + c = paddle.transpose(c, [1, 0]).unsqueeze(0) c = nn.Pad1D(self.aux_context_window, mode='replicate')(c) out = self(x, c).squeeze(0).transpose([1, 0]) return out diff --git a/parakeet/modules/adversarial_loss.py b/parakeet/modules/adversarial_loss.py new file mode 100644 index 00000000..02e8c807 --- /dev/null +++ b/parakeet/modules/adversarial_loss.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adversarial loss modules.""" +import paddle +import paddle.nn.functional as F +from paddle import nn + + +class GeneratorAdversarialLoss(nn.Layer): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators=True, + loss_type="mse", ): + """Initialize GeneratorAversarialLoss module.""" + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward(self, outputs): + """Calcualate generator adversarial loss. + Parameters + ---------- + outputs: Tensor or List + Discriminator outputs or list of discriminator outputs. + Returns + ---------- + Tensor + Generator adversarial loss value. + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, paddle.ones_like(x)) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(nn.Layer): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators=True, + loss_type="mse", ): + """Initialize DiscriminatorAversarialLoss module.""" + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + + def forward(self, outputs_hat, outputs): + """Calcualate discriminator adversarial loss. + Parameters + ---------- + outputs_hat : Tensor or list + Discriminator outputs or list of + discriminator outputs calculated from generator outputs. + outputs : Tensor or list + Discriminator outputs or list of + discriminator outputs calculated from groundtruth. + Returns + ---------- + Tensor + Discriminator real loss value. + Tensor + Discriminator fake loss value. + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, + outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x): + return F.mse_loss(x, paddle.ones_like(x)) + + def _mse_fake_loss(self, x): + return F.mse_loss(x, paddle.zeros_like(x)) diff --git a/parakeet/modules/causal_conv.py b/parakeet/modules/causal_conv.py new file mode 100644 index 00000000..c0dd5b28 --- /dev/null +++ b/parakeet/modules/causal_conv.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Causal convolusion layer modules.""" +import paddle + + +class CausalConv1D(paddle.nn.Layer): + """CausalConv1D module with customized initialization.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + bias=True, + pad="Pad1D", + pad_params={"value": 0.0}, ): + """Initialize CausalConv1d module.""" + super().__init__() + self.pad = getattr(paddle.nn, pad)((kernel_size - 1) * dilation, + **pad_params) + self.conv = paddle.nn.Conv1D( + in_channels, + out_channels, + kernel_size, + dilation=dilation, + bias_attr=bias) + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input tensor (B, in_channels, T). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T). + """ + return self.conv(self.pad(x))[:, :, :x.shape[2]] + + +class CausalConv1DTranspose(paddle.nn.Layer): + """CausalConv1DTranspose module with customized initialization.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + bias=True): + """Initialize CausalConvTranspose1d module.""" + super().__init__() + self.deconv = paddle.nn.Conv1DTranspose( + in_channels, out_channels, kernel_size, stride, bias_attr=bias) + self.stride = stride + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input tensor (B, in_channels, T_in). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T_out). + """ + return self.deconv(x)[:, :, :-self.stride] diff --git a/parakeet/modules/pqmf.py b/parakeet/modules/pqmf.py new file mode 100644 index 00000000..275addd2 --- /dev/null +++ b/parakeet/modules/pqmf.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pseudo QMF modules.""" +import numpy as np +import paddle +import paddle.nn.functional as F +from scipy.signal import kaiser + + +def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0): + """Design prototype filter for PQMF. + This method is based on `A Kaiser window approach for the design of prototype + filters of cosine modulated filterbanks`_. + Parameters + ---------- + taps : int + The number of filter taps. + cutoff_ratio : float + Cut-off frequency ratio. + beta : float + Beta coefficient for kaiser window. + Returns + ---------- + ndarray + Impluse response of prototype filter (taps + 1,). + .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: + https://ieeexplore.ieee.org/abstract/document/681427 + """ + # check the arguments are valid + assert taps % 2 == 0, "The number of taps mush be even number." + assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." + + # make initial filter + omega_c = np.pi * cutoff_ratio + with np.errstate(invalid="ignore"): + h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / ( + np.pi * (np.arange(taps + 1) - 0.5 * taps)) + h_i[taps // + 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form + + # apply kaiser window + w = kaiser(taps + 1, beta) + h = h_i * w + + return h + + +class PQMF(paddle.nn.Layer): + """PQMF module. + This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. + .. _`Near-perfect-reconstruction pseudo-QMF banks`: + https://ieeexplore.ieee.org/document/258122 + """ + + def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): + """Initilize PQMF module. + The cutoff_ratio and beta parameters are optimized for #subbands = 4. + See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195. + Parameters + ---------- + subbands : int + The number of subbands. + taps : int + The number of filter taps. + cutoff_ratio : float + Cut-off frequency ratio. + beta : float + Beta coefficient for kaiser window. + """ + super(PQMF, self).__init__() + + # build analysis & synthesis filter coefficients + h_proto = design_prototype_filter(taps, cutoff_ratio, beta) + h_analysis = np.zeros((subbands, len(h_proto))) + h_synthesis = np.zeros((subbands, len(h_proto))) + for k in range(subbands): + h_analysis[k] = ( + 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * ( + np.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4)) + h_synthesis[k] = ( + 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * ( + np.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4)) + + # convert to tensor + self.analysis_filter = paddle.to_tensor( + h_analysis, dtype="float32").unsqueeze(1) + self.synthesis_filter = paddle.to_tensor( + h_synthesis, dtype="float32").unsqueeze(0) + + # filter for downsampling & upsampling + updown_filter = paddle.zeros( + (subbands, subbands, subbands), dtype="float32") + for k in range(subbands): + updown_filter[k, k, 0] = 1.0 + self.updown_filter = updown_filter + self.subbands = subbands + + # keep padding info + self.pad_fn = paddle.nn.Pad1D(taps // 2, mode='constant', value=0.0) + + def analysis(self, x): + """Analysis with PQMF. + Parameters + ---------- + x : Tensor + Input tensor (B, 1, T). + Returns + ---------- + Tensor + Output tensor (B, subbands, T // subbands). + """ + x = F.conv1d(self.pad_fn(x), self.analysis_filter) + return F.conv1d(x, self.updown_filter, stride=self.subbands) + + def synthesis(self, x): + """Synthesis with PQMF. + Parameters + ---------- + x : Tensor + Input tensor (B, subbands, T // subbands). + Returns + ---------- + Tensor + Output tensor (B, 1, T). + """ + + x = F.conv1d_transpose( + x, self.updown_filter * self.subbands, stride=self.subbands) + return F.conv1d(self.pad_fn(x), self.synthesis_filter) diff --git a/parakeet/modules/residual_stack.py b/parakeet/modules/residual_stack.py new file mode 100644 index 00000000..135c32e5 --- /dev/null +++ b/parakeet/modules/residual_stack.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Residual stack module in MelGAN.""" +from typing import Any +from typing import Dict + +from paddle import nn + +from parakeet.modules.causal_conv import CausalConv1D + + +class ResidualStack(nn.Layer): + """Residual stack module introduced in MelGAN.""" + + def __init__( + self, + kernel_size: int=3, + channels: int=32, + dilation: int=1, + bias: bool=True, + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, + use_causal_conv: bool=False, ): + """Initialize ResidualStack module. + Parameters + ---------- + kernel_size : int + Kernel size of dilation convolution layer. + channels : int + Number of channels of convolution layers. + dilation : int + Dilation factor. + bias : bool + Whether to add bias parameter in convolution layers. + nonlinear_activation : str + Activation function module name. + nonlinear_activation_params : Dict[str,Any] + Hyperparameters for activation function. + pad : str + Padding function module name before dilated convolution layer. + pad_params : Dict[str, Any] + Hyperparameters for padding function. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + + # defile residual stack part + if not use_causal_conv: + assert (kernel_size - 1 + ) % 2 == 0, "Not support even number kernel size." + self.stack = nn.Sequential( + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), + getattr(nn, pad)((kernel_size - 1) // 2 * dilation, + **pad_params), + nn.Conv1D( + channels, + channels, + kernel_size, + dilation=dilation, + bias_attr=bias), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), + nn.Conv1D(channels, channels, 1, bias_attr=bias), ) + else: + self.stack = nn.Sequential( + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), + CausalConv1D( + channels, + channels, + kernel_size, + dilation=dilation, + bias=bias, + pad=pad, + pad_params=pad_params, ), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), + nn.Conv1D(channels, channels, 1, bias_attr=bias), ) + + # defile extra layer for skip connection + self.skip_layer = nn.Conv1D(channels, channels, 1, bias_attr=bias) + + def forward(self, c): + """Calculate forward propagation. + Parameters + ---------- + c : Tensor + Input tensor (B, channels, T). + Returns + ---------- + Tensor + Output tensor (B, chennels, T). + """ + return self.stack(c) + self.skip_layer(c)