modify iSTFT.yaml and hifigan.py

pull/3006/head
longrookie 3 years ago
parent 394e635958
commit ef5e96f5a2

@ -24,16 +24,15 @@ fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
# GENERATOR NETWORK ARCHITECTURE SETTING # # GENERATOR NETWORK ARCHITECTURE SETTING #
########################################################### ###########################################################
generator_params: generator_params:
istft: True # use iSTFTNet use_istft: True # Use iSTFTNet
post_n_fft: 48 # stft fft_num istft_layer_id: 2 # Apply iSTFT after the first istft_layer_id upsample layers if use_istft=True
gen_istft_hop_size: 12 # istft hop_length overlap_ratio: 4 # The ratio of istft_n_fft to istft_hop_size
gen_istft_n_fft: 48 # istft n_fft
in_channels: 80 # Number of input channels. in_channels: 80 # Number of input channels.
out_channels: 1 # Number of output channels. out_channels: 1 # Number of output channels.
channels: 512 # Number of initial channels. channels: 512 # Number of initial channels.
kernel_size: 7 # Kernel size of initial and final conv layers. kernel_size: 7 # Kernel size of initial and final conv layers.
upsample_scales: [5,5] # Upsampling scales. upsample_scales: [5, 5, 4, 3] # Upsampling scales.
upsample_kernel_sizes: [10,10] # Kernel size for upsampling layers. upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers.
resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
resblock_dilations: # Dilations for residual blocks. resblock_dilations: # Dilations for residual blocks.
- [1, 3, 5] - [1, 3, 5]
@ -163,7 +162,7 @@ discriminator_grad_norm: -1 # Discriminator's gradient norm.
########################################################### ###########################################################
generator_train_start_steps: 1 # Number of steps to start to train generator. generator_train_start_steps: 1 # Number of steps to start to train generator.
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
train_max_steps: 50000 # Number of training steps. train_max_steps: 2500000 # Number of training steps.
save_interval_steps: 5000 # Interval steps to save checkpoint. save_interval_steps: 5000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network. eval_interval_steps: 1000 # Interval steps to evaluate the network.

