[TTS]clean starganv2 vc model code and add docstring (#2987)

* clean code

* add docstring
pull/3067/head
TianYuan 3 years ago committed by GitHub
parent 880c172db7
commit 0a2e367ff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import paddle
import paddle.nn.functional as F
import paddleaudio.functional as audio_F
@ -46,7 +44,8 @@ class LinearNorm(nn.Layer):
self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
def forward(self, x: paddle.Tensor):
return self.linear_layer(x)
out = self.linear_layer(x)
return out
class ConvNorm(nn.Layer):
@ -82,85 +81,6 @@ class ConvNorm(nn.Layer):
return conv_signal
class CausualConv(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int=1,
stride: int=1,
padding: int=1,
dilation: int=1,
bias: bool=True,
w_init_gain: str='linear',
param=None):
super().__init__()
if padding is None:
assert (kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2) * 2
else:
self.padding = padding * 2
self.conv = nn.Conv1D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=dilation,
bias_attr=bias)
xavier_uniform_(
self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))
def forward(self, x: paddle.Tensor):
x = self.conv(x)
x = x[:, :, :-self.padding]
return x
class CausualBlock(nn.Layer):
def __init__(self,
hidden_dim: int,
n_conv: int=3,
dropout_p: float=0.2,
activ: str='lrelu'):
super().__init__()
self.blocks = nn.LayerList([
self._get_conv(
hidden_dim=hidden_dim,
dilation=3**i,
activ=activ,
dropout_p=dropout_p) for i in range(n_conv)
])
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self,
hidden_dim: int,
dilation: int,
activ: str='lrelu',
dropout_p: float=0.2):
layers = [
CausualConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation), _get_activation_fn(activ),
nn.BatchNorm1D(hidden_dim), nn.Dropout(p=dropout_p), CausualConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
padding=1,
dilation=1), _get_activation_fn(activ), nn.Dropout(p=dropout_p)
]
return nn.Sequential(*layers)
class ConvBlock(nn.Layer):
def __init__(self,
hidden_dim: int,
@ -264,13 +184,14 @@ class Attention(nn.Layer):
"""
Args:
query:
decoder output (batch, n_mel_channels * n_frames_per_step)
decoder output (B, n_mel_channels * n_frames_per_step)
processed_memory:
processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat:
cumulative and prev. att weights (B, 2, max_time)
Returns:
Tensor: alignment (batch, max_time)
Tensor:
alignment (B, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
@ -316,144 +237,6 @@ class Attention(nn.Layer):
return attention_context, attention_weights
class ForwardAttentionV2(nn.Layer):
def __init__(self,
attention_rnn_dim: int,
embedding_dim: int,
attention_dim: int,
attention_location_n_filters: int,
attention_location_kernel_size: int):
super().__init__()
self.query_layer = LinearNorm(
in_dim=attention_rnn_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.memory_layer = LinearNorm(
in_dim=embedding_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
self.location_layer = LocationLayer(
attention_n_filters=attention_location_n_filters,
attention_kernel_size=attention_location_kernel_size,
attention_dim=attention_dim)
self.score_mask_value = -float(1e20)
def get_alignment_energies(self,
query: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor):
"""
Args:
query:
decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory:
processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat:
prev. and cumulative att weights (B, 2, max_time)
Returns:
Tensor: alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
paddle.tanh(processed_query + processed_attention_weights +
processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self,
attention_hidden_state: paddle.Tensor,
memory: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor,
mask: paddle.Tensor,
log_alpha: paddle.Tensor):
"""
Args:
attention_hidden_state:
attention rnn last output
memory:
encoder outputs
processed_memory:
processed encoder outputs
attention_weights_cat:
previous and cummulative attention weights
mask:
binary mask for padded data
"""
log_energy = self.get_alignment_energies(
query=attention_hidden_state,
processed_memory=processed_memory,
attention_weights_cat=attention_weights_cat)
if mask is not None:
log_energy[:] = paddle.where(
mask,
paddle.full(log_energy.shape, self.score_mask_value,
log_energy.dtype), log_energy)
log_alpha_shift_padded = []
max_time = log_energy.shape[1]
for sft in range(2):
shifted = log_alpha[:, :max_time - sft]
shift_padded = F.pad(shifted, (sft, 0), 'constant',
self.score_mask_value)
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
biased = paddle.logsumexp(paddle.conat(log_alpha_shift_padded, 2), 2)
log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, axis=1)
attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2D(nn.Layer):
def __init__(self, n: int=2):
super().__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x: paddle.Tensor, move: int=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = paddle.concat([right, left], axis=3)
return shuffled
class PhaseShuffle1D(nn.Layer):
def __init__(self, n: int=2):
super().__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x: paddle.Tensor, move: int=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :move]
right = x[:, :, move:]
shuffled = paddle.concat([right, left], axis=2)
return shuffled
class MFCC(nn.Layer):
def __init__(self, n_mfcc: int=40, n_mels: int=80):
super().__init__()
@ -473,7 +256,6 @@ class MFCC(nn.Layer):
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
self.dct_mat).transpose([0, 2, 1])
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)

