|
|
|
@ -33,10 +33,9 @@ class JDCNet(nn.Layer):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.seq_len = seq_len
|
|
|
|
|
self.num_class = num_class
|
|
|
|
|
|
|
|
|
|
# input = (b, 1, 31, 513), b = batch size
|
|
|
|
|
# input: (B, num_class, T, n_mels)
|
|
|
|
|
self.conv_block = nn.Sequential(
|
|
|
|
|
# out: (b, 64, 31, 513)
|
|
|
|
|
# output: (B, out_channels, T, n_mels)
|
|
|
|
|
nn.Conv2D(
|
|
|
|
|
in_channels=1,
|
|
|
|
|
out_channels=64,
|
|
|
|
@ -45,127 +44,100 @@ class JDCNet(nn.Layer):
|
|
|
|
|
bias_attr=False),
|
|
|
|
|
nn.BatchNorm2D(num_features=64),
|
|
|
|
|
nn.LeakyReLU(leaky_relu_slope),
|
|
|
|
|
# (b, 64, 31, 513)
|
|
|
|
|
# out: (B, out_channels, T, n_mels)
|
|
|
|
|
nn.Conv2D(64, 64, 3, padding=1, bias_attr=False), )
|
|
|
|
|
|
|
|
|
|
# res blocks
|
|
|
|
|
# (b, 128, 31, 128)
|
|
|
|
|
# output: (B, out_channels, T, n_mels//2)
|
|
|
|
|
self.res_block1 = ResBlock(in_channels=64, out_channels=128)
|
|
|
|
|
# (b, 192, 31, 32)
|
|
|
|
|
# output: (B, out_channels, T, n_mels//4)
|
|
|
|
|
self.res_block2 = ResBlock(in_channels=128, out_channels=192)
|
|
|
|
|
# (b, 256, 31, 8)
|
|
|
|
|
# output: (B, out_channels, T, n_mels//8)
|
|
|
|
|
self.res_block3 = ResBlock(in_channels=192, out_channels=256)
|
|
|
|
|
|
|
|
|
|
# pool block
|
|
|
|
|
self.pool_block = nn.Sequential(
|
|
|
|
|
nn.BatchNorm2D(num_features=256),
|
|
|
|
|
nn.LeakyReLU(leaky_relu_slope),
|
|
|
|
|
# (b, 256, 31, 2)
|
|
|
|
|
# (B, num_features, T, 2)
|
|
|
|
|
nn.MaxPool2D(kernel_size=(1, 4)),
|
|
|
|
|
nn.Dropout(p=0.5), )
|
|
|
|
|
|
|
|
|
|
# maxpool layers (for auxiliary network inputs)
|
|
|
|
|
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
|
|
|
|
|
self.maxpool1 = nn.MaxPool2D(kernel_size=(1, 40))
|
|
|
|
|
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
|
|
|
|
|
self.maxpool2 = nn.MaxPool2D(kernel_size=(1, 20))
|
|
|
|
|
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
|
|
|
|
|
self.maxpool3 = nn.MaxPool2D(kernel_size=(1, 10))
|
|
|
|
|
|
|
|
|
|
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
|
|
|
|
|
self.detector_conv = nn.Sequential(
|
|
|
|
|
nn.Conv2D(
|
|
|
|
|
in_channels=640,
|
|
|
|
|
out_channels=256,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
bias_attr=False),
|
|
|
|
|
nn.BatchNorm2D(256),
|
|
|
|
|
nn.LeakyReLU(leaky_relu_slope),
|
|
|
|
|
nn.Dropout(p=0.5), )
|
|
|
|
|
|
|
|
|
|
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
|
|
|
|
# output: (b, 31, 512)
|
|
|
|
|
# input: (B, T, input_size) - resized from (B, input_size//2, T, 2)
|
|
|
|
|
# output: (B, T, input_size)
|
|
|
|
|
self.bilstm_classifier = nn.LSTM(
|
|
|
|
|
input_size=512,
|
|
|
|
|
hidden_size=256,
|
|
|
|
|
time_major=False,
|
|
|
|
|
direction='bidirectional')
|
|
|
|
|
|
|
|
|
|
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
|
|
|
|
# output: (b, 31, 512)
|
|
|
|
|
self.bilstm_detector = nn.LSTM(
|
|
|
|
|
input_size=512,
|
|
|
|
|
hidden_size=256,
|
|
|
|
|
time_major=False,
|
|
|
|
|
direction='bidirectional')
|
|
|
|
|
|
|
|
|
|
# input: (b * 31, 512)
|
|
|
|
|
# output: (b * 31, num_class)
|
|
|
|
|
# input: (B * T, in_features)
|
|
|
|
|
# output: (B * T, num_class)
|
|
|
|
|
self.classifier = nn.Linear(
|
|
|
|
|
in_features=512, out_features=self.num_class)
|
|
|
|
|
|
|
|
|
|
# input: (b * 31, 512)
|
|
|
|
|
# output: (b * 31, 2) - binary classifier
|
|
|
|
|
self.detector = nn.Linear(in_features=512, out_features=2)
|
|
|
|
|
|
|
|
|
|
# initialize weights
|
|
|
|
|
self.apply(self.init_weights)
|
|
|
|
|
|
|
|
|
|
def get_feature_GAN(self, x: paddle.Tensor):
    """Calculate feature_GAN.

    Runs the convolutional trunk (``conv_block`` and the three residual
    blocks) followed by only the first two layers of ``pool_block``, then
    swaps the last two axes back to the layout of the input.

    Args:
        x(Tensor(float32)):
            Shape (B, num_class, n_mels, T).
    Returns:
        Tensor:
            Shape (B, num_features, n_mels//8, T).
    """
    x = x.astype(paddle.float32)
    # Swap the last two axes exactly once so the layers see (..., T, n_mels);
    # applying this transpose a second time would undo the layout change.
    x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else [0, 2, 1])

    convblock_out = self.conv_block(x)

    resblock1_out = self.res_block1(convblock_out)
    resblock2_out = self.res_block2(resblock1_out)
    resblock3_out = self.res_block3(resblock2_out)

    # Only pool_block[0] and pool_block[1] are applied here; the remaining
    # layers of pool_block are deliberately not run for this feature.
    poolblock_out = self.pool_block[0](resblock3_out)
    poolblock_out = self.pool_block[1](poolblock_out)

    # Swap the last two axes back before handing the feature to the caller.
    GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
        poolblock_out.shape) == 4 else [0, 2, 1])
    return GAN_feature
