add docstring

pull/2987/head
TianYuan 3 years ago
parent 6ee353ccaa
commit 14d46fe613

@@ -44,7 +44,8 @@ class LinearNorm(nn.Layer):
             self.linear_layer.weight, gain=_calculate_gain(w_init_gain))

     def forward(self, x: paddle.Tensor):
-        return self.linear_layer(x)
+        out = self.linear_layer(x)
+        return out


 class ConvNorm(nn.Layer):
@@ -183,13 +184,14 @@ class Attention(nn.Layer):
         """
         Args:
             query:
-                decoder output (batch, n_mel_channels * n_frames_per_step)
+                decoder output (B, n_mel_channels * n_frames_per_step)
             processed_memory:
                 processed encoder outputs (B, T_in, attention_dim)
             attention_weights_cat:
                 cumulative and prev. att weights (B, 2, max_time)
         Returns:
-            Tensor: alignment (batch, max_time)
+            Tensor:
+                alignment (B, max_time)
         """

         processed_query = self.query_layer(query.unsqueeze(1))
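For readers checking the shapes documented above: a minimal additive-attention sketch that lines up with the new docstring, assuming a hypothetical query size of 1024 (not taken from the diff) and omitting the location features (attention_weights_cat) that the real layer also mixes into the energies:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

B, T_in, attention_dim = 2, 50, 128
query_dim = 1024  # assumed decoder output size, for illustration only

query_layer = nn.Linear(query_dim, attention_dim, bias_attr=False)
v = nn.Linear(attention_dim, 1, bias_attr=False)

query = paddle.randn([B, query_dim])                # (B, n_mel_channels * n_frames_per_step)
processed_memory = paddle.randn([B, T_in, attention_dim])

processed_query = query_layer(query.unsqueeze(1))   # (B, 1, attention_dim)
energies = v(paddle.tanh(processed_query + processed_memory)).squeeze(-1)
alignment = F.softmax(energies, axis=1)             # (B, max_time), as the new docstring states
print(alignment.shape)                              # [2, 50]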
@@ -254,7 +256,6 @@ class MFCC(nn.Layer):
         # -> (channel, time, n_mfcc).tranpose(...)
         mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
                              self.dct_mat).transpose([0, 2, 1])
-
         # unpack batch
         if unsqueezed:
             mfcc = mfcc.squeeze(0)
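The matmul in this hunk is just a per-frame change of basis. A shape-only sketch, where the random dct_mat stands in for the DCT matrix the class precomputes:

import paddle

B, n_mels, T, n_mfcc = 2, 80, 100, 13
mel_specgram = paddle.randn([B, n_mels, T])
dct_mat = paddle.randn([n_mels, n_mfcc])  # placeholder; the real layer uses a DCT-II basis

# (B, n_mels, T) -> (B, T, n_mels) @ (n_mels, n_mfcc) -> (B, T, n_mfcc) -> (B, n_mfcc, T)
mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]), dct_mat).transpose([0, 2, 1])
print(mfcc.shape)  # [2, 13, 100]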

@@ -194,8 +194,7 @@ class ASRS2S(nn.Layer):
             logit_outputs += [logit]
             alignments += [attention_weights]

-        hidden_outputs, logit_outputs, alignments = \
-            self.parse_decoder_outputs(
+        hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
             hidden_outputs, logit_outputs, alignments)

         return hidden_outputs, logit_outputs, alignments

@@ -59,7 +59,7 @@ class JDCNet(nn.Layer):
             # (B, num_features, T, 2)
             nn.MaxPool2D(kernel_size=(1, 4)),
             nn.Dropout(p=0.5), )

-        # input: (B, T, input_size) - resized from (B, input_size//2, T, 2)
+        # input: (B, T, input_size), resized from (B, input_size // 2, T, 2)
         # output: (B, T, input_size)
         self.bilstm_classifier = nn.LSTM(
             input_size=512,
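The rewritten comment describes a layout change before the LSTM. One way to realize it is sketched below; the exact transpose order is an assumption, only the shape contract comes from the comment:

import paddle

B, C, T = 2, 256, 31          # C == input_size // 2
feat = paddle.randn([B, C, T, 2])
lstm_in = feat.transpose([0, 2, 1, 3]).reshape((B, T, C * 2))
print(lstm_in.shape)          # [2, 31, 512], matching nn.LSTM(input_size=512, ...)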
@@ -108,7 +108,6 @@ class JDCNet(nn.Layer):
             GAN_feature. Shape: (B, num_features, n_mels // 8, seq_len)
             Tensor:
                 poolblock_out. Shape (B, seq_len, 512)
         """
-
         ###############################
         # forward pass for classifier #

@@ -32,12 +32,23 @@ class DownSample(nn.Layer):
         self.layer_type = layer_type

     def forward(self, x: paddle.Tensor):
+        """Calculate forward propagation.
+        Args:
+            x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
+        Returns:
+            Tensor:
+                layer_type == 'none': Shape (B, dim_in, n_mels, T)
+                layer_type == 'timepreserve': Shape (B, dim_in, n_mels // 2, T)
+                layer_type == 'half': Shape (B, dim_in, n_mels // 2, T // 2)
+        """
         if self.layer_type == 'none':
             return x
         elif self.layer_type == 'timepreserve':
-            return F.avg_pool2d(x, (2, 1))
+            out = F.avg_pool2d(x, (2, 1))
+            return out
         elif self.layer_type == 'half':
-            return F.avg_pool2d(x, 2)
+            out = F.avg_pool2d(x, 2)
+            return out
         else:
             raise RuntimeError(
                 'Got unexpected donwsampletype %s, expected is [none, timepreserve, half]'
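The three shape cases in the new docstring follow directly from the pooling kernels (stride defaults to the kernel size). A quick check, assuming an 80-mel input:

import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 32, 80, 128])        # (B, dim_in, n_mels, T)
print(F.avg_pool2d(x, (2, 1)).shape)      # [4, 32, 40, 128]  'timepreserve'
print(F.avg_pool2d(x, 2).shape)           # [4, 32, 40, 64]   'half'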
@@ -50,12 +61,23 @@ class UpSample(nn.Layer):
         self.layer_type = layer_type

     def forward(self, x: paddle.Tensor):
+        """Calculate forward propagation.
+        Args:
+            x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
+        Returns:
+            Tensor:
+                layer_type == 'none': Shape (B, dim_in, n_mels, T)
+                layer_type == 'timepreserve': Shape (B, dim_in, n_mels * 2, T)
+                layer_type == 'half': Shape (B, dim_in, n_mels * 2, T * 2)
+        """
         if self.layer_type == 'none':
             return x
         elif self.layer_type == 'timepreserve':
-            return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+            out = F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+            return out
         elif self.layer_type == 'half':
-            return F.interpolate(x, scale_factor=2, mode='nearest')
+            out = F.interpolate(x, scale_factor=2, mode='nearest')
+            return out
         else:
             raise RuntimeError(
                 'Got unexpected upsampletype %s, expected is [none, timepreserve, half]'
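And the mirror-image check for the upsampling docstring, using nearest-neighbor interpolation exactly as the code does:

import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 32, 40, 64])         # (B, dim_in, n_mels, T)
print(F.interpolate(x, scale_factor=(2, 1), mode='nearest').shape)  # [4, 32, 80, 64]  'timepreserve'
print(F.interpolate(x, scale_factor=2, mode='nearest').shape)       # [4, 32, 80, 128] 'half'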
@@ -126,7 +148,9 @@ class ResBlk(nn.Layer):
             x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
         Returns:
             Tensor:
-                Shape (B, dim_out, T, n_mels//(1 or 2), T//(1 or 2)).
+                downsample == 'none': Shape (B, dim_in, n_mels, T).
+                downsample == 'timepreserve': Shape (B, dim_out, T, n_mels // 2, T).
+                downsample == 'half': Shape (B, dim_out, T, n_mels // 2, T // 2).
         """
         x = self._shortcut(x) + self._residual(x)
         # unit variance
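The "# unit variance" comment refers to the rescaling applied right after this sum, which the hunk does not show. If shortcut and residual are independent and each has unit variance, their sum has variance 2, so the usual StarGAN v2 convention divides by sqrt(2); a sketch of that convention (the divisor is an assumption, not visible in this diff):

import math
import paddle

shortcut = paddle.randn([4, 32, 80, 128])   # stands in for self._shortcut(x)
residual = paddle.randn([4, 32, 80, 128])   # stands in for self._residual(x)
x = (shortcut + residual) / math.sqrt(2)    # restores roughly unit variance
print(float(x.var()))                       # close to 1.0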
@@ -142,12 +166,21 @@ class AdaIN(nn.Layer):
         self.fc = nn.Linear(style_dim, num_features * 2)

     def forward(self, x: paddle.Tensor, s: paddle.Tensor):
+        """Calculate forward propagation.
+        Args:
+            x(Tensor(float32)): Shape (B, style_dim, n_mels, T).
+            s(Tensor(float32)): Shape (style_dim, ).
+        Returns:
+            Tensor:
+                Shape (B, style_dim, T, n_mels, T).
+        """
         if len(s.shape) == 1:
             s = s[None]
         h = self.fc(s)
         h = h.reshape((h.shape[0], h.shape[1], 1, 1))
         gamma, beta = paddle.split(h, 2, axis=1)
-        return (1 + gamma) * self.norm(x) + beta
+        out = (1 + gamma) * self.norm(x) + beta
+        return out
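A self-contained sketch of the AdaIN computation in this hunk. Note it gives x num_features channels, which is what the code implies (self.fc maps style_dim to num_features * 2 and the norm is per-channel), even though the new docstring writes style_dim; all sizes here are illustrative:

import paddle
import paddle.nn as nn

style_dim, num_features = 64, 256
fc = nn.Linear(style_dim, num_features * 2)
norm = nn.InstanceNorm2D(num_features, weight_attr=False, bias_attr=False)

x = paddle.randn([2, num_features, 80, 120])   # content features
s = paddle.randn([style_dim])                  # one style vector

s = s[None]                                    # (1, style_dim), as the len(s.shape) == 1 branch does
h = fc(s).reshape((1, num_features * 2, 1, 1))
gamma, beta = paddle.split(h, 2, axis=1)       # each (1, num_features, 1, 1)
out = (1 + gamma) * norm(x) + beta             # same shape as x
print(out.shape)                               # [2, 256, 80, 120]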
@@ -164,6 +197,7 @@ class AdainResBlk(nn.Layer):
         self.upsample = UpSample(layer_type=upsample)
         self.learned_sc = dim_in != dim_out
         self._build_weights(dim_in, dim_out, style_dim)
+        self.layer_type = upsample

     def _build_weights(self, dim_in: int, dim_out: int, style_dim: int=64):
         self.conv1 = nn.Conv2D(
@@ -209,12 +243,14 @@ class AdainResBlk(nn.Layer):
         """Calculate forward propagation.
         Args:
             x(Tensor(float32)):
-                Shape (B, dim_in, n_mels', T').
+                Shape (B, dim_in, n_mels, T).
             s(Tensor(float32)):
                 Shape (64,).
         Returns:
             Tensor:
-                Shape (B, dim_out, n_mels'', T'').
+                upsample == 'none': Shape (B, dim_out, T, n_mels, T).
+                upsample == 'timepreserve': Shape (B, dim_out, T, n_mels * 2, T).
+                upsample == 'half': Shape (B, dim_out, T, n_mels * 2, T * 2).
         """
         out = self._residual(x, s)
         if self.w_hpf == 0:
@@ -398,7 +434,6 @@ class MappingNetwork(nn.Layer):
                 Shape (B, 1, n_mels, T).
             y(Tensor(float32)):
                 speaker label. Shape (B, ).
-
         Returns:
             Tensor:
                 Shape (style_dim, )
@@ -502,10 +537,12 @@ class Discriminator(nn.Layer):
         self.num_domains = num_domains

     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
-        return self.dis(x, y)
+        out = self.dis(x, y)
+        return out

     def classifier(self, x: paddle.Tensor):
-        return self.cls.get_feature(x)
+        out = self.cls.get_feature(x)
+        return out


 class Discriminator2D(nn.Layer):
