diff --git a/data_utils/data.py b/data_utils/data.py index 8bff6826..7ddf1f33 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -327,7 +327,7 @@ class DataGenerator(object): shift_len = self._rng.randint(0, batch_size - 1) batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) self._rng.shuffle(batch_manifest) - batch_manifest = list(sum(batch_manifest, ())) + batch_manifest = [item for batch in batch_manifest for item in batch] if not clipped: res_len = len(manifest) - shift_len - len(batch_manifest) batch_manifest.extend(manifest[-res_len:])