[TTS]add starganv2 vc trainer (#3143)
* add starganv2 vc trainer * fix StarGANv2VCUpdater and losses * fix StarGANv2VCEvaluator * add some typehintpull/3155/head
parent
54ef90fcec
commit
72aa19c32c
@ -1,22 +1,123 @@
|
||||
generator_params:
|
||||
###########################################################
|
||||
# FEATURE EXTRACTION SETTING #
|
||||
###########################################################
|
||||
# 其实没用上,其实用的是 16000
|
||||
sr: 24000
|
||||
n_fft: 2048
|
||||
win_length: 1200
|
||||
hop_length: 300
|
||||
n_mels: 80
|
||||
###########################################################
|
||||
# MODEL SETTING #
|
||||
###########################################################
|
||||
generator_params:
|
||||
dim_in: 64
|
||||
style_dim: 64
|
||||
max_conv_dim: 512
|
||||
w_hpf: 0
|
||||
F0_channel: 256
|
||||
mapping_network_params:
|
||||
mapping_network_params:
|
||||
num_domains: 20 # num of speakers in StarGANv2
|
||||
latent_dim: 16
|
||||
style_dim: 64 # same as style_dim in generator_params
|
||||
hidden_dim: 512 # same as max_conv_dim in generator_params
|
||||
style_encoder_params:
|
||||
style_encoder_params:
|
||||
dim_in: 64 # same as dim_in in generator_params
|
||||
style_dim: 64 # same as style_dim in generator_params
|
||||
num_domains: 20 # same as num_domains in generator_params
|
||||
max_conv_dim: 512 # same as max_conv_dim in generator_params
|
||||
discriminator_params:
|
||||
discriminator_params:
|
||||
dim_in: 64 # same as dim_in in generator_params
|
||||
num_domains: 20 # same as num_domains in mapping_network_params
|
||||
max_conv_dim: 512 # same as max_conv_dim in generator_params
|
||||
n_repeat: 4
|
||||
asr_params:
|
||||
input_dim: 80
|
||||
hidden_dim: 256
|
||||
n_token: 80
|
||||
token_embedding_dim: 256
|
||||
|
||||
###########################################################
|
||||
# ADVERSARIAL LOSS SETTING #
|
||||
###########################################################
|
||||
loss_params:
|
||||
g_loss:
|
||||
lambda_sty: 1.
|
||||
lambda_cyc: 5.
|
||||
lambda_ds: 1.
|
||||
lambda_norm: 1.
|
||||
lambda_asr: 10.
|
||||
lambda_f0: 5.
|
||||
lambda_f0_sty: 0.1
|
||||
lambda_adv: 2.
|
||||
lambda_adv_cls: 0.5
|
||||
norm_bias: 0.5
|
||||
d_loss:
|
||||
lambda_reg: 1.
|
||||
lambda_adv_cls: 0.1
|
||||
lambda_con_reg: 10.
|
||||
|
||||
adv_cls_epoch: 50
|
||||
con_reg_epoch: 30
|
||||
|
||||
|
||||
###########################################################
|
||||
# DATA LOADER SETTING #
|
||||
###########################################################
|
||||
batch_size: 5 # Batch size.
|
||||
num_workers: 2 # Number of workers in DataLoader.
|
||||
|
||||
###########################################################
|
||||
# OPTIMIZER & SCHEDULER SETTING #
|
||||
###########################################################
|
||||
generator_optimizer_params:
|
||||
beta1: 0.0
|
||||
beta2: 0.99
|
||||
weight_decay: 1e-4
|
||||
epsilon: 1e-9
|
||||
generator_scheduler_params:
|
||||
max_learning_rate: 2e-4
|
||||
phase_pct: 0.0
|
||||
divide_factor: 1
|
||||
total_steps: 200000 # train_max_steps
|
||||
end_learning_rate: 2e-4
|
||||
style_encoder_optimizer_params:
|
||||
beta1: 0.0
|
||||
beta2: 0.99
|
||||
weight_decay: 1e-4
|
||||
epsilon: 1e-9
|
||||
style_encoder_scheduler_params:
|
||||
max_learning_rate: 2e-4
|
||||
phase_pct: 0.0
|
||||
divide_factor: 1
|
||||
total_steps: 200000 # train_max_steps
|
||||
end_learning_rate: 2e-4
|
||||
mapping_network_optimizer_params:
|
||||
beta1: 0.0
|
||||
beta2: 0.99
|
||||
weight_decay: 1e-4
|
||||
epsilon: 1e-9
|
||||
mapping_network_scheduler_params:
|
||||
max_learning_rate: 2e-6
|
||||
phase_pct: 0.0
|
||||
divide_factor: 1
|
||||
total_steps: 200000 # train_max_steps
|
||||
end_learning_rate: 2e-6
|
||||
discriminator_optimizer_params:
|
||||
beta1: 0.0
|
||||
beta2: 0.99
|
||||
weight_decay: 1e-4
|
||||
epsilon: 1e-9
|
||||
discriminator_scheduler_params:
|
||||
max_learning_rate: 2e-4
|
||||
phase_pct: 0.0
|
||||
divide_factor: 1
|
||||
total_steps: 200000 # train_max_steps
|
||||
end_learning_rate: 2e-4
|
||||
|
||||
###########################################################
|
||||
# TRAINING SETTING #
|
||||
###########################################################
|
||||
max_epoch: 150
|
||||
num_snapshots: 5
|
||||
seed: 1
|
@ -0,0 +1,259 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import paddle
|
||||
import yaml
|
||||
from paddle import DataParallel
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from paddle.optimizer import AdamW
|
||||
from paddle.optimizer.lr import OneCycleLR
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
|
||||
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||
from paddlespeech.t2s.models.starganv2_vc import ASRCNN
|
||||
from paddlespeech.t2s.models.starganv2_vc import Generator
|
||||
from paddlespeech.t2s.models.starganv2_vc import JDCNet
|
||||
from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
|
||||
from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCEvaluator
|
||||
from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCUpdater
|
||||
from paddlespeech.t2s.models.starganv2_vc import StyleEncoder
|
||||
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
|
||||
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
|
||||
from paddlespeech.t2s.training.seeding import seed_everything
|
||||
from paddlespeech.t2s.training.trainer import Trainer
|
||||
from paddlespeech.utils.env import MODEL_HOME
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
# decides device type and whether to run in parallel
|
||||
# setup running environment correctly
|
||||
world_size = paddle.distributed.get_world_size()
|
||||
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
||||
paddle.set_device("cpu")
|
||||
else:
|
||||
paddle.set_device("gpu")
|
||||
if world_size > 1:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
# set the random seed, it is a must for multiprocess training
|
||||
seed_everything(config.seed)
|
||||
|
||||
print(
|
||||
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||
)
|
||||
# to edit
|
||||
fields = ["speech", "speech_lengths"]
|
||||
converters = {"speech": np.load}
|
||||
|
||||
collate_fn = starganv2_vc_batch_fn
|
||||
|
||||
# dataloader has been too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
# construct dataset for training and validation
|
||||
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||
train_metadata = list(reader)
|
||||
train_dataset = DataTable(
|
||||
data=train_metadata,
|
||||
fields=fields,
|
||||
converters=converters, )
|
||||
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||
dev_metadata = list(reader)
|
||||
dev_dataset = DataTable(
|
||||
data=dev_metadata,
|
||||
fields=fields,
|
||||
converters=converters, )
|
||||
|
||||
# collate function and dataloader
|
||||
train_sampler = DistributedBatchSampler(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
|
||||
print("samplers done!")
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
batch_size=config.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
print("dataloaders done!")
|
||||
|
||||
# load model
|
||||
model_version = '1.0'
|
||||
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
|
||||
MODEL_HOME)
|
||||
|
||||
generator = Generator(**config['generator_params'])
|
||||
mapping_network = MappingNetwork(**config['mapping_network_params'])
|
||||
style_encoder = StyleEncoder(**config['style_encoder_params'])
|
||||
|
||||
# load pretrained model
|
||||
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
|
||||
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')
|
||||
|
||||
F0_model = JDCNet(num_class=1, seq_len=192)
|
||||
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
|
||||
F0_model.eval()
|
||||
|
||||
asr_model = ASRCNN(**config['asr_params'])
|
||||
asr_model.set_state_dict(paddle.load(asr_model_dir)['main_params'])
|
||||
asr_model.eval()
|
||||
|
||||
if world_size > 1:
|
||||
generator = DataParallel(generator)
|
||||
discriminator = DataParallel(discriminator)
|
||||
print("models done!")
|
||||
|
||||
lr_schedule_g = OneCycleLR(**config["generator_scheduler_params"])
|
||||
optimizer_g = AdamW(
|
||||
learning_rate=lr_schedule_g,
|
||||
parameters=generator.parameters(),
|
||||
**config["generator_optimizer_params"])
|
||||
|
||||
lr_schedule_s = OneCycleLR(**config["style_encoder_scheduler_params"])
|
||||
optimizer_s = AdamW(
|
||||
learning_rate=lr_schedule_s,
|
||||
parameters=style_encoder.parameters(),
|
||||
**config["style_encoder_optimizer_params"])
|
||||
|
||||
lr_schedule_m = OneCycleLR(**config["mapping_network_scheduler_params"])
|
||||
optimizer_m = AdamW(
|
||||
learning_rate=lr_schedule_m,
|
||||
parameters=mapping_network.parameters(),
|
||||
**config["mapping_network_optimizer_params"])
|
||||
|
||||
lr_schedule_d = OneCycleLR(**config["discriminator_scheduler_params"])
|
||||
optimizer_d = AdamW(
|
||||
learning_rate=lr_schedule_d,
|
||||
parameters=discriminator.parameters(),
|
||||
**config["discriminator_optimizer_params"])
|
||||
print("optimizers done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
if dist.get_rank() == 0:
|
||||
config_name = args.config.split("/")[-1]
|
||||
# copy conf to output_dir
|
||||
shutil.copyfile(args.config, output_dir / config_name)
|
||||
|
||||
updater = StarGANv2VCUpdater(
|
||||
models={
|
||||
"generator": generator,
|
||||
"style_encoder": style_encoder,
|
||||
"mapping_network": mapping_network,
|
||||
"discriminator": discriminator,
|
||||
"F0_model": F0_model,
|
||||
"asr_model": asr_model,
|
||||
},
|
||||
optimizers={
|
||||
"generator": optimizer_g,
|
||||
"style_encoder": optimizer_s,
|
||||
"mapping_network": optimizer_m,
|
||||
"discriminator": optimizer_d,
|
||||
},
|
||||
schedulers={
|
||||
"generator": lr_schedule_g,
|
||||
"style_encoder": lr_schedule_s,
|
||||
"mapping_network": lr_schedule_m,
|
||||
"discriminator": lr_schedule_d,
|
||||
},
|
||||
dataloader=train_dataloader,
|
||||
g_loss_params=config.loss_params.g_loss,
|
||||
d_loss_params=config.loss_params.d_loss,
|
||||
adv_cls_epoch=config.loss_params.adv_cls_epoch,
|
||||
con_reg_epoch=config.loss_params.con_reg_epoch,
|
||||
output_dir=output_dir)
|
||||
|
||||
evaluator = StarGANv2VCEvaluator(
|
||||
models={
|
||||
"generator": generator,
|
||||
"style_encoder": style_encoder,
|
||||
"mapping_network": mapping_network,
|
||||
"discriminator": discriminator,
|
||||
"F0_model": F0_model,
|
||||
"asr_model": asr_model,
|
||||
},
|
||||
dataloader=dev_dataloader,
|
||||
g_loss_params=config.loss_params.g_loss,
|
||||
d_loss_params=config.loss_params.d_loss,
|
||||
adv_cls_epoch=config.loss_params.adv_cls_epoch,
|
||||
con_reg_epoch=config.loss_params.con_reg_epoch,
|
||||
output_dir=output_dir)
|
||||
|
||||
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
|
||||
trainer.extend(
|
||||
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
|
||||
print("Trainer Done!")
|
||||
|
||||
trainer.run()
|
||||
|
||||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
|
||||
parser.add_argument("--config", type=str, help="HiFiGAN config file.")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
config = CfgNode(yaml.safe_load(f))
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
print(config)
|
||||
print(
|
||||
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||
)
|
||||
|
||||
# dispatch
|
||||
if args.ngpu > 1:
|
||||
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
|
||||
else:
|
||||
train_sp(args, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
|
||||
|
||||
## 1. RandomTimeStrech
|
||||
class TimeStrech(nn.Layer):
|
||||
def __init__(self, scale):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
mel_size = x.shape[-1]
|
||||
|
||||
x = F.interpolate(
|
||||
x,
|
||||
scale_factor=(1, self.scale),
|
||||
align_corners=False,
|
||||
mode='bilinear').squeeze()
|
||||
|
||||
if x.shape[-1] < mel_size:
|
||||
noise_length = (mel_size - x.shape[-1])
|
||||
random_pos = random.randint(0, x.shape[-1]) - noise_length
|
||||
if random_pos < 0:
|
||||
random_pos = 0
|
||||
noise = x[..., random_pos:random_pos + noise_length]
|
||||
x = paddle.concat([x, noise], axis=-1)
|
||||
else:
|
||||
x = x[..., :mel_size]
|
||||
|
||||
return x.unsqueeze(1)
|
||||
|
||||
|
||||
## 2. PitchShift
|
||||
class PitchShift(nn.Layer):
|
||||
def __init__(self, shift):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0)
|
||||
x = x.squeeze()
|
||||
mel_size = x.shape[1]
|
||||
shift_scale = (mel_size + self.shift) / mel_size
|
||||
x = F.interpolate(
|
||||
x.unsqueeze(1),
|
||||
scale_factor=(shift_scale, 1.),
|
||||
align_corners=False,
|
||||
mode='bilinear').squeeze(1)
|
||||
|
||||
x = x[:, :mel_size]
|
||||
if x.shape[1] < mel_size:
|
||||
pad_size = mel_size - x.shape[1]
|
||||
x = paddle.cat(
|
||||
[x, paddle.zeros(x.shape[0], pad_size, x.shape[2])], axis=1)
|
||||
x = x.squeeze()
|
||||
return x.unsqueeze(1)
|
||||
|
||||
|
||||
## 3. ShiftBias
|
||||
class ShiftBias(nn.Layer):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
return x + self.bias
|
||||
|
||||
|
||||
## 4. Scaling
|
||||
class SpectScaling(nn.Layer):
|
||||
def __init__(self, scale):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
return x * self.scale
|
||||
|
||||
|
||||
## 5. Time Flip
|
||||
class TimeFlip(nn.Layer):
|
||||
def __init__(self, length):
|
||||
super().__init__()
|
||||
self.length = round(length)
|
||||
|
||||
def forward(self, x: paddle.Tensor):
|
||||
if self.length > 1:
|
||||
start = np.random.randint(0, x.shape[-1] - self.length)
|
||||
x_ret = x.clone()
|
||||
x_ret[..., start:start + self.length] = paddle.flip(
|
||||
x[..., start:start + self.length], axis=[-1])
|
||||
x = x_ret
|
||||
return x
|
||||
|
||||
|
||||
class PhaseShuffle2D(nn.Layer):
|
||||
def __init__(self, n: int=2):
|
||||
super().__init__()
|
||||
self.n = n
|
||||
self.random = random.Random(1)
|
||||
|
||||
def forward(self, x: paddle.Tensor, move=None):
|
||||
# x.size = (B, C, M, L)
|
||||
if move is None:
|
||||
move = self.random.randint(-self.n, self.n)
|
||||
|
||||
if move == 0:
|
||||
return x
|
||||
else:
|
||||
left = x[:, :, :, :move]
|
||||
right = x[:, :, :, move:]
|
||||
shuffled = paddle.concat([right, left], axis=3)
|
||||
|
||||
return shuffled
|
||||
|
||||
|
||||
def build_transforms():
|
||||
transforms = [
|
||||
lambda M: TimeStrech(1 + (np.random.random() - 0.5) * M * 0.2),
|
||||
lambda M: SpectScaling(1 + (np.random.random() - 1) * M * 0.1),
|
||||
lambda M: PhaseShuffle2D(192),
|
||||
]
|
||||
N, M = len(transforms), np.random.random()
|
||||
composed = nn.Sequential(
|
||||
* [trans(M) for trans in np.random.choice(transforms, N)])
|
||||
return composed
|
Loading…
Reference in new issue