add StarGANv2-VC model scripts, test=tts

pull/2842/head
TianYuan 3 years ago
parent 478fd2593e
commit 61178b3a5f

@ -10,3 +10,4 @@
* voc2 - MelGAN * voc2 - MelGAN
* voc3 - MultiBand MelGAN * voc3 - MultiBand MelGAN
* ernie_sat - ERNIE-SAT * ernie_sat - ERNIE-SAT
* vc3 - StarGANv2-VC

@ -0,0 +1,22 @@
generator_params:
dim_in: 64
style_dim: 64
max_conv_dim: 512
w_hpf: 0
F0_channel: 256
mapping_network_params:
num_domains: 20 # num of speakers in StarGANv2
latent_dim: 16
style_dim: 64 # same as style_dim in generator_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
style_encoder_params:
dim_in: 64 # same as dim_in in generator_params
style_dim: 64 # same as style_dim in generator_params
num_domains: 20 # same as num_domains in generator_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
discriminator_params:
dim_in: 64 # same as dim_in in generator_params
num_domains: 20 # same as num_domains in mapping_network_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
n_repeat: 4

@ -0,0 +1,18 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=starganv2_vc
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,33 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_331.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan by default
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_conversion.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,29 @@
log_dir: "logs"
save_freq: 20
device: "cuda"
epochs: 180
batch_size: 48
pretrained_model: ""
train_data: "asr_train_list.txt"
val_data: "asr_val_list.txt"
dataset_params:
data_augmentation: true
preprocess_parasm:
sr: 24000
spect_params:
n_fft: 2048
win_length: 1200
hop_length: 300
mel_params:
n_mels: 80
model_params:
input_dim: 80
hidden_dim: 256
n_token: 80
token_embedding_dim: 256
optimizer_params:
lr: 0.0005

