|
|
|
@@ -78,7 +78,8 @@ class BatchDataLoader():
|
|
|
|
|
load_aux_input: bool=False,
|
|
|
|
|
load_aux_output: bool=False,
|
|
|
|
|
num_encs: int=1,
|
|
|
|
|
dist_sampler: bool=False):
|
|
|
|
|
dist_sampler: bool=False,
|
|
|
|
|
shortest_first: bool=False):
|
|
|
|
|
self.json_file = json_file
|
|
|
|
|
self.train_mode = train_mode
|
|
|
|
|
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
|
|
|
|
@@ -97,6 +98,7 @@ class BatchDataLoader():
|
|
|
|
|
self.load_aux_input = load_aux_input
|
|
|
|
|
self.load_aux_output = load_aux_output
|
|
|
|
|
self.dist_sampler = dist_sampler
|
|
|
|
|
self.shortest_first = shortest_first
|
|
|
|
|
|
|
|
|
|
# read json data
|
|
|
|
|
with jsonlines.open(json_file, 'r') as reader:
|
|
|
|
@@ -113,7 +115,7 @@ class BatchDataLoader():
|
|
|
|
|
maxlen_out,
|
|
|
|
|
minibatches, # for debug
|
|
|
|
|
min_batch_size=mini_batch_size,
|
|
|
|
|
shortest_first=self.use_sortagrad,
|
|
|
|
|
shortest_first=self.shortest_first or self.use_sortagrad,
|
|
|
|
|
count=batch_count,
|
|
|
|
|
batch_bins=batch_bins,
|
|
|
|
|
batch_frames_in=batch_frames_in,
|
|
|
|
@@ -149,13 +151,13 @@ class BatchDataLoader():
|
|
|
|
|
self.reader)
|
|
|
|
|
|
|
|
|
|
if self.dist_sampler:
|
|
|
|
|
self.sampler = DistributedBatchSampler(
|
|
|
|
|
self.batch_sampler = DistributedBatchSampler(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_size=1,
|
|
|
|
|
shuffle=not self.use_sortagrad if self.train_mode else False,
|
|
|
|
|
drop_last=False, )
|
|
|
|
|
else:
|
|
|
|
|
self.sampler = BatchSampler(
|
|
|
|
|
self.batch_sampler = BatchSampler(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_size=1,
|
|
|
|
|
shuffle=not self.use_sortagrad if self.train_mode else False,
|
|
|
|
@@ -163,7 +165,7 @@ class BatchDataLoader():
|
|
|
|
|
|
|
|
|
|
self.dataloader = DataLoader(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_sampler=self.sampler,
|
|
|
|
|
batch_sampler=self.batch_sampler,
|
|
|
|
|
collate_fn=batch_collate,
|
|
|
|
|
num_workers=self.n_iter_processes, )
|
|
|
|
|
|
|
|
|
@@ -194,5 +196,6 @@ class BatchDataLoader():
|
|
|
|
|
echo += f"load_aux_input: {self.load_aux_input}, "
|
|
|
|
|
echo += f"load_aux_output: {self.load_aux_output}, "
|
|
|
|
|
echo += f"dist_sampler: {self.dist_sampler}, "
|
|
|
|
|
echo += f"shortest_first: {self.shortest_first}, "
|
|
|
|
|
echo += f"file: {self.json_file}"
|
|
|
|
|
return echo
|
|
|
|
|