# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage

def add_length_method(obj):
    """Give obj a __len__ method by re-classing it into a dynamically created subclass whose __len__ returns obj.size."""

    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj

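
# Usage sketch (illustrative, not part of the original module): with_length()
# on the class below relies on add_length_method, e.g.
#
#     pipeline = DataPipeline(...).with_length(1000)
#     assert len(pipeline) == 1000  # the stream itself is unchanged
#
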
class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                # a list argument contributes its elements as individual stages
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage."""
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

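    # Note on invoke() above (illustrative): a stage can be a PipelineStage
    # (its .run() is applied), an IterableDataset/DataLoader or a plain list
    # used as the source, or any callable mapping an input iterator to an
    # output iterator, e.g. a generator function:
    #
    #     def double(samples):
    #         for sample in samples:
    #             yield sample * 2
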
    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        if self.repetitions != 1:
            if self.nsamples > 0:
                return islice(self.iterator(), self.nsamples)
            else:
                return self.iterator()
        else:
            return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)
        return self

    def append_list(self, *args):
        """Append multiple pipeline stages (modifies the object)."""
        for arg in args:
            self.pipeline.append(arg)
        return self

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy."""
        result = copy.copy(self)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

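    # Usage sketch for with_epoch() (illustrative): shard-based streams have no
    # fixed natural length, so
    #
    #     pipeline = pipeline.with_epoch(10000)
    #
    # repeats the underlying data indefinitely and slices the stream so that
    # each "epoch" yields exactly 10000 samples (or batches, if applied after
    # batching).
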
    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self
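

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; assumes this module is importable as
    # part of its package so that the relative imports above resolve).

    def source():
        # A zero-argument callable is a valid first stage: invoke() calls it
        # and the result is iterated.
        return iter([1, 2, 3])

    def double(samples):
        # A callable stage receives the upstream iterator and yields samples.
        for sample in samples:
            yield sample * 2

    demo = DataPipeline(source, double).with_epoch(nsamples=4)
    # repetitions is now effectively unbounded and the stream is sliced to
    # 4 samples, so this prints [2, 4, 6, 2].
    print(list(demo))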