@ -38,8 +38,8 @@ class HiFiGANGenerator(nn.Layer):
channels: int=512, channels: int=512,
global_channels: int=-1, global_channels: int=-1,
kernel_size: int=7, kernel_size: int=7,
upsample_scales: List[int]=(8, 8, 2, 2), upsample_scales: List[int]=(5, 5, 4, 3),
upsample_kernel_sizes: List[int]=(16, 16, 4, 4), upsample_kernel_sizes: List[int]=(10, 10, 8, 6),
resblock_kernel_sizes: List[int]=(3, 7, 11), resblock_kernel_sizes: List[int]=(3, 7, 11),
resblock_dilations: List[List[int]]=[(1, 3, 5), (1, 3, 5), resblock_dilations: List[List[int]]=[(1, 3, 5), (1, 3, 5),
(1, 3, 5)], (1, 3, 5)],
@ -49,12 +49,11 @@ class HiFiGANGenerator(nn.Layer):
nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1}, nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1},
use_weight_norm: bool=True, use_weight_norm: bool=True,
init_type: str="xavier_uniform", init_type: str="xavier_uniform",
istft: bool = False, use_istft: bool=False,
post_n_fft: int=16, istft_layer_id: int=2,
gen_istft_hop_size: int=12, overlap_ratio: float=4, ):
gen_istft_n_fft: int=16,
):
"""Initialize HiFiGANGenerator module. """Initialize HiFiGANGenerator module.
Args: Args:
in_channels (int): in_channels (int):
Number of input channels. Number of input channels.
@ -85,14 +84,12 @@ class HiFiGANGenerator(nn.Layer):
use_weight_norm (bool): use_weight_norm (bool):
Whether to use weight norm. Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
istft (bool): use_istft (bool):
If set to true, it will be an iSTFTNet based on HiFiGAN. If set to true, it will be an iSTFTNet based on HiFiGAN.
post_n_fft (int): istft_layer_id (int):
Emulate nfft in stft Apply iSTFT after the first istft_layer_id upsample layers if use_istft=True
gen_istft_hop_size (int): overlap_ratio (float):
Hop_length in istft The ratio of istft_n_fft to istft_hop_size
gen_istft_n_fft (int):
N_fft in istft, equal to post_n_fft
""" """
super().__init__() super().__init__()
@ -103,9 +100,10 @@ class HiFiGANGenerator(nn.Layer):
assert kernel_size % 2 == 1, "Kernel size must be odd number." assert kernel_size % 2 == 1, "Kernel size must be odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes) assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes) assert len(resblock_dilations) == len(resblock_kernel_sizes)
assert len(upsample_scales) >= istft_layer_id if use_istft else True
# define modules # define modules
self.num_upsamples = len(upsample_kernel_sizes) self.num_upsamples = len(upsample_kernel_sizes) if not use_istft else istft_layer_id
self.num_blocks = len(resblock_kernel_sizes) self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = nn.Conv1D( self.input_conv = nn.Conv1D(
in_channels, in_channels,
@ -115,7 +113,7 @@ class HiFiGANGenerator(nn.Layer):
padding=(kernel_size - 1) // 2, ) padding=(kernel_size - 1) // 2, )
self.upsamples = nn.LayerList() self.upsamples = nn.LayerList()
self.blocks = nn.LayerList() self.blocks = nn.LayerList()
for i in range(len(upsample_kernel_sizes)): for i in range(self.num_upsamples):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples.append( self.upsamples.append(
nn.Sequential( nn.Sequential(
@ -149,13 +147,18 @@ class HiFiGANGenerator(nn.Layer):
1, 1,
padding=(kernel_size - 1) // 2, ), padding=(kernel_size - 1) // 2, ),
nn.Tanh(), ) nn.Tanh(), )
self.istft = istft self.use_istft = use_istft
if self.istft: if self.use_istft:
self.post_n_fft = post_n_fft self.istft_hop_size = 1
self.gen_istft_hop_size = gen_istft_hop_size for j in range(istft_layer_id, len(upsample_scales)):
self.gen_istft_n_fft = gen_istft_n_fft self.istft_hop_size *= upsample_scales[j]
self.istft_layer_id = istft_layer_id
self.istft_n_fft = int(self.istft_hop_size * overlap_ratio)
self.istft_win_size = self.istft_n_fft
self.reflection_pad = nn.Pad1D(padding=[1,0], mode='reflect') self.reflection_pad = nn.Pad1D(padding=[1,0], mode='reflect')
self.conv_post = nn.Conv1D(channels// (2**(i + 1)), self.post_n_fft + 2, 7, 1, padding=3) self.conv_post = nn.Conv1D(channels// (2**(i + 1)), (self.istft_n_fft // 2 + 1)*2, kernel_size, 1, padding=(kernel_size - 1) // 2, )
else:
self.istft_layer_id = len(upsample_scales)
if global_channels > 0: if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1) self.global_conv = nn.Conv1D(global_channels, channels, 1)
@ -189,7 +192,7 @@ class HiFiGANGenerator(nn.Layer):
cs += self.blocks[i * self.num_blocks + j](c) cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks c = cs / self.num_blocks
if self.istft: if self.use_istft:
c = F.leaky_relu(c) c = F.leaky_relu(c)
c = self.reflection_pad(c) c = self.reflection_pad(c)
c = self.conv_post(c) c = self.conv_post(c)
@ -198,11 +201,11 @@ class HiFiGANGenerator(nn.Layer):
https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html
Use Euler's formula to implement spec*paddle.exp(1j*phase) Use Euler's formula to implement spec*paddle.exp(1j*phase)
""" """
spec = paddle.exp(c[:,:self.post_n_fft // 2 + 1, :]) spec = paddle.exp(c[:, :self.istft_n_fft // 2 + 1, :])
phase = paddle.sin(c[:, self.post_n_fft // 2 + 1:, :]) phase = paddle.sin(c[:, self.istft_n_fft // 2 + 1:, :])
c = paddle.complex(spec*(paddle.cos(phase)), spec*(paddle.sin(phase))) c = paddle.complex(spec*(paddle.cos(phase)), spec*(paddle.sin(phase)))
c = paddle.signal.istft(c , self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft) c = paddle.signal.istft(c, n_fft=self.istft_n_fft, hop_length=self.istft_hop_size, win_length=self.istft_win_size)
c = c.unsqueeze(1) c = c.unsqueeze(1)
else: else:
c = self.output_conv(c) c = self.output_conv(c)

Loading…
Cancel
Save