|
|
|
|
|
|
|
|
|
|
def forward(self, x: paddle.Tensor):
    """Calculate forward propagation.

    Args:
        x(Tensor(float32)):
            Shape (B, num_class, n_mels, seq_len). The time axis must equal
            ``self.seq_len`` because the reshapes below use that value.
    Returns:
        Tensor:
            classifier output consists of predicted pitch classes per frame.
            Shape: (B, seq_len, num_class) before ``squeeze()``; size-1
            axes are removed from the returned tensor.
        Tensor:
            GAN_feature. Shape: (B, num_features, n_mels//8, seq_len)
        Tensor:
            poolblock_out, the output of ``pool_block[2]``. Presumably
            (B, 256, seq_len, 2), going by the reshape comment in the
            body - TODO confirm. It is returned before the reshape to
            (B, seq_len, 512).
    """
    ###############################
    # forward pass for classifier #
    ###############################
    # (B, num_class, n_mels, T) -> (B, num_class, T, n_mels)
    x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else
                    [0, 2, 1]).astype(paddle.float32)

    convblock_out = self.conv_block(x)

    resblock1_out = self.res_block1(convblock_out)
    resblock2_out = self.res_block2(resblock1_out)
    resblock3_out = self.res_block3(resblock2_out)

    poolblock_out = self.pool_block[0](resblock3_out)
    poolblock_out = self.pool_block[1](poolblock_out)
    # GAN_feature is tapped after pool_block[1], before pool_block[2] runs.
    GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
        poolblock_out.shape) == 4 else [0, 2, 1])
    poolblock_out = self.pool_block[2](poolblock_out)

    # (B, 256, seq_len, 2) => (B, seq_len, 256, 2) => (B, seq_len, 512)
    classifier_out = poolblock_out.transpose([0, 2, 1, 3]).reshape(
        (-1, self.seq_len, 512))
    self.bilstm_classifier.flatten_parameters()
    # Run the BiLSTM exactly once on the 3-D sequence input and
    # ignore the hidden states.
    classifier_out, _ = self.bilstm_classifier(classifier_out)

    # (B * seq_len, 512)
    classifier_out = classifier_out.reshape((-1, 512))
    classifier_out = self.classifier(classifier_out)
    # (B, seq_len, num_class)
    classifier_out = classifier_out.reshape(
        (-1, self.seq_len, self.num_class))

    return paddle.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -188,10 +160,9 @@ class ResBlock(nn.Layer):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
in_channels: int,
|
|
|
|
|
out_channels: int,
|
|
|
|
|
leaky_relu_slope=0.01):
|
|
|
|
|
leaky_relu_slope: float=0.01):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.downsample = in_channels != out_channels
|
|
|
|
|
|
|
|
|
|
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
|
|
|
|
|
self.pre_conv = nn.Sequential(
|
|
|
|
|
nn.BatchNorm2D(num_features=in_channels),
|
|
|
|
@ -215,7 +186,6 @@ class ResBlock(nn.Layer):
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
padding=1,
|
|
|
|
|
bias_attr=False), )
|
|
|
|
|
|
|
|
|
|
# 1 x 1 convolution layer to match the feature dimensions
|
|
|
|
|
self.conv1by1 = None
|
|
|
|
|
if self.downsample:
|
|
|
|
@ -226,6 +196,13 @@ class ResBlock(nn.Layer):
|
|
|
|
|
bias_attr=False)
|
|
|
|
|
|
|
|
|
|
def forward(self, x: paddle.Tensor):
|
|
|
|
|
"""Calculate forward propagation.
|
|
|
|
|
Args:
|
|
|
|
|
x(Tensor(float32)): Shape (B, in_channels, T, n_mels).
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor:
|
|
|
|
|
The residual output, Shape (B, out_channels, T, n_mels//2).
|
|
|
|
|
"""
|
|
|
|
|
x = self.pre_conv(x)
|
|
|
|
|
if self.downsample:
|
|
|
|
|
x = self.conv(x) + self.conv1by1(x)
|
|
|
|
|