fix dataloader pickle bugs

pull/785/head
Hui Zhang 3 years ago
parent dde3267e9b
commit d1db859657

@@ -83,11 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable
FILES = [
fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
]
# yapf: enable
LIBS = ['stdc++']
if platform.system() != 'Darwin':

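For reference, the filter above simply drops kenlm/openfst sources whose names end in main.cc, test.cc, or unittest.cc before they are handed to the extension build. A tiny standalone check of the same predicate on hypothetical paths:

# Same predicate as the setup.py filter above, applied to made-up paths.
sample = [
    'kenlm/util/pool.cc',
    'kenlm/util/file_piece_test.cc',       # dropped: ends with test.cc
    'kenlm/lm/build_binary_main.cc',       # dropped: ends with main.cc
    'openfst-1.6.3/src/lib/fst.cc',
]
kept = [
    fn for fn in sample
    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or
            fn.endswith('unittest.cc'))
]
print(kept)  # ['kenlm/util/pool.cc', 'openfst-1.6.3/src/lib/fst.cc']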
@@ -171,10 +171,7 @@ class U2Trainer(Trainer):
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:

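For context, the set_epoch call removed here is the usual way to reseed a distributed batch sampler so each epoch sees a different shuffle. A minimal sketch with paddle.io.DistributedBatchSampler (the toy dataset and sizes are made up; the pattern only applies while the train loader is still built around such a sampler):

# Hedged sketch: per-epoch reshuffling with a distributed batch sampler.
import numpy as np
from paddle.io import Dataset, DataLoader, DistributedBatchSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return np.array([idx], dtype='float32')

dataset = ToyDataset()
sampler = DistributedBatchSampler(dataset, batch_size=8, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)   # reseed the shuffle so each epoch gets a new order
    for batch in loader:
        pass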
@@ -153,7 +153,7 @@ class SpecAugmentor(AugmentorBase):
window = max_time_warp = self.W
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:

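For context on the mode == "PIL" branch: the ESPnet-style time warp splits the spectrogram at a random center frame and stretches the two parts with PIL resizing so the total length stays the same. A self-contained sketch of that idea (not the repo's exact code; shapes and parameters are illustrative):

import random
import numpy as np
from PIL import Image

def time_warp_pil(x: np.ndarray, max_time_warp: int = 5) -> np.ndarray:
    """Warp a (time, freq) spectrogram along time, ESPnet 'PIL' style."""
    window = max_time_warp
    if window == 0:
        return x
    t = x.shape[0]
    if t - window <= window:
        return x
    center = random.randrange(window, t - window)
    warped = random.randrange(center - window, center + window) + 1
    # Stretch the left part to `warped` frames and the right part to the rest.
    left = Image.fromarray(x[:center]).resize((x.shape[1], warped), Image.BICUBIC)
    right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), Image.BICUBIC)
    return np.concatenate([np.array(left), np.array(right)], axis=0)

spec = np.random.rand(200, 83).astype(np.float32)   # hypothetical fbank features
print(time_warp_pil(spec, max_time_warp=5).shape)   # (200, 83): length preserved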
@@ -43,7 +43,7 @@ class CustomConverter():
batch (list): The batch to transform.
Returns:
tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)
tuple(np.ndarray, np.ndarray, np.ndarray)
"""
# batch should be located in list

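The corrected docstring reflects that the converter hands back numpy arrays; framework tensors are created later in the training step, which keeps everything crossing the DataLoader worker boundary cheap to pickle. A minimal sketch of a converter with that contract (padding values and names are assumptions, not the repo's CustomConverter):

import numpy as np

def convert_batch(batch):
    """Pad a batch of (feature, token) pairs and return numpy arrays only."""
    xs, ys = zip(*batch)
    ilens = np.array([x.shape[0] for x in xs], dtype=np.int64)
    olens = np.array([len(y) for y in ys], dtype=np.int64)
    xs_pad = np.zeros((len(xs), ilens.max(), xs[0].shape[1]), dtype=np.float32)
    ys_pad = np.full((len(ys), olens.max()), fill_value=-1, dtype=np.int64)
    for i, (x, y) in enumerate(zip(xs, ys)):
        xs_pad[i, :x.shape[0]] = x
        ys_pad[i, :len(y)] = y
    return xs_pad, ilens, ys_pad, olens

batch = [(np.random.rand(7, 83).astype(np.float32), [3, 9, 2]),
         (np.random.rand(5, 83).astype(np.float32), [1, 4])]
xs_pad, ilens, ys_pad, olens = convert_batch(batch)
print(xs_pad.shape, ys_pad.shape)  # (2, 7, 83) (2, 3)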
@@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
return feat_dim, vocab_size
def batch_collate(x):
"""de-tuple.
Args:
x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
Returns:
Tuple: (utts, xs, ilens, ys, olens)
"""
return x[0]
class BatchDataLoader():
def __init__(self,
json_file: str,
@@ -120,15 +132,15 @@ class BatchDataLoader():
# actual batch size is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self.dataset = TransformDataset(
self.minibaches,
lambda data: self.converter([self.reader(data, return_uttid=True)]))
self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader)
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )
shuffle=not self.use_sortagrad if self.train_mode else False,
collate_fn=batch_collate,
num_workers=self.n_iter_processes, )
def __repr__(self):
echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "

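The swap from lambda collate_fn and captured-lambda transform to the module-level batch_collate and explicit converter/reader attributes is what the commit title refers to: when num_workers > 0 and worker processes are spawned rather than forked, the dataset and collate_fn are pickled into the workers, and lambdas cannot be pickled. A minimal sketch of that difference, using plain pickle with no DataLoader involved:

import pickle

def batch_collate(x):
    """De-tuple: the dataset already yields full minibatches, so the
    DataLoader wraps each one in a single-element list; return that element."""
    return x[0]

collate_lambda = lambda x: x[0]          # the old collate_fn

pickle.dumps(batch_collate)              # OK: named functions pickle by reference
try:
    pickle.dumps(collate_lambda)
except (pickle.PicklingError, AttributeError) as err:
    print("lambda collate_fn cannot cross a worker-process boundary:", err)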
@@ -129,15 +129,16 @@ class TransformDataset(Dataset):
Args:
data: list object from make_batchset
transform: transform function
converter: batch function
reader: read data
"""
def __init__(self, data, transform):
def __init__(self, data, converter, reader):
"""Init function."""
super().__init__()
self.data = data
self.transform = transform
self.converter = converter
self.reader = reader
def __len__(self):
"""Len function."""
@@ -145,4 +146,4 @@ class TransformDataset(Dataset):
def __getitem__(self, idx):
"""[] operator."""
return self.transform(self.data[idx])
return self.converter([self.reader(self.data[idx], return_uttid=True)])

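With the lambda gone, the dataset itself becomes picklable: each item of data is a pre-built minibatch, and the per-item transform lives in plain attributes plus a regular method. A framework-agnostic toy sketch of the same pattern (the reader and converter bodies here are made up stand-ins for the repo's reader and CustomConverter):

import pickle
import numpy as np

def toy_reader(item, return_uttid=True):
    """Hypothetical reader: turn a list of utterance ids into a feature array."""
    return np.zeros((len(item), 5, 83), dtype=np.float32)

def toy_converter(batches):
    """Hypothetical converter: unpack the single pre-built minibatch."""
    return batches[0]

class TransformDataset:
    """Each item of `data` is a pre-built minibatch; converter/reader are
    plain attributes, so the whole dataset pickles cleanly into workers."""
    def __init__(self, data, converter, reader):
        self.data = data
        self.converter = converter
        self.reader = reader

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.converter([self.reader(self.data[idx], return_uttid=True)])

dataset = TransformDataset([["utt1", "utt2"], ["utt3"]], toy_converter, toy_reader)
pickle.dumps(dataset)                      # now succeeds: no lambda inside
print(dataset[0].shape, dataset[1].shape)  # (2, 5, 83) (1, 5, 83)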
@@ -29,7 +29,7 @@
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true
"replace_with_zero": true
},
"prob": 1.0
}

@@ -2,17 +2,17 @@
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"W": 5,
"warp_mode": "PIL",
"F": 30,
"n_freq_masks": 2,
"T": 40,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
"replace_with_zero": false
},
"prob": 1.0
}

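For reference, a hedged sketch of how a pipeline config like the one above is typically consumed: each entry names an augmentor type, its params dict is passed to the augmentor's constructor, and prob gates whether it is applied at all. The SpecAugmentor constructor is an assumption about the repo's API, so the call is left commented out.

import json
import random

config_text = '''
[
  {
    "type": "specaug",
    "params": {"F": 30, "n_freq_masks": 2, "T": 40, "n_time_masks": 2,
               "p": 1.0, "W": 80, "adaptive_number_ratio": 0,
               "adaptive_size_ratio": 0, "max_n_time_masks": 20,
               "replace_with_zero": false, "warp_mode": "PIL"},
    "prob": 1.0
  }
]
'''

for entry in json.loads(config_text):
    if random.random() < entry["prob"]:          # "prob" gates whether to apply
        params = entry["params"]
        print(entry["type"], "->", sorted(params))
        # augmentor = SpecAugmentor(rng, **params)   # assumed constructor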
@@ -8,26 +8,23 @@ collator:
vocab_filepath: data/train_960_unigram5000_units.txt
unit_type: 'spm'
spm_model_prefix: 'data/train_960_unigram5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 83
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen_in, batch size is automatically reduced
maxlen_out: 150 # if output length > maxlen_out, batch size is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
augmentation_config: conf/augmentation.json
num_workers: 2
subsampling_factor: 1
num_encs: 1
# network architecture

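The new collator block follows ESPnet-style batching. Below is a hedged sketch of how two of the new keys are commonly interpreted, with the config embedded inline so it runs standalone; the batch-size reduction rule is modeled on ESPnet's batchfy and is an assumption about this repo's exact logic.

import yaml

config_text = """
collator:
  sortagrad: 0
  batch_size: 32
  maxlen_in: 512
  maxlen_out: 150
"""
collator = yaml.safe_load(config_text)["collator"]

sortagrad = collator["sortagrad"]
# -1: shortest-first ordering for every epoch, 0: disabled, N: first N epochs only
use_sortagrad = sortagrad == -1 or sortagrad > 0

def reduced_batch_size(batch_size, ilen, olen, maxlen_in, maxlen_out, min_bs=1):
    # Shrink the batch when an utterance exceeds the length caps so the padded
    # batch still fits in memory (ESPnet-style rule, assumed here).
    factor = max(int(ilen / maxlen_in), int(olen / maxlen_out))
    return max(min_bs, int(batch_size / (1 + factor)))

print(use_sortagrad)                                   # False for sortagrad == 0
print(reduced_batch_size(collator["batch_size"], ilen=1024, olen=100,
                         maxlen_in=collator["maxlen_in"],
                         maxlen_out=collator["maxlen_out"]))   # 10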