fix dataloader pickle bugs

pull/785/head
Hui Zhang 3 years ago
parent dde3267e9b
commit d1db859657

@@ -83,11 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
 FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 
+# yapf: disable
 FILES = [
     fn for fn in FILES
     if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
         'unittest.cc'))
 ]
+# yapf: enable
 
 LIBS = ['stdc++']
 if platform.system() != 'Darwin':
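A note on the two added comments: "# yapf: disable" / "# yapf: enable" are standard yapf pragmas that tell the formatter to leave the enclosed lines exactly as written, so the hand-wrapped comprehension above survives reformatting. A minimal illustration (the MATRIX name is just an example, not from this repo):

# yapf: disable
MATRIX = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
# yapf: enable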

@@ -171,10 +171,7 @@ class U2Trainer(Trainer):
         if from_scratch:
             # save init model, i.e. 0 epoch
             self.save(tag='init')
         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
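Context for the removed set_epoch call: that pattern only applies when the loader is driven by a DistributedBatchSampler, which re-seeds its shuffle from the epoch number so each rank sees a fresh permutation. The new BatchDataLoader builds its own plain DataLoader, so there is no batch sampler left to re-seed. A minimal sketch of the pattern that was dropped (toy dataset, assuming paddle.io's sampler API):

import paddle
from paddle.io import DataLoader, DistributedBatchSampler, TensorDataset

# toy dataset, only to show the per-epoch re-seeding that set_epoch performs
dataset = TensorDataset([paddle.arange(100, dtype='float32').reshape([100, 1])])
sampler = DistributedBatchSampler(dataset, batch_size=10, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # each rank gets a fresh shuffle per epoch
    for batch in loader:
        pass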

@@ -153,7 +153,7 @@ class SpecAugmentor(AugmentorBase):
         window = max_time_warp = self.W
         if window == 0:
             return x
 
         if mode == "PIL":
             t = x.shape[0]
             if t - window <= window:
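For readers unfamiliar with the "PIL" warp mode referenced here: it picks a random center frame, moves it to a nearby position, and resizes the two halves of the spectrogram with PIL instead of TensorFlow's sparse_image_warp. A rough, self-contained sketch of the idea (not the project's exact implementation):

import numpy as np
from PIL import Image

def time_warp_pil(x: np.ndarray, window: int = 5) -> np.ndarray:
    """Sketch of "PIL"-mode time warp on a (time, freq) spectrogram: pick a
    center frame, shift it within +/- window, and resize both halves so the
    total number of frames is preserved."""
    t = x.shape[0]
    if t - window <= window:
        return x
    center = np.random.randint(window, t - window)
    warped = np.random.randint(center - window, center + window) + 1
    left = Image.fromarray(x[:center]).resize((x.shape[1], warped))
    right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped))
    return np.concatenate((np.asarray(left), np.asarray(right)), axis=0)

spec = np.random.rand(100, 80).astype(np.float32)  # PIL's 'F' mode wants float32
print(time_warp_pil(spec).shape)  # (100, 80)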

@@ -43,7 +43,7 @@ class CustomConverter():
             batch (list): The batch to transform.
 
         Returns:
-            tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)
+            tuple(np.ndarray, np.ndarray, np.ndarray)
         """
 
         # batch should be located in list
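The docstring change suggests the converter now hands back plain numpy arrays instead of paddle Tensors, leaving tensor conversion to the consuming code; numpy arrays are cheap to pickle and ship back from loader workers. A small sketch of that downstream conversion (the array names and shapes here are placeholders, not the project's exact code):

import numpy as np
import paddle

# placeholder padded features and lengths, as a converter might return them
xs_pad = np.zeros((2, 100, 83), dtype=np.float32)
ilens = np.array([95, 100], dtype=np.int64)

# convert to tensors only where the model actually needs them
xs = paddle.to_tensor(xs_pad)
ilens = paddle.to_tensor(ilens)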

@@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
     return feat_dim, vocab_size
 
 
+def batch_collate(x):
+    """de-tuple.
+
+    Args:
+        x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
+
+    Returns:
+        Tuple: (utts, xs, ilens, ys, olens)
+    """
+    return x[0]
+
+
 class BatchDataLoader():
     def __init__(self,
                  json_file: str,
@@ -120,15 +132,15 @@ class BatchDataLoader():
         # actual bathsize is included in a list
         # default collate function converts numpy array to pytorch tensor
         # we used an empty collate function instead which returns list
-        self.dataset = TransformDataset(
-            self.minibaches,
-            lambda data: self.converter([self.reader(data, return_uttid=True)]))
+        self.dataset = TransformDataset(self.minibaches, self.converter,
+                                        self.reader)
         self.dataloader = DataLoader(
             dataset=self.dataset,
             batch_size=1,
-            shuffle=not self.use_sortagrad if train_mode else False,
-            collate_fn=lambda x: x[0],
-            num_workers=n_iter_processes, )
+            shuffle=not self.use_sortagrad if self.train_mode else False,
+            collate_fn=batch_collate,
+            num_workers=self.n_iter_processes, )
 
     def __repr__(self):
         echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
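This hunk is the heart of the pickle fix: a DataLoader with num_workers > 0 has to pickle the dataset and the collate_fn to send them to worker processes, and neither a lambda nor a closure capturing self can be pickled. Passing the module-level batch_collate instead of "lambda x: x[0]", and handing converter/reader to TransformDataset rather than wrapping them in a lambda, keeps every pickled object importable by name. A quick self-contained check of the difference:

import pickle

def batch_collate(x):
    """Module-level function: picklable, so multi-worker loaders can use it."""
    return x[0]

collate_lambda = lambda x: x[0]  # what the old code passed as collate_fn

pickle.dumps(batch_collate)      # fine
try:
    pickle.dumps(collate_lambda)
except Exception as e:           # PicklingError / AttributeError depending on scope
    print("lambda collate_fn cannot be pickled:", e)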

@@ -129,15 +129,16 @@ class TransformDataset(Dataset):
 
     Args:
         data: list object from make_batchset
-        transfrom: transform function
+        converter: batch function
+        reader: read data
     """
 
-    def __init__(self, data, transform):
+    def __init__(self, data, converter, reader):
         """Init function."""
         super().__init__()
         self.data = data
-        self.transform = transform
+        self.converter = converter
+        self.reader = reader
 
     def __len__(self):
         """Len function."""
@@ -145,4 +146,4 @@ class TransformDataset(Dataset):
 
     def __getitem__(self, idx):
         """[] operator."""
-        return self.transform(self.data[idx])
+        return self.converter([self.reader(self.data[idx], return_uttid=True)])
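With this refactor, __getitem__ performs exactly the composition the removed lambda used to do, but converter and reader are now plain attributes, so the dataset itself stays picklable. A toy usage sketch of the TransformDataset defined above (the reader and converter here are hypothetical stand-ins for the project's LoadInputsAndTargets and CustomConverter):

# hypothetical stand-ins, only to show how the pieces compose
def toy_reader(minibatch, return_uttid=True):
    uttids, feats = zip(*minibatch)
    return list(feats), list(uttids)

def toy_converter(batch):
    feats, uttids = batch[0]
    return uttids, feats

dataset = TransformDataset(
    data=[[("utt1", [1.0, 2.0]), ("utt2", [3.0])]],
    converter=toy_converter,
    reader=toy_reader)
print(dataset[0])  # (['utt1', 'utt2'], [[1.0, 2.0], [3.0]])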

@@ -29,7 +29,7 @@
         "adaptive_number_ratio": 0,
         "adaptive_size_ratio": 0,
         "max_n_time_masks": 20,
         "replace_with_zero": true
     },
     "prob": 1.0
 }

@@ -2,17 +2,17 @@
 {
     "type": "specaug",
     "params": {
-        "F": 10,
-        "T": 50,
+        "W": 5,
+        "warp_mode": "PIL",
+        "F": 30,
         "n_freq_masks": 2,
+        "T": 40,
         "n_time_masks": 2,
         "p": 1.0,
-        "W": 80,
         "adaptive_number_ratio": 0,
         "adaptive_size_ratio": 0,
         "max_n_time_masks": 20,
-        "replace_with_zero": true,
-        "warp_mode": "PIL"
+        "replace_with_zero": false
     },
     "prob": 1.0
 }
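For reference, the reordered parameters map onto the usual SpecAugment knobs: W is the time-warp window, F the maximum width of each of the n_freq_masks frequency masks, T the maximum width of each of the n_time_masks time masks, and "replace_with_zero": false fills masked regions with the feature mean instead of zeros. A minimal sketch of how such a frequency mask is applied (illustrative, not the project's exact code):

import numpy as np

def freq_mask(x, F=30, n_freq_masks=2, replace_with_zero=False):
    """Mask up to F consecutive frequency bins, n_freq_masks times,
    on a (time, freq) spectrogram."""
    x = x.copy()
    fill = 0.0 if replace_with_zero else x.mean()
    for _ in range(n_freq_masks):
        f = np.random.randint(0, F)
        f0 = np.random.randint(0, max(1, x.shape[1] - f))
        x[:, f0:f0 + f] = fill
    return x

spec = np.random.rand(200, 80).astype(np.float32)
masked = freq_mask(spec)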

@@ -8,26 +8,23 @@ collator:
   vocab_filepath: data/train_960_unigram5000_units.txt
   unit_type: 'spm'
   spm_model_prefix: 'data/train_960_unigram5000'
-  mean_std_filepath: ""
-  augmentation_config: conf/augmentation.json
-  batch_size: 64
-  raw_wav: True  # use raw_wav or kaldi feature
-  specgram_type: fbank # linear, mfcc, fbank
   feat_dim: 83
-  delta_delta: False
-  dither: 1.0
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
   stride_ms: 10.0
   window_ms: 25.0
-  use_dB_normalization: True
-  target_dB: -20
-  random_seed: 0
-  keep_transcription_text: False
-  sortagrad: True
-  shuffle_method: batch_shuffle
+  sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+  batch_size: 32
+  maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
+  maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+  minibatches: 0 # for debug
+  batch_count: auto
+  batch_bins: 0
+  batch_frames_in: 0
+  batch_frames_out: 0
+  batch_frames_inout: 0
+  augmentation_config: conf/augmentation.json
   num_workers: 2
+  subsampling_factor: 1
+  num_encs: 1
 
 # network architecture
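The new collator block switches from a fixed batch_size of 64 to espnet-style dynamic batching: with batch_count: auto, utterances longer than maxlen_in frames (input) or maxlen_out tokens (output) shrink the batch, sortagrad: 0 disables shortest-first feeding, and the batch_bins / batch_frames_* fields stay at 0 unless a bin- or frame-based count is chosen. A rough sketch of the maxlen-driven reduction that the config comments describe (the formula is illustrative, not the exact library rule):

def reduced_batch_size(ilen, olen, batch_size=32, maxlen_in=512, maxlen_out=150):
    """Longer utterances -> smaller batch, as hinted by the config comments."""
    factor = max(int(ilen / maxlen_in), int(olen / maxlen_out))
    return max(1, int(batch_size / (1 + factor)))

print(reduced_batch_size(400, 80))    # 32: short enough, full batch
print(reduced_batch_size(1600, 300))  # 8: long utterance, batch shrinks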
