# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception

from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    """Give `obj` a `__len__` returning `obj.size` by swapping in a derived class.

    A new class is created on the fly that inherits from the object's current
    class plus IterableDataset, with a `__len__` that reports `self.size`.
    The object's `__class__` is rebound in place; the same object is returned.
    """

    def length(self):
        return self.size

    # Dynamically build "<OriginalClass>_Length" so isinstance checks against
    # the original class keep working while __len__ becomes available.
    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        # repetitions/nsamples control epoch repetition and slicing; see
        # with_epoch(), repeat(), and __iter__().
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            # A list argument contributes each of its elements as a stage.
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        :param f: a PipelineStage, dataset/loader, list, or callable
        :returns: an iterator (or the stage's result) over the stage's output
        :raises ValueError: if `f` is not a recognized kind of stage
        """
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        # A raw dataset/loader can only be a source (no upstream input).
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        # The first stage is the source; each later stage wraps the previous one.
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        # Slicing only applies when repeating; a single epoch is never truncated.
        if self.repetitions != 1 and self.nsamples > 0:
            return islice(self.iterator(), self.nsamples)
        return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)
        return self

    def append_list(self, *args):
        """Append multiple pipeline stages (modifies the object)."""
        for arg in args:
            self.pipeline.append(arg)
        return self

    def compose(self, *args):
        """Append a pipeline stage to a copy of the pipeline and returns the copy."""
        result = copy.copy(self)
        # BUGFIX: copy.copy is shallow, so result.pipeline would alias
        # self.pipeline and appending below would mutate the original
        # pipeline as well. Copy the stage list so the copy is independent.
        result.pipeline = list(self.pipeline)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        # Repeat "forever"; __iter__ slices the stream down to nsamples.
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            # Non-positive nepochs means repeat indefinitely (bounded only by
            # nbatches, if positive).
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self