@ -99,7 +99,7 @@ class ASRCNN(nn.Layer):
unmask_futre_steps (int):
unmasking future step size.
Return:
mask (paddle.BoolTensor):
Tensor (paddle.Tensor(bool)):
mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
"""
index_tensor = paddle.arange(out_length).unsqueeze(0).expand(
@ -194,8 +194,7 @@ class ASRS2S(nn.Layer):
logit_outputs += [logit]
alignments += [attention_weights]
hidden_outputs, logit_outputs, alignments = \
self.parse_decoder_outputs(
hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
hidden_outputs, logit_outputs, alignments)
return hidden_outputs, logit_outputs, alignments

@ -33,10 +33,9 @@ class JDCNet(nn.Layer):
super().__init__()
self.seq_len = seq_len
self.num_class = num_class
# input = (b, 1, 31, 513), b = batch size
# input: (B, num_class, T, n_mels)
self.conv_block = nn.Sequential(
# out: (b, 64, 31, 513)
# output: (B, out_channels, T, n_mels)
nn.Conv2D(
in_channels=1,
out_channels=64,
@ -45,127 +44,99 @@ class JDCNet(nn.Layer):
bias_attr=False),
nn.BatchNorm2D(num_features=64),
nn.LeakyReLU(leaky_relu_slope),
# (b, 64, 31, 513)
# out: (B, out_channels, T, n_mels)
nn.Conv2D(64, 64, 3, padding=1, bias_attr=False), )
# res blocks
# (b, 128, 31, 128)
# output: (B, out_channels, T, n_mels // 2)
self.res_block1 = ResBlock(in_channels=64, out_channels=128)
# (b, 192, 31, 32)
# output: (B, out_channels, T, n_mels // 4)
self.res_block2 = ResBlock(in_channels=128, out_channels=192)
# (b, 256, 31, 8)
# output: (B, out_channels, T, n_mels // 8)
self.res_block3 = ResBlock(in_channels=192, out_channels=256)
# pool block
self.pool_block = nn.Sequential(
nn.BatchNorm2D(num_features=256),
nn.LeakyReLU(leaky_relu_slope),
# (b, 256, 31, 2)
# (B, num_features, T, 2)
nn.MaxPool2D(kernel_size=(1, 4)),
nn.Dropout(p=0.5), )
# maxpool layers (for auxiliary network inputs)
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
self.maxpool1 = nn.MaxPool2D(kernel_size=(1, 40))
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
self.maxpool2 = nn.MaxPool2D(kernel_size=(1, 20))
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
self.maxpool3 = nn.MaxPool2D(kernel_size=(1, 10))
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
self.detector_conv = nn.Sequential(
nn.Conv2D(
in_channels=640,
out_channels=256,
kernel_size=1,
bias_attr=False),
nn.BatchNorm2D(256),
nn.LeakyReLU(leaky_relu_slope),
nn.Dropout(p=0.5), )
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
# output: (b, 31, 512)
# input: (B, T, input_size), resized from (B, input_size // 2, T, 2)
# output: (B, T, input_size)
self.bilstm_classifier = nn.LSTM(
input_size=512,
hidden_size=256,
time_major=False,
direction='bidirectional')
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
# output: (b, 31, 512)
self.bilstm_detector = nn.LSTM(
input_size=512,
hidden_size=256,
time_major=False,
direction='bidirectional')
# input: (b * 31, 512)
# output: (b * 31, num_class)
# input: (B * T, in_features)
# output: (B * T, num_class)
self.classifier = nn.Linear(
in_features=512, out_features=self.num_class)
# input: (b * 31, 512)
# output: (b * 31, 2) - binary classifier
self.detector = nn.Linear(in_features=512, out_features=2)
# initialize weights
self.apply(self.init_weights)
def get_feature_GAN(self, x: paddle.Tensor):
seq_len = x.shape[-2]
x = x.astype(paddle.float32).transpose([0, 1, 3, 2] if len(x.shape) == 4
else [0, 2, 1])
"""Calculate feature_GAN.
Args:
x(Tensor(float32)):
Shape (B, num_class, n_mels, T).
Returns:
Tensor:
Shape (B, num_features, n_mels // 8, T).
"""
x = x.astype(paddle.float32)
x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else [0, 2, 1])
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return poolblock_out.transpose([0, 1, 3, 2] if len(poolblock_out.shape)
== 4 else [0, 2, 1])
GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
poolblock_out.shape) == 4 else [0, 2, 1])
return GAN_feature
def forward(self, x: paddle.Tensor):
"""
"""Calculate forward propagation.
Args:
x(Tensor(float32)):
Shape (B, num_class, n_mels, seq_len).
Returns:
classification_prediction, detection_prediction
sizes: (b, 31, 722), (b, 31, 2)
Tensor:
classifier output consists of predicted pitch classes per frame.
Shape: (B, seq_len, num_class).
Tensor:
GAN_feature. Shape: (B, num_features, n_mels // 8, seq_len)
Tensor:
poolblock_out. Shape (B, seq_len, 512)
"""
###############################
# forward pass for classifier #
###############################
# (B, num_class, n_mels, T) -> (B, num_class, T, n_mels)
x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else
[0, 2, 1]).astype(paddle.float32)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
poolblock_out.shape) == 4 else [0, 2, 1])
poolblock_out = self.pool_block[2](poolblock_out)
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
# (B, 256, seq_len, 2) => (B, seq_len, 256, 2) => (B, seq_len, 512)
classifier_out = poolblock_out.transpose([0, 2, 1, 3]).reshape(
(-1, self.seq_len, 512))
self.bilstm_classifier.flatten_parameters()
classifier_out, _ = self.bilstm_classifier(
classifier_out) # ignore the hidden states
classifier_out = classifier_out.reshape((-1, 512)) # (b * 31, 512)
# ignore the hidden states
classifier_out, _ = self.bilstm_classifier(classifier_out)
# (B * seq_len, 512)
classifier_out = classifier_out.reshape((-1, 512))
classifier_out = self.classifier(classifier_out)
# (B, seq_len, num_class)
classifier_out = classifier_out.reshape(
(-1, self.seq_len, self.num_class)) # (b, 31, num_class)
# sizes: (b, 31, 722), (b, 31, 2)
# classifier output consists of predicted pitch classes per frame
# detector output consists of: (isvoice, notvoice) estimates per frame
(-1, self.seq_len, self.num_class))
return paddle.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
@staticmethod
@ -188,10 +159,9 @@ class ResBlock(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
leaky_relu_slope=0.01):
leaky_relu_slope: float=0.01):
super().__init__()
self.downsample = in_channels != out_channels
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
self.pre_conv = nn.Sequential(
nn.BatchNorm2D(num_features=in_channels),
@ -215,7 +185,6 @@ class ResBlock(nn.Layer):
kernel_size=3,
padding=1,
bias_attr=False), )
# 1 x 1 convolution layer to match the feature dimensions
self.conv1by1 = None
if self.downsample:
@ -226,6 +195,13 @@ class ResBlock(nn.Layer):
bias_attr=False)
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)): Shape (B, in_channels, T, n_mels).
Returns:
Tensor:
The residual output, Shape (B, out_channels, T, n_mels // 2).
"""
x = self.pre_conv(x)
if self.downsample:
x = self.conv(x) + self.conv1by1(x)

@ -19,31 +19,36 @@ This work is licensed under the Creative Commons Attribution-NonCommercial
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""
# import copy
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.utils.initialize import _calculate_gain
from paddlespeech.utils.initialize import xavier_uniform_
# from munch import Munch
class DownSample(nn.Layer):
def __init__(self, layer_type: str):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
Returns:
Tensor:
layer_type == 'none': Shape (B, dim_in, n_mels, T)
layer_type == 'timepreserve': Shape (B, dim_in, n_mels // 2, T)
layer_type == 'half': Shape (B, dim_in, n_mels // 2, T // 2)
"""
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.avg_pool2d(x, (2, 1))
out = F.avg_pool2d(x, (2, 1))
return out
elif self.layer_type == 'half':
return F.avg_pool2d(x, 2)
out = F.avg_pool2d(x, 2)
return out
else:
raise RuntimeError(
'Got unexpected donwsampletype %s, expected is [none, timepreserve, half]'
@ -55,13 +60,24 @@ class UpSample(nn.Layer):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
Returns:
Tensor:
layer_type == 'none': Shape (B, dim_in, n_mels, T)
layer_type == 'timepreserve': Shape (B, dim_in, n_mels * 2, T)
layer_type == 'half': Shape (B, dim_in, n_mels * 2, T * 2)
"""
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
out = F.interpolate(x, scale_factor=(2, 1), mode='nearest')
return out
elif self.layer_type == 'half':
return F.interpolate(x, scale_factor=2, mode='nearest')
out = F.interpolate(x, scale_factor=2, mode='nearest')
return out
else:
raise RuntimeError(
'Got unexpected upsampletype %s, expected is [none, timepreserve, half]'
@ -127,9 +143,19 @@ class ResBlk(nn.Layer):
return x
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
Returns:
Tensor:
downsample == 'none': Shape (B, dim_in, n_mels, T).
downsample == 'timepreserve': Shape (B, dim_out, T, n_mels // 2, T).
downsample == 'half': Shape (B, dim_out, T, n_mels // 2, T // 2).
"""
x = self._shortcut(x) + self._residual(x)
# unit variance
return x / math.sqrt(2)
out = x / math.sqrt(2)
return out
class AdaIN(nn.Layer):
@ -140,12 +166,21 @@ class AdaIN(nn.Layer):
self.fc = nn.Linear(style_dim, num_features * 2)
def forward(self, x: paddle.Tensor, s: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)): Shape (B, style_dim, n_mels, T).
s(Tensor(float32)): Shape (style_dim, ).
Returns:
Tensor:
Shape (B, style_dim, T, n_mels, T).
"""
if len(s.shape) == 1:
s = s[None]
h = self.fc(s)
h = h.reshape((h.shape[0], h.shape[1], 1, 1))
gamma, beta = paddle.split(h, 2, axis=1)
return (1 + gamma) * self.norm(x) + beta
out = (1 + gamma) * self.norm(x) + beta
return out
class AdainResBlk(nn.Layer):
@ -162,6 +197,7 @@ class AdainResBlk(nn.Layer):
self.upsample = UpSample(layer_type=upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
self.layer_type = upsample
def _build_weights(self, dim_in: int, dim_out: int, style_dim: int=64):
self.conv1 = nn.Conv2D(
@ -204,6 +240,18 @@ class AdainResBlk(nn.Layer):
return x
def forward(self, x: paddle.Tensor, s: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)):
Shape (B, dim_in, n_mels, T).
s(Tensor(float32)):
Shape (64,).
Returns:
Tensor:
upsample == 'none': Shape (B, dim_out, T, n_mels, T).
upsample == 'timepreserve': Shape (B, dim_out, T, n_mels * 2, T).
upsample == 'half': Shape (B, dim_out, T, n_mels * 2, T * 2).
"""
out = self._residual(x, s)
if self.w_hpf == 0:
out = (out + self._shortcut(x)) / math.sqrt(2)
@ -219,7 +267,8 @@ class HighPass(nn.Layer):
def forward(self, x: paddle.Tensor):
filter = self.filter.unsqueeze(0).unsqueeze(1).tile(
[x.shape[1], 1, 1, 1])
return F.conv2d(x, filter, padding=1, groups=x.shape[1])
out = F.conv2d(x, filter, padding=1, groups=x.shape[1])
return out
class Generator(nn.Layer):
@ -276,12 +325,10 @@ class Generator(nn.Layer):
w_hpf=w_hpf,
upsample=_downtype)) # stack-like
dim_in = dim_out
# bottleneck blocks (encoder)
for _ in range(2):
self.encode.append(
ResBlk(dim_in=dim_out, dim_out=dim_out, normalize=True))
# F0 blocks
if F0_channel != 0:
self.decode.insert(0,
@ -290,7 +337,6 @@ class Generator(nn.Layer):
dim_out=dim_out,
style_dim=style_dim,
w_hpf=w_hpf))
# bottleneck blocks (decoder)
for _ in range(2):
self.decode.insert(0,
@ -299,7 +345,6 @@ class Generator(nn.Layer):
dim_out=dim_out + int(F0_channel / 2),
style_dim=style_dim,
w_hpf=w_hpf))
if F0_channel != 0:
self.F0_conv = nn.Sequential(
ResBlk(
@ -307,7 +352,6 @@ class Generator(nn.Layer):
dim_out=int(F0_channel / 2),
normalize=True,
downsample="half"), )
if w_hpf > 0:
self.hpf = HighPass(w_hpf)
@ -316,26 +360,44 @@ class Generator(nn.Layer):
s: paddle.Tensor,
masks: paddle.Tensor=None,
F0: paddle.Tensor=None):
"""Calculate forward propagation.
Args:
x(Tensor(float32)):
Shape (B, 1, n_mels, T).
s(Tensor(float32)):
Shape (64,).
masks:
None.
F0:
Shape (B, num_features(256), n_mels // 8, T).
Returns:
Tensor:
output of generator. Shape (B, 1, n_mels, T // 4 * 4)
"""
x = self.stem(x)
cache = {}
# output: (B, max_conv_dim, n_mels // 16, T // 4)
for block in self.encode:
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
cache[x.shape[2]] = x
x = block(x)
if F0 is not None:
# input: (B, num_features(256), n_mels // 8, T)
# output: (B, num_features(256) // 2, n_mels // 16, T // 2)
F0 = self.F0_conv(F0)
# output: (B, num_features(256) // 2, n_mels // 16, T // 4)
F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])
x = paddle.concat([x, F0], axis=1)
# input: (B, max_conv_dim+num_features(256) // 2, n_mels // 16, T // 4 * 4)
# output: (B, dim_in, n_mels, T // 4 * 4)
for block in self.decode:
x = block(x, s)
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
mask = masks[0] if x.shape[2] in [32] else masks[1]
mask = F.interpolate(mask, size=x.shape[2], mode='bilinear')
x = x + self.hpf(mask * cache[x.shape[2]])
return self.to_out(x)
out = self.to_out(x)
return out
class MappingNetwork(nn.Layer):
@ -366,14 +428,25 @@ class MappingNetwork(nn.Layer):
])
def forward(self, z: paddle.Tensor, y: paddle.Tensor):
"""Calculate forward propagation.
Args:
z(Tensor(float32)):
Shape (B, 1, n_mels, T).
y(Tensor(float32)):
speaker label. Shape (B, ).
Returns:
Tensor:
Shape (style_dim, )
"""
h = self.shared(z)
out = []
for layer in self.unshared:
out += [layer(h)]
# (batch, num_domains, style_dim)
# (B, num_domains, style_dim)
out = paddle.stack(out, axis=1)
idx = paddle.arange(y.shape[0])
# (batch, style_dim)
# (style_dim, )
s = out[idx, y]
return s
@ -419,15 +492,25 @@ class StyleEncoder(nn.Layer):
self.unshared.append(nn.Linear(dim_out, style_dim))
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)):
Shape (B, 1, n_mels, T).
y(Tensor(float32)):
speaker label. Shape (B, ).
Returns:
Tensor:
Shape (style_dim, )
"""
h = self.shared(x)
h = h.reshape((h.shape[0], -1))
out = []
for layer in self.unshared:
out += [layer(h)]
# (batch, num_domains, style_dim)
# (B, num_domains, style_dim)
out = paddle.stack(out, axis=1)
idx = paddle.arange(y.shape[0])
# (batch, style_dim)
# (style_dim,)
s = out[idx, y]
return s
@ -454,25 +537,12 @@ class Discriminator(nn.Layer):
self.num_domains = num_domains
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
return self.dis(x, y)
out = self.dis(x, y)
return out
def classifier(self, x: paddle.Tensor):
return self.cls.get_feature(x)
class LinearNorm(nn.Layer):
def __init__(self,
in_dim: int,
out_dim: int,
bias: bool=True,
w_init_gain: str='linear'):
super().__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
xavier_uniform_(
self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
out = self.cls.get_feature(x)
return out
class Discriminator2D(nn.Layer):
@ -520,97 +590,13 @@ class Discriminator2D(nn.Layer):
def get_feature(self, x: paddle.Tensor):
out = self.main(x)
# (batch, num_domains)
# (B, num_domains)
out = out.reshape((out.shape[0], -1))
return out
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
out = self.get_feature(x)
idx = paddle.arange(y.shape[0])
# (batch)
# (B,) ?
out = out[idx, y]
return out
'''
def build_model(args, F0_model: nn.Layer, ASR_model: nn.Layer):
generator = Generator(
dim_in=args.dim_in,
style_dim=args.style_dim,
max_conv_dim=args.max_conv_dim,
w_hpf=args.w_hpf,
F0_channel=args.F0_channel)
mapping_network = MappingNetwork(
latent_dim=args.latent_dim,
style_dim=args.style_dim,
num_domains=args.num_domains,
hidden_dim=args.max_conv_dim)
style_encoder = StyleEncoder(
dim_in=args.dim_in,
style_dim=args.style_dim,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim)
discriminator = Discriminator(
dim_in=args.dim_in,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim,
n_repeat=args.n_repeat)
generator_ema = copy.deepcopy(generator)
mapping_network_ema = copy.deepcopy(mapping_network)
style_encoder_ema = copy.deepcopy(style_encoder)
nets = Munch(
generator=generator,
mapping_network=mapping_network,
style_encoder=style_encoder,
discriminator=discriminator,
f0_model=F0_model,
asr_model=ASR_model)
nets_ema = Munch(
generator=generator_ema,
mapping_network=mapping_network_ema,
style_encoder=style_encoder_ema)
return nets, nets_ema
class StarGANv2VC(nn.Layer):
def __init__(
self,
# spk_num
num_domains: int=20,
dim_in: int=64,
style_dim: int=64,
latent_dim: int=16,
max_conv_dim: int=512,
n_repeat: int=4,
w_hpf: int=0,
F0_channel: int=256):
super().__init__()
self.generator = Generator(
dim_in=dim_in,
style_dim=style_dim,
max_conv_dim=max_conv_dim,
w_hpf=w_hpf,
F0_channel=F0_channel)
# MappingNetwork and StyleEncoder are used to generate reference_embeddings
self.mapping_network = MappingNetwork(
latent_dim=latent_dim,
style_dim=style_dim,
num_domains=num_domains,
hidden_dim=max_conv_dim)
self.style_encoder = StyleEncoder(
dim_in=dim_in,
style_dim=style_dim,
num_domains=num_domains,
max_conv_dim=max_conv_dim)
self.discriminator = Discriminator(
dim_in=dim_in,
num_domains=num_domains,
max_conv_dim=max_conv_dim,
repeat_num=n_repeat)
'''

Loading…
Cancel
Save