You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
3.9 KiB
142 lines
3.9 KiB
3 years ago
|
#
|
||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
# This file is part of the WebDataset library.
|
||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||
|
# Modified from https://github.com/webdataset/webdataset
|
||
|
#
|
||
|
|
||
|
|
||
|
"""Train PyTorch models directly from POSIX tar archive.
|
||
|
|
||
|
Code works locally or over HTTP connections.
|
||
|
"""
|
||
|
|
||
|
import itertools as itt
|
||
|
import os
|
||
|
import random
|
||
|
import sys
|
||
|
|
||
|
import braceexpand
|
||
|
|
||
|
from . import utils
|
||
|
from .paddle_utils import IterableDataset
|
||
|
from .utils import PipelineStage
|
||
|
|
||
|
|
||
|
class MockDataset(IterableDataset):
    """Dataset that emits one fixed sample a fixed number of times.

    Intended for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Set up the mock dataset.

        :param sample: the object yielded on every iteration step
        :param length: how many times ``sample`` is yielded per pass
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Yield ``self.sample`` exactly ``self.length`` times."""
        # The loop counter is irrelevant; only the repetition count matters.
        for _ in range(self.length):
            yield self.sample
|
||
|
|
||
|
|
||
|
class repeatedly(IterableDataset, PipelineStage):
    """Pipeline stage that repeatedly yields samples from a source."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create a ``repeatedly`` stage.

        :param source: source dataset; stored as ``self.source``
        :param nepochs: repeat for a maximum of nepochs
            (forwarded to ``utils.repeatedly``)
        :param nbatches: repeat for a maximum of nbatches
            (forwarded to ``utils.repeatedly``)
        :param length: stored as ``self.length``; not otherwise used
            by this class
        """
        self.source = source
        self.length = length
        # ``invoke`` reads ``self.nepochs``; it must be stored here or
        # every call to ``invoke`` fails with AttributeError.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        :param source: the iterable handed to ``utils.repeatedly``
        :returns: whatever ``utils.repeatedly`` returns for ``source``
            with the configured ``nepochs`` and ``nbatches`` limits
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
|
||
|
|
||
|
|
||
|
class with_epoch(IterableDataset):
    """Impose epoch boundaries of a fixed length on an iterable dataset.

    Each pass produced by ``invoke`` yields at most ``length`` samples.
    When the underlying iterator runs dry in the middle of a pass, it is
    restarted so the pass can continue.  This exists mainly as a
    workaround for the odd logic in DataLoader, and is also useful for
    choosing smaller nominal epoch sizes with very large datasets.
    """

    def __init__(self, dataset, length):
        """Configure the epoch length.

        :param dataset: accepted for interface compatibility; it is not
            stored here, the dataset to iterate is passed to ``invoke``
        :param length: number of samples that make up one epoch
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the state used for pickling.

        The live iterator in ``source`` cannot be pickled, so it is
        replaced by ``None`` in the returned state.
        """
        state = self.__dict__.copy()
        state["source"] = None
        return state

    def invoke(self, dataset):
        """Yield up to ``self.length`` samples drawn from ``dataset``.

        An iterator left over in ``self.source`` is resumed if present;
        otherwise a new one is created.  If the iterator is exhausted it
        is recreated once; if the fresh iterator is empty too, the
        generator stops early.

        :param dataset: the iterable to draw samples from
        """
        exhausted = object()  # private marker signalling StopIteration
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            item = next(self.source, exhausted)
            if item is exhausted:
                # Underlying iterator ran out: start over from the dataset.
                self.source = iter(dataset)
                item = next(self.source, exhausted)
                if item is exhausted:
                    # The dataset itself is empty; nothing more to yield.
                    return
            yield item
        # A full pass completed: drop the iterator reference.
        self.source = None
|
||
|
|
||
|
|
||
|
class with_length(IterableDataset, PipelineStage):
    """Pipeline stage that attaches a user-specified length to a dataset."""

    def __init__(self, dataset, length):
        """Create a ``with_length`` stage.

        :param dataset: source dataset
        :param length: stated length, reported by ``len()``
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return a plain iterator over ``dataset``; samples pass through unchanged."""
        return iter(dataset)

    def __len__(self):
        """Return the length given at construction time."""
        return self.length
|