|
|
|
@ -29,9 +29,12 @@ from paddle.optimizer import AdamW
|
|
|
|
|
from paddle.optimizer.lr import OneCycleLR
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
|
|
|
|
|
from paddlespeech.t2s.datasets.data_table import DataTable
|
|
|
|
|
from paddlespeech.cli.utils import download_and_decompress
|
|
|
|
|
from paddlespeech.resource.pretrained_models import StarGANv2VC_source
|
|
|
|
|
from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
|
|
|
|
|
from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
|
|
|
|
|
from paddlespeech.t2s.models.starganv2_vc import ASRCNN
|
|
|
|
|
from paddlespeech.t2s.models.starganv2_vc import Discriminator
|
|
|
|
|
from paddlespeech.t2s.models.starganv2_vc import Generator
|
|
|
|
|
from paddlespeech.t2s.models.starganv2_vc import JDCNet
|
|
|
|
|
from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
|
|
|
|
@ -66,7 +69,9 @@ def train_sp(args, config):
|
|
|
|
|
fields = ["speech", "speech_lengths"]
|
|
|
|
|
converters = {"speech": np.load}
|
|
|
|
|
|
|
|
|
|
collate_fn = starganv2_vc_batch_fn
|
|
|
|
|
collate_fn = build_starganv2_vc_collate_fn(
|
|
|
|
|
latent_dim=config['mapping_network_params']['latent_dim'],
|
|
|
|
|
max_mel_length=config['max_mel_length'])
|
|
|
|
|
|
|
|
|
|
# dataloader has been too verbose
|
|
|
|
|
logging.getLogger("DataLoader").disabled = True
|
|
|
|
@ -74,16 +79,10 @@ def train_sp(args, config):
|
|
|
|
|
# construct dataset for training and validation
|
|
|
|
|
with jsonlines.open(args.train_metadata, 'r') as reader:
|
|
|
|
|
train_metadata = list(reader)
|
|
|
|
|
train_dataset = DataTable(
|
|
|
|
|
data=train_metadata,
|
|
|
|
|
fields=fields,
|
|
|
|
|
converters=converters, )
|
|
|
|
|
train_dataset = StarGANv2VCDataTable(data=train_metadata)
|
|
|
|
|
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
|
|
|
|
dev_metadata = list(reader)
|
|
|
|
|
dev_dataset = DataTable(
|
|
|
|
|
data=dev_metadata,
|
|
|
|
|
fields=fields,
|
|
|
|
|
converters=converters, )
|
|
|
|
|
dev_dataset = StarGANv2VCDataTable(data=dev_metadata)
|
|
|
|
|
|
|
|
|
|
# collate function and dataloader
|
|
|
|
|
train_sampler = DistributedBatchSampler(
|
|
|
|
@ -118,6 +117,7 @@ def train_sp(args, config):
|
|
|
|
|
generator = Generator(**config['generator_params'])
|
|
|
|
|
mapping_network = MappingNetwork(**config['mapping_network_params'])
|
|
|
|
|
style_encoder = StyleEncoder(**config['style_encoder_params'])
|
|
|
|
|
discriminator = Discriminator(**config['discriminator_params'])
|
|
|
|
|
|
|
|
|
|
# load pretrained model
|
|
|
|
|
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
|
|
|
|
|