|
|
@ -131,10 +131,10 @@ def train_sp(args, config):
|
|
|
|
converters=converters, )
|
|
|
|
converters=converters, )
|
|
|
|
|
|
|
|
|
|
|
|
# collate function and dataloader
|
|
|
|
# collate function and dataloader
|
|
|
|
|
|
|
|
train_batch_size = min(len(train_metadata), config.batch_size)
|
|
|
|
train_sampler = DistributedBatchSampler(
|
|
|
|
train_sampler = DistributedBatchSampler(
|
|
|
|
train_dataset,
|
|
|
|
train_dataset,
|
|
|
|
batch_size=config.batch_size,
|
|
|
|
batch_size=train_batch_size,
|
|
|
|
shuffle=True,
|
|
|
|
shuffle=True,
|
|
|
|
drop_last=True)
|
|
|
|
drop_last=True)
|
|
|
|
|
|
|
|
|
|
|
|