diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index c994faa5a..c0238a98a 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -24,13 +24,13 @@ import yaml from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader -from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam from yacs.config import CfgNode from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.sampler import ErnieSATSampler from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITSEvaluator from paddlespeech.t2s.models.vits import VITSUpdater @@ -107,12 +107,12 @@ def train_sp(args, config): converters=converters, ) # collate function and dataloader - train_sampler = DistributedBatchSampler( + train_sampler = ErnieSATSampler( train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True) - dev_sampler = DistributedBatchSampler( + dev_sampler = ErnieSATSampler( dev_dataset, batch_size=config.batch_size, shuffle=False,