@ -0,0 +1,480 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import paddle
import paddle.nn.functional as F
import paddleaudio.functional as audio_F
from paddle import nn
from paddlespeech.utils.initialize import _calculate_gain
from paddlespeech.utils.initialize import xavier_uniform_
def _get_activation_fn(activ):
if activ == 'relu':
return nn.ReLU()
elif activ == 'lrelu':
return nn.LeakyReLU(0.2)
elif activ == 'swish':
return nn.Swish()
else:
raise RuntimeError(
'Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
class LinearNorm(nn.Layer):
def __init__(self,
in_dim: int,
out_dim: int,
bias: bool=True,
w_init_gain: str='linear'):
super().__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
xavier_uniform_(
self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
def forward(self, x: paddle.Tensor):
return self.linear_layer(x)
class ConvNorm(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int=1,
stride: int=1,
padding: int=None,
dilation: int=1,
bias: bool=True,
w_init_gain: str='linear',
param=None):
super().__init__()
if padding is None:
assert (kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = nn.Conv1D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias_attr=bias)
xavier_uniform_(
self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))
def forward(self, signal: paddle.Tensor):
conv_signal = self.conv(signal)
return conv_signal
class CausualConv(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int=1,
stride: int=1,
padding: int=1,
dilation: int=1,
bias: bool=True,
w_init_gain: str='linear',
param=None):
super().__init__()
if padding is None:
assert (kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2) * 2
else:
self.padding = padding * 2
self.conv = nn.Conv1D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=dilation,
bias_attr=bias)
xavier_uniform_(
self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))
def forward(self, x: paddle.Tensor):
x = self.conv(x)
x = x[:, :, :-self.padding]
return x
class CausualBlock(nn.Layer):
def __init__(self,
hidden_dim: int,
n_conv: int=3,
dropout_p: float=0.2,
activ: str='lrelu'):
super().__init__()
self.blocks = nn.LayerList([
self._get_conv(
hidden_dim=hidden_dim,
dilation=3**i,
activ=activ,
dropout_p=dropout_p) for i in range(n_conv)
])
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self,
hidden_dim: int,
dilation: int,
activ: str='lrelu',
dropout_p: float=0.2):
layers = [
CausualConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation), _get_activation_fn(activ),
nn.BatchNorm1D(hidden_dim), nn.Dropout(p=dropout_p), CausualConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
padding=1,
dilation=1), _get_activation_fn(activ), nn.Dropout(p=dropout_p)
]
return nn.Sequential(*layers)
class ConvBlock(nn.Layer):
def __init__(self,
hidden_dim: int,
n_conv: int=3,
dropout_p: float=0.2,
activ: str='relu'):
super().__init__()
self._n_groups = 8
self.blocks = nn.LayerList([
self._get_conv(
hidden_dim=hidden_dim,
dilation=3**i,
activ=activ,
dropout_p=dropout_p) for i in range(n_conv)
])
def forward(self, x: paddle.Tensor):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self,
hidden_dim: int,
dilation: int,
activ: str='relu',
dropout_p: float=0.2):
layers = [
ConvNorm(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation), _get_activation_fn(activ),
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
nn.Dropout(p=dropout_p), ConvNorm(
hidden_dim, hidden_dim, kernel_size=3, padding=1,
dilation=1), _get_activation_fn(activ), nn.Dropout(p=dropout_p)
]
return nn.Sequential(*layers)
class LocationLayer(nn.Layer):
def __init__(self,
attention_n_filters: int,
attention_kernel_size: int,
attention_dim: int):
super().__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(
in_channels=2,
out_channels=attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding,
bias=False,
stride=1,
dilation=1)
self.location_dense = LinearNorm(
in_dim=attention_n_filters,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
def forward(self, attention_weights_cat: paddle.Tensor):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose([0, 2, 1])
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Layer):
def __init__(self,
attention_rnn_dim: int,
embedding_dim: int,
attention_dim: int,
attention_location_n_filters: int,
attention_location_kernel_size: int):
super().__init__()
self.query_layer = LinearNorm(
in_dim=attention_rnn_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.memory_layer = LinearNorm(
in_dim=embedding_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
self.location_layer = LocationLayer(
attention_n_filters=attention_location_n_filters,
attention_kernel_size=attention_location_kernel_size,
attention_dim=attention_dim)
self.score_mask_value = -float("inf")
def get_alignment_energies(self,
query: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor):
"""
Args:
query:
decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory:
processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat:
cumulative and prev. att weights (B, 2, max_time)
Returns:
Tensor: alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
paddle.tanh(processed_query + processed_attention_weights +
processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self,
attention_hidden_state: paddle.Tensor,
memory: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor,
mask: paddle.Tensor):
"""
Args:
attention_hidden_state:
attention rnn last output
memory:
encoder outputs
processed_memory:
processed encoder outputs
attention_weights_cat:
previous and cummulative attention weights
mask:
binary mask for padded data
"""
alignment = self.get_alignment_energies(
query=attention_hidden_state,
processed_memory=processed_memory,
attention_weights_cat=attention_weights_cat)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, axis=1)
attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class ForwardAttentionV2(nn.Layer):
def __init__(self,
attention_rnn_dim: int,
embedding_dim: int,
attention_dim: int,
attention_location_n_filters: int,
attention_location_kernel_size: int):
super().__init__()
self.query_layer = LinearNorm(
in_dim=attention_rnn_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.memory_layer = LinearNorm(
in_dim=embedding_dim,
out_dim=attention_dim,
bias=False,
w_init_gain='tanh')
self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
self.location_layer = LocationLayer(
attention_n_filters=attention_location_n_filters,
attention_kernel_size=attention_location_kernel_size,
attention_dim=attention_dim)
self.score_mask_value = -float(1e20)
def get_alignment_energies(self,
query: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor):
"""
Args:
query:
decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory:
processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat:
prev. and cumulative att weights (B, 2, max_time)
Returns:
Tensor: alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
paddle.tanh(processed_query + processed_attention_weights +
processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self,
attention_hidden_state: paddle.Tensor,
memory: paddle.Tensor,
processed_memory: paddle.Tensor,
attention_weights_cat: paddle.Tensor,
mask: paddle.Tensor,
log_alpha: paddle.Tensor):
"""
Args:
attention_hidden_state:
attention rnn last output
memory:
encoder outputs
processed_memory:
processed encoder outputs
attention_weights_cat:
previous and cummulative attention weights
mask:
binary mask for padded data
"""
log_energy = self.get_alignment_energies(
query=attention_hidden_state,
processed_memory=processed_memory,
attention_weights_cat=attention_weights_cat)
if mask is not None:
log_energy[:] = paddle.where(
mask,
paddle.full(log_energy.shape, self.score_mask_value,
log_energy.dtype), log_energy)
log_alpha_shift_padded = []
max_time = log_energy.shape[1]
for sft in range(2):
shifted = log_alpha[:, :max_time - sft]
shift_padded = F.pad(shifted, (sft, 0), 'constant',
self.score_mask_value)
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
biased = paddle.logsumexp(paddle.conat(log_alpha_shift_padded, 2), 2)
log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, axis=1)
attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2D(nn.Layer):
def __init__(self, n: int=2):
super().__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x: paddle.Tensor, move: int=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = paddle.concat([right, left], axis=3)
return shuffled
class PhaseShuffle1D(nn.Layer):
def __init__(self, n: int=2):
super().__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x: paddle.Tensor, move: int=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :move]
right = x[:, :, move:]
shuffled = paddle.concat([right, left], axis=2)
return shuffled
class MFCC(nn.Layer):
def __init__(self, n_mfcc: int=40, n_mels: int=80):
super().__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = 'ortho'
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer('dct_mat', dct_mat)
def forward(self, mel_specgram: paddle.Tensor):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
self.dct_mat).transpose([0, 2, 1])
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc

