parent c7a7b113c8 commit 0c7abc1f17
@ -0,0 +1,445 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Automatically decode webdataset samples."""

import io, json, os, pickle, re, tempfile
from functools import partial

import numpy as np

"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()


################################################################
# handle basic datatypes
################################################################


def paddle_loads(data):
    """Load data using paddle.load, importing paddle only if needed.

    :param data: data to be decoded
    """
    import io

    import paddle

    stream = io.BytesIO(data)
    return paddle.load(stream)


def tenbin_loads(data):
    from . import tenbin

    return tenbin.decode_buffer(data)


def msgpack_loads(data):
    import msgpack

    return msgpack.unpackb(data)


def npy_loads(data):
    import numpy.lib.format

    stream = io.BytesIO(data)
    return numpy.lib.format.read_array(stream)


def cbor_loads(data):
    import cbor

    return cbor.loads(data)


decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pdparams": lambda data: paddle_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}


def basichandlers(key, data):
    """Handle basic file decoding.

    This function is usually part of the post= decoders.
    This handles the following forms of decoding:

    - txt text transcript -> unicode string
    - cls cls2 index inx id -> int
    - json jsn -> JSON decoding
    - pyd pickle -> pickle decoding
    - pdparams -> paddle.load
    - ten tb -> fast tensor loading
    - mp msg -> messagepack decoding
    - npy npz -> numpy decoding

    :param key: file name extension
    :param data: binary data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)

    if extension in decoders:
        return decoders[extension](data)

    return None
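

# Usage sketch (illustrative, not part of the library): feed basichandlers
# a dotted key and raw bytes; unknown extensions fall through and return
# None. The sample values below are hypothetical.
def _basichandlers_example():
    assert basichandlers("sample.cls", b"3") == 3
    assert basichandlers("sample.txt", b"hello") == "hello"
    assert basichandlers("sample.unknown", b"\x00") is None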


################################################################
# Generic extension handler.
################################################################


def call_extension_handler(key, data, f, extensions):
    """Call the function f with the given data if the key matches the extensions.

    :param key: actual key found in the sample
    :param data: binary data
    :param f: decoder function
    :param extensions: list of matching extensions
    """
    extension = key.lower().split(".")
    for target in extensions:
        target = target.split(".")
        if len(target) > len(extension):
            continue
        if extension[-len(target):] == target:
            return f(data)
    return None


def handle_extension(extensions, f):
    """Return a decoder function for the list of extensions.

    Extensions can be a space separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.

    Examples:
    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
    """
    extensions = extensions.lower().split()
    return partial(call_extension_handler, f=f, extensions=extensions)


################################################################
# handle images
################################################################

imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "paddlel8": ("paddle", "uint8", "l"),
    "paddlergb8": ("paddle", "uint8", "rgb"),
    "paddlergba8": ("paddle", "uint8", "rgba"),
    "paddlel": ("paddle", "float", "l"),
    "paddlergb": ("paddle", "float", "rgb"),
    "paddle": ("paddle", "float", "rgb"),
    "paddlergba": ("paddle", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}


class ImageHandler:
    """Decode image data using the given `imagespec`.

    The `imagespec` specifies whether the image is decoded
    to numpy/paddle/pil, decoded to uint8/float, and decoded
    to l/rgb/rgba:

    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - paddlel8: paddle uint8 l
    - paddlergb8: paddle uint8 rgb
    - paddlergba8: paddle uint8 rgba
    - paddlel: paddle float l
    - paddlergb: paddle float rgb
    - paddle: paddle float rgb
    - paddlergba: paddle float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba
    """

    def __init__(self, imagespec, extensions=image_extensions):
        """Create an image handler.

        :param imagespec: short string indicating the type of decoding
        :param extensions: list of extensions the image handler is invoked for
        """
        if imagespec not in imagespecs:
            raise ValueError("Unknown imagespec: %s" % imagespec)
        self.imagespec = imagespec.lower()
        self.extensions = extensions

    def __call__(self, key, data):
        """Perform image decoding.

        :param key: file name extension
        :param data: binary data
        """
        import PIL.Image

        extension = re.sub(r".*[.]", "", key)
        if extension.lower() not in self.extensions:
            return None
        imagespec = self.imagespec
        atype, etype, mode = imagespecs[imagespec]
        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            img.load()
            img = img.convert(mode.upper())
        if atype == "pil":
            return img
        elif atype == "numpy":
            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: numpy image must be uint8")
            if etype == "uint8":
                return result
            else:
                return result.astype("f") / 255.0
        elif atype == "paddle":
            import paddle

            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: paddle image must be uint8")
            if etype == "uint8":
                result = np.array(result.transpose(2, 0, 1))
                return paddle.to_tensor(result)
            else:
                result = np.array(result.transpose(2, 0, 1))
                return paddle.to_tensor(result) / 255.0
        return None
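

# Usage sketch (illustrative): decode an in-memory PNG to a float RGB numpy
# array. The `png_bytes` argument is hypothetical; in practice the data
# comes from a tar sample.
def _imagehandler_example(png_bytes):
    handler = ImageHandler("rgb")
    array = handler("sample.png", png_bytes)  # HxWx3 float array in [0, 1]
    return array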


def imagehandler(imagespec, extensions=image_extensions):
    """Create an image handler.

    This is just a lower case alias for ImageHandler.

    :param imagespec: textual image spec
    :param extensions: list of extensions the handler should be applied for
    """
    return ImageHandler(imagespec, extensions)


################################################################
# torch video
################################################################

'''
def torch_video(key, data):
    """Decode video using the torchvision library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
        return None

    import torchvision.io

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return torchvision.io.read_video(fname, pts_unit="sec")
'''


################################################################
# paddleaudio
################################################################


def paddle_audio(key, data):
    """Decode audio using the paddleaudio library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
        return None

    import paddleaudio

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return paddleaudio.load(fname)


################################################################
# special class for continuing decoding
################################################################


class Continue:
    """Special class for continuing decoding.

    This is mostly used for decompression, as in:

        def decompressor(key, data):
            if key.endswith(".gz"):
                return Continue(key[:-3], decompress(data))
            return None
    """

    def __init__(self, key, data):
        """Create a Continue marker.

        :param key: new key for the data (e.g., with ".gz" stripped)
        :param data: transformed binary data, to be decoded further
        """
        self.key, self.data = key, data


def gzfilter(key, data):
    """Decode .gz files.

    This decompresses the data and then continues decoding under the
    stripped key.

    :param key: file name extension
    :param data: binary data
    """
    import gzip

    if not key.endswith(".gz"):
        return None
    decompressed = gzip.open(io.BytesIO(data)).read()
    return Continue(key[:-3], decompressed)


################################################################
# decode entire training samples
################################################################


default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]


class Decoder:
    """Decode samples using a list of handlers.

    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    """

    def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
        """Create a Decoder.

        :param handlers: main list of handlers
        :param pre: handlers called before the main list (.gz handler by default)
        :param post: handlers called after the main list (default handlers by default)
        :param only: a list of extensions; when given, only fields with those
            extensions are decoded and all others are passed through unchanged
        :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
        """
        if isinstance(only, str):
            only = only.split()
        self.only = only if only is None else set(only)
        if pre is None:
            pre = default_pre_handlers
        if post is None:
            post = default_post_handlers
        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
        assert all(callable(h) for h in pre), f"one of {pre} not callable"
        assert all(callable(h) for h in post), f"one of {post} not callable"
        self.handlers = pre + handlers + post
        self.partial = partial

    def decode1(self, key, data):
        """Decode a single field of a sample.

        :param key: file name extension
        :param data: binary data
        """
        key = "." + key
        for f in self.handlers:
            result = f(key, data)
            if isinstance(result, Continue):
                key, data = result.key, result.data
                continue
            if result is not None:
                return result
        return data

    def decode(self, sample):
        """Decode an entire sample.

        :param sample: the sample, a dictionary of key value pairs
        """
        result = {}
        assert isinstance(sample, dict), sample
        for k, v in list(sample.items()):
            if k[0] == "_":
                if isinstance(v, bytes):
                    v = v.decode("utf-8")
                result[k] = v
                continue
            if self.only is not None and k not in self.only:
                result[k] = v
                continue
            assert v is not None
            if self.partial:
                if isinstance(v, bytes):
                    result[k] = self.decode1(k, v)
                else:
                    result[k] = v
            else:
                assert isinstance(v, bytes)
                result[k] = self.decode1(k, v)
        return result

    def __call__(self, sample):
        """Decode an entire sample.

        :param sample: the sample
        """
        assert isinstance(sample, dict), (len(sample), sample)
        return self.decode(sample)
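

# Usage sketch (illustrative): build a Decoder that turns raw tar fields
# into Python objects. The sample below is hypothetical; real samples come
# from a tar reader and always carry a "__key__" metadata field.
def _decoder_example():
    decoder = Decoder([imagehandler("rgb")])
    sample = {"__key__": "a/b", "cls": b"3", "txt": b"hello"}
    decoded = decoder(sample)
    assert decoded["cls"] == 3 and decoded["txt"] == "hello"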

@ -0,0 +1,141 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Train models directly from POSIX tar archives.

Code works locally or over HTTP connections.
"""

import itertools as itt
import os
import random
import sys

import braceexpand

from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage


class MockDataset(IterableDataset):
    """MockDataset.

    A mock dataset for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Return an iterator over this mock dataset."""
        for _ in range(self.length):
            yield self.sample
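

# Usage sketch (illustrative): a MockDataset yields the same sample a fixed
# number of times, which makes it handy for benchmarking a pipeline without
# touching real data.
def _mockdataset_example():
    ds = MockDataset({"x": 0}, 1000)
    assert sum(1 for _ in ds) == 1000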


class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of repeatedly.

        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        """
        self.source = source
        self.length = length
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )


class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.
    """

    def __init__(self, dataset, length):
        """Chop the dataset to the given length.

        :param dataset: IterableDataset
        :param length: declared length of the dataset
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        This iterator returns as many samples as given by the `length`
        parameter.
        """
        if self.source is None:
            self.source = iter(dataset)
        for i in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None
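

# Usage sketch (illustrative): wrap a short source so each "epoch" sees
# exactly 10 samples, restarting the source as needed.
def _with_epoch_example():
    src = MockDataset({"x": 0}, 3)
    epoch = with_epoch(src, 10)
    assert sum(1 for _ in epoch.invoke(src)) == 10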


class with_length(IterableDataset, PipelineStage):
    """Give a dataset a stated length."""

    def __init__(self, dataset, length):
        """Create an instance of with_length.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator over the source dataset."""
        return iter(dataset)

    def __len__(self):
        """Return the user specified length."""
        return self.length

@ -0,0 +1,340 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Open URLs by calling subcommands."""

import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse

# global used for printing additional node information during verbose output
info = {}


class Pipe:
    """Wrapper class for subprocess.Popen.

    This class looks like a stream from the outside, but it checks
    subprocess status and handles timeouts with exceptions.
    This way, clients of the class do not need to know that they are
    dealing with subprocesses.

    :param *args: passed to `subprocess.Popen`
    :param **kw: passed to `subprocess.Popen`
    :param timeout: timeout for closing/waiting
    :param ignore_errors: don't raise exceptions on subprocess errors
    :param ignore_status: list of status codes to ignore
    """

    def __init__(
        self,
        *args,
        mode=None,
        timeout=7200.0,
        ignore_errors=False,
        ignore_status=[],
        **kw,
    ):
        """Create an IO Pipe."""
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + ignore_status
        self.timeout = timeout
        self.args = (args, kw)
        if mode[0] == "r":
            self.proc = Popen(*args, stdout=PIPE, **kw)
            self.stream = self.proc.stdout
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        elif mode[0] == "w":
            self.proc = Popen(*args, stdin=PIPE, **kw)
            self.stream = self.proc.stdin
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        else:
            raise ValueError(f"{mode}: unknown mode")
        self.status = None

    def __str__(self):
        return f"<Pipe {self.args}>"

    def check_status(self):
        """Poll the process and handle any errors."""
        status = self.proc.poll()
        if status is not None:
            self.wait_for_child()

    def wait_for_child(self):
        """Check the status variable and raise an exception if necessary."""
        verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
        if self.status is not None and verbose:
            # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
            return
        self.status = self.proc.wait()
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr,
            )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")

    def read(self, *args, **kw):
        """Wrap stream.read and check status."""
        result = self.stream.read(*args, **kw)
        self.check_status()
        return result

    def write(self, *args, **kw):
        """Wrap stream.write and check status."""
        result = self.stream.write(*args, **kw)
        self.check_status()
        return result

    def readline(self, *args, **kw):
        """Wrap stream.readline and check status."""
        result = self.stream.readline(*args, **kw)
        self.status = self.proc.poll()
        self.check_status()
        return result

    def close(self):
        """Wrap stream.close, wait for the subprocess, and handle errors."""
        self.stream.close()
        self.status = self.proc.wait(self.timeout)
        self.wait_for_child()

    def __enter__(self):
        """Context handler."""
        return self

    def __exit__(self, etype, value, traceback):
        """Context handler."""
        self.close()
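

# Usage sketch (illustrative): run a shell command and read its output as a
# stream; gopen_pipe below builds Pipes like this from "pipe:" URLs.
def _pipe_example():
    with Pipe("ls /", mode="rb", shell=True, ignore_status=[141]) as stream:
        listing = stream.read()
    return listing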


def set_options(
    obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
    when its argument is a pipe.

    :param obj: any kind of stream
    :param timeout: desired timeout
    :param ignore_errors: desired ignore_errors setting
    :param ignore_status: desired ignore_status setting
    :param handler: desired error handler
    """
    if not isinstance(obj, Pipe):
        return False
    if timeout is not None:
        obj.timeout = timeout
    if ignore_errors is not None:
        obj.ignore_errors = ignore_errors
    if ignore_status is not None:
        obj.ignore_status = ignore_status
    if handler is not None:
        obj.handler = handler
    return True


def gopen_file(url, mode="rb", bufsize=8192):
    """Open a local file.

    :param url: URL to be opened
    :param mode: mode to open it with
    :param bufsize: requested buffer size
    """
    return open(url, mode)


def gopen_pipe(url, mode="rb", bufsize=8192):
    """Use gopen to open a pipe.

    :param url: a pipe: URL
    :param mode: desired mode
    :param bufsize: desired buffer size
    """
    assert url.startswith("pipe:")
    cmd = url[5:]
    if mode[0] == "r":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_curl(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_htgs(url, mode="rb", bufsize=8192):
    """Open an htgs: (Google Cloud Storage) URL for reading via `curl`.

    :param url: url (htgs://...)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        url = re.sub(r"(?i)^htgs://", "gs://", url)
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `gsutil`.

    :param url: url (gs://...)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"gsutil cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_error(url, *args, **kw):
    """Raise a value error.

    :param url: url
    :param args: other arguments
    :param kw: other keywords
    """
    raise ValueError(f"{url}: no gopen handler defined")


"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
    __default__=gopen_error,
    pipe=gopen_pipe,
    http=gopen_curl,
    https=gopen_curl,
    sftp=gopen_curl,
    ftps=gopen_curl,
    scp=gopen_curl,
    gs=gopen_gsutil,
    htgs=gopen_htgs,
)
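

# Extension sketch (illustrative): additional schemes can be supported by
# registering a handler in the dispatch table before calling gopen. The
# "hdfs" scheme and command below are hypothetical.
def _register_hdfs_handler():
    def gopen_hdfs(url, mode="rb", bufsize=8192):
        assert mode[0] == "r", mode
        cmd = f"hdfs dfs -cat '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141])

    gopen_schemes["hdfs"] = gopen_hdfs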


def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL.

    This uses the `gopen_schemes` dispatch table to dispatch based
    on scheme.

    Support for the following schemes is built-in: pipe, file,
    http, https, sftp, ftps, scp, gs, htgs.

    When no scheme is given, the url is treated as a file.

    You can use the GOPEN_VERBOSE environment variable to get info about
    files being opened.

    :param url: the source URL
    :param mode: the mode ("rb", "wb")
    :param bufsize: the buffer size
    """
    global fallback_gopen
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        if mode == "rb":
            return sys.stdin.buffer
        elif mode == "wb":
            return sys.stdout.buffer
        else:
            raise ValueError(f"unknown mode {mode}")
    pr = urlparse(url)
    if pr.scheme == "":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url, mode, buffering=bufsize)
    if pr.scheme == "file":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(pr.path, mode, buffering=bufsize)
    handler = gopen_schemes["__default__"]
    handler = gopen_schemes.get(pr.scheme, handler)
    return handler(url, mode, bufsize, **kw)


def reader(url, **kw):
    """Open url with gopen and mode "rb".

    :param url: source URL
    :param kw: other keywords forwarded to gopen
    """
    return gopen(url, "rb", **kw)
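

# Usage sketch (illustrative): the same call opens local paths, HTTP URLs,
# and shell pipelines; the URLs below are hypothetical.
def _gopen_examples():
    local = gopen("data/shard-000000.tar", "rb")
    remote = reader("https://example.com/shard-000000.tar")
    piped = gopen("pipe:cat data/shard-000000.tar", "rb")
    return local, remote, piped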

@ -0,0 +1,47 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Pluggable exception handlers.

These are functions that take an exception as an argument and then return...

- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)

They are used as handler= arguments in much of the library.
"""

import time, warnings


def reraise_exception(exn):
    """Call in an exception handler to re-raise the exception."""
    raise exn


def ignore_and_continue(exn):
    """Call in an exception handler to ignore any exception and continue."""
    return True


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return True


def ignore_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    return False


def warn_and_stop(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and stop further processing."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return False
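

# Usage sketch (illustrative): a processing loop consults the handler to
# decide whether to re-raise, skip, or stop after an exception.
def _apply_with_handler(samples, process, handler=reraise_exception):
    for sample in samples:
        try:
            yield process(sample)
        except Exception as exn:
            if handler(exn):
                continue
            break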

@ -0,0 +1,85 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Classes for mixing samples from multiple sources."""

import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
    """Yield samples from the sources in turn, stopping at the first exhausted source."""
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1


def round_robin_longest(*sources):
    """Yield samples from the sources in turn until all sources are exhausted."""
    sources = list(sources)
    i = 0
    while len(sources) > 0:
        i %= len(sources)
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]


class RoundRobin(IterableDataset):
    """Iterate over multiple datasets in round-robin order."""

    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)
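

# Usage sketch (illustrative): alternate between two in-memory datasets,
# continuing until the longer one is exhausted.
def _roundrobin_example():
    ds = RoundRobin([[{"src": "a"}] * 2, [{"src": "b"}] * 3], longest=True)
    assert len(list(ds)) == 5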


def random_samples(sources, probs=None, longest=False):
    """Yield samples randomly from the sources, weighted by probs."""
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break


class RandomMix(IterableDataset):
    """Randomly mix samples from multiple datasets."""

    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
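

# Usage sketch (illustrative): draw from the second dataset about four times
# as often as the first, stopping when either is exhausted.
def _randommix_example():
    mixed = RandomMix([[{"src": "a"}] * 10, [{"src": "b"}] * 10], probs=[1, 4])
    return list(mixed)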

@ -0,0 +1,450 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Classes and functions for writing tar files and WebDataset files."""

import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union

import numpy as np

from . import gopen


def imageencoder(image: Any, format: str = "PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a bytestring.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)
    """
    import PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}"
                )
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()
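

# Usage sketch (illustrative): encode a random float image as PNG bytes.
def _imageencoder_example():
    image = np.random.uniform(size=(64, 64, 3)).astype("f")
    png_bytes = imageencoder(image, "PNG")
    assert png_bytes[:4] == b"\x89PNG"
    return png_bytes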


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.

    This delays importing paddle until needed.

    :param data: data to be dumped
    """
    import io

    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io

    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: dict):
    """Dump a dict of arrays into a bytestring using numpy npz format.

    :param data: dict of arrays to be dumped
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()


def tenbin_dumps(x):
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    import msgpack

    return msgpack.packb(x)


def add_handlers(d, keys, value):
    """Register a handler for one or more extensions.

    :param d: handler dictionary to update
    :param keys: extension or space separated list of extensions
    :param value: encoder function
    """
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a dictionary of handlers for encoding data."""
    handlers = {}
    add_handlers(
        handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
    )
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()
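

# Extension sketch (illustrative): the handler table can be extended with
# custom encoders before it is handed to make_encoder; the "yaml" extension
# and PyYAML dependency here are hypothetical.
def _add_yaml_handler(handlers=default_handlers):
    import yaml

    add_handlers(handlers, "yaml yml", lambda x: yaml.dump(x).encode("utf-8"))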


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
    }


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: specification
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f

    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g

    else:
        raise ValueError(f"{spec}: unknown encoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[bool] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        **kw,
    ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with each finished shard file name
        :param start_shard: index of the first shard
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
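

# Usage sketch (illustrative): write samples into shards of at most 10000
# records each; the file pattern and sample fields are hypothetical.
def _shardwriter_example(samples):
    with ShardWriter("shards/data-%06d.tar", maxcount=10000) as sink:
        for key, image, label in samples:
            sink.write({"__key__": key, "png": image, "cls": label})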