Adjust Sampler for VITS, test=tts

pull/2770/head
WongLaw 3 years ago
parent c5f8e44e53
commit 760e4d71a4

@ -24,13 +24,13 @@ import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam from paddle.optimizer import Adam
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSEvaluator from paddlespeech.t2s.models.vits import VITSEvaluator
from paddlespeech.t2s.models.vits import VITSUpdater from paddlespeech.t2s.models.vits import VITSUpdater
@ -107,12 +107,12 @@ def train_sp(args, config):
converters=converters, ) converters=converters, )
# collate function and dataloader # collate function and dataloader
train_sampler = DistributedBatchSampler( train_sampler = ErnieSATSampler(
train_dataset, train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True)
dev_sampler = DistributedBatchSampler( dev_sampler = ErnieSATSampler(
dev_dataset, dev_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=False, shuffle=False,

Loading…
Cancel
Save