@ -0,0 +1,239 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from .layers import Attention
from .layers import ConvBlock
from .layers import ConvNorm
from .layers import LinearNorm
from .layers import MFCC
from paddlespeech.utils.initialize import uniform_
class ASRCNN(nn.Layer):
def __init__(
self,
input_dim: int=80,
hidden_dim: int=256,
n_token: int=35,
n_layers: int=6,
token_embedding_dim: int=256, ):
super().__init__()
self.n_token = n_token
self.n_down = 1
self.to_mfcc = MFCC()
self.init_cnn = ConvNorm(
in_channels=input_dim // 2,
out_channels=hidden_dim,
kernel_size=7,
padding=3,
stride=2)
self.cnns = nn.Sequential(* [
nn.Sequential(
ConvBlock(hidden_dim),
nn.GroupNorm(num_groups=1, num_channels=hidden_dim))
for n in range(n_layers)
])
self.projection = ConvNorm(
in_channels=hidden_dim, out_channels=hidden_dim // 2)
self.ctc_linear = nn.Sequential(
LinearNorm(in_dim=hidden_dim // 2, out_dim=hidden_dim),
nn.ReLU(), LinearNorm(in_dim=hidden_dim, out_dim=n_token))
self.asr_s2s = ASRS2S(
embedding_dim=token_embedding_dim,
hidden_dim=hidden_dim // 2,
n_token=n_token)
def forward(self,
x: paddle.Tensor,
src_key_padding_mask: paddle.Tensor=None,
text_input: paddle.Tensor=None):
x = self.to_mfcc(x)
x = self.init_cnn(x)
x = self.cnns(x)
x = self.projection(x)
x = x.transpose([0, 2, 1])
ctc_logit = self.ctc_linear(x)
if text_input is not None:
_, s2s_logit, s2s_attn = self.asr_s2s(
memory=x,
memory_mask=src_key_padding_mask,
text_input=text_input)
return ctc_logit, s2s_logit, s2s_attn
else:
return ctc_logit
def get_feature(self, x: paddle.Tensor):
x = self.to_mfcc(x.squeeze(1))
x = self.init_cnn(x)
x = self.cnns(x)
x = self.projection(x)
return x
def length_to_mask(self, lengths: paddle.Tensor):
mask = paddle.arange(lengths.max()).unsqueeze(0).expand(
(lengths.shape[0], -1)).astype(lengths.dtype)
mask = paddle.greater_than(mask + 1, lengths.unsqueeze(1))
return mask
def get_future_mask(self, out_length: int, unmask_future_steps: int=0):
"""
Args:
out_length (int):
returned mask shape is (out_length, out_length).
unmask_futre_steps (int):
unmasking future step size.
Return:
mask (paddle.BoolTensor):
mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
"""
index_tensor = paddle.arange(out_length).unsqueeze(0).expand(
[out_length, -1])
mask = paddle.greater_than(index_tensor,
index_tensor.T + unmask_future_steps)
return mask
class ASRS2S(nn.Layer):
def __init__(self,
embedding_dim: int=256,
hidden_dim: int=512,
n_location_filters: int=32,
location_kernel_size: int=63,
n_token: int=40):
super().__init__()
self.embedding = nn.Embedding(n_token, embedding_dim)
val_range = math.sqrt(6 / hidden_dim)
uniform_(self.embedding.weight, -val_range, val_range)
self.decoder_rnn_dim = hidden_dim
self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
self.attention_layer = Attention(
attention_rnn_dim=self.decoder_rnn_dim,
embedding_dim=hidden_dim,
attention_dim=hidden_dim,
attention_location_n_filters=n_location_filters,
attention_location_kernel_size=location_kernel_size)
self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim,
self.decoder_rnn_dim)
self.project_to_hidden = nn.Sequential(
LinearNorm(in_dim=self.decoder_rnn_dim * 2, out_dim=hidden_dim),
nn.Tanh())
self.sos = 1
self.eos = 2
def initialize_decoder_states(self,
memory: paddle.Tensor,
mask: paddle.Tensor):
"""
moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
"""
B, L, H = memory.shape
dtype = memory.dtype
self.decoder_hidden = paddle.zeros(
(B, self.decoder_rnn_dim)).astype(dtype)
self.decoder_cell = paddle.zeros(
(B, self.decoder_rnn_dim)).astype(dtype)
self.attention_weights = paddle.zeros((B, L)).astype(dtype)
self.attention_weights_cum = paddle.zeros((B, L)).astype(dtype)
self.attention_context = paddle.zeros((B, H)).astype(dtype)
self.memory = memory
self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
self.unk_index = 3
self.random_mask = 0.1
def forward(self,
memory: paddle.Tensor,
memory_mask: paddle.Tensor,
text_input: paddle.Tensor):
"""
moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
moemory_mask.shape = (B, L, )
texts_input.shape = (B, T)
"""
self.initialize_decoder_states(memory, memory_mask)
# text random mask
random_mask = (paddle.rand(text_input.shape) < self.random_mask)
_text_input = text_input.clone()
_text_input[:] = paddle.where(
condition=random_mask,
x=paddle.full(
shape=_text_input.shape,
fill_value=self.unk_index,
dtype=_text_input.dtype),
y=_text_input)
decoder_inputs = self.embedding(_text_input).transpose(
[1, 0, 2]) # -> [T, B, channel]
start_embedding = self.embedding(
paddle.to_tensor(
[self.sos] * decoder_inputs.shape[1], dtype=paddle.long))
decoder_inputs = paddle.concat(
(start_embedding.unsqueeze(0), decoder_inputs), axis=0)
hidden_outputs, logit_outputs, alignments = [], [], []
while len(hidden_outputs) < decoder_inputs.shape[0]:
decoder_input = decoder_inputs[len(hidden_outputs)]
hidden, logit, attention_weights = self.decode(decoder_input)
hidden_outputs += [hidden]
logit_outputs += [logit]
alignments += [attention_weights]
hidden_outputs, logit_outputs, alignments = \
self.parse_decoder_outputs(
hidden_outputs, logit_outputs, alignments)
return hidden_outputs, logit_outputs, alignments
def decode(self, decoder_input: paddle.Tensor):
cell_input = paddle.concat((decoder_input, self.attention_context), -1)
self.decoder_rnn.flatten_parameters()
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
cell_input, (self.decoder_hidden, self.decoder_cell))
attention_weights_cat = paddle.concat(
(self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
axis=1)
self.attention_context, self.attention_weights = self.attention_layer(
self.decoder_hidden, self.memory, self.processed_memory,
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
hidden_and_context = paddle.concat(
(self.decoder_hidden, self.attention_context), -1)
hidden = self.project_to_hidden(hidden_and_context)
# dropout to increasing g
logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
return hidden, logit, self.attention_weights
def parse_decoder_outputs(self,
hidden: paddle.Tensor,
logit: paddle.Tensor,
alignments: paddle.Tensor):
# -> [B, T_out + 1, max_time]
alignments = paddle.stack(alignments).transpose([1, 0, 2])
# [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
logit = paddle.stack(logit).transpose([1, 0, 2])
hidden = paddle.stack(hidden).transpose([1, 0, 2])
return hidden, logit, alignments

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,234 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import paddle
from paddle import nn
class JDCNet(nn.Layer):
"""
Joint Detection and Classification Network model for singing voice melody.
"""
def __init__(self,
num_class: int=722,
seq_len: int=31,
leaky_relu_slope: float=0.01):
super().__init__()
self.seq_len = seq_len
self.num_class = num_class
# input = (b, 1, 31, 513), b = batch size
self.conv_block = nn.Sequential(
# out: (b, 64, 31, 513)
nn.Conv2D(
in_channels=1,
out_channels=64,
kernel_size=3,
padding=1,
bias_attr=False),
nn.BatchNorm2D(num_features=64),
nn.LeakyReLU(leaky_relu_slope),
# (b, 64, 31, 513)
nn.Conv2D(64, 64, 3, padding=1, bias_attr=False), )
# res blocks
# (b, 128, 31, 128)
self.res_block1 = ResBlock(in_channels=64, out_channels=128)
# (b, 192, 31, 32)
self.res_block2 = ResBlock(in_channels=128, out_channels=192)
# (b, 256, 31, 8)
self.res_block3 = ResBlock(in_channels=192, out_channels=256)
# pool block
self.pool_block = nn.Sequential(
nn.BatchNorm2D(num_features=256),
nn.LeakyReLU(leaky_relu_slope),
# (b, 256, 31, 2)
nn.MaxPool2D(kernel_size=(1, 4)),
nn.Dropout(p=0.5), )
# maxpool layers (for auxiliary network inputs)
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
self.maxpool1 = nn.MaxPool2D(kernel_size=(1, 40))
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
self.maxpool2 = nn.MaxPool2D(kernel_size=(1, 20))
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
self.maxpool3 = nn.MaxPool2D(kernel_size=(1, 10))
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
self.detector_conv = nn.Sequential(
nn.Conv2D(
in_channels=640,
out_channels=256,
kernel_size=1,
bias_attr=False),
nn.BatchNorm2D(256),
nn.LeakyReLU(leaky_relu_slope),
nn.Dropout(p=0.5), )
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
# output: (b, 31, 512)
self.bilstm_classifier = nn.LSTM(
input_size=512,
hidden_size=256,
time_major=False,
direction='bidirectional')
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
# output: (b, 31, 512)
self.bilstm_detector = nn.LSTM(
input_size=512,
hidden_size=256,
time_major=False,
direction='bidirectional')
# input: (b * 31, 512)
# output: (b * 31, num_class)
self.classifier = nn.Linear(
in_features=512, out_features=self.num_class)
# input: (b * 31, 512)
# output: (b * 31, 2) - binary classifier
self.detector = nn.Linear(in_features=512, out_features=2)
# initialize weights
self.apply(self.init_weights)
def get_feature_GAN(self, x: paddle.Tensor):
seq_len = x.shape[-2]
x = x.astype(paddle.float32).transpose([0, 1, 3, 2] if len(x.shape) == 4
else [0, 2, 1])
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return poolblock_out.transpose([0, 1, 3, 2] if len(poolblock_out.shape)
== 4 else [0, 2, 1])
def forward(self, x: paddle.Tensor):
"""
Returns:
classification_prediction, detection_prediction
sizes: (b, 31, 722), (b, 31, 2)
"""
###############################
# forward pass for classifier #
###############################
x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else
[0, 2, 1]).astype(paddle.float32)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
poolblock_out.shape) == 4 else [0, 2, 1])
poolblock_out = self.pool_block[2](poolblock_out)
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
classifier_out = poolblock_out.transpose([0, 2, 1, 3]).reshape(
(-1, self.seq_len, 512))
self.bilstm_classifier.flatten_parameters()
classifier_out, _ = self.bilstm_classifier(
classifier_out) # ignore the hidden states
classifier_out = classifier_out.reshape((-1, 512)) # (b * 31, 512)
classifier_out = self.classifier(classifier_out)
classifier_out = classifier_out.reshape(
(-1, self.seq_len, self.num_class)) # (b, 31, num_class)
# sizes: (b, 31, 722), (b, 31, 2)
# classifier output consists of predicted pitch classes per frame
# detector output consists of: (isvoice, notvoice) estimates per frame
return paddle.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.initializer.KaimingUniform()(m.weight)
if m.bias is not None:
nn.initializer.Constant(0)(m.bias)
elif isinstance(m, nn.Conv2D):
nn.initializer.XavierNormal()(m.weight)
elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
for p in m.parameters():
if len(p.shape) >= 2:
nn.initializer.Orthogonal()(p)
else:
nn.initializer.Normal()(p)
class ResBlock(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
leaky_relu_slope=0.01):
super().__init__()
self.downsample = in_channels != out_channels
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
self.pre_conv = nn.Sequential(
nn.BatchNorm2D(num_features=in_channels),
nn.LeakyReLU(leaky_relu_slope),
# apply downsampling on the y axis only
nn.MaxPool2D(kernel_size=(1, 2)), )
# conv layers
self.conv = nn.Sequential(
nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
bias_attr=False),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU(leaky_relu_slope),
nn.Conv2D(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
bias_attr=False), )
# 1 x 1 convolution layer to match the feature dimensions
self.conv1by1 = None
if self.downsample:
self.conv1by1 = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias_attr=False)
def forward(self, x: paddle.Tensor):
x = self.pre_conv(x)
if self.downsample:
x = self.conv(x) + self.conv1by1(x)
else:
x = self.conv(x) + x
return x

