fix some trainer

pull/3182/head
TianYuan 2 years ago
parent dc56c3a10e
commit 90acc1b435

@@ -41,7 +41,7 @@ discriminator_params:
     dim_in: 64  # same as dim_in in generator_params
     num_domains: 20  # same as num_domains in mapping_network_params
     max_conv_dim: 512  # same as max_conv_dim in generator_params
-    n_repeat: 4
+    repeat_num: 4
 asr_params:
     input_dim: 80
     hidden_dim: 256
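The `n_repeat` → `repeat_num` rename matters because this config block is splatted into the constructor later in the commit (see the `discriminator = Discriminator(**config['discriminator_params'])` hunk below): a YAML key that does not match the constructor's parameter name raises a TypeError. A minimal sketch of the failure mode, using a stand-in class rather than the real Discriminator:

# Stand-in class; only the keyword names mirror discriminator_params.
class DemoDiscriminator:
    def __init__(self, dim_in=64, num_domains=20, max_conv_dim=512, repeat_num=4):
        self.repeat_num = repeat_num

params = {"dim_in": 64, "num_domains": 20, "max_conv_dim": 512, "n_repeat": 4}
try:
    DemoDiscriminator(**params)
except TypeError as err:
    print(err)  # unexpected keyword argument 'n_repeat'

params["repeat_num"] = params.pop("n_repeat")
DemoDiscriminator(**params)  # constructs cleanly after the rename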
@@ -77,6 +77,7 @@ loss_params:
 ###########################################################
 batch_size: 5  # Batch size.
 num_workers: 2  # Number of workers in DataLoader.
+max_mel_length: 192
 ###########################################################
 #              OPTIMIZER & SCHEDULER SETTING              #
@@ -84,47 +85,47 @@ num_workers: 2  # Number of workers in DataLoader.
 generator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 generator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000  # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 style_encoder_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 style_encoder_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000  # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 mapping_network_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 mapping_network_scheduler_params:
-    max_learning_rate: 2e-6
+    max_learning_rate: 2.0e-6
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000  # train_max_steps
-    end_learning_rate: 2e-6
+    end_learning_rate: 2.0e-6
 discriminator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 discriminator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000  # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 ###########################################################
 #                    TRAINING SETTING                     #
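The `1e-4` → `1.0e-4` style edits above are not cosmetic. YAML 1.1 resolvers, including PyYAML which yacs uses under the hood, only tag a scalar as a float when the mantissa contains a decimal point, so `2e-4` loads as the string `'2e-4'` and later fails when handed to an optimizer as a learning rate. A quick self-contained check:

import yaml

print(type(yaml.safe_load("lr: 2e-4")["lr"]))    # <class 'str'>
print(type(yaml.safe_load("lr: 2.0e-4")["lr"]))  # <class 'float'>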

@@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=1 \
-    --phones-dict=dump/phone_id_map.txt \
-    --speaker-dict=dump/speaker_id_map.txt
+    --ngpu=1

@@ -852,20 +852,21 @@ class StarGANv2VCCollateFn:
         # (B,)
         label = paddle.to_tensor(label)
         ref_label = paddle.to_tensor(ref_label)
-        # [B, 80, T] -> [B, 1, 80, T]
-        mel = paddle.to_tensor(mel)
-        ref_mel = paddle.to_tensor(ref_mel)
-        ref_mel_2 = paddle.to_tensor(ref_mel_2)
+        # [B, T, 80] -> [B, 1, 80, T]
+        mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose(
+            [0, 2, 1]).unsqueeze(1)
-        z_trg = paddle.randn(batch_size, self.latent_dim)
-        z_trg2 = paddle.randn(batch_size, self.latent_dim)
+        z_trg = paddle.randn([batch_size, self.latent_dim])
+        z_trg2 = paddle.randn([batch_size, self.latent_dim])
         batch = {
-            "x_real": mels,
-            "y_org": labels,
-            "x_ref": ref_mels,
-            "x_ref2": ref_mels_2,
-            "y_trg": ref_labels,
+            "x_real": mel,
+            "y_org": label,
+            "x_ref": ref_mel,
+            "x_ref2": ref_mel_2,
+            "y_trg": ref_label,
             "z_trg": z_trg,
             "z_trg2": z_trg2
         }
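Three separate fixes in this hunk. First, `paddle.randn` takes the shape as one list/tuple argument, not as varargs the way `torch.randn(batch_size, latent_dim)` does, so the old calls were invalid. Second, mels arrive from the collate loop as `[B, T, 80]` (frames first), while the 2D-convolutional networks expect a single-channel image layout `[B, 1, 80, T]`, hence the `transpose` plus `unsqueeze`. Third, the batch dict referenced plural names (`mels`, `labels`, ...) that never existed in this scope, a straight NameError. A standalone shape check (`latent_dim=16` is an arbitrary demo value):

import paddle

batch_size, latent_dim = 5, 16                       # latent_dim is illustrative
mel = paddle.zeros([batch_size, 192, 80])            # [B, T, 80] as collated
mel = mel.transpose([0, 2, 1]).unsqueeze(1)          # -> [B, 1, 80, T]
print(mel.shape)                                     # [5, 1, 80, 192]
z_trg = paddle.randn([batch_size, latent_dim])       # shape as a single list
print(z_trg.shape)                                   # [5, 16]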

@@ -29,9 +29,12 @@ from paddle.optimizer import AdamW
 from paddle.optimizer.lr import OneCycleLR
 from yacs.config import CfgNode
-from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
-from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.resource.pretrained_models import StarGANv2VC_source
+from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
+from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
 from paddlespeech.t2s.models.starganv2_vc import ASRCNN
 from paddlespeech.t2s.models.starganv2_vc import Discriminator
 from paddlespeech.t2s.models.starganv2_vc import Generator
 from paddlespeech.t2s.models.starganv2_vc import JDCNet
 from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
@@ -66,7 +69,9 @@ def train_sp(args, config):
     fields = ["speech", "speech_lengths"]
     converters = {"speech": np.load}
-    collate_fn = starganv2_vc_batch_fn
+    collate_fn = build_starganv2_vc_collate_fn(
+        latent_dim=config['mapping_network_params']['latent_dim'],
+        max_mel_length=config['max_mel_length'])

     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
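The batch function is now produced by a builder parameterized from the config: `latent_dim` sizes the sampled style codes (`z_trg`, `z_trg2` above) and the new `max_mel_length` key (192, added in the first hunk) fixes the number of frames per clip so examples can be stacked into one dense tensor. A hedged sketch of the crop-or-pad idea; the real logic lives in `am_batch_fn.py` and may differ in details such as random versus fixed cropping:

import numpy as np

def fix_length(mel: np.ndarray, max_mel_length: int = 192) -> np.ndarray:
    """Crop or zero-pad a [T, 80] mel to exactly max_mel_length frames.
    Sketch only; mirrors the role of max_mel_length in the collate fn."""
    t = mel.shape[0]
    if t >= max_mel_length:
        start = np.random.randint(0, t - max_mel_length + 1)
        return mel[start:start + max_mel_length]
    pad = np.zeros((max_mel_length - t, mel.shape[1]), dtype=mel.dtype)
    return np.concatenate([mel, pad], axis=0)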
@@ -74,16 +79,10 @@ def train_sp(args, config):
     # construct dataset for training and validation
     with jsonlines.open(args.train_metadata, 'r') as reader:
         train_metadata = list(reader)
-    train_dataset = DataTable(
-        data=train_metadata,
-        fields=fields,
-        converters=converters, )
+    train_dataset = StarGANv2VCDataTable(data=train_metadata)
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
-    dev_dataset = DataTable(
-        data=dev_metadata,
-        fields=fields,
-        converters=converters, )
+    dev_dataset = StarGANv2VCDataTable(data=dev_metadata)

     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
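The generic `DataTable(fields=..., converters=...)` projection is replaced by a task-specific table: StarGANv2-VC training needs each example to carry a speaker label alongside the mel, which the plain `speech`/`speech_lengths` field selection did not provide. A hypothetical minimal shape of such a table; the real `StarGANv2VCDataTable` lives in `data_table.py`, and the `spk_id` key here is an assumption:

import numpy as np

class MiniVCDataTable:
    """Sketch: each metadata row maps to a (mel, speaker label) pair."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return {"utt_id": row["utt_id"],
                "mel": np.load(row["speech"]),  # path stored in metadata.jsonl
                "label": row["spk_id"]}         # key name is an assumption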
@@ -118,6 +117,7 @@ def train_sp(args, config):
     generator = Generator(**config['generator_params'])
     mapping_network = MappingNetwork(**config['mapping_network_params'])
     style_encoder = StyleEncoder(**config['style_encoder_params'])
+    discriminator = Discriminator(**config['discriminator_params'])
     # load pretrained model
     jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')

@@ -21,10 +21,13 @@ from paddle.nn import Layer
 from paddle.optimizer import Optimizer
 from paddle.optimizer.lr import LRScheduler
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
 logging.basicConfig(
     format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
     datefmt='[%Y-%m-%d %H:%M:%S]')
@@ -62,10 +65,10 @@ class StarGANv2VCUpdater(StandardUpdater):
         self.models = models
         self.optimizers = optimizers
-        self.optimizer_g = optimizers['optimizer_g']
-        self.optimizer_s = optimizers['optimizer_s']
-        self.optimizer_m = optimizers['optimizer_m']
-        self.optimizer_d = optimizers['optimizer_d']
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_s = optimizers['style_encoder']
+        self.optimizer_m = optimizers['mapping_network']
+        self.optimizer_d = optimizers['discriminator']
         self.schedulers = schedulers
         self.scheduler_g = schedulers['generator']
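The updater now looks up optimizers by model name instead of the abbreviated `optimizer_*` keys, so the dict assembled at the call site must use the same names; one naming scheme shared by `models`, `optimizers`, and `schedulers` keeps the three dicts consistent. A sketch of the expected call-site dict, with throwaway parameters just to make the AdamW instances real (hyperparameters copied from the YAML config above):

import paddle
from paddle.optimizer import AdamW

def make_opt():
    # Throwaway parameter so AdamW has something to optimize in this demo.
    p = paddle.create_parameter([4, 4], dtype='float32')
    return AdamW(learning_rate=2.0e-4, beta1=0.0, beta2=0.99,
                 weight_decay=1.0e-4, epsilon=1.0e-9, parameters=[p])

# Keys must match the lookups in StarGANv2VCUpdater.__init__ above.
optimizers = {name: make_opt()
              for name in ("generator", "style_encoder",
                           "mapping_network", "discriminator")}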
