diff --git a/examples/other/tts_finetune/tts3/local/finetune.py b/examples/other/tts_finetune/tts3/local/finetune.py index 496c2355..814497aa 100644 --- a/examples/other/tts_finetune/tts3/local/finetune.py +++ b/examples/other/tts_finetune/tts3/local/finetune.py @@ -131,10 +131,10 @@ def train_sp(args, config): converters=converters, ) # collate function and dataloader - + train_batch_size = min(len(train_metadata), config.batch_size) train_sampler = DistributedBatchSampler( train_dataset, - batch_size=config.batch_size, + batch_size=train_batch_size, shuffle=True, drop_last=True)