|
|
|
@ -18,6 +18,7 @@ from typing import Text
|
|
|
|
|
|
|
|
|
|
import jsonlines
|
|
|
|
|
import numpy as np
|
|
|
|
|
from paddle.io import BatchSampler
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
|
|
|
|
|
@ -76,7 +77,8 @@ class BatchDataLoader():
|
|
|
|
|
subsampling_factor: int=1,
|
|
|
|
|
load_aux_input: bool=False,
|
|
|
|
|
load_aux_output: bool=False,
|
|
|
|
|
num_encs: int=1):
|
|
|
|
|
num_encs: int=1,
|
|
|
|
|
dist_sampler: bool=False):
|
|
|
|
|
self.json_file = json_file
|
|
|
|
|
self.train_mode = train_mode
|
|
|
|
|
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
|
|
|
|
@ -94,6 +96,7 @@ class BatchDataLoader():
|
|
|
|
|
self.n_iter_processes = n_iter_processes
|
|
|
|
|
self.load_aux_input = load_aux_input
|
|
|
|
|
self.load_aux_output = load_aux_output
|
|
|
|
|
self.dist_sampler = dist_sampler
|
|
|
|
|
|
|
|
|
|
# read json data
|
|
|
|
|
with jsonlines.open(json_file, 'r') as reader:
|
|
|
|
@ -145,11 +148,18 @@ class BatchDataLoader():
|
|
|
|
|
self.dataset = TransformDataset(self.minibaches, self.converter,
|
|
|
|
|
self.reader)
|
|
|
|
|
|
|
|
|
|
self.sampler = DistributedBatchSampler(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_size=1,
|
|
|
|
|
shuffle=not self.use_sortagrad if self.train_mode else False,
|
|
|
|
|
)
|
|
|
|
|
if self.dist_sampler:
|
|
|
|
|
self.sampler = DistributedBatchSampler(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_size=1,
|
|
|
|
|
shuffle=not self.use_sortagrad if self.train_mode else False,
|
|
|
|
|
drop_last=False, )
|
|
|
|
|
else:
|
|
|
|
|
self.sampler = BatchSampler(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
|
batch_size=1,
|
|
|
|
|
shuffle=not self.use_sortagrad if self.train_mode else False,
|
|
|
|
|
drop_last=False, )
|
|
|
|
|
|
|
|
|
|
self.dataloader = DataLoader(
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
@ -181,5 +191,8 @@ class BatchDataLoader():
|
|
|
|
|
echo += f"subsampling_factor: {self.subsampling_factor}, "
|
|
|
|
|
echo += f"num_encs: {self.num_encs}, "
|
|
|
|
|
echo += f"num_workers: {self.n_iter_processes}, "
|
|
|
|
|
echo += f"load_aux_input: {self.load_aux_input}, "
|
|
|
|
|
echo += f"load_aux_output: {self.load_aux_output}, "
|
|
|
|
|
echo += f"dist_sampler: {self.dist_sampler}, "
|
|
|
|
|
echo += f"file: {self.json_file}"
|
|
|
|
|
return echo
|
|
|
|
|