From 9125d71a8193ee2f86680eddc2d408395869b348 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 29 Oct 2021 06:48:16 +0000
Subject: [PATCH] fix pwg inference

---
 deepspeech/exps/deepspeech2/model.py          |   1 -
 examples/csmsc/voc3/conf/default.yaml         |   2 +-
 examples/csmsc/voc3/conf/use_tanh.yaml        | 139 ++++++++++++++++++
 parakeet/models/melgan/melgan.py              |   9 +-
 .../parallel_wavegan/parallel_wavegan.py      |  14 +-
 parakeet/modules/residual_stack.py            |   5 +-
 6 files changed, 155 insertions(+), 15 deletions(-)
 create mode 100644 examples/csmsc/voc3/conf/use_tanh.yaml

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 6424cfdf..5c010f56 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -189,7 +189,6 @@ class DeepSpeech2Trainer(Trainer):
         self.lr_scheduler = lr_scheduler
         logger.info("Setup optimizer/lr_scheduler!")
 
-
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml
index 87a237c3..f6fcfced 100644
--- a/examples/csmsc/voc3/conf/default.yaml
+++ b/examples/csmsc/voc3/conf/default.yaml
@@ -88,7 +88,7 @@ lambda_adv: 2.5        # Loss balancing coefficient for adversarial loss.
 ###########################################################
 batch_size: 64         # Batch size.
 batch_max_steps: 16200 # Length of each audio in batch. Make sure divisible by hop_size.
-num_workers: 4         # Number of workers in DataLoader.
+num_workers: 2         # Number of workers in DataLoader.
 
 ###########################################################
 #             OPTIMIZER & SCHEDULER SETTING               #
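The default.yaml change above only lowers the number of DataLoader worker processes. For context, here is a minimal sketch of how `batch_size` and `num_workers` typically reach `paddle.io.DataLoader`; the toy dataset is a hypothetical stand-in, not the repo's actual vocoder dataset:

```python
import paddle
from paddle.io import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the repo's vocoder dataset."""

    def __len__(self):
        return 128

    def __getitem__(self, idx):
        # A (mel, wav) pair with shapes implied by the config:
        # 54 frames * 300 hop = 16200 samples (batch_max_steps).
        return paddle.randn([54, 80]), paddle.randn([16200])


loader = DataLoader(
    ToyDataset(),
    batch_size=64,   # batch_size from the config
    shuffle=True,
    num_workers=2)   # num_workers, lowered from 4 in this patch

mel, wav = next(iter(loader))  # mel: [64, 54, 80], wav: [64, 16200]
```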
diff --git a/examples/csmsc/voc3/conf/use_tanh.yaml b/examples/csmsc/voc3/conf/use_tanh.yaml
new file mode 100644
index 00000000..820c2a76
--- /dev/null
+++ b/examples/csmsc/voc3/conf/use_tanh.yaml
@@ -0,0 +1,139 @@
+# This is the hyperparameter configuration file for MelGAN.
+# Please make sure this is adjusted for the CSMSC dataset. If you want to
+# apply it to another dataset, you might need to carefully change some parameters.
+# This configuration requires ~8GB of GPU memory and finishes within 7 days on a Titan V.
+
+# This configuration is based on full-band MelGAN, but the hop size and sampling
+# rate are different from the paper (16kHz vs 24kHz). The number of iterations
+# is not given in the paper, so we currently train for 1M iterations (which may
+# not be enough to converge). The optimizer setting is based on @dathudeptrai's advice.
+# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size (in samples).
+n_shift: 300       # Hop size (in samples).
+win_length: 1200   # Window length (in samples).
+                   # If set to null, it will be the same as fft_size.
+window: "hann"     # Window function.
+n_mels: 80         # Number of mel basis.
+fmin: 80           # Minimum frequency in mel basis calculation (Hz).
+fmax: 7600         # Maximum frequency in mel basis calculation (Hz).
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    in_channels: 80         # Number of input channels.
+    out_channels: 4         # Number of output channels.
+    kernel_size: 7          # Kernel size of initial and final conv layers.
+    channels: 384           # Initial number of channels for conv layers.
+    upsample_scales: [5, 5, 3]   # List of upsampling scales.
+    stack_kernel_size: 3    # Kernel size of dilated conv layers in residual stack.
+    stacks: 4               # Number of stacks in a single residual stack module.
+    use_weight_norm: True   # Whether to use weight normalization.
+    use_causal_conv: False  # Whether to use causal convolution.
+    use_final_nonlinear_activation: True   # If True, spectral_convergence_loss and
+                                           # sub_spectral_convergence_loss will be very large (e.g. 30).
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+    in_channels: 1          # Number of input channels.
+    out_channels: 1         # Number of output channels.
+    scales: 3               # Number of multi-scales.
+    downsample_pooling: "AvgPool1D"  # Pooling type for the input downsampling.
+    downsample_pooling_params:       # Parameters of the above pooling function.
+        kernel_size: 4
+        stride: 2
+        padding: 1
+        exclusive: True
+    kernel_sizes: [5, 3]    # List of kernel sizes.
+    channels: 16            # Number of channels of the initial conv layer.
+    max_downsample_channels: 512   # Maximum number of channels of downsampling layers.
+    downsample_scales: [4, 4, 4]   # List of downsampling scales.
+    nonlinear_activation: "LeakyReLU"   # Nonlinear activation function.
+    nonlinear_activation_params:        # Parameters of the nonlinear activation function.
+        negative_slope: 0.2
+    use_weight_norm: True   # Whether to use weight norm.
+
+###########################################################
+#                    STFT LOSS SETTING                    #
+###########################################################
+use_stft_loss: true
+stft_loss_params:
+    fft_sizes: [1024, 2048, 512]   # List of FFT sizes for STFT-based loss.
+    hop_sizes: [120, 240, 50]      # List of hop sizes for STFT-based loss.
+    win_lengths: [600, 1200, 240]  # List of window lengths for STFT-based loss.
+    window: "hann"                 # Window function for STFT-based loss.
+use_subband_stft_loss: true
+subband_stft_loss_params:
+    fft_sizes: [384, 683, 171]     # List of FFT sizes for sub-band STFT-based loss.
+    hop_sizes: [30, 60, 10]        # List of hop sizes for sub-band STFT-based loss.
+    win_lengths: [150, 300, 60]    # List of window lengths for sub-band STFT-based loss.
+    window: "hann"                 # Window function for sub-band STFT-based loss.
+
+###########################################################
+#                ADVERSARIAL LOSS SETTING                 #
+###########################################################
+use_feat_match_loss: false   # Whether to use feature matching loss.
+lambda_adv: 2.5              # Loss balancing coefficient for adversarial loss.
+
+###########################################################
+#                   DATA LOADER SETTING                   #
+###########################################################
+batch_size: 64           # Batch size.
+batch_max_steps: 16200   # Length of each audio in batch. Make sure divisible by hop_size.
+num_workers: 2           # Number of workers in DataLoader.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    epsilon: 1.0e-7      # Generator's epsilon.
+    weight_decay: 0.0    # Generator's weight decay coefficient.
+
+generator_grad_norm: -1  # Generator's gradient norm.
+generator_scheduler_params:
+    learning_rate: 1.0e-3   # Generator's learning rate.
+    gamma: 0.5              # Generator's scheduler gamma.
+    milestones:             # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+discriminator_optimizer_params:
+    epsilon: 1.0e-7      # Discriminator's epsilon.
+    weight_decay: 0.0    # Discriminator's weight decay coefficient.
+
+discriminator_grad_norm: -1  # Discriminator's gradient norm.
+discriminator_scheduler_params:
+    learning_rate: 1.0e-3   # Discriminator's learning rate.
+    gamma: 0.5              # Discriminator's scheduler gamma.
+    milestones:             # At each milestone, lr will be multiplied by gamma.
+        - 100000
+        - 200000
+        - 300000
+        - 400000
+        - 500000
+        - 600000
+
+###########################################################
+#                     INTERVAL SETTING                    #
+###########################################################
+discriminator_train_start_steps: 200000  # Number of steps to start to train discriminator.
+train_max_steps: 1000000                 # Number of training steps.
+save_interval_steps: 50000               # Interval steps to save checkpoint.
+eval_interval_steps: 1000                # Interval steps to evaluate the network.
+
+###########################################################
+#                      OTHER SETTING                      #
+###########################################################
+num_snapshots: 10   # Max number of snapshots to keep while training.
+seed: 42            # Random seed for paddle, random, and np.random.
\ No newline at end of file
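A consistency check worth knowing about this config: the generator's total upsampling times the number of output channels must equal the hop size, and `batch_max_steps` must divide evenly by it. A short sketch of the arithmetic (variable names simply mirror the YAML keys; reading `out_channels: 4` as sub-bands recombined by PQMF synthesis is the usual multi-band MelGAN setup, not something the file states):

```python
from math import prod

# Values copied from the YAML above.
upsample_scales = [5, 5, 3]   # per-band upsampling: 5 * 5 * 3 = 75
out_channels = 4              # sub-bands, recombined by PQMF synthesis
n_shift = 300                 # hop size used for mel extraction
batch_max_steps = 16200       # samples of audio per training example

# 75 samples per band * 4 bands = 300 samples per mel frame,
# so the generator's output grid lines up with the hop size.
assert prod(upsample_scales) * out_channels == n_shift

# "Make sure divisible by hop_size."
assert batch_max_steps % n_shift == 0
```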
diff --git a/parakeet/models/melgan/melgan.py b/parakeet/models/melgan/melgan.py
index ccc19d5f..0347ff22 100644
--- a/parakeet/models/melgan/melgan.py
+++ b/parakeet/models/melgan/melgan.py
@@ -19,7 +19,6 @@ from typing import List
 import numpy as np
 import paddle
 from paddle import nn
-from paddle.fluid.layers import Normal
 
 from parakeet.modules.causal_conv import CausalConv1D
 from parakeet.modules.causal_conv import CausalConv1DTranspose
@@ -238,7 +237,7 @@ class MelGANGenerator(nn.Layer):
         """
         # Define a normal distribution with float parameters.
-        dist = Normal(loc=0.0, scale=0.02)
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
 
         def _reset_parameters(m):
             if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
@@ -290,8 +289,8 @@ class MelGANDiscriminator(nn.Layer):
         """Initialize MelGAN discriminator module.
         Parameters
         ----------
-        in_channels :
-        int): Number of input channels.
+        in_channels : int
+            Number of input channels.
         out_channels : int
             Number of output channels.
         kernel_sizes : List[int]
@@ -531,7 +530,7 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
         """
         # Define a normal distribution with float parameters.
-        dist = Normal(loc=0.0, scale=0.02)
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
 
         def _reset_parameters(m):
             if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
diff --git a/parakeet/models/parallel_wavegan/parallel_wavegan.py b/parakeet/models/parallel_wavegan/parallel_wavegan.py
index e166ccde..fe4ec355 100644
--- a/parakeet/models/parallel_wavegan/parallel_wavegan.py
+++ b/parakeet/models/parallel_wavegan/parallel_wavegan.py
@@ -495,25 +495,25 @@ class PWGGenerator(nn.Layer):
         self.apply(_remove_weight_norm)
 
-    def inference(self, c):
+    def inference(self, c=None):
         """Waveform generation. This function is used for single instance
         inference.
-
         Parameters
         ----------
-        c : Tensor
+        c : Tensor, optional
             Shape (T', C_aux), the auxiliary input, by default None
-
         Returns
         -------
         Tensor
             Shape (T, C_out), the generated waveform
+
+        Notes
+        -----
+        The noise input x is no longer a parameter; a sample of shape
+        (T, C_in) is drawn internally from a gaussian distribution.
         """
-        # a sample is drawn from a gaussian distribution.
+        # When converting to a static graph, x cannot be passed as an input;
+        # see https://github.com/PaddlePaddle/Parakeet/pull/132/files
         x = paddle.randn(
             [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
-        # pseudo batch
-        c = paddle.transpose(c, [1, 0]).unsqueeze(0)
+        c = paddle.transpose(c, [1, 0]).unsqueeze(0)  # pseudo batch
         c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
         out = self(x, c).squeeze(0).transpose([1, 0])
         return out
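With the new signature, callers pass only the mel-spectrogram and the noise is drawn inside `inference`, which is what allows the method to be traced for static-graph export with a single input. A minimal usage sketch, assuming the default `PWGGenerator` hyperparameters (80 aux channels); the import path mirrors the file above and may differ in your checkout:

```python
import paddle
# Import path follows the file touched above; adjust to your checkout.
from parakeet.models.parallel_wavegan.parallel_wavegan import PWGGenerator

generator = PWGGenerator()        # assuming default hyperparameters
generator.remove_weight_norm()    # customary before inference (applies _remove_weight_norm)
generator.eval()

c = paddle.randn([250, 80])       # (T', C_aux): 250 mel frames, 80 mel bins
with paddle.no_grad():
    wav = generator.inference(c)  # (T, C_out) with T = T' * upsample_factor
```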
diff --git a/parakeet/modules/residual_stack.py b/parakeet/modules/residual_stack.py
index 135c32e5..b798fbb6 100644
--- a/parakeet/modules/residual_stack.py
+++ b/parakeet/modules/residual_stack.py
@@ -106,4 +106,7 @@ class ResidualStack(nn.Layer):
         Tensor
             Output tensor (B, channels, T).
         """
-        return self.stack(c) + self.skip_layer(c)
+        stack_output = self.stack(c)
+        skip_layer_output = self.skip_layer(c)
+        out = stack_output + skip_layer_output
+        return out
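The residual_stack.py rewrite is behavior-preserving: it only names each branch before summing. The patch does not state its motivation, but named intermediates tend to make traced programs easier to inspect, which fits the static-graph theme of this commit. A tiny sketch of the same pattern using a hypothetical layer (not the repo's ResidualStack):

```python
import paddle
from paddle import nn


class TinyResidualBlock(nn.Layer):
    """Hypothetical miniature of the residual pattern above."""

    def __init__(self, channels: int = 16):
        super().__init__()
        self.stack = nn.Conv1D(channels, channels, 3, padding=1)
        self.skip_layer = nn.Conv1D(channels, channels, 1)

    def forward(self, c):
        # Named intermediates instead of a fused
        # `return self.stack(c) + self.skip_layer(c)`.
        stack_output = self.stack(c)
        skip_layer_output = self.skip_layer(c)
        out = stack_output + skip_layer_output
        return out


# Usage: (B, channels, T) in, same shape out.
y = TinyResidualBlock()(paddle.randn([2, 16, 100]))
```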