From 6dbcd7720d349813258c0a532093b4123432986d Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Tue, 26 Oct 2021 11:45:12 +0000
Subject: [PATCH 1/4] add csmsc mb melgan example

---
 examples/csmsc/voc3/README.md                 | 127 ++++
 examples/csmsc/voc3/conf/default.yaml         | 139 +++++
 examples/csmsc/voc3/local/preprocess.sh       |  55 ++
 examples/csmsc/voc3/local/synthesize.sh       |  13 +
 examples/csmsc/voc3/local/train.sh            |  13 +
 examples/csmsc/voc3/path.sh                   |  13 +
 examples/csmsc/voc3/run.sh                    |  32 ++
 parakeet/datasets/vocoder_batch_fn.py         |   9 +-
 .../gan_vocoder/multi_band_melgan/__init__.py |  13 +
 .../multi_band_melgan/synthesize.py           |  97 ++++
 .../gan_vocoder/multi_band_melgan/train.py    | 271 +++++++++
 parakeet/models/melgan/__init__.py            |  15 +
 parakeet/models/melgan/melgan.py              | 541 ++++++++++++++++++
 parakeet/models/melgan/melgan_updater.py      | 231 ++++++++
 .../melgan/multi_band_melgan_updater.py       | 245 ++++++++
 .../parallel_wavegan/parallel_wavegan.py      |  11 +-
 parakeet/modules/adversarial_loss.py          | 124 ++++
 parakeet/modules/causal_conv.py               |  81 +++
 parakeet/modules/pqmf.py                      | 140 +++++
 parakeet/modules/residual_stack.py            | 109 ++++
 20 files changed, 2271 insertions(+), 8 deletions(-)
 create mode 100644 examples/csmsc/voc3/README.md
 create mode 100644 examples/csmsc/voc3/conf/default.yaml
 create mode 100755 examples/csmsc/voc3/local/preprocess.sh
 create mode 100755 examples/csmsc/voc3/local/synthesize.sh
 create mode 100755 examples/csmsc/voc3/local/train.sh
 create mode 100755 examples/csmsc/voc3/path.sh
 create mode 100755 examples/csmsc/voc3/run.sh
 create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py
 create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py
 create mode 100644 parakeet/exps/gan_vocoder/multi_band_melgan/train.py
 create mode 100644 parakeet/models/melgan/__init__.py
 create mode 100644 parakeet/models/melgan/melgan.py
 create mode 100644 parakeet/models/melgan/melgan_updater.py
 create mode 100644 parakeet/models/melgan/multi_band_melgan_updater.py
 create mode 100644 parakeet/modules/adversarial_loss.py
 create mode 100644 parakeet/modules/causal_conv.py
 create mode 100644 parakeet/modules/pqmf.py
 create mode 100644 parakeet/modules/residual_stack.py

diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md
new file mode 100644
index 00000000..780a8ccd
--- /dev/null
+++ b/examples/csmsc/voc3/README.md
@@ -0,0 +1,127 @@
+# Multi Band MelGAN with CSMSC
+This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
+## Dataset
+### Download and Extract the dataset
+Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. The dataset is then in the directory `~/datasets/BZNSYP`.
+
+### Get MFA results for silence trimming
+We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to trim the silence at the edges of audio.
+You can download [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) in our repo.

+## Get Started
+Assume the path to the dataset is `~/datasets/BZNSYP`.
+Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
+Run the command below to
+1. **source path**,
+2. preprocess the dataset,
+3. train the model,
+4. synthesize wavs.
+   - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
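+`run.sh` sources `utils/parse_options.sh`, so you can also run a subset of the stages, as in the sketch below (stage numbers follow `run.sh`):
+```bash
+# preprocessing already done: run only training (stage 1) and synthesis (stage 2)
+./run.sh --stage 1 --stop-stage 2
+```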
+### Preprocess the dataset
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.
+
+```text
+dump
+├── dev
+│   ├── norm
+│   └── raw
+├── test
+│   ├── norm
+│   └── raw
+└── train
+    ├── norm
+    ├── raw
+    └── feats_stats.npy
+```
+The dataset is split into 3 parts, namely `train`, `dev` and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set and are located in `dump/train/feats_stats.npy`.
+
+There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the utterance id and the path to the spectrogram of each utterance.
+
+### Train the model
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
+```
+`./local/train.sh` calls `${BIN_DIR}/train.py`.
+Here's the complete help message.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
+                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
+                [--device DEVICE] [--nprocs NPROCS] [--verbose VERBOSE]
+                [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
+                [--run-benchmark RUN_BENCHMARK]
+                [--profiler_options PROFILER_OPTIONS]
+
+Train a Multi-Band MelGAN model.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --config CONFIG       config file to overwrite default config.
+  --train-metadata TRAIN_METADATA
+                        training data.
+  --dev-metadata DEV_METADATA
+                        dev data.
+  --output-dir OUTPUT_DIR
+                        output dir.
+  --device DEVICE       device type to use.
+  --nprocs NPROCS       number of processes.
+  --verbose VERBOSE     verbose.
+
+benchmark:
+  arguments related to benchmark.
+
+  --batch-size BATCH_SIZE
+                        batch size.
+  --max-iter MAX_ITER   train max steps.
+  --run-benchmark RUN_BENCHMARK
+                        running benchmark or not, if True, use the --batch-size
+                        and --max-iter.
+  --profiler_options PROFILER_OPTIONS
+                        The option of profiler, which should be in format
+                        "key1=value1;key2=value2;key3=value3".
+```
+
+1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
+2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
+3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
+4. `--device` is the type of the device to run the experiment, 'cpu' and 'gpu' are supported.
+5. `--nprocs` is the number of processes to run in parallel; note that nprocs > 1 is only supported when `--device` is 'gpu'.
+
+### Synthesize
+`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
+```
+```text
+usage: synthesize.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT]
+                     [--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR]
+                     [--device DEVICE] [--verbose VERBOSE]
+
+Synthesize with multi band melgan.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --config CONFIG       multi band melgan config file.
+  --checkpoint CHECKPOINT
+                        snapshot to load.
+  --test-metadata TEST_METADATA
+                        test data.
+  --output-dir OUTPUT_DIR
+                        output dir.
+  --device DEVICE       device to run.
+  --verbose VERBOSE     verbose.
+```
+
+1. `--config` is the multi band melgan config file. You should use the same config with which the model was trained.
+2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
+3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
+4. `--output-dir` is the directory to save the synthesized audio files.
+5. `--device` is the type of device to run synthesis, 'cpu' and 'gpu' are supported.
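+
+To vocode a single mel spectrogram without going through `metadata.jsonl`, the snippet below is a minimal sketch of what `synthesize.py` does; the checkpoint name and the `.npy` feature path are placeholders for your own files.
+```python
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from yacs.config import CfgNode
+
+from parakeet.models.melgan import MelGANGenerator
+
+with open("conf/default.yaml") as f:
+    config = CfgNode(yaml.safe_load(f))
+
+# build the generator and load weights from a trained snapshot (placeholder name)
+generator = MelGANGenerator(**config["generator_params"])
+state_dict = paddle.load("exp/default/checkpoints/snapshot_iter_1000000.pdz")
+generator.set_state_dict(state_dict["generator_params"])
+generator.remove_weight_norm()
+generator.eval()
+
+# a normalized mel spectrogram of shape (T, n_mels), e.g. one file from dump/test/norm
+mel = paddle.to_tensor(np.load("dump/test/norm/feats/some_utt.npy"))
+with paddle.no_grad():
+    wav = generator.inference(c=mel)  # (T * n_shift, 1)
+sf.write("out.wav", wav.numpy(), samplerate=config.fs)
+```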
+
+## Pretrained Models
diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml
new file mode 100644
index 00000000..87a237c3
--- /dev/null
+++ b/examples/csmsc/voc3/conf/default.yaml
@@ -0,0 +1,139 @@
+# This is the hyperparameter configuration file for Multi-Band MelGAN.
+# Please make sure this is adjusted for the CSMSC dataset. If you want to
+# apply it to another dataset, you might need to carefully change some parameters.
+# This configuration requires ~8GB memory and will finish within 7 days on Titan V.
+
+# This configuration is based on full-band MelGAN, but the hop size and sampling
+# rate are different from the paper (16kHz vs 24kHz). The number of iterations
+# is not shown in the paper, so we currently train for 1M iterations (which may
+# not be enough to converge). The optimizer setting is based on @dathudeptrai's advice.
+# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size. (in samples)
+n_shift: 300       # Hop size. (in samples)
+win_length: 1200   # Window length. (in samples)
+                   # If set to null, it will be the same as n_fft.
+window: "hann"     # Window function.
+n_mels: 80         # Number of mel basis.
+fmin: 80           # Minimum frequency in mel basis calculation. (Hz)
+fmax: 7600         # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    in_channels: 80               # Number of input channels.
+    out_channels: 4               # Number of output channels (= number of subbands).
+    kernel_size: 7                # Kernel size of initial and final conv layers.
+    channels: 384                 # Initial number of channels for conv layers.
+    upsample_scales: [5, 5, 3]    # List of upsampling scales.
+    stack_kernel_size: 3          # Kernel size of dilated conv layers in residual stack.
+    stacks: 4                     # Number of stacks in a single residual stack module.
+    use_weight_norm: True         # Whether to use weight normalization.
+    use_causal_conv: False        # Whether to use causal convolution.
+    use_final_nonlinear_activation: False   # If True, spectral_convergence_loss and sub_spectral_convergence_loss will be too large (e.g. 30).
+
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+    in_channels: 1                    # Number of input channels.
+    out_channels: 1                   # Number of output channels.
+    scales: 3                         # Number of multi-scales.
+    downsample_pooling: "AvgPool1D"   # Pooling type for the input downsampling.
+    downsample_pooling_params:        # Parameters of the above pooling function.
+        kernel_size: 4
+        stride: 2
+        padding: 1
+        exclusive: True
+    kernel_sizes: [5, 3]              # List of kernel sizes.
+    channels: 16                      # Number of channels of the initial conv layer.
+    max_downsample_channels: 512      # Maximum number of channels of downsampling layers.
+    downsample_scales: [4, 4, 4]      # List of downsampling scales.
+    nonlinear_activation: "LeakyReLU" # Nonlinear activation function.
+    nonlinear_activation_params:      # Parameters of nonlinear activation function.
+        negative_slope: 0.2
+    use_weight_norm: True             # Whether to use weight norm.
+
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+use_stft_loss: true
+stft_loss_params:
+    fft_sizes: [1024, 2048, 512]  # List of FFT sizes for STFT-based loss.
+    hop_sizes: [120, 240, 50]     # List of hop sizes for STFT-based loss.
+    win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
+    window: "hann"                # Window function for STFT-based loss.
+use_subband_stft_loss: true
+subband_stft_loss_params:
+    fft_sizes: [384, 683, 171]    # List of FFT sizes for STFT-based loss.
+    hop_sizes: [30, 60, 10]       # List of hop sizes for STFT-based loss.
+    win_lengths: [150, 300, 60]   # List of window lengths for STFT-based loss.
+    window: "hann"                # Window function for STFT-based loss.
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+use_feat_match_loss: false # Whether to use feature matching loss.
+lambda_adv: 2.5            # Loss balancing coefficient for adversarial loss.
+
+###########################################################
+#                  DATA LOADER SETTING                    #
+###########################################################
+batch_size: 64             # Batch size.
+batch_max_steps: 16200     # Length of each audio clip in the batch. Make sure it is divisible by hop_size.
+num_workers: 4             # Number of workers in DataLoader.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    epsilon: 1.0e-7                   # Generator's epsilon.
+    weight_decay: 0.0                 # Generator's weight decay coefficient.
+
+generator_grad_norm: -1               # Generator's gradient norm.
+generator_scheduler_params:
+    learning_rate: 1.0e-3             # Generator's learning rate.
+    gamma: 0.5                        # Generator's scheduler gamma.
+    milestones:                       # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+discriminator_optimizer_params:
+    epsilon: 1.0e-7                   # Discriminator's epsilon.
+    weight_decay: 0.0                 # Discriminator's weight decay coefficient.
+
+discriminator_grad_norm: -1           # Discriminator's gradient norm.
+discriminator_scheduler_params:
+    learning_rate: 1.0e-3             # Discriminator's learning rate.
+    gamma: 0.5                        # Discriminator's scheduler gamma.
+    milestones:                       # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+
+###########################################################
+#                    INTERVAL SETTING                     #
+###########################################################
+discriminator_train_start_steps: 200000 # Training step at which to start training the discriminator.
+train_max_steps: 1000000                # Number of training steps.
+save_interval_steps: 50000              # Interval steps to save checkpoint.
+eval_interval_steps: 1000               # Interval steps to evaluate the network.
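+
+# Note: each training clip covers batch_max_steps / n_shift = 16200 / 300 = 54
+# mel frames, and 16200 is also divisible by the 4 subbands (out_channels), so
+# the PQMF analysis/synthesis used during training sees consistent shapes.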
+ +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random \ No newline at end of file diff --git a/examples/csmsc/voc3/local/preprocess.sh b/examples/csmsc/voc3/local/preprocess.sh new file mode 100755 index 00000000..61d6d62b --- /dev/null +++ b/examples/csmsc/voc3/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/BZNSYP/ \ + --dataset=baker \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/csmsc/voc3/local/synthesize.sh b/examples/csmsc/voc3/local/synthesize.sh new file mode 100755 index 00000000..9f904ac0 --- /dev/null +++ b/examples/csmsc/voc3/local/synthesize.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test diff --git a/examples/csmsc/voc3/local/train.sh b/examples/csmsc/voc3/local/train.sh new file mode 100755 index 00000000..1ef860c3 --- /dev/null +++ b/examples/csmsc/voc3/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --nprocs=1 diff --git a/examples/csmsc/voc3/path.sh b/examples/csmsc/voc3/path.sh new file mode 100755 index 00000000..f6b9fe61 --- /dev/null +++ b/examples/csmsc/voc3/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid 
UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=multi_band_melgan
+export BIN_DIR=${MAIN_ROOT}/parakeet/exps/gan_vocoder/${MODEL}
\ No newline at end of file
diff --git a/examples/csmsc/voc3/run.sh b/examples/csmsc/voc3/run.sh
new file mode 100755
index 00000000..360f6ec2
--- /dev/null
+++ b/examples/csmsc/voc3/run.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_50000.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model; all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/parakeet/datasets/vocoder_batch_fn.py b/parakeet/datasets/vocoder_batch_fn.py
index 30adb142..2de4fb12 100644
--- a/parakeet/datasets/vocoder_batch_fn.py
+++ b/parakeet/datasets/vocoder_batch_fn.py
@@ -107,8 +107,13 @@ class Clip(object):
             features, this process will be needed.
 
         """
-        if len(x) < c.shape[1] * self.hop_size:
-            x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)), mode="edge")
+        if len(x) < c.shape[0] * self.hop_size:
+            x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge")
+        elif len(x) > c.shape[0] * self.hop_size:
+            print(
+                f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size})"
+            )
+            x = x[:c.shape[0] * self.hop_size]
 
         # check the length is valid
         assert len(x) == c.shape[
diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py
new file mode 100644
index 00000000..185a92b8
--- /dev/null
+++ b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py
new file mode 100644
index 00000000..d48fbbd0
--- /dev/null
+++ b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from paddle import distributed as dist
+from timer import timer
+from yacs.config import CfgNode
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.models.melgan import MelGANGenerator
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Synthesize with multi band melgan.")
+    parser.add_argument(
+        "--config", type=str, help="multi band melgan config file.")
+    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+    parser.add_argument("--test-metadata", type=str, help="test data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device to run.")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
+
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    paddle.set_device(args.device)
+    generator = MelGANGenerator(**config["generator_params"])
+    state_dict = paddle.load(args.checkpoint)
+    generator.set_state_dict(state_dict["generator_params"])
+
+    generator.remove_weight_norm()
+    generator.eval()
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        metadata = list(reader)
+
+    test_dataset = DataTable(
+        metadata,
+        fields=['utt_id', 'feats'],
+        converters={
+            'utt_id': None,
+            'feats': np.load,
+        })
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    N = 0
+    T = 0
+    for example in test_dataset:
+        utt_id = example['utt_id']
+        mel = example['feats']
+        mel = paddle.to_tensor(mel)  # (T, C)
+        with timer() as t:
+            with paddle.no_grad():
+                wav = generator.inference(c=mel)
+            wav = wav.numpy()
+            N += wav.size
+            T += t.elapse
+            speed = wav.size / t.elapse
+            print(
+                f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}."
+            )
+        sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
+    print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/train.py b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py
new file mode 100644
index 00000000..bb9b0b8a
--- /dev/null
+++ b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py
@@ -0,0 +1,271 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle import nn +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from paddle.optimizer.lr import MultiStepDecay +from visualdl import LogWriter +from yacs.config import CfgNode + +from parakeet.datasets.data_table import DataTable +from parakeet.datasets.vocoder_batch_fn import Clip +from parakeet.models.melgan import MBMelGANEvaluator +from parakeet.models.melgan import MBMelGANUpdater +from parakeet.models.melgan import MelGANGenerator +from parakeet.models.melgan import MelGANMultiScaleDiscriminator +from parakeet.modules.adversarial_loss import DiscriminatorAdversarialLoss +from parakeet.modules.adversarial_loss import GeneratorAdversarialLoss +from parakeet.modules.pqmf import PQMF +from parakeet.modules.stft_loss import MultiResolutionSTFTLoss +from parakeet.training.extensions.snapshot import Snapshot +from parakeet.training.extensions.visualizer import VisualDL +from parakeet.training.seeding import seed_everything +from parakeet.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if not paddle.is_compiled_with_cuda(): + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + if "aux_context_window" in config.generator_params: + aux_context_window = config.generator_params.aux_context_window + else: + aux_context_window = 0 + train_batch_fn = Clip( + batch_max_steps=config.batch_max_steps, + hop_size=config.n_shift, + aux_context_window=aux_context_window) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + 
num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + generator = MelGANGenerator(**config["generator_params"]) + discriminator = MelGANMultiScaleDiscriminator( + **config["discriminator_params"]) + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"]) + criterion_sub_stft = MultiResolutionSTFTLoss( + **config["subband_stft_loss_params"]) + criterion_gen_adv = GeneratorAdversarialLoss() + criterion_dis_adv = DiscriminatorAdversarialLoss() + # define special module for subband processing + criterion_pqmf = PQMF(subbands=config["generator_params"]["out_channels"]) + print("criterions done!") + + lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"]) + # Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used + generator_grad_norm = config["generator_grad_norm"] + gradient_clip_g = nn.ClipGradByGlobalNorm( + generator_grad_norm) if generator_grad_norm > 0 else None + print("gradient_clip_g:", gradient_clip_g) + + optimizer_g = Adam( + learning_rate=lr_schedule_g, + grad_clip=gradient_clip_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"]) + discriminator_grad_norm = config["discriminator_grad_norm"] + gradient_clip_d = nn.ClipGradByGlobalNorm( + discriminator_grad_norm) if discriminator_grad_norm > 0 else None + print("gradient_clip_d:", gradient_clip_d) + optimizer_d = Adam( + learning_rate=lr_schedule_d, + grad_clip=gradient_clip_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = MBMelGANUpdater( + models={ + "generator": generator, + "discriminator": discriminator, + }, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "stft": criterion_stft, + "sub_stft": criterion_sub_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "pqmf": criterion_pqmf + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + discriminator_train_start_steps=config.discriminator_train_start_steps, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + evaluator = MBMelGANEvaluator( + models={ + "generator": generator, + "discriminator": discriminator, + }, + criterions={ + "stft": criterion_stft, + "sub_stft": criterion_sub_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "pqmf": criterion_pqmf + }, + dataloader=dev_dataloader, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + trainer = Trainer( + updater, + stop_trigger=(config.train_max_steps, "iteration"), + out=output_dir) + + if dist.get_rank() == 0: + trainer.extend( + evaluator, trigger=(config.eval_interval_steps, 'iteration')) + writer = LogWriter(str(trainer.out)) + trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) + trainer.extend( + Snapshot(max_size=config.num_snapshots), + trigger=(config.save_interval_steps, 
'iteration'))
+
+    print("Trainer Done!")
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+
+    parser = argparse.ArgumentParser(
+        description="Train a Multi-Band MelGAN model.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device type to use.")
+    parser.add_argument(
+        "--nprocs", type=int, default=1, help="number of processes.")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
+
+    args = parser.parse_args()
+    if args.device == "cpu" and args.nprocs > 1:
+        raise RuntimeError("Multiprocess training on CPU is not supported.")
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.nprocs > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/parakeet/models/melgan/__init__.py b/parakeet/models/melgan/__init__.py
new file mode 100644
index 00000000..d4f557db
--- /dev/null
+++ b/parakeet/models/melgan/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .melgan import *
+from .multi_band_melgan_updater import *
diff --git a/parakeet/models/melgan/melgan.py b/parakeet/models/melgan/melgan.py
new file mode 100644
index 00000000..ccc19d5f
--- /dev/null
+++ b/parakeet/models/melgan/melgan.py
@@ -0,0 +1,541 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
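+#
+# Note: the modules below follow Paddle's channels-first Conv1D convention,
+# i.e. intermediate tensors are (B, C, T); `MelGANGenerator.inference`
+# transposes its (T, C) input into this layout (and back) around the network.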
+"""MelGAN Modules.""" +from typing import Any +from typing import Dict +from typing import List + +import numpy as np +import paddle +from paddle import nn +from paddle.fluid.layers import Normal + +from parakeet.modules.causal_conv import CausalConv1D +from parakeet.modules.causal_conv import CausalConv1DTranspose +from parakeet.modules.pqmf import PQMF +from parakeet.modules.residual_stack import ResidualStack + + +class MelGANGenerator(nn.Layer): + """MelGAN generator module.""" + + def __init__( + self, + in_channels: int=80, + out_channels: int=1, + kernel_size: int=7, + channels: int=512, + bias: bool=True, + upsample_scales: List[int]=[8, 8, 2, 2], + stack_kernel_size: int=3, + stacks: int=3, + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, + use_final_nonlinear_activation: bool=True, + use_weight_norm: bool=True, + use_causal_conv: bool=False, ): + """Initialize MelGANGenerator module. + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels, + the number of sub-band is out_channels in multi-band melgan. + kernel_size : int + Kernel size of initial and final conv layer. + channels : int + Initial number of channels for conv layer. + bias : bool + Whether to add bias parameter in convolution layers. + upsample_scales : List[int] + List of upsampling scales. + stack_kernel_size : int + Kernel size of dilated conv layers in residual stack. + stacks : int + Number of stacks in a single residual stack. + nonlinear_activation : Optional[str], optional + Non linear activation in upsample network, by default None + nonlinear_activation_params : Dict[str, Any], optional + Parameters passed to the linear activation in the upsample network, + by default {} + pad : str + Padding function module name before dilated convolution layer. + pad_params : dict + Hyperparameters for padding function. + use_final_nonlinear_activation : paddle.nn.Layer + Activation function for the final layer. + use_weight_norm : bool + Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + + # check hyper parameters is valid + assert channels >= np.prod(upsample_scales) + assert channels % (2**len(upsample_scales)) == 0 + if not use_causal_conv: + assert (kernel_size - 1 + ) % 2 == 0, "Not support even number kernel size." 
+ # add initial layer + layers = [] + if not use_causal_conv: + layers += [ + getattr(paddle.nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D(in_channels, channels, kernel_size, bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + in_channels, + channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + + for i, upsample_scale in enumerate(upsample_scales): + # add upsampling layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + nn.Conv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + bias_attr=bias, ) + ] + else: + layers += [ + CausalConv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + bias=bias, ) + ] + + # add residual stack + for j in range(stacks): + layers += [ + ResidualStack( + kernel_size=stack_kernel_size, + channels=channels // (2**(i + 1)), + dilation=stack_kernel_size**j, + bias=bias, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + use_causal_conv=use_causal_conv, ) + ] + + # add final layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + getattr(nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + if use_final_nonlinear_activation: + layers += [nn.Tanh()] + + # define the model as a single function + self.melgan = nn.Sequential(*layers) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + # initialize pqmf for multi-band melgan inference + if out_channels > 1: + self.pqmf = PQMF(subbands=out_channels) + else: + self.pqmf = None + + def forward(self, c): + """Calculate forward propagation. + Parameters + ---------- + c : Tensor + Input tensor (B, in_channels, T). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T ** prod(upsample_scales)). + """ + out = self.melgan(c) + return out + + def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. + """ + + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) + + def reset_parameters(self): + """Reset parameters. + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + """ + + # 定义参数为float的正态分布。 + dist = Normal(loc=0.0, scale=0.02) + + def _reset_parameters(m): + if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): + w = dist.sample(m.weight.shape) + m.weight.set_value(w) + + self.apply(_reset_parameters) + + def inference(self, c): + """Perform inference. 
+ Parameters + ---------- + c : Union[Tensor, ndarray] + Input tensor (T, in_channels). + Returns + ---------- + Tensor + Output tensor (out_channels*T ** prod(upsample_scales), 1). + """ + if not isinstance(c, paddle.Tensor): + c = paddle.to_tensor(c, dtype="float32") + # pseudo batch + c = c.transpose([1, 0]).unsqueeze(0) + # (B, out_channels, T ** prod(upsample_scales) + out = self.melgan(c) + if self.pqmf is not None: + # (B, 1, out_channels * T ** prod(upsample_scales) + out = self.pqmf.synthesis(out) + out = out.squeeze(0).transpose([1, 0]) + return out + + +class MelGANDiscriminator(nn.Layer): + """MelGAN discriminator module.""" + + def __init__( + self, + in_channels: int=1, + out_channels: int=1, + kernel_sizes: List[int]=[5, 3], + channels: int=16, + max_downsample_channels: int=1024, + bias: bool=True, + downsample_scales: List[int]=[4, 4, 4, 4], + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, ): + """Initilize MelGAN discriminator module. + Parameters + ---------- + in_channels : + int): Number of input channels. + out_channels : int + Number of output channels. + kernel_sizes : List[int] + List of two kernel sizes. The prod will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, + the last two layers' kernel size will be 5 and 3, respectively. + channels : int + Initial number of channels for conv layer. + max_downsample_channels : int + Maximum number of channels for downsampling layers. + bias : bool + Whether to add bias parameter in convolution layers. + downsample_scales : List[int] + List of downsampling scales. + nonlinear_activation : str + Activation function module name. + nonlinear_activation_params : dict + Hyperparameters for activation function. + pad : str + Padding function module name before dilated convolution layer. + pad_params : dict + Hyperparameters for padding function. 
+ """ + super().__init__() + self.layers = nn.LayerList() + + # check kernel size is valid + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + # add first layer + self.layers.append( + nn.Sequential( + getattr(nn, pad)((np.prod(kernel_sizes) - 1) // 2, ** + pad_params), + nn.Conv1D( + in_channels, + channels, + int(np.prod(kernel_sizes)), + bias_attr=bias), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + + # add downsample layers + in_chs = channels + for downsample_scale in downsample_scales: + out_chs = min(in_chs * downsample_scale, max_downsample_channels) + self.layers.append( + nn.Sequential( + nn.Conv1D( + in_chs, + out_chs, + kernel_size=downsample_scale * 10 + 1, + stride=downsample_scale, + padding=downsample_scale * 5, + groups=in_chs // 4, + bias_attr=bias, ), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + in_chs = out_chs + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers.append( + nn.Sequential( + nn.Conv1D( + in_chs, + out_chs, + kernel_sizes[0], + padding=(kernel_sizes[0] - 1) // 2, + bias_attr=bias, ), + getattr(nn, nonlinear_activation)( + **nonlinear_activation_params), )) + self.layers.append( + nn.Conv1D( + out_chs, + out_channels, + kernel_sizes[1], + padding=(kernel_sizes[1] - 1) // 2, + bias_attr=bias, ), ) + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input noise signal (B, 1, T). + Returns + ---------- + List + List of output tensors of each layer (for feat_match_loss). + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + +class MelGANMultiScaleDiscriminator(nn.Layer): + """MelGAN multi-scale discriminator module.""" + + def __init__( + self, + in_channels: int=1, + out_channels: int=1, + scales: int=3, + downsample_pooling: str="AvgPool1D", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any]={ + "kernel_size": 4, + "stride": 2, + "padding": 1, + "exclusive": True, + }, + kernel_sizes: List[int]=[5, 3], + channels: int=16, + max_downsample_channels: int=1024, + bias: bool=True, + downsample_scales: List[int]=[4, 4, 4, 4], + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, + pad: str="Pad1D", + pad_params: Dict[str, Any]={"mode": "reflect"}, + use_weight_norm: bool=True, ): + """Initilize MelGAN multi-scale discriminator module. + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + scales : int + Number of multi-scales. + downsample_pooling : str + Pooling module name for downsampling of the inputs. + downsample_pooling_params : dict + Parameters for the above pooling module. + kernel_sizes : List[int] + List of two kernel sizes. The sum will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + channels : int + Initial number of channels for conv layer. + max_downsample_channels : int + Maximum number of channels for downsampling layers. + bias : bool + Whether to add bias parameter in convolution layers. + downsample_scales : List[int] + List of downsampling scales. + nonlinear_activation : str + Activation function module name. + nonlinear_activation_params : dict + Hyperparameters for activation function. + pad : str + Padding function module name before dilated convolution layer. 
+ pad_params : dict + Hyperparameters for padding function. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + self.discriminators = nn.LayerList() + + # add discriminators + for _ in range(scales): + self.discriminators.append( + MelGANDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + channels=channels, + max_downsample_channels=max_downsample_channels, + bias=bias, + downsample_scales=downsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, )) + self.pooling = getattr(nn, downsample_pooling)( + **downsample_pooling_params) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input noise signal (B, 1, T). + Returns + ---------- + List + List of list of each discriminator outputs, which consists of each layer output tensors. + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. + """ + + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) + + def reset_parameters(self): + """Reset parameters. + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + """ + + # 定义参数为float的正态分布。 + dist = Normal(loc=0.0, scale=0.02) + + def _reset_parameters(m): + if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): + w = dist.sample(m.weight.shape) + m.weight.set_value(w) + + self.apply(_reset_parameters) diff --git a/parakeet/models/melgan/melgan_updater.py b/parakeet/models/melgan/melgan_updater.py new file mode 100644 index 00000000..7bd59881 --- /dev/null +++ b/parakeet/models/melgan/melgan_updater.py @@ -0,0 +1,231 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
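+#
+# The updater/evaluator below mirror the ParallelWaveGAN training loop (the
+# generator consumes a noise input conditioned on mel; multi-resolution STFT
+# loss plus an MSE adversarial loss once the discriminator starts training).
+# The Multi-Band MelGAN loop lives in multi_band_melgan_updater.py.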
+import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler +from timer import timer + +from parakeet.training.extensions.evaluator import StandardEvaluator +from parakeet.training.reporter import report +from parakeet.training.updaters.standard_updater import StandardUpdater +from parakeet.training.updaters.standard_updater import UpdaterState +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class PWGUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + discriminator_train_start_steps: int, + lambda_adv: float, + output_dir=None): + self.models = models + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_mse = criterions['mse'] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.discriminator_train_start_steps = discriminator_train_start_steps + self.lambda_adv = lambda_adv + self.state = UpdaterState(iteration=0, epoch=0) + + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + # parse batch + wav, mel = batch + + # Generator + noise = paddle.randn(wav.shape) + + with timer() as t: + wav_ = self.generator(noise, mel) + # logging.debug(f"Generator takes {t.elapse}s.") + + # initialize + gen_loss = 0.0 + + ## Multi-resolution stft loss + with timer() as t: + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") + + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + gen_loss += sc_loss + mag_loss + + ## Adversarial loss + if self.state.iteration > self.discriminator_train_start_steps: + with timer() as t: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) + # logging.debug( + # f"Discriminator and adversarial loss takes {t.elapse}s") + report("train/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss += self.lambda_adv * adv_loss + + report("train/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + with timer() as t: + self.optimizer_g.clear_grad() + gen_loss.backward() + # logging.debug(f"Backward takes {t.elapse}s.") + + with timer() as t: + self.optimizer_g.step() + self.scheduler_g.step() + # logging.debug(f"Update takes 
{t.elapse}s.") + + # Disctiminator + if self.state.iteration > self.discriminator_train_start_steps: + with paddle.no_grad(): + wav_ = self.generator(noise, mel) + p = self.discriminator(wav) + p_ = self.discriminator(wav_.detach()) + real_loss = self.criterion_mse(p, paddle.ones_like(p)) + fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) + dis_loss = real_loss + fake_loss + report("train/real_loss", float(real_loss)) + report("train/fake_loss", float(fake_loss)) + report("train/discriminator_loss", float(dis_loss)) + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.optimizer_d.clear_grad() + dis_loss.backward() + + self.optimizer_d.step() + self.scheduler_d.step() + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + +class PWGEvaluator(StandardEvaluator): + def __init__(self, + models, + criterions, + dataloader, + lambda_adv, + output_dir=None): + self.models = models + self.generator = models['generator'] + self.discriminator = models['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_mse = criterions['mse'] + + self.dataloader = dataloader + self.lambda_adv = lambda_adv + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def evaluate_core(self, batch): + # logging.debug("Evaluate: ") + self.msg = "Evaluate: " + losses_dict = {} + + wav, mel = batch + noise = paddle.randn(wav.shape) + + with timer() as t: + wav_ = self.generator(noise, mel) + # logging.debug(f"Generator takes {t.elapse}s") + + ## Adversarial loss + with timer() as t: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) + # logging.debug( + # f"Discriminator and adversarial loss takes {t.elapse}s") + report("eval/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss = self.lambda_adv * adv_loss + + # stft loss + with timer() as t: + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") + + report("eval/spectral_convergence_loss", float(sc_loss)) + report("eval/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + gen_loss += sc_loss + mag_loss + + report("eval/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + # Disctiminator + p = self.discriminator(wav) + real_loss = self.criterion_mse(p, paddle.ones_like(p)) + fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) + dis_loss = real_loss + fake_loss + report("eval/real_loss", float(real_loss)) + report("eval/fake_loss", float(fake_loss)) + report("eval/discriminator_loss", float(dis_loss)) + + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/parakeet/models/melgan/multi_band_melgan_updater.py b/parakeet/models/melgan/multi_band_melgan_updater.py new file mode 100644 index 00000000..0783cb97 --- /dev/null +++ b/parakeet/models/melgan/multi_band_melgan_updater.py @@ -0,0 +1,245 @@ +# Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler + +from parakeet.training.extensions.evaluator import StandardEvaluator +from parakeet.training.reporter import report +from parakeet.training.updaters.standard_updater import StandardUpdater +from parakeet.training.updaters.standard_updater import UpdaterState +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class MBMelGANUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + discriminator_train_start_steps: int, + lambda_adv: float, + output_dir=None): + self.models = models + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_sub_stft = criterions['sub_stft'] + self.criterion_pqmf = criterions['pqmf'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.discriminator_train_start_steps = discriminator_train_start_steps + self.lambda_adv = lambda_adv + self.state = UpdaterState(iteration=0, epoch=0) + + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + # parse batch + wav, mel = batch + # Generator + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) + wav_mb_ = wav_ + # (B, 1, out_channels*T ** prod(upsample_scales) + wav_ = self.criterion_pqmf.synthesis(wav_mb_) + + # initialize + gen_loss = 0.0 + + # full band Multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # for balancing with subband stft loss + # Eq.(9) in paper + gen_loss += 0.5 * (sc_loss + mag_loss) + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + # sub band 
Multi-resolution stft loss + # (B, subbands, T // subbands) + wav_mb = self.criterion_pqmf.analysis(wav) + sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb) + # Eq.(9) in paper + gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + report("train/sub_spectral_convergence_loss", float(sub_sc_loss)) + report("train/sub_log_stft_magnitude_loss", float(sub_mag_loss)) + losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss) + losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss) + + ## Adversarial loss + if self.state.iteration > self.discriminator_train_start_steps: + p_ = self.discriminator(wav_) + adv_loss = self.criterion_gen_adv(p_) + + report("train/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss += self.lambda_adv * adv_loss + + report("train/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + self.optimizer_g.clear_grad() + gen_loss.backward() + + self.optimizer_g.step() + self.scheduler_g.step() + + # Disctiminator + if self.state.iteration > self.discriminator_train_start_steps: + # re-compute wav_ which leads better quality + with paddle.no_grad(): + wav_ = self.generator(mel) + wav_ = self.criterion_pqmf.synthesis(wav_) + p = self.discriminator(wav) + p_ = self.discriminator(wav_.detach()) + real_loss, fake_loss = self.criterion_dis_adv(p_, p) + dis_loss = real_loss + fake_loss + report("train/real_loss", float(real_loss)) + report("train/fake_loss", float(fake_loss)) + report("train/discriminator_loss", float(dis_loss)) + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.optimizer_d.clear_grad() + dis_loss.backward() + + self.optimizer_d.step() + self.scheduler_d.step() + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + +class MBMelGANEvaluator(StandardEvaluator): + def __init__(self, + models, + criterions, + dataloader, + lambda_adv, + output_dir=None): + self.models = models + self.generator = models['generator'] + self.discriminator = models['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_sub_stft = criterions['sub_stft'] + self.criterion_pqmf = criterions['pqmf'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + + self.dataloader = dataloader + self.lambda_adv = lambda_adv + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def evaluate_core(self, batch): + # logging.debug("Evaluate: ") + self.msg = "Evaluate: " + losses_dict = {} + + wav, mel = batch + # Generator + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) + wav_mb_ = wav_ + # (B, 1, out_channels*T ** prod(upsample_scales) + wav_ = self.criterion_pqmf.synthesis(wav_mb_) + + ## Adversarial loss + p_ = self.discriminator(wav_) + adv_loss = self.criterion_gen_adv(p_) + + report("eval/adversarial_loss", float(adv_loss)) + losses_dict["adversarial_loss"] = float(adv_loss) + gen_loss = self.lambda_adv * adv_loss + + # Multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # Eq.(9) in paper + gen_loss += 0.5 * (sc_loss + mag_loss) + report("eval/spectral_convergence_loss", float(sc_loss)) + report("eval/log_stft_magnitude_loss", float(mag_loss)) + 
losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + # sub band Multi-resolution stft loss + # (B, subbands, T // subbands) + wav_mb = self.criterion_pqmf.analysis(wav) + sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb) + # Eq.(9) in paper + gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + report("eval/sub_spectral_convergence_loss", float(sub_sc_loss)) + report("eval/sub_log_stft_magnitude_loss", float(sub_mag_loss)) + losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss) + losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss) + + report("eval/generator_loss", float(gen_loss)) + losses_dict["generator_loss"] = float(gen_loss) + + # Disctiminator + p = self.discriminator(wav) + real_loss, fake_loss = self.criterion_dis_adv(p_, p) + dis_loss = real_loss + fake_loss + report("eval/real_loss", float(real_loss)) + report("eval/fake_loss", float(fake_loss)) + report("eval/discriminator_loss", float(dis_loss)) + + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/parakeet/models/parallel_wavegan/parallel_wavegan.py b/parakeet/models/parallel_wavegan/parallel_wavegan.py index bb214653..e166ccde 100644 --- a/parakeet/models/parallel_wavegan/parallel_wavegan.py +++ b/parakeet/models/parallel_wavegan/parallel_wavegan.py @@ -495,26 +495,25 @@ class PWGGenerator(nn.Layer): self.apply(_remove_weight_norm) - def inference(self, c=None): + def inference(self, c): """Waveform generation. This function is used for single instance inference. Parameters ---------- - c : Tensor, optional + c : Tensor Shape (T', C_aux), the auxiliary input, by default None - x : Tensor, optional - Shape (T, C_in), the noise waveform, by default None - If not provided, a sample is drawn from a gaussian distribution. Returns ------- Tensor Shape (T, C_out), the generated waveform """ + # a sample is drawn from a gaussian distribution. x = paddle.randn( [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor]) - c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch + # pseudo batch + c = paddle.transpose(c, [1, 0]).unsqueeze(0) c = nn.Pad1D(self.aux_context_window, mode='replicate')(c) out = self(x, c).squeeze(0).transpose([1, 0]) return out diff --git a/parakeet/modules/adversarial_loss.py b/parakeet/modules/adversarial_loss.py new file mode 100644 index 00000000..02e8c807 --- /dev/null +++ b/parakeet/modules/adversarial_loss.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Adversarial loss modules."""
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+
+class GeneratorAdversarialLoss(nn.Layer):
+    """Generator adversarial loss module."""
+
+    def __init__(
+            self,
+            average_by_discriminators=True,
+            loss_type="mse", ):
+        """Initialize GeneratorAdversarialLoss module."""
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.criterion = self._mse_loss
+        else:
+            self.criterion = self._hinge_loss
+
+    def forward(self, outputs):
+        """Calculate generator adversarial loss.
+        Parameters
+        ----------
+        outputs: Tensor or List
+            Discriminator outputs or list of discriminator outputs.
+        Returns
+        ----------
+        Tensor
+            Generator adversarial loss value.
+        """
+        if isinstance(outputs, (tuple, list)):
+            adv_loss = 0.0
+            for i, outputs_ in enumerate(outputs):
+                if isinstance(outputs_, (tuple, list)):
+                    # case including feature maps
+                    outputs_ = outputs_[-1]
+                adv_loss += self.criterion(outputs_)
+            if self.average_by_discriminators:
+                adv_loss /= i + 1
+        else:
+            adv_loss = self.criterion(outputs)
+
+        return adv_loss
+
+    def _mse_loss(self, x):
+        return F.mse_loss(x, paddle.ones_like(x))
+
+    def _hinge_loss(self, x):
+        return -x.mean()
+
+
+class DiscriminatorAdversarialLoss(nn.Layer):
+    """Discriminator adversarial loss module."""
+
+    def __init__(
+            self,
+            average_by_discriminators=True,
+            loss_type="mse", ):
+        """Initialize DiscriminatorAdversarialLoss module."""
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.fake_criterion = self._mse_fake_loss
+            self.real_criterion = self._mse_real_loss
+
+    def forward(self, outputs_hat, outputs):
+        """Calculate discriminator adversarial loss.
+        Parameters
+        ----------
+        outputs_hat : Tensor or list
+            Discriminator outputs or list of
+            discriminator outputs calculated from generator outputs.
+        outputs : Tensor or list
+            Discriminator outputs or list of
+            discriminator outputs calculated from ground truth.
+        Returns
+        ----------
+        Tensor
+            Discriminator real loss value.
+        Tensor
+            Discriminator fake loss value.
+        """
+        if isinstance(outputs, (tuple, list)):
+            real_loss = 0.0
+            fake_loss = 0.0
+            for i, (outputs_hat_,
+                    outputs_) in enumerate(zip(outputs_hat, outputs)):
+                if isinstance(outputs_hat_, (tuple, list)):
+                    # case including feature maps
+                    outputs_hat_ = outputs_hat_[-1]
+                    outputs_ = outputs_[-1]
+                real_loss += self.real_criterion(outputs_)
+                fake_loss += self.fake_criterion(outputs_hat_)
+            if self.average_by_discriminators:
+                fake_loss /= i + 1
+                real_loss /= i + 1
+        else:
+            real_loss = self.real_criterion(outputs)
+            fake_loss = self.fake_criterion(outputs_hat)
+
+        return real_loss, fake_loss
+
+    def _mse_real_loss(self, x):
+        return F.mse_loss(x, paddle.ones_like(x))
+
+    def _mse_fake_loss(self, x):
+        return F.mse_loss(x, paddle.zeros_like(x))
diff --git a/parakeet/modules/causal_conv.py b/parakeet/modules/causal_conv.py
new file mode 100644
index 00000000..c0dd5b28
--- /dev/null
+++ b/parakeet/modules/causal_conv.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
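Both loss modules above accept either a single discriminator output or a list of per-scale outputs (each entry optionally a tuple or list whose last element is the score, as with feature-map-returning discriminators). A quick usage sketch with random stand-in tensors; the shapes are illustrative only:

```python
import paddle

from parakeet.modules.adversarial_loss import (DiscriminatorAdversarialLoss,
                                               GeneratorAdversarialLoss)

gen_adv = GeneratorAdversarialLoss(loss_type="mse")
dis_adv = DiscriminatorAdversarialLoss(loss_type="mse")

# stand-ins for a 3-scale discriminator's outputs on fake / real audio
p_hat = [paddle.rand([4, 1, 100]) for _ in range(3)]
p = [paddle.rand([4, 1, 100]) for _ in range(3)]

adv_loss = gen_adv(p_hat)                 # MSE against ones, averaged over scales
real_loss, fake_loss = dis_adv(p_hat, p)  # MSE against ones / zeros respectively
```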
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Causal convolution layer modules."""
+import paddle
+
+
+class CausalConv1D(paddle.nn.Layer):
+    """CausalConv1D module with customized initialization."""
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            dilation=1,
+            bias=True,
+            pad="Pad1D",
+            pad_params={"value": 0.0}, ):
+        """Initialize CausalConv1D module."""
+        super().__init__()
+        self.pad = getattr(paddle.nn, pad)((kernel_size - 1) * dilation,
+                                           **pad_params)
+        self.conv = paddle.nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            dilation=dilation,
+            bias_attr=bias)
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, in_channels, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, out_channels, T).
+        """
+        return self.conv(self.pad(x))[:, :, :x.shape[2]]
+
+
+class CausalConv1DTranspose(paddle.nn.Layer):
+    """CausalConv1DTranspose module with customized initialization."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 bias=True):
+        """Initialize CausalConv1DTranspose module."""
+        super().__init__()
+        self.deconv = paddle.nn.Conv1DTranspose(
+            in_channels, out_channels, kernel_size, stride, bias_attr=bias)
+        self.stride = stride
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, in_channels, T_in).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, out_channels, T_out).
+        """
+        return self.deconv(x)[:, :, :-self.stride]
diff --git a/parakeet/modules/pqmf.py b/parakeet/modules/pqmf.py
new file mode 100644
index 00000000..275addd2
--- /dev/null
+++ b/parakeet/modules/pqmf.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pseudo QMF modules."""
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from scipy.signal import kaiser
+
+
+def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
+    """Design prototype filter for PQMF.
+    This method is based on `A Kaiser window approach for the design of prototype
+    filters of cosine modulated filterbanks`_.
+    Parameters
+    ----------
+    taps : int
+        The number of filter taps.
+    cutoff_ratio : float
+        Cut-off frequency ratio.
+    beta : float
+        Beta coefficient for kaiser window.
+    Returns
+    ----------
+    ndarray
+        Impulse response of prototype filter (taps + 1,).
+    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+        https://ieeexplore.ieee.org/abstract/document/681427
+    """
+    # check the arguments are valid
+    assert taps % 2 == 0, "The number of taps must be an even number."
+    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+    # make initial filter
+    omega_c = np.pi * cutoff_ratio
+    with np.errstate(invalid="ignore"):
+        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
+            np.pi * (np.arange(taps + 1) - 0.5 * taps))
+    # fix nan due to indeterminate form
+    h_i[taps // 2] = np.cos(0) * cutoff_ratio
+
+    # apply kaiser window
+    w = kaiser(taps + 1, beta)
+    h = h_i * w
+
+    return h
+
+
+class PQMF(paddle.nn.Layer):
+    """PQMF module.
+    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+        https://ieeexplore.ieee.org/document/258122
+    """
+
+    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
+        """Initialize PQMF module.
+        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
+        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
+        Parameters
+        ----------
+        subbands : int
+            The number of subbands.
+        taps : int
+            The number of filter taps.
+        cutoff_ratio : float
+            Cut-off frequency ratio.
+        beta : float
+            Beta coefficient for kaiser window.
+        """
+        super(PQMF, self).__init__()
+
+        # build analysis & synthesis filter coefficients
+        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+        h_analysis = np.zeros((subbands, len(h_proto)))
+        h_synthesis = np.zeros((subbands, len(h_proto)))
+        for k in range(subbands):
+            h_analysis[k] = (
+                2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
+                    np.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4))
+            h_synthesis[k] = (
+                2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
+                    np.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4))
+
+        # convert to tensor
+        self.analysis_filter = paddle.to_tensor(
+            h_analysis, dtype="float32").unsqueeze(1)
+        self.synthesis_filter = paddle.to_tensor(
+            h_synthesis, dtype="float32").unsqueeze(0)
+
+        # filter for downsampling & upsampling
+        updown_filter = paddle.zeros(
+            (subbands, subbands, subbands), dtype="float32")
+        for k in range(subbands):
+            updown_filter[k, k, 0] = 1.0
+        self.updown_filter = updown_filter
+        self.subbands = subbands
+
+        # keep padding info
+        self.pad_fn = paddle.nn.Pad1D(taps // 2, mode='constant', value=0.0)
+
+    def analysis(self, x):
+        """Analysis with PQMF.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, 1, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, subbands, T // subbands).
+        """
+        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+        return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+    def synthesis(self, x):
+        """Synthesis with PQMF.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, subbands, T // subbands).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, 1, T).
+        """
+        x = F.conv1d_transpose(
+            x, self.updown_filter * self.subbands, stride=self.subbands)
+        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
diff --git a/parakeet/modules/residual_stack.py b/parakeet/modules/residual_stack.py
new file mode 100644
index 00000000..135c32e5
--- /dev/null
+++ b/parakeet/modules/residual_stack.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
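The PQMF bank above is what lets Multi-Band MelGAN predict four sub-band streams while still being supervised on, and synthesizing, the full-band waveform. A round-trip sketch (batch size and length are arbitrary, but T must be divisible by `subbands`):

```python
import paddle

from parakeet.modules.pqmf import PQMF

pqmf = PQMF(subbands=4)            # taps/cutoff/beta are tuned for 4 bands
wav = paddle.randn([2, 1, 16000])  # (B, 1, T)
wav_mb = pqmf.analysis(wav)        # (B, 4, T // 4)
recon = pqmf.synthesis(wav_mb)     # (B, 1, T), near-perfect reconstruction
print(wav_mb.shape, recon.shape)
```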
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Residual stack module in MelGAN."""
+from typing import Any
+from typing import Dict
+
+from paddle import nn
+
+from parakeet.modules.causal_conv import CausalConv1D
+
+
+class ResidualStack(nn.Layer):
+    """Residual stack module introduced in MelGAN."""
+
+    def __init__(
+            self,
+            kernel_size: int=3,
+            channels: int=32,
+            dilation: int=1,
+            bias: bool=True,
+            nonlinear_activation: str="LeakyReLU",
+            nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
+            pad: str="Pad1D",
+            pad_params: Dict[str, Any]={"mode": "reflect"},
+            use_causal_conv: bool=False, ):
+        """Initialize ResidualStack module.
+        Parameters
+        ----------
+        kernel_size : int
+            Kernel size of dilation convolution layer.
+        channels : int
+            Number of channels of convolution layers.
+        dilation : int
+            Dilation factor.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        nonlinear_activation : str
+            Activation function module name.
+        nonlinear_activation_params : Dict[str,Any]
+            Hyperparameters for activation function.
+        pad : str
+            Padding function module name before dilated convolution layer.
+        pad_params : Dict[str, Any]
+            Hyperparameters for padding function.
+        use_causal_conv : bool
+            Whether to use causal convolution.
+        """
+        super().__init__()
+
+        # define residual stack part
+        if not use_causal_conv:
+            assert (kernel_size - 1
+                    ) % 2 == 0, "Even-number kernel sizes are not supported."
+            self.stack = nn.Sequential(
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                getattr(nn, pad)((kernel_size - 1) // 2 * dilation,
+                                 **pad_params),
+                nn.Conv1D(
+                    channels,
+                    channels,
+                    kernel_size,
+                    dilation=dilation,
+                    bias_attr=bias),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                nn.Conv1D(channels, channels, 1, bias_attr=bias), )
+        else:
+            self.stack = nn.Sequential(
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                CausalConv1D(
+                    channels,
+                    channels,
+                    kernel_size,
+                    dilation=dilation,
+                    bias=bias,
+                    pad=pad,
+                    pad_params=pad_params, ),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                nn.Conv1D(channels, channels, 1, bias_attr=bias), )
+
+        # define extra layer for skip connection
+        self.skip_layer = nn.Conv1D(channels, channels, 1, bias_attr=bias)
+
+    def forward(self, c):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        c : Tensor
+            Input tensor (B, channels, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, channels, T).
+ """ + return self.stack(c) + self.skip_layer(c) From 04bcb6a12d44641f3c17c6aac1b2c6bf5eabb752 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 28 Oct 2021 07:46:11 +0000 Subject: [PATCH 2/4] fix rtf, fix inf input of speedyspeech, fix stft dir for 2.2.0 --- parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py | 3 ++- parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py | 3 ++- parakeet/exps/speedyspeech/inference.py | 4 ++-- parakeet/modules/stft_loss.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py index d48fbbd0..00b1b96c 100644 --- a/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py @@ -86,8 +86,9 @@ def main(): N += wav.size T += t.elapse speed = wav.size / t.elapse + rtf = config.fs / speed print( - f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}." + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." ) sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") diff --git a/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py b/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py index 9129caa5..2400e00b 100644 --- a/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py +++ b/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py @@ -86,8 +86,9 @@ def main(): N += wav.size T += t.elapse speed = wav.size / t.elapse + rtf = config.fs / speed print( - f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}." + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
) sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") diff --git a/parakeet/exps/speedyspeech/inference.py b/parakeet/exps/speedyspeech/inference.py index bf144d76..77a90915 100644 --- a/parakeet/exps/speedyspeech/inference.py +++ b/parakeet/exps/speedyspeech/inference.py @@ -96,8 +96,8 @@ def main(): input_ids = frontend.get_input_ids( sentence, merge_sentences=True, get_tone_ids=True) - phone_ids = input_ids["phone_ids"] - tone_ids = input_ids["tone_ids"] + phone_ids = input_ids["phone_ids"].numpy() + tone_ids = input_ids["tone_ids"].numpy() phones = phone_ids[0] tones = tone_ids[0] diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py index 1f400b46..8af55ab1 100644 --- a/parakeet/modules/stft_loss.py +++ b/parakeet/modules/stft_loss.py @@ -51,7 +51,7 @@ def stft(x, # calculate window window = signal.get_window(window, win_length, fftbins=True) window = paddle.to_tensor(window) - x_stft = paddle.tensor.signal.stft( + x_stft = paddle.signal.stft( x, fft_size, hop_length, From 88668513b1986e39f458d149be564b21d85191e8 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Fri, 29 Oct 2021 06:23:13 +0000 Subject: [PATCH 3/4] fix mv writer to visualdl in train --- README.md | 39 ++- docs/source/tts/install.md | 4 +- examples/tiny/s0/README.md | 1 - parakeet/data/batch.py | 4 +- parakeet/exps/fastspeech2/train.py | 4 +- .../gan_vocoder/multi_band_melgan/train.py | 4 +- .../gan_vocoder/parallelwave_gan/train.py | 4 +- parakeet/exps/speedyspeech/train.py | 4 +- parakeet/exps/tacotron2/ljspeech.py | 9 +- parakeet/exps/transformer_tts/train.py | 4 +- parakeet/models/melgan/melgan_updater.py | 231 ------------------ parakeet/training/extensions/visualizer.py | 6 +- 12 files changed, 38 insertions(+), 276 deletions(-) delete mode 100644 parakeet/models/melgan/melgan_updater.py diff --git a/README.md b/README.md index e0769720..468f42a6 100644 --- a/README.md +++ b/README.md @@ -9,34 +9,34 @@ English | [简体中文](README_ch.md)

-  Quick Start 
-  | Tutorials 
-  | Models List 
+  Quick Start
+  | Tutorials
+  | Models List

- + ------------------------------------------------------------------------------------ ![License](https://img.shields.io/badge/license-Apache%202-red.svg) ![python version](https://img.shields.io/badge/python-3.7+-orange.svg) ![support os](https://img.shields.io/badge/os-linux-yellow.svg) -**PaddleSpeech** is an open-source toolkit on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for two critical tasks in Speech - **Automatic Speech Recognition (ASR)** and **Text-To-Speech Synthesis (TTS)**, with modules involving state-of-art and influential models. +**PaddleSpeech** is an open-source toolkit on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for two critical tasks in Speech - **Automatic Speech Recognition (ASR)** and **Text-To-Speech Synthesis (TTS)**, with modules involving state-of-art and influential models. Via the easy-to-use, efficient, flexible and scalable implementation, our vision is to empower both industrial application and academic research, including training, inference & testing module, and deployment. Besides, this toolkit also features at: - **Fast and Light-weight**: we provide a high-speed and ultra-lightweight model that is convenient for industrial deployment. -- **Rule-based Chinese frontend**: our frontend contains Text Normalization (TN) and Grapheme-to-Phoneme (G2P, including Polyphone and Tone Sandhi). Moreover, we use self-defined linguistic rules to adapt Chinese context. -- **Varieties of Functions that Vitalize Research**: +- **Rule-based Chinese frontend**: our frontend contains Text Normalization (TN) and Grapheme-to-Phoneme (G2P, including Polyphone and Tone Sandhi). Moreover, we use self-defined linguistic rules to adapt Chinese context. +- **Varieties of Functions that Vitalize Research**: - *Integration of mainstream models and datasets*: the toolkit implements modules that participate in the whole pipeline of both ASR and TTS, and uses datasets like LibriSpeech, LJSpeech, AIShell, etc. See also [model lists](#models-list) for more details. - *Support of ASR streaming and non-streaming data*: This toolkit contains non-streaming/streaming models like [DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf), [Transformer](https://arxiv.org/abs/1706.03762), [Conformer](https://arxiv.org/abs/2005.08100) and [U2](https://arxiv.org/pdf/2012.05481.pdf). - -Let's install PaddleSpeech with only a few lines of code! + +Let's install PaddleSpeech with only a few lines of code! >Note: The official name is still deepspeech. 2021/10/26 @@ -44,7 +44,7 @@ Let's install PaddleSpeech with only a few lines of code! # 1. Install essential libraries and paddlepaddle first. # install prerequisites sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev libsndfile1 -# `pip install paddlepaddle-gpu` instead if you are using GPU. +# `pip install paddlepaddle-gpu` instead if you are using GPU. pip install paddlepaddle # 2.Then install PaddleSpeech. @@ -109,7 +109,7 @@ If you want to try more functions like training and tuning, please see [ASR gett PaddleSpeech ASR supports a lot of mainstream models, which are summarized as follow. For more information, please refer to [ASR Models](./docs/source/asr/released_model.md). 
@@ -125,7 +125,7 @@ The current hyperlinks redirect to [Previous Parakeet](https://github.com/Paddle - + @@ -199,7 +199,7 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model - @@ -292,11 +292,11 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model
      Acoustic Model
      Aishell
-     Conv + 5 LSTM layers with only forward direction
+     2 Conv + 5 LSTM layers with only forward direction
      Ds2 Online Aishell Model
      Text Frontend
+     chinese-fronted
-## Tutorials +## Tutorials Normally, [Speech SoTA](https://paperswithcode.com/area/speech) gives you an overview of the hot academic topics in speech. If you want to focus on the two tasks in PaddleSpeech, you will find the following guidelines are helpful to grasp the core ideas. -The original ASR module is based on [Baidu's DeepSpeech](https://arxiv.org/abs/1412.5567) which is an independent product named [DeepSpeech](https://deepspeech.readthedocs.io). However, the toolkit aligns almost all the SoTA modules in the pipeline. Specifically, these modules are +The original ASR module is based on [Baidu's DeepSpeech](https://arxiv.org/abs/1412.5567) which is an independent product named [DeepSpeech](https://deepspeech.readthedocs.io). However, the toolkit aligns almost all the SoTA modules in the pipeline. Specifically, these modules are * [Data Prepration](docs/source/asr/data_preparation.md) * [Data Augmentation](docs/source/asr/augmentation.md) @@ -318,4 +318,3 @@ PaddleSpeech is provided under the [Apache-2.0 License](./LICENSE). ## Acknowledgement PaddleSpeech depends on a lot of open source repos. See [references](docs/source/asr/reference.md) for more information. - diff --git a/docs/source/tts/install.md b/docs/source/tts/install.md index c4249a18..b092acff 100644 --- a/docs/source/tts/install.md +++ b/docs/source/tts/install.md @@ -10,13 +10,13 @@ Example instruction to install paddlepaddle via pip is listed below. ### PaddlePaddle with GPU ```python -# PaddlePaddle for CUDA10.1 +# PaddlePaddle for CUDA10.1 python -m pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html # PaddlePaddle for CUDA10.2 python -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple # PaddlePaddle for CUDA11.0 python -m pip install paddlepaddle-gpu==2.1.2.post110 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html -# PaddlePaddle for CUDA11.2 +# PaddlePaddle for CUDA11.2 python -m pip install paddlepaddle-gpu==2.1.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html ``` ### PaddlePaddle with CPU diff --git a/examples/tiny/s0/README.md b/examples/tiny/s0/README.md index 7dc16dc3..11118dc4 100644 --- a/examples/tiny/s0/README.md +++ b/examples/tiny/s0/README.md @@ -37,4 +37,3 @@ ```bash bash local/export.sh ckpt_path saved_jit_model_path ``` - diff --git a/parakeet/data/batch.py b/parakeet/data/batch.py index 5e7ac399..515074d1 100644 --- a/parakeet/data/batch.py +++ b/parakeet/data/batch.py @@ -53,8 +53,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64): peek_example = minibatch[0] assert len(peek_example.shape) == 1, "text example is an 1D tensor" - lengths = [example.shape[0] for example in - minibatch] # assume (channel, n_samples) or (n_samples, ) + lengths = [example.shape[0] for example in minibatch + ] # assume (channel, n_samples) or (n_samples, ) max_len = np.max(lengths) batch = [] diff --git a/parakeet/exps/fastspeech2/train.py b/parakeet/exps/fastspeech2/train.py index 59b1ea3a..47ad1b4d 100644 --- a/parakeet/exps/fastspeech2/train.py +++ b/parakeet/exps/fastspeech2/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn @@ -160,8 +159,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, 
trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) # print(trainer.extensions) diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/train.py b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py index bb9b0b8a..c03fb354 100644 --- a/parakeet/exps/gan_vocoder/multi_band_melgan/train.py +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py @@ -28,7 +28,6 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam from paddle.optimizer.lr import MultiStepDecay -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.data_table import DataTable @@ -219,8 +218,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend( evaluator, trigger=(config.eval_interval_steps, 'iteration')) - writer = LogWriter(str(trainer.out)) - trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) + trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(config.save_interval_steps, 'iteration')) diff --git a/parakeet/exps/gan_vocoder/parallelwave_gan/train.py b/parakeet/exps/gan_vocoder/parallelwave_gan/train.py index 7a16ca59..ad50b65c 100644 --- a/parakeet/exps/gan_vocoder/parallelwave_gan/train.py +++ b/parakeet/exps/gan_vocoder/parallelwave_gan/train.py @@ -28,7 +28,6 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam # No RAdaom from paddle.optimizer.lr import StepDecay -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.data_table import DataTable @@ -193,8 +192,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend( evaluator, trigger=(config.eval_interval_steps, 'iteration')) - writer = LogWriter(str(trainer.out)) - trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) + trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(config.save_interval_steps, 'iteration')) diff --git a/parakeet/exps/speedyspeech/train.py b/parakeet/exps/speedyspeech/train.py index ea9fe20d..6a4bf59e 100644 --- a/parakeet/exps/speedyspeech/train.py +++ b/parakeet/exps/speedyspeech/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import speedyspeech_batch_fn @@ -153,8 +152,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) trainer.run() diff --git a/parakeet/exps/tacotron2/ljspeech.py b/parakeet/exps/tacotron2/ljspeech.py index 59c855eb..20dc29d3 100644 --- a/parakeet/exps/tacotron2/ljspeech.py +++ b/parakeet/exps/tacotron2/ljspeech.py @@ -67,16 +67,19 @@ class LJSpeechCollector(object): # Sort by text_len in descending order texts = [ - i for i, _ in sorted( + i + for i, _ in sorted( zip(texts, text_lens), key=lambda x: x[1], 
reverse=True) ] mels = [ - i for i, _ in sorted( + i + for i, _ in sorted( zip(mels, text_lens), key=lambda x: x[1], reverse=True) ] mel_lens = [ - i for i, _ in sorted( + i + for i, _ in sorted( zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True) ] diff --git a/parakeet/exps/transformer_tts/train.py b/parakeet/exps/transformer_tts/train.py index fdaff347..bf066390 100644 --- a/parakeet/exps/transformer_tts/train.py +++ b/parakeet/exps/transformer_tts/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import transformer_single_spk_batch_fn @@ -148,8 +147,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) # print(trainer.extensions) diff --git a/parakeet/models/melgan/melgan_updater.py b/parakeet/models/melgan/melgan_updater.py deleted file mode 100644 index 7bd59881..00000000 --- a/parakeet/models/melgan/melgan_updater.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from typing import Dict - -import paddle -from paddle import distributed as dist -from paddle.io import DataLoader -from paddle.nn import Layer -from paddle.optimizer import Optimizer -from paddle.optimizer.lr import LRScheduler -from timer import timer - -from parakeet.training.extensions.evaluator import StandardEvaluator -from parakeet.training.reporter import report -from parakeet.training.updaters.standard_updater import StandardUpdater -from parakeet.training.updaters.standard_updater import UpdaterState -logging.basicConfig( - format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', - datefmt='[%Y-%m-%d %H:%M:%S]') -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class PWGUpdater(StandardUpdater): - def __init__(self, - models: Dict[str, Layer], - optimizers: Dict[str, Optimizer], - criterions: Dict[str, Layer], - schedulers: Dict[str, LRScheduler], - dataloader: DataLoader, - discriminator_train_start_steps: int, - lambda_adv: float, - output_dir=None): - self.models = models - self.generator: Layer = models['generator'] - self.discriminator: Layer = models['discriminator'] - - self.optimizers = optimizers - self.optimizer_g: Optimizer = optimizers['generator'] - self.optimizer_d: Optimizer = optimizers['discriminator'] - - self.criterions = criterions - self.criterion_stft = criterions['stft'] - self.criterion_mse = criterions['mse'] - - self.schedulers = schedulers - self.scheduler_g = schedulers['generator'] - self.scheduler_d = schedulers['discriminator'] - - self.dataloader = dataloader - - self.discriminator_train_start_steps = discriminator_train_start_steps - self.lambda_adv = lambda_adv - self.state = UpdaterState(iteration=0, epoch=0) - - self.train_iterator = iter(self.dataloader) - - log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) - self.filehandler = logging.FileHandler(str(log_file)) - logger.addHandler(self.filehandler) - self.logger = logger - self.msg = "" - - def update_core(self, batch): - self.msg = "Rank: {}, ".format(dist.get_rank()) - losses_dict = {} - - # parse batch - wav, mel = batch - - # Generator - noise = paddle.randn(wav.shape) - - with timer() as t: - wav_ = self.generator(noise, mel) - # logging.debug(f"Generator takes {t.elapse}s.") - - # initialize - gen_loss = 0.0 - - ## Multi-resolution stft loss - with timer() as t: - sc_loss, mag_loss = self.criterion_stft(wav_, wav) - # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") - - report("train/spectral_convergence_loss", float(sc_loss)) - report("train/log_stft_magnitude_loss", float(mag_loss)) - - losses_dict["spectral_convergence_loss"] = float(sc_loss) - losses_dict["log_stft_magnitude_loss"] = float(mag_loss) - - gen_loss += sc_loss + mag_loss - - ## Adversarial loss - if self.state.iteration > self.discriminator_train_start_steps: - with timer() as t: - p_ = self.discriminator(wav_) - adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) - # logging.debug( - # f"Discriminator and adversarial loss takes {t.elapse}s") - report("train/adversarial_loss", float(adv_loss)) - losses_dict["adversarial_loss"] = float(adv_loss) - gen_loss += self.lambda_adv * adv_loss - - report("train/generator_loss", float(gen_loss)) - losses_dict["generator_loss"] = float(gen_loss) - - with timer() as t: - self.optimizer_g.clear_grad() - gen_loss.backward() - # logging.debug(f"Backward takes {t.elapse}s.") - - with timer() as t: - self.optimizer_g.step() - self.scheduler_g.step() - # logging.debug(f"Update takes 
{t.elapse}s.") - - # Disctiminator - if self.state.iteration > self.discriminator_train_start_steps: - with paddle.no_grad(): - wav_ = self.generator(noise, mel) - p = self.discriminator(wav) - p_ = self.discriminator(wav_.detach()) - real_loss = self.criterion_mse(p, paddle.ones_like(p)) - fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) - dis_loss = real_loss + fake_loss - report("train/real_loss", float(real_loss)) - report("train/fake_loss", float(fake_loss)) - report("train/discriminator_loss", float(dis_loss)) - losses_dict["real_loss"] = float(real_loss) - losses_dict["fake_loss"] = float(fake_loss) - losses_dict["discriminator_loss"] = float(dis_loss) - - self.optimizer_d.clear_grad() - dis_loss.backward() - - self.optimizer_d.step() - self.scheduler_d.step() - - self.msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_dict.items()) - - -class PWGEvaluator(StandardEvaluator): - def __init__(self, - models, - criterions, - dataloader, - lambda_adv, - output_dir=None): - self.models = models - self.generator = models['generator'] - self.discriminator = models['discriminator'] - - self.criterions = criterions - self.criterion_stft = criterions['stft'] - self.criterion_mse = criterions['mse'] - - self.dataloader = dataloader - self.lambda_adv = lambda_adv - - log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) - self.filehandler = logging.FileHandler(str(log_file)) - logger.addHandler(self.filehandler) - self.logger = logger - self.msg = "" - - def evaluate_core(self, batch): - # logging.debug("Evaluate: ") - self.msg = "Evaluate: " - losses_dict = {} - - wav, mel = batch - noise = paddle.randn(wav.shape) - - with timer() as t: - wav_ = self.generator(noise, mel) - # logging.debug(f"Generator takes {t.elapse}s") - - ## Adversarial loss - with timer() as t: - p_ = self.discriminator(wav_) - adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) - # logging.debug( - # f"Discriminator and adversarial loss takes {t.elapse}s") - report("eval/adversarial_loss", float(adv_loss)) - losses_dict["adversarial_loss"] = float(adv_loss) - gen_loss = self.lambda_adv * adv_loss - - # stft loss - with timer() as t: - sc_loss, mag_loss = self.criterion_stft(wav_, wav) - # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") - - report("eval/spectral_convergence_loss", float(sc_loss)) - report("eval/log_stft_magnitude_loss", float(mag_loss)) - losses_dict["spectral_convergence_loss"] = float(sc_loss) - losses_dict["log_stft_magnitude_loss"] = float(mag_loss) - gen_loss += sc_loss + mag_loss - - report("eval/generator_loss", float(gen_loss)) - losses_dict["generator_loss"] = float(gen_loss) - - # Disctiminator - p = self.discriminator(wav) - real_loss = self.criterion_mse(p, paddle.ones_like(p)) - fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) - dis_loss = real_loss + fake_loss - report("eval/real_loss", float(real_loss)) - report("eval/fake_loss", float(fake_loss)) - report("eval/discriminator_loss", float(dis_loss)) - - losses_dict["real_loss"] = float(real_loss) - losses_dict["fake_loss"] = float(fake_loss) - losses_dict["discriminator_loss"] = float(dis_loss) - - self.msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_dict.items()) - self.logger.info(self.msg) diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py index 1c66ad8d..bc62c976 100644 --- a/parakeet/training/extensions/visualizer.py +++ b/parakeet/training/extensions/visualizer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from visualdl import LogWriter + from parakeet.training import extension from parakeet.training.trainer import Trainer @@ -26,8 +28,8 @@ class VisualDL(extension.Extension): default_name = 'visualdl' priority = extension.PRIORITY_READER - def __init__(self, writer): - self.writer = writer + def __init__(self, logdir): + self.writer = LogWriter(str(logdir)) def __call__(self, trainer: Trainer): for k, v in trainer.observation.items(): From 9125d71a8193ee2f86680eddc2d408395869b348 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Fri, 29 Oct 2021 06:48:16 +0000 Subject: [PATCH 4/4] fix pwg inference --- deepspeech/exps/deepspeech2/model.py | 1 - examples/csmsc/voc3/conf/default.yaml | 2 +- examples/csmsc/voc3/conf/use_tanh.yaml | 139 ++++++++++++++++++ parakeet/models/melgan/melgan.py | 9 +- .../parallel_wavegan/parallel_wavegan.py | 14 +- parakeet/modules/residual_stack.py | 5 +- 6 files changed, 155 insertions(+), 15 deletions(-) create mode 100644 examples/csmsc/voc3/conf/use_tanh.yaml diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 6424cfdf..5c010f56 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -189,7 +189,6 @@ class DeepSpeech2Trainer(Trainer): self.lr_scheduler = lr_scheduler logger.info("Setup optimizer/lr_scheduler!") - def setup_dataloader(self): config = self.config.clone() config.defrost() diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml index 87a237c3..f6fcfced 100644 --- a/examples/csmsc/voc3/conf/default.yaml +++ b/examples/csmsc/voc3/conf/default.yaml @@ -88,7 +88,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. ########################################################### batch_size: 64 # Batch size. batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size. -num_workers: 4 # Number of workers in DataLoader. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/csmsc/voc3/conf/use_tanh.yaml b/examples/csmsc/voc3/conf/use_tanh.yaml new file mode 100644 index 00000000..820c2a76 --- /dev/null +++ b/examples/csmsc/voc3/conf/use_tanh.yaml @@ -0,0 +1,139 @@ +# This is the hyperparameter configuration file for MelGAN. +# Please make sure this is adjusted for the CSMSC dataset. If you want to +# apply to the other dataset, you might need to carefully change some parameters. +# This configuration requires ~ 8GB memory and will finish within 7 days on Titan V. + +# This configuration is based on full-band MelGAN but the hop size and sampling +# rate is different from the paper (16kHz vs 24kHz). The number of iteraions +# is not shown in the paper so currently we train 1M iterations (not sure enough +# to converge). The optimizer setting is based on @dathudeptrai advice. +# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906 + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size. (in samples) +n_shift: 300 # Hop size. (in samples) +win_length: 1200 # Window length. (in samples) + # If set to null, it will be the same as fft_size. 
+window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 4 # Number of output channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + channels: 384 # Initial number of channels for conv layers. + upsample_scales: [5, 5, 3] # List of Upsampling scales. + stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. + stacks: 4 # Number of stacks in a single residual stack module. + use_weight_norm: True # Whether to use weight normalization. + use_causal_conv: False # Whether to use causal convolution. + use_final_nonlinear_activation: True # If True, spectral_convergence_loss and sub_spectral_convergence_loss will be too large (eg.30) + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + scales: 3 # Number of multi-scales. + downsample_pooling: "AvgPool1D" # Pooling type for the input downsampling. + downsample_pooling_params: # Parameters of the above pooling function. + kernel_size: 4 + stride: 2 + padding: 1 + exclusive: True + kernel_sizes: [5, 3] # List of kernel size. + channels: 16 # Number of channels of the initial conv layer. + max_downsample_channels: 512 # Maximum number of channels of downsampling layers. + downsample_scales: [4, 4, 4] # List of downsampling scales. + nonlinear_activation: "LeakyReLU" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + negative_slope: 0.2 + use_weight_norm: True # Whether to use weight norm. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: true +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss + win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss +use_subband_stft_loss: true +subband_stft_loss_params: + fft_sizes: [384, 683, 171] # List of FFT size for STFT-based loss. + hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss + win_lengths: [150, 300, 60] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +use_feat_match_loss: false # Whether to use feature matching loss. +lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 64 # Batch size. +batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 2 # Number of workers in DataLoader. 
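Aside from the documentation comments, the deltas of this variant visible in the patch are `use_final_nonlinear_activation: True` (a tanh on the generator output, at the cost of larger spectral convergence losses) and `num_workers: 2`. A hedged sketch of loading such a config the way this repo's training entry points consume `CfgNode` (the exact loading code is assumed, not shown in this diff):

```python
import yaml
from yacs.config import CfgNode

with open("examples/csmsc/voc3/conf/use_tanh.yaml") as f:
    config = CfgNode(yaml.safe_load(f))

print(config.generator_params.use_final_nonlinear_activation)  # True
print(config.batch_size, config.num_workers)                   # 64 2
```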
+ +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + epsilon: 1.0e-7 # Generator's epsilon. + weight_decay: 0.0 # Generator's weight decay coefficient. + +generator_grad_norm: -1 # Generator's gradient norm. +generator_scheduler_params: + learning_rate: 1.0e-3 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 +discriminator_optimizer_params: + epsilon: 1.0e-7 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. + +discriminator_grad_norm: -1 # Discriminator's gradient norm. +discriminator_scheduler_params: + learning_rate: 1.0e-3 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 + +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator. +train_max_steps: 1000000 # Number of training steps. +save_interval_steps: 50000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random \ No newline at end of file diff --git a/parakeet/models/melgan/melgan.py b/parakeet/models/melgan/melgan.py index ccc19d5f..0347ff22 100644 --- a/parakeet/models/melgan/melgan.py +++ b/parakeet/models/melgan/melgan.py @@ -19,7 +19,6 @@ from typing import List import numpy as np import paddle from paddle import nn -from paddle.fluid.layers import Normal from parakeet.modules.causal_conv import CausalConv1D from parakeet.modules.causal_conv import CausalConv1DTranspose @@ -238,7 +237,7 @@ class MelGANGenerator(nn.Layer): """ # 定义参数为float的正态分布。 - dist = Normal(loc=0.0, scale=0.02) + dist = paddle.distribution.Normal(loc=0.0, scale=0.02) def _reset_parameters(m): if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): @@ -290,8 +289,8 @@ class MelGANDiscriminator(nn.Layer): """Initilize MelGAN discriminator module. Parameters ---------- - in_channels : - int): Number of input channels. + in_channels : int + Number of input channels. out_channels : int Number of output channels. kernel_sizes : List[int] @@ -531,7 +530,7 @@ class MelGANMultiScaleDiscriminator(nn.Layer): """ # 定义参数为float的正态分布。 - dist = Normal(loc=0.0, scale=0.02) + dist = paddle.distribution.Normal(loc=0.0, scale=0.02) def _reset_parameters(m): if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose): diff --git a/parakeet/models/parallel_wavegan/parallel_wavegan.py b/parakeet/models/parallel_wavegan/parallel_wavegan.py index e166ccde..fe4ec355 100644 --- a/parakeet/models/parallel_wavegan/parallel_wavegan.py +++ b/parakeet/models/parallel_wavegan/parallel_wavegan.py @@ -495,25 +495,25 @@ class PWGGenerator(nn.Layer): self.apply(_remove_weight_norm) - def inference(self, c): + def inference(self, c=None): """Waveform generation. 
This function is used for single instance inference.
-
         Parameters
         ----------
-        c : Tensor
+        c : Tensor, optional
             Shape (T', C_aux), the auxiliary input, by default None
-
+        x : Tensor, optional
+            Shape (T, C_in), the noise waveform, by default None
+            If not provided, a sample is drawn from a gaussian distribution.
         Returns
         -------
         Tensor
             Shape (T, C_out), the generated waveform
         """
-        # a sample is drawn from a gaussian distribution.
+        # when converting to a static graph, x can not be an input, see https://github.com/PaddlePaddle/Parakeet/pull/132/files
         x = paddle.randn(
             [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
-        # pseudo batch
-        c = paddle.transpose(c, [1, 0]).unsqueeze(0)
+        c = paddle.transpose(c, [1, 0]).unsqueeze(0)  # pseudo batch
         c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
         out = self(x, c).squeeze(0).transpose([1, 0])
         return out
diff --git a/parakeet/modules/residual_stack.py b/parakeet/modules/residual_stack.py
index 135c32e5..b798fbb6 100644
--- a/parakeet/modules/residual_stack.py
+++ b/parakeet/modules/residual_stack.py
@@ -106,4 +106,7 @@ class ResidualStack(nn.Layer):
         Tensor
             Output tensor (B, channels, T).
         """
-        return self.stack(c) + self.skip_layer(c)
+        stack_output = self.stack(c)
+        skip_layer_output = self.skip_layer(c)
+        out = stack_output + skip_layer_output
+        return out
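The `inference` signature revert above exists so the generator can be exported to a static graph with only the mel condition as input; the noise `x` is drawn inside the method (see the linked PR). A sketch of that export path, with the import path and default hyperparameters assumed rather than taken from this patch:

```python
import paddle
from parakeet.models.parallel_wavegan.parallel_wavegan import PWGGenerator

generator = PWGGenerator()       # assumes defaults matching the trained model
generator.remove_weight_norm()   # weight norm is removed before export
generator.eval()

static_infer = paddle.jit.to_static(
    generator.inference,
    input_spec=[paddle.static.InputSpec(shape=[None, 80], dtype="float32")])
# static_infer(mel) maps (T', C_aux=80) -> (T, C_out=1)
```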