From 760e4d71a4326ae6355a33e11c91dba7e58535fc Mon Sep 17 00:00:00 2001 From: WongLaw Date: Tue, 27 Dec 2022 09:01:52 +0000 Subject: [PATCH] Adjust Sampler for VITS, test=tts --- paddlespeech/t2s/exps/vits/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index c994faa5a..c0238a98a 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -24,13 +24,13 @@ import yaml from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader -from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam from yacs.config import CfgNode from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.sampler import ErnieSATSampler from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITSEvaluator from paddlespeech.t2s.models.vits import VITSUpdater @@ -107,12 +107,12 @@ def train_sp(args, config): converters=converters, ) # collate function and dataloader - train_sampler = DistributedBatchSampler( + train_sampler = ErnieSATSampler( train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True) - dev_sampler = DistributedBatchSampler( + dev_sampler = ErnieSATSampler( dev_dataset, batch_size=config.batch_size, shuffle=False,