@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .starganv2_vc import *
from .starganv2_vc_updater import *
from .ASR.model import *
from .JDC.model import *

@ -0,0 +1,255 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
from munch import Munch
from starganv2vc_paddle.transforms import build_transforms
# 这些都写到 updater 里
def compute_d_loss(nets,
args,
x_real,
y_org,
y_trg,
z_trg=None,
x_ref=None,
use_r1_reg=True,
use_adv_cls=False,
use_con_reg=False):
args = Munch(args)
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets.discriminator(x_real, y_org)
loss_real = adv_loss(out, 1)
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
if use_con_reg:
t = build_transforms()
out_aug = nets.discriminator(t(x_real).detach(), y_org)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# with fake audios
with paddle.no_grad():
if z_trg is not None:
s_trg = nets.mapping_network(z_trg, y_trg)
else: # x_ref is not None
s_trg = nets.style_encoder(x_ref, y_trg)
F0 = nets.f0_model.get_feature_GAN(x_real)
x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)
out = nets.discriminator(x_fake, y_trg)
loss_fake = adv_loss(out, 0)
if use_con_reg:
out_aug = nets.discriminator(t(x_fake).detach(), y_trg)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# adversarial classifier loss
if use_adv_cls:
out_de = nets.discriminator.classifier(x_fake)
loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_org[y_org != y_trg])
if use_con_reg:
out_de_aug = nets.discriminator.classifier(t(x_fake).detach())
loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
else:
loss_real_adv_cls = paddle.zeros([1]).mean()
loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
args.lambda_adv_cls * loss_real_adv_cls + \
args.lambda_con_reg * loss_con_reg
return loss, Munch(
real=loss_real.item(),
fake=loss_fake.item(),
reg=loss_reg.item(),
real_adv_cls=loss_real_adv_cls.item(),
con_reg=loss_con_reg.item())
def compute_g_loss(nets,
args,
x_real,
y_org,
y_trg,
z_trgs=None,
x_refs=None,
use_adv_cls=False):
args = Munch(args)
assert (z_trgs is None) != (x_refs is None)
if z_trgs is not None:
z_trg, z_trg2 = z_trgs
if x_refs is not None:
x_ref, x_ref2 = x_refs
# compute style vectors
if z_trgs is not None:
s_trg = nets.mapping_network(z_trg, y_trg)
else:
s_trg = nets.style_encoder(x_ref, y_trg)
# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets.f0_model(x_real)
ASR_real = nets.asr_model.get_feature(x_real)
# adversarial loss
x_fake = nets.generator(x_real, s_trg, masks=None, F0=GAN_F0_real)
out = nets.discriminator(x_fake, y_trg)
loss_adv = adv_loss(out, 1)
# compute ASR/F0 features (fake)
F0_fake, GAN_F0_fake, _ = nets.f0_model(x_fake)
ASR_fake = nets.asr_model.get_feature(x_fake)
# norm consistency loss
x_fake_norm = log_norm(x_fake)
x_real_norm = log_norm(x_real)
loss_norm = ((
paddle.nn.ReLU()(paddle.abs(x_fake_norm - x_real_norm) - args.norm_bias)
)**2).mean()
# F0 loss
loss_f0 = f0_loss(F0_fake, F0_real)
# style F0 loss (style initialization)
if x_refs is not None and args.lambda_f0_sty > 0 and not use_adv_cls:
F0_sty, _, _ = nets.f0_model(x_ref)
loss_f0_sty = F.l1_loss(
compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
else:
loss_f0_sty = paddle.zeros([1]).mean()
# ASR loss
loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
# style reconstruction loss
s_pred = nets.style_encoder(x_fake, y_trg)
loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
# diversity sensitive loss
if z_trgs is not None:
s_trg2 = nets.mapping_network(z_trg2, y_trg)
else:
s_trg2 = nets.style_encoder(x_ref2, y_trg)
x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=GAN_F0_real)
x_fake2 = x_fake2.detach()
_, GAN_F0_fake2, _ = nets.f0_model(x_fake2)
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach())
# cycle-consistency loss
s_org = nets.style_encoder(x_real, y_org)
x_rec = nets.generator(x_fake, s_org, masks=None, F0=GAN_F0_fake)
loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
# F0 loss in cycle-consistency loss
if args.lambda_f0 > 0:
_, _, cyc_F0_rec = nets.f0_model(x_rec)
loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real)
if args.lambda_asr > 0:
ASR_recon = nets.asr_model.get_feature(x_rec)
loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real)
# adversarial classifier loss
if use_adv_cls:
out_de = nets.discriminator.classifier(x_fake)
loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_trg[y_org != y_trg])
else:
loss_adv_cls = paddle.zeros([1]).mean()
loss = args.lambda_adv * loss_adv + args.lambda_sty * loss_sty \
- args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc\
+ args.lambda_norm * loss_norm \
+ args.lambda_asr * loss_asr \
+ args.lambda_f0 * loss_f0 \
+ args.lambda_f0_sty * loss_f0_sty \
+ args.lambda_adv_cls * loss_adv_cls
return loss, Munch(
adv=loss_adv.item(),
sty=loss_sty.item(),
ds=loss_ds.item(),
cyc=loss_cyc.item(),
norm=loss_norm.item(),
asr=loss_asr.item(),
f0=loss_f0.item(),
adv_cls=loss_adv_cls.item())
# for norm consistency loss
def log_norm(x, mean=-4, std=4, axis=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
x = paddle.log(paddle.exp(x * std + mean).norm(axis=axis))
return x
# for adversarial loss
def adv_loss(logits, target):
assert target in [1, 0]
if len(logits.shape) > 1:
logits = logits.reshape([-1])
targets = paddle.full_like(logits, fill_value=target)
logits = logits.clip(min=-10, max=10) # prevent nan
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss
# for R1 regularization loss
def r1_reg(d_out, x_in):
# zero-centered gradient penalty for real images
batch_size = x_in.shape[0]
grad_dout = paddle.grad(
outputs=d_out.sum(),
inputs=x_in,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
grad_dout2 = grad_dout.pow(2)
assert (grad_dout2.shape == x_in.shape)
reg = 0.5 * grad_dout2.reshape((batch_size, -1)).sum(1).mean(0)
return reg
# for F0 consistency loss
def compute_mean_f0(f0):
f0_mean = f0.mean(-1)
f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose(
(1, 0)) # (B, M)
return f0_mean
def f0_loss(x_f0, y_f0):
"""
x.shape = (B, 1, M, L): predict
y.shape = (B, 1, M, L): target
"""
# compute the mean
x_mean = compute_mean_f0(x_f0)
y_mean = compute_mean_f0(y_f0)
loss = F.l1_loss(x_f0 / x_mean, y_f0 / y_mean)
return loss

@ -0,0 +1,613 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""
import copy
import math
import paddle
import paddle.nn.functional as F
from munch import Munch
from paddle import nn
from paddlespeech.utils.initialize import _calculate_gain
from paddlespeech.utils.initialize import xavier_uniform_
class DownSample(nn.Layer):
def __init__(self, layer_type: str):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.avg_pool2d(x, (2, 1))
elif self.layer_type == 'half':
return F.avg_pool2d(x, 2)
else:
raise RuntimeError(
'Got unexpected donwsampletype %s, expected is [none, timepreserve, half]'
% self.layer_type)
class UpSample(nn.Layer):
def __init__(self, layer_type: str):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
elif self.layer_type == 'half':
return F.interpolate(x, scale_factor=2, mode='nearest')
else:
raise RuntimeError(
'Got unexpected upsampletype %s, expected is [none, timepreserve, half]'
% self.layer_type)
class ResBlk(nn.Layer):
def __init__(self,
dim_in: int,
dim_out: int,
actv: nn.LeakyReLU=nn.LeakyReLU(0.2),
normalize: bool=False,
downsample: str='none'):
super().__init__()
self.actv = actv
self.normalize = normalize
self.downsample = DownSample(layer_type=downsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out)
def _build_weights(self, dim_in: int, dim_out: int):
self.conv1 = nn.Conv2D(
in_channels=dim_in,
out_channels=dim_in,
kernel_size=3,
stride=1,
padding=1)
self.conv2 = nn.Conv2D(
in_channels=dim_in,
out_channels=dim_out,
kernel_size=3,
stride=1,
padding=1)
if self.normalize:
self.norm1 = nn.InstanceNorm2D(dim_in)
self.norm2 = nn.InstanceNorm2D(dim_in)
if self.learned_sc:
self.conv1x1 = nn.Conv2D(
in_channels=dim_in,
out_channels=dim_out,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
def _shortcut(self, x: paddle.Tensor):
if self.learned_sc:
x = self.conv1x1(x)
if self.downsample:
x = self.downsample(x)
return x
def _residual(self, x: paddle.Tensor):
if self.normalize:
x = self.norm1(x)
x = self.actv(x)
x = self.conv1(x)
x = self.downsample(x)
if self.normalize:
x = self.norm2(x)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x: paddle.Tensor):
x = self._shortcut(x) + self._residual(x)
# unit variance
return x / math.sqrt(2)
class AdaIN(nn.Layer):
def __init__(self, style_dim: int, num_features: int):
super().__init__()
self.norm = nn.InstanceNorm2D(
num_features=num_features, weight_attr=False, bias_attr=False)
self.fc = nn.Linear(style_dim, num_features * 2)
def forward(self, x: paddle.Tensor, s: paddle.Tensor):
if len(s.shape) == 1:
s = s[None]
h = self.fc(s)
h = h.reshape((h.shape[0], h.shape[1], 1, 1))
gamma, beta = paddle.split(h, 2, axis=1)
return (1 + gamma) * self.norm(x) + beta
class AdainResBlk(nn.Layer):
def __init__(self,
dim_in: int,
dim_out: int,
style_dim: int=64,
w_hpf: int=0,
actv: nn.Layer=nn.LeakyReLU(0.2),
upsample: str='none'):
super().__init__()
self.w_hpf = w_hpf
self.actv = actv
self.upsample = UpSample(layer_type=upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
def _build_weights(self, dim_in: int, dim_out: int, style_dim: int=64):
self.conv1 = nn.Conv2D(
in_channels=dim_in,
out_channels=dim_out,
kernel_size=3,
stride=1,
padding=1)
self.conv2 = nn.Conv2D(
in_channels=dim_out,
out_channels=dim_out,
kernel_size=3,
stride=1,
padding=1)
self.norm1 = AdaIN(style_dim=style_dim, num_features=dim_in)
self.norm2 = AdaIN(style_dim=style_dim, num_features=dim_out)
if self.learned_sc:
self.conv1x1 = nn.Conv2D(
in_channels=dim_in,
out_channels=dim_out,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
def _shortcut(self, x: paddle.Tensor):
x = self.upsample(x)
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x: paddle.Tensor, s: paddle.Tensor):
x = self.norm1(x, s)
x = self.actv(x)
x = self.upsample(x)
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x: paddle.Tensor, s: paddle.Tensor):
out = self._residual(x, s)
if self.w_hpf == 0:
out = (out + self._shortcut(x)) / math.sqrt(2)
return out
class HighPass(nn.Layer):
def __init__(self, w_hpf: int):
super().__init__()
self.filter = paddle.to_tensor([[-1, -1, -1], [-1, 8., -1],
[-1, -1, -1]]) / w_hpf
def forward(self, x: paddle.Tensor):
filter = self.filter.unsqueeze(0).unsqueeze(1).tile(
[x.shape[1], 1, 1, 1])
return F.conv2d(x, filter, padding=1, groups=x.shape[1])
class Generator(nn.Layer):
def __init__(self,
dim_in: int=48,
style_dim: int=48,
max_conv_dim: int=48 * 8,
w_hpf: int=1,
F0_channel: int=0):
super().__init__()
self.stem = nn.Conv2D(
in_channels=1,
out_channels=dim_in,
kernel_size=3,
stride=1,
padding=1)
self.encode = nn.LayerList()
self.decode = nn.LayerList()
self.to_out = nn.Sequential(
nn.InstanceNorm2D(dim_in),
nn.LeakyReLU(0.2),
nn.Conv2D(
in_channels=dim_in,
out_channels=1,
kernel_size=1,
stride=1,
padding=0))
self.F0_channel = F0_channel
# down/up-sampling blocks
# int(np.log2(img_size)) - 4
repeat_num = 4
if w_hpf > 0:
repeat_num += 1
for lid in range(repeat_num):
if lid in [1, 3]:
_downtype = 'timepreserve'
else:
_downtype = 'half'
dim_out = min(dim_in * 2, max_conv_dim)
self.encode.append(
ResBlk(
dim_in=dim_in,
dim_out=dim_out,
normalize=True,
downsample=_downtype))
(self.decode.insert if lid else
lambda i, sublayer: self.decode.append(sublayer))(0, AdainResBlk(
dim_in=dim_out,
dim_out=dim_in,
style_dim=style_dim,
w_hpf=w_hpf,
upsample=_downtype)) # stack-like
dim_in = dim_out
# bottleneck blocks (encoder)
for _ in range(2):
self.encode.append(
ResBlk(dim_in=dim_out, dim_out=dim_out, normalize=True))
# F0 blocks
if F0_channel != 0:
self.decode.insert(0,
AdainResBlk(
dim_in=dim_out + int(F0_channel / 2),
dim_out=dim_out,
style_dim=style_dim,
w_hpf=w_hpf))
# bottleneck blocks (decoder)
for _ in range(2):
self.decode.insert(0,
AdainResBlk(
dim_in=dim_out + int(F0_channel / 2),
dim_out=dim_out + int(F0_channel / 2),
style_dim=style_dim,
w_hpf=w_hpf))
if F0_channel != 0:
self.F0_conv = nn.Sequential(
ResBlk(
dim_in=F0_channel,
dim_out=int(F0_channel / 2),
normalize=True,
downsample="half"), )
if w_hpf > 0:
self.hpf = HighPass(w_hpf)
def forward(self,
x: paddle.Tensor,
s: paddle.Tensor,
masks: paddle.Tensor=None,
F0: paddle.Tensor=None):
x = self.stem(x)
cache = {}
for block in self.encode:
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
cache[x.shape[2]] = x
x = block(x)
if F0 is not None:
F0 = self.F0_conv(F0)
F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])
x = paddle.concat([x, F0], axis=1)
for block in self.decode:
x = block(x, s)
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
mask = masks[0] if x.shape[2] in [32] else masks[1]
mask = F.interpolate(mask, size=x.shape[2], mode='bilinear')
x = x + self.hpf(mask * cache[x.shape[2]])
return self.to_out(x)
class MappingNetwork(nn.Layer):
def __init__(self,
latent_dim: int=16,
style_dim: int=48,
num_domains: int=2,
hidden_dim: int=384):
super().__init__()
layers = []
layers += [nn.Linear(latent_dim, hidden_dim)]
layers += [nn.ReLU()]
for _ in range(3):
layers += [nn.Linear(hidden_dim, hidden_dim)]
layers += [nn.ReLU()]
self.shared = nn.Sequential(*layers)
self.unshared = nn.LayerList()
for _ in range(num_domains):
self.unshared.extend([
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(), nn.Linear(hidden_dim, style_dim))
])
def forward(self, z: paddle.Tensor, y: paddle.Tensor):
h = self.shared(z)
out = []
for layer in self.unshared:
out += [layer(h)]
# (batch, num_domains, style_dim)
out = paddle.stack(out, axis=1)
idx = paddle.arange(y.shape[0])
# (batch, style_dim)
s = out[idx, y]
return s
class StyleEncoder(nn.Layer):
def __init__(self,
dim_in: int=48,
style_dim: int=48,
num_domains: int=2,
max_conv_dim: int=384):
super().__init__()
blocks = []
blocks += [
nn.Conv2D(
in_channels=1,
out_channels=dim_in,
kernel_size=3,
stride=1,
padding=1)
]
repeat_num = 4
for _ in range(repeat_num):
dim_out = min(dim_in * 2, max_conv_dim)
blocks += [
ResBlk(dim_in=dim_in, dim_out=dim_out, downsample='half')
]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [
nn.Conv2D(
in_channels=dim_out,
out_channels=dim_out,
kernel_size=5,
stride=1,
padding=0)
]
blocks += [nn.AdaptiveAvgPool2D(1)]
blocks += [nn.LeakyReLU(0.2)]
self.shared = nn.Sequential(*blocks)
self.unshared = nn.LayerList()
for _ in range(num_domains):
self.unshared.append(nn.Linear(dim_out, style_dim))
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
h = self.shared(x)
h = h.reshape((h.shape[0], -1))
out = []
for layer in self.unshared:
out += [layer(h)]
# (batch, num_domains, style_dim)
out = paddle.stack(out, axis=1)
idx = paddle.arange(y.shape[0])
# (batch, style_dim)
s = out[idx, y]
return s
class Discriminator(nn.Layer):
def __init__(self,
dim_in: int=48,
num_domains: int=2,
max_conv_dim: int=384,
repeat_num: int=4):
super().__init__()
# real/fake discriminator
self.dis = Discriminator2D(
dim_in=dim_in,
num_domains=num_domains,
max_conv_dim=max_conv_dim,
repeat_num=repeat_num)
# adversarial classifier
self.cls = Discriminator2D(
dim_in=dim_in,
num_domains=num_domains,
max_conv_dim=max_conv_dim,
repeat_num=repeat_num)
self.num_domains = num_domains
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
return self.dis(x, y)
def classifier(self, x: paddle.Tensor):
return self.cls.get_feature(x)
class LinearNorm(nn.Layer):
def __init__(self,
in_dim: int,
out_dim: int,
bias: bool=True,
w_init_gain: str='linear'):
super().__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
xavier_uniform_(
self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class Discriminator2D(nn.Layer):
def __init__(self,
dim_in: int=48,
num_domains: int=2,
max_conv_dim: int=384,
repeat_num: int=4):
super().__init__()
blocks = []
blocks += [
nn.Conv2D(
in_channels=1,
out_channels=dim_in,
kernel_size=3,
stride=1,
padding=1)
]
for lid in range(repeat_num):
dim_out = min(dim_in * 2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, downsample='half')]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [
nn.Conv2D(
in_channels=dim_out,
out_channels=dim_out,
kernel_size=5,
stride=1,
padding=0)
]
blocks += [nn.LeakyReLU(0.2)]
blocks += [nn.AdaptiveAvgPool2D(1)]
blocks += [
nn.Conv2D(
in_channels=dim_out,
out_channels=num_domains,
kernel_size=1,
stride=1,
padding=0)
]
self.main = nn.Sequential(*blocks)
def get_feature(self, x: paddle.Tensor):
out = self.main(x)
# (batch, num_domains)
out = out.reshape((out.shape[0], -1))
return out
def forward(self, x: paddle.Tensor, y: paddle.Tensor):
out = self.get_feature(x)
idx = paddle.arange(y.shape[0])
# (batch)
out = out[idx, y]
return out
def build_model(args, F0_model: nn.Layer, ASR_model: nn.Layer):
generator = Generator(
dim_in=args.dim_in,
style_dim=args.style_dim,
max_conv_dim=args.max_conv_dim,
w_hpf=args.w_hpf,
F0_channel=args.F0_channel)
mapping_network = MappingNetwork(
latent_dim=args.latent_dim,
style_dim=args.style_dim,
num_domains=args.num_domains,
hidden_dim=args.max_conv_dim)
style_encoder = StyleEncoder(
dim_in=args.dim_in,
style_dim=args.style_dim,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim)
discriminator = Discriminator(
dim_in=args.dim_in,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim,
n_repeat=args.n_repeat)
generator_ema = copy.deepcopy(generator)
mapping_network_ema = copy.deepcopy(mapping_network)
style_encoder_ema = copy.deepcopy(style_encoder)
nets = Munch(
generator=generator,
mapping_network=mapping_network,
style_encoder=style_encoder,
discriminator=discriminator,
f0_model=F0_model,
asr_model=ASR_model)
nets_ema = Munch(
generator=generator_ema,
mapping_network=mapping_network_ema,
style_encoder=style_encoder_ema)
return nets, nets_ema
class StarGANv2VC(nn.Layer):
def __init__(
self,
# spk_num
num_domains: int=20,
dim_in: int=64,
style_dim: int=64,
latent_dim: int=16,
max_conv_dim: int=512,
n_repeat: int=4,
w_hpf: int=0,
F0_channel: int=256):
super().__init__()
self.generator = Generator(
dim_in=dim_in,
style_dim=style_dim,
max_conv_dim=max_conv_dim,
w_hpf=w_hpf,
F0_channel=F0_channel)
# MappingNetwork and StyleEncoder are used to generate reference_embeddings
self.mapping_network = MappingNetwork(
latent_dim=latent_dim,
style_dim=style_dim,
num_domains=num_domains,
hidden_dim=max_conv_dim)
self.style_encoder = StyleEncoder(
dim_in=dim_in,
style_dim=style_dim,
num_domains=num_domains,
max_conv_dim=max_conv_dim)
self.discriminator = Discriminator(
dim_in=dim_in,
num_domains=num_domains,
max_conv_dim=max_conv_dim,
repeat_num=n_repeat)

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading…
Cancel
Save