modify iSTFT.yaml and hifigan.py

pull/3006/head
longrookie 3 years ago
parent 394e635958
commit ef5e96f5a2

@ -24,16 +24,15 @@ fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
# GENERATOR NETWORK ARCHITECTURE SETTING #
###########################################################
generator_params:
istft: True # use iSTFTNet
post_n_fft: 48 # stft fft_num
gen_istft_hop_size: 12 # istft hop_length
gen_istft_n_fft: 48 # istft n_fft
use_istft: True # Use iSTFTNet
istft_layer_id: 2 # Use istft after istft_layer_id layers of upsample layer if use_istft=True
overlap_ratio: 4 # The ratio of istft_n_fft and istft_hop_size
in_channels: 80 # Number of input channels.
out_channels: 1 # Number of output channels.
channels: 512 # Number of initial channels.
kernel_size: 7 # Kernel size of initial and final conv layers.
upsample_scales: [5,5] # Upsampling scales.
upsample_kernel_sizes: [10,10] # Kernel size for upsampling layers.
upsample_scales: [5, 5, 4, 3] # Upsampling scales.
upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers.
resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
resblock_dilations: # Dilations for residual blocks.
- [1, 3, 5]
@ -163,7 +162,7 @@ discriminator_grad_norm: -1 # Discriminator's gradient norm.
###########################################################
generator_train_start_steps: 1 # Number of steps to start to train discriminator.
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
train_max_steps: 50000 # Number of training steps.
train_max_steps: 2500000 # Number of training steps.
save_interval_steps: 5000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.

@ -38,8 +38,8 @@ class HiFiGANGenerator(nn.Layer):
channels: int=512,
global_channels: int=-1,
kernel_size: int=7,
upsample_scales: List[int]=(8, 8, 2, 2),
upsample_kernel_sizes: List[int]=(16, 16, 4, 4),
upsample_scales: List[int]=(5, 5, 4, 3),
upsample_kernel_sizes: List[int]=(10, 10, 8, 6),
resblock_kernel_sizes: List[int]=(3, 7, 11),
resblock_dilations: List[List[int]]=[(1, 3, 5), (1, 3, 5),
(1, 3, 5)],
@ -49,12 +49,11 @@ class HiFiGANGenerator(nn.Layer):
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1},
use_weight_norm: bool=True,
init_type: str="xavier_uniform",
istft: bool = False,
post_n_fft: int=16,
gen_istft_hop_size: int=12,
gen_istft_n_fft: int=16,
):
use_istft: bool=False,
istft_layer_id: int=2,
overlap_ratio: float=4, ):
"""Initialize HiFiGANGenerator module.
Args:
in_channels (int):
Number of input channels.
@ -85,14 +84,12 @@ class HiFiGANGenerator(nn.Layer):
use_weight_norm (bool):
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
istft (bool):
use_istft (bool):
If set to true, it will be a iSTFTNet based on hifigan.
post_n_fft (int):
Emulate nfft in stft
gen_istft_hop_size (int):
Hop_length in istft
gen_istft_n_fft (int):
N_fft in istft, equal to post_n_fft
istft_layer_id (int):
Use istft after istft_layer_id layers of upsample layer if use_istft=True
overlap_ratio (float):
The ratio of istft_n_fft and istft_hop_size
"""
super().__init__()
@ -103,9 +100,10 @@ class HiFiGANGenerator(nn.Layer):
assert kernel_size % 2 == 1, "Kernel size must be odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
assert len(upsample_scales) >= istft_layer_id if use_istft else True
# define modules
self.num_upsamples = len(upsample_kernel_sizes)
self.num_upsamples = len(upsample_kernel_sizes) if not use_istft else istft_layer_id
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = nn.Conv1D(
in_channels,
@ -115,7 +113,7 @@ class HiFiGANGenerator(nn.Layer):
padding=(kernel_size - 1) // 2, )
self.upsamples = nn.LayerList()
self.blocks = nn.LayerList()
for i in range(len(upsample_kernel_sizes)):
for i in range(self.num_upsamples):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples.append(
nn.Sequential(
@ -149,13 +147,18 @@ class HiFiGANGenerator(nn.Layer):
1,
padding=(kernel_size - 1) // 2, ),
nn.Tanh(), )
self.istft = istft
if self.istft:
self.post_n_fft = post_n_fft
self.gen_istft_hop_size = gen_istft_hop_size
self.gen_istft_n_fft = gen_istft_n_fft
self.use_istft = use_istft
if self.use_istft:
self.istft_hop_size = 1
for j in range(istft_layer_id, len(upsample_scales)):
self.istft_hop_size *= upsample_scales[j]
self.istft_layer_id = istft_layer_id
self.istft_n_fft = int(self.istft_hop_size * overlap_ratio)
self.istft_win_size = self.istft_n_fft
self.reflection_pad = nn.Pad1D(padding=[1,0], mode='reflect')
self.conv_post = nn.Conv1D(channels// (2**(i + 1)), self.post_n_fft + 2, 7, 1, padding=3)
self.conv_post = nn.Conv1D(channels// (2**(i + 1)), (self.istft_n_fft // 2 + 1)*2, kernel_size, 1, padding=(kernel_size - 1) // 2, )
else:
self.istft_layer_id = len(upsample_scales)
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
@ -188,8 +191,8 @@ class HiFiGANGenerator(nn.Layer):
for j in range(self.num_blocks):
cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks
if self.istft:
if self.use_istft:
c = F.leaky_relu(c)
c = self.reflection_pad(c)
c = self.conv_post(c)
@ -198,11 +201,11 @@ class HiFiGANGenerator(nn.Layer):
https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html
Use Euler's formula to implement spec*paddle.exp(1j*phase)
"""
spec = paddle.exp(c[:,:self.post_n_fft // 2 + 1, :])
phase = paddle.sin(c[:, self.post_n_fft // 2 + 1:, :])
spec = paddle.exp(c[:, :self.istft_n_fft // 2 + 1, :])
phase = paddle.sin(c[:, self.istft_n_fft // 2 + 1:, :])
c = paddle.complex(spec*(paddle.cos(phase)), spec*(paddle.sin(phase)))
c = paddle.signal.istft(c , self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)
c = paddle.signal.istft(c, n_fft=self.istft_n_fft, hop_length=self.istft_hop_size, win_length=self.istft_win_size)
c = c.unsqueeze(1)
else:
c = self.output_conv(c)

Loading…
Cancel
Save