parent c7a7b113c8 · commit 0c7abc1f17
@@ -0,0 +1,445 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Automatically decode webdataset samples."""

import io, json, os, pickle, re, tempfile
from functools import partial

import numpy as np

"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()


################################################################
# handle basic datatypes
################################################################


def paddle_loads(data):
    """Load data using paddle.load, importing paddle only if needed.

    :param data: data to be decoded
    """
    import io

    import paddle

    stream = io.BytesIO(data)
    return paddle.load(stream)


def tenbin_loads(data):
    from . import tenbin

    return tenbin.decode_buffer(data)


def msgpack_loads(data):
    import msgpack

    return msgpack.unpackb(data)


def npy_loads(data):
    import numpy.lib.format

    stream = io.BytesIO(data)
    return numpy.lib.format.read_array(stream)


def cbor_loads(data):
    import cbor

    return cbor.loads(data)


decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pdparams": lambda data: paddle_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}


def basichandlers(key, data):
    """Handle basic file decoding.

    This function is usually part of the post= decoders.
    This handles the following forms of decoding:

    - txt text transcript -> unicode string
    - cls cls2 index inx id -> int
    - json jsn -> JSON decoding
    - pyd pickle -> pickle decoding
    - pdparams -> paddle.load
    - ten tb -> fast tensor loading
    - mp msg -> messagepack decoding
    - npy -> Python NPY decoding

    :param key: file name extension
    :param data: binary data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)

    if extension in decoders:
        return decoders[extension](data)

    return None
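

# Usage sketch (illustrative, not part of the original commit): decoding
# single fields with basichandlers.
#
#     basichandlers("sample.json", b'{"x": 1}')   # -> {"x": 1}
#     basichandlers("sample.cls", b"5")           # -> 5
#     basichandlers("sample.xyz", b"raw")         # -> None (unknown extension)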


################################################################
# Generic extension handler.
################################################################


def call_extension_handler(key, data, f, extensions):
    """Call the function f with the given data if the key matches the extensions.

    :param key: actual key found in the sample
    :param data: binary data
    :param f: decoder function
    :param extensions: list of matching extensions
    """
    extension = key.lower().split(".")
    for target in extensions:
        target = target.split(".")
        if len(target) > len(extension):
            continue
        if extension[-len(target) :] == target:
            return f(data)
    return None


def handle_extension(extensions, f):
    """Return a decoder function for the list of extensions.

    Extensions can be a space-separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.

    Examples:
    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
    """
    extensions = extensions.lower().split()
    return partial(call_extension_handler, f=f, extensions=extensions)
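

# Usage sketch (illustrative; the length-based decoder below is a stand-in
# for a real one):
#
#     handler = handle_extension("jpg jpeg", lambda data: len(data))
#     handler("a/b.jpg", b"1234")    # -> 4 (extension matches)
#     handler("a/b.png", b"1234")    # -> None (no match)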


################################################################
# handle images
################################################################

imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "paddlel8": ("paddle", "uint8", "l"),
    "paddlergb8": ("paddle", "uint8", "rgb"),
    "paddlergba8": ("paddle", "uint8", "rgba"),
    "paddlel": ("paddle", "float", "l"),
    "paddlergb": ("paddle", "float", "rgb"),
    "paddle": ("paddle", "float", "rgb"),
    "paddlergba": ("paddle", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}


class ImageHandler:
    """Decode image data using the given `imagespec`.

    The `imagespec` specifies whether the image is decoded
    to numpy/paddle/pil, decoded to uint8/float, and decoded
    to l/rgb/rgba:

    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - paddlel8: paddle uint8 l
    - paddlergb8: paddle uint8 rgb
    - paddlergba8: paddle uint8 rgba
    - paddlel: paddle float l
    - paddlergb: paddle float rgb
    - paddle: paddle float rgb
    - paddlergba: paddle float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba
    """

    def __init__(self, imagespec, extensions=image_extensions):
        """Create an image handler.

        :param imagespec: short string indicating the type of decoding
        :param extensions: list of extensions the image handler is invoked for
        """
        if imagespec not in imagespecs:
            raise ValueError("Unknown imagespec: %s" % imagespec)
        self.imagespec = imagespec.lower()
        self.extensions = extensions

    def __call__(self, key, data):
        """Perform image decoding.

        :param key: file name extension
        :param data: binary data
        """
        import PIL.Image

        extension = re.sub(r".*[.]", "", key)
        if extension.lower() not in self.extensions:
            return None
        imagespec = self.imagespec
        atype, etype, mode = imagespecs[imagespec]
        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            img.load()
            img = img.convert(mode.upper())
        if atype == "pil":
            return img
        elif atype == "numpy":
            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: numpy image must be uint8")
            if etype == "uint8":
                return result
            else:
                return result.astype("f") / 255.0
        elif atype == "paddle":
            import paddle

            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: paddle image must be uint8")
            if etype == "uint8":
                result = np.array(result.transpose(2, 0, 1))
                return paddle.to_tensor(result)
            else:
                result = np.array(result.transpose(2, 0, 1))
                return paddle.to_tensor(result) / 255.0
        return None


def imagehandler(imagespec, extensions=image_extensions):
    """Create an image handler.

    This is just a lower-case alias for ImageHandler.

    :param imagespec: textual image spec
    :param extensions: list of extensions the handler should be applied for
    """
    return ImageHandler(imagespec, extensions)
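

# Usage sketch (illustrative): decode PNG bytes to a float RGB array.
#
#     handler = imagehandler("rgb")
#     array = handler("sample.png", png_bytes)   # png_bytes: encoded image data
#     # array is a float32 numpy array in [0, 1] with shape (height, width, 3)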


################################################################
# torch video
################################################################

'''
def torch_video(key, data):
    """Decode video using the torchvideo library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
        return None

    import torchvision.io

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return torchvision.io.read_video(fname, pts_unit="sec")
'''


################################################################
# paddleaudio
################################################################


def paddle_audio(key, data):
    """Decode audio using the paddleaudio library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
        return None

    import paddleaudio

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return paddleaudio.load(fname)


################################################################
# special class for continuing decoding
################################################################


class Continue:
    """Special class for continuing decoding.

    This is mostly used for decompression, as in:

        def decompressor(key, data):
            if key.endswith(".gz"):
                return Continue(key[:-3], decompress(data))
            return None
    """

    def __init__(self, key, data):
        """Create a Continue marker carrying the new key and data.

        :param key: new key to continue decoding under
        :param data: transformed binary data
        """
        self.key, self.data = key, data


def gzfilter(key, data):
    """Decode .gz files.

    This decompresses the data and then continues decoding
    under the shortened key.

    :param key: file name extension
    :param data: binary data
    """
    import gzip

    if not key.endswith(".gz"):
        return None
    decompressed = gzip.open(io.BytesIO(data)).read()
    return Continue(key[:-3], decompressed)


################################################################
# decode entire training samples
################################################################


default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]


class Decoder:
    """Decode samples using a list of handlers.

    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    """

    def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
        """Create a Decoder.

        :param handlers: main list of handlers
        :param pre: handlers called before the main list (.gz handler by default)
        :param post: handlers called after the main list (default handlers by default)
        :param only: a list of extensions; when given, only fields with those extensions are decoded
        :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
        """
        if isinstance(only, str):
            only = only.split()
        self.only = only if only is None else set(only)
        if pre is None:
            pre = default_pre_handlers
        if post is None:
            post = default_post_handlers
        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
        assert all(callable(h) for h in pre), f"one of {pre} not callable"
        assert all(callable(h) for h in post), f"one of {post} not callable"
        self.handlers = pre + handlers + post
        self.partial = partial

    def decode1(self, key, data):
        """Decode a single field of a sample.

        :param key: file name extension
        :param data: binary data
        """
        key = "." + key
        for f in self.handlers:
            result = f(key, data)
            if isinstance(result, Continue):
                key, data = result.key, result.data
                continue
            if result is not None:
                return result
        return data

    def decode(self, sample):
        """Decode an entire sample.

        :param sample: the sample, a dictionary of key-value pairs
        """
        result = {}
        assert isinstance(sample, dict), sample
        for k, v in list(sample.items()):
            if k[0] == "_":
                if isinstance(v, bytes):
                    v = v.decode("utf-8")
                result[k] = v
                continue
            if self.only is not None and k not in self.only:
                result[k] = v
                continue
            assert v is not None
            if self.partial:
                if isinstance(v, bytes):
                    result[k] = self.decode1(k, v)
                else:
                    result[k] = v
            else:
                assert isinstance(v, bytes)
                result[k] = self.decode1(k, v)
        return result

    def __call__(self, sample):
        """Decode an entire sample.

        :param sample: the sample
        """
        assert isinstance(sample, dict), (len(sample), sample)
        return self.decode(sample)
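

# Usage sketch (illustrative): combine an image handler with the default
# pre/post handlers and decode one sample of raw bytes.
#
#     decoder = Decoder([imagehandler("rgb8")])
#     sample = {"__key__": "a/b", "png": png_bytes, "cls": b"3"}
#     decoded = decoder(sample)
#     # decoded["png"] is a uint8 HxWx3 array, decoded["cls"] == 3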
@@ -0,0 +1,141 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#


"""Train models directly from POSIX tar archives.

Code works locally or over HTTP connections.
"""

import itertools as itt
import os
import random
import sys

import braceexpand

from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage


class MockDataset(IterableDataset):
    """MockDataset.

    A mock dataset for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Return an iterator over this mock dataset."""
        for i in range(self.length):
            yield self.sample


class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of repeatedly.

        :param source: source dataset
        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        :param length: stated length of the resulting dataset
        """
        self.source = source
        self.length = length
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )


class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.
    """

    def __init__(self, dataset, length):
        """Chop the dataset to the given length.

        :param dataset: IterableDataset
        :param length: declared length of the dataset
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        This iterator returns as many samples as given by the `length`
        parameter.
        """
        if self.source is None:
            self.source = iter(dataset)
        for i in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None


class with_length(IterableDataset, PipelineStage):
    """Present a dataset with a stated length."""

    def __init__(self, dataset, length):
        """Create an instance of with_length.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator over the source dataset."""
        return iter(dataset)

    def __len__(self):
        """Return the user-specified length."""
        return self.length
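

# Usage sketch (illustrative; these stages are normally composed into a
# pipeline whose framework calls `invoke` with the upstream dataset):
#
#     stage = with_epoch(dataset=None, length=10000)
#     for sample in stage.invoke(MockDataset({"x": 0}, length=10**9)):
#         pass   # yields exactly 10000 samples, then the epoch ends
#
#     sized = with_length(MockDataset({"x": 0}, length=100), 100)
#     assert len(sized) == 100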
@@ -0,0 +1,340 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#


"""Open URLs by calling subcommands."""

import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse

# global used for printing additional node information during verbose output
info = {}


class Pipe:
    """Wrapper class for subprocess.Popen.

    This class looks like a stream from the outside, but it checks
    subprocess status and handles timeouts with exceptions.
    This way, clients of the class do not need to know that they are
    dealing with subprocesses.

    :param *args: passed to `subprocess.Popen`
    :param **kw: passed to `subprocess.Popen`
    :param timeout: timeout for closing/waiting
    :param ignore_errors: don't raise exceptions on subprocess errors
    :param ignore_status: list of status codes to ignore
    """

    def __init__(
        self,
        *args,
        mode=None,
        timeout=7200.0,
        ignore_errors=False,
        ignore_status=[],
        **kw,
    ):
        """Create an IO Pipe."""
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + ignore_status
        self.timeout = timeout
        self.args = (args, kw)
        if mode[0] == "r":
            self.proc = Popen(*args, stdout=PIPE, **kw)
            self.stream = self.proc.stdout
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        elif mode[0] == "w":
            self.proc = Popen(*args, stdin=PIPE, **kw)
            self.stream = self.proc.stdin
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        self.status = None

    def __str__(self):
        return f"<Pipe {self.args}>"

    def check_status(self):
        """Poll the process and handle any errors."""
        status = self.proc.poll()
        if status is not None:
            self.wait_for_child()

    def wait_for_child(self):
        """Check the status variable and raise an exception if necessary."""
        verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
        if self.status is not None and verbose:
            # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
            return
        self.status = self.proc.wait()
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr,
            )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")

    def read(self, *args, **kw):
        """Wrap stream.read and check status."""
        result = self.stream.read(*args, **kw)
        self.check_status()
        return result

    def write(self, *args, **kw):
        """Wrap stream.write and check status."""
        result = self.stream.write(*args, **kw)
        self.check_status()
        return result

    def readLine(self, *args, **kw):
        """Wrap stream.readline and check status."""
        result = self.stream.readline(*args, **kw)
        self.status = self.proc.poll()
        self.check_status()
        return result

    def close(self):
        """Wrap stream.close, wait for the subprocess, and handle errors."""
        self.stream.close()
        self.status = self.proc.wait(self.timeout)
        self.wait_for_child()

    def __enter__(self):
        """Context handler."""
        return self

    def __exit__(self, etype, value, traceback):
        """Context handler."""
        self.close()


def set_options(
    obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
    when its argument is a pipe.

    :param obj: any kind of stream
    :param timeout: desired timeout
    :param ignore_errors: desired ignore_errors setting
    :param ignore_status: desired ignore_status setting
    :param handler: desired error handler
    """
    if not isinstance(obj, Pipe):
        return False
    if timeout is not None:
        obj.timeout = timeout
    if ignore_errors is not None:
        obj.ignore_errors = ignore_errors
    if ignore_status is not None:
        obj.ignore_status = ignore_status
    if handler is not None:
        obj.handler = handler
    return True


def gopen_file(url, mode="rb", bufsize=8192):
    """Open a local file.

    The URL is interpreted directly as a file path; remote schemes
    are handled by the other gopen_* handlers.

    :param url: URL to be opened
    :param mode: mode to open it with
    :param bufsize: requested buffer size
    """
    return open(url, mode)


def gopen_pipe(url, mode="rb", bufsize=8192):
    """Use gopen to open a pipe.

    :param url: a pipe: URL
    :param mode: desired mode
    :param bufsize: desired buffer size
    """
    assert url.startswith("pipe:")
    cmd = url[5:]
    if mode[0] == "r":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_curl(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_htgs(url, mode="rb", bufsize=8192):
    """Open an htgs: URL by rewriting it to gs: and fetching with `curl`.

    :param url: url (htgs://)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        url = re.sub(r"(?i)^htgs://", "gs://", url)
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `gsutil`.

    :param url: url (gs://)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"gsutil cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_error(url, *args, **kw):
    """Raise a value error.

    :param url: url
    :param args: other arguments
    :param kw: other keywords
    """
    raise ValueError(f"{url}: no gopen handler defined")
"""A dispatch table mapping URL schemes to handlers."""
|
||||
gopen_schemes = dict(
|
||||
__default__=gopen_error,
|
||||
pipe=gopen_pipe,
|
||||
http=gopen_curl,
|
||||
https=gopen_curl,
|
||||
sftp=gopen_curl,
|
||||
ftps=gopen_curl,
|
||||
scp=gopen_curl,
|
||||
gs=gopen_gsutil,
|
||||
htgs=gopen_htgs,
|
||||
)
|
||||
|
||||
|
||||


def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL.

    This uses the `gopen_schemes` dispatch table to dispatch based
    on scheme.

    Support for the following schemes is built-in: pipe, file,
    http, https, sftp, ftps, scp, gs, htgs.

    When no scheme is given the url is treated as a file.

    You can set the GOPEN_VERBOSE environment variable to get info about
    files being opened.

    :param url: the source URL
    :param mode: the mode ("rb", "wb")
    :param bufsize: the buffer size
    """
    global fallback_gopen
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        if mode == "rb":
            return sys.stdin.buffer
        elif mode == "wb":
            return sys.stdout.buffer
        else:
            raise ValueError(f"unknown mode {mode}")
    pr = urlparse(url)
    if pr.scheme == "":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url, mode, buffering=bufsize)
    if pr.scheme == "file":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(pr.path, mode, buffering=bufsize)
    handler = gopen_schemes["__default__"]
    handler = gopen_schemes.get(pr.scheme, handler)
    return handler(url, mode, bufsize, **kw)


def reader(url, **kw):
    """Open url with gopen and mode "rb".

    :param url: source URL
    :param kw: other keywords forwarded to gopen
    """
    return gopen(url, "rb", **kw)
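

# Usage sketch (illustrative):
#
#     with gopen("pipe:curl -s -L 'http://example.com/data.tar'") as stream:
#         data = stream.read()
#     stream = reader("http://example.com/data.tar")   # same effect, mode "rb"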
@@ -0,0 +1,47 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Pluggable exception handlers.

These are functions that take an exception as an argument and then return...

- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)

They are used as handler= arguments in much of the library.
"""

import time, warnings


def reraise_exception(exn):
    """Call in an exception handler to re-raise the exception."""
    raise exn


def ignore_and_continue(exn):
    """Call in an exception handler to ignore any exception and continue."""
    return True


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return True


def ignore_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    return False


def warn_and_stop(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and stop further processing."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return False
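

# Usage sketch (illustrative; `map_samples` and `f` are hypothetical): this is
# how pipeline code typically consults a handler.
#
#     def map_samples(samples, f, handler=warn_and_continue):
#         for sample in samples:
#             try:
#                 yield f(sample)
#             except Exception as exn:
#                 if handler(exn):
#                     continue   # skip this sample and keep going
#                 break          # handler returned False: stop processing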
@@ -0,0 +1,85 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Classes for mixing samples from multiple sources."""

import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
    """Yield samples from the sources in turn, stopping at the first exhausted source."""
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1


def round_robin_longest(*sources):
    """Yield samples from the sources in turn, dropping exhausted sources until none remain."""
    sources = list(sources)
    i = 0
    while len(sources) > 0:
        i %= len(sources)
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]


class RoundRobin(IterableDataset):
    """Iterate over multiple datasets in round-robin fashion."""

    def __init__(self, datasets, longest=False):
        """Create a RoundRobin mixer.

        :param datasets: list of datasets to mix
        :param longest: keep going until the longest dataset is exhausted
        """
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)


def random_samples(sources, probs=None, longest=False):
    """Yield samples randomly from the sources with the given probabilities."""
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break


class RandomMix(IterableDataset):
    """Iterate over multiple datasets, choosing a source at random for each sample."""

    def __init__(self, datasets, probs=None, longest=False):
        """Create a RandomMix mixer.

        :param datasets: list of datasets to mix
        :param probs: relative sampling probabilities (uniform by default)
        :param longest: keep sampling until the longest dataset is exhausted
        """
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
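

# Usage sketch (illustrative): draw samples from ds1 and ds2 with 2:1 odds
# until the first source is exhausted.
#
#     mixed = RandomMix([ds1, ds2], probs=[2, 1])
#     for sample in mixed:
#         ...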
@@ -0,0 +1,450 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Classes and functions for writing tar files and WebDataset files."""

import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union

import numpy as np

from . import gopen


def imageencoder(image: Any, format: str = "PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a string.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)
    """
    import PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}"
                )
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()
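

# Usage sketch (illustrative): encode a random float image as PNG bytes.
#
#     image = np.random.uniform(size=(64, 64, 3))   # float values in [0, 1]
#     png_bytes = imageencoder(image, "PNG")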


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.

    This delays importing paddle until needed.

    :param data: data to be dumped
    """
    import io

    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io

    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: dict):
    """Dump a dictionary of arrays into a bytestring using numpy npz format.

    :param data: dictionary mapping names to arrays
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()


def tenbin_dumps(x):
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    import msgpack

    return msgpack.packb(x)


def add_handlers(d, keys, value):
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a dictionary of handlers for encoding data."""
    handlers = {}
    add_handlers(
        handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
    )
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
    }


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: specification
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f

    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g

    else:
        raise ValueError(f"{spec}: unknown encoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[bool] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        **kw,
    ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with each finished shard name
        :param start_shard: index of the first shard to write
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
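

# Usage sketch (illustrative): write samples into shards of at most 10000
# records, named shard-000000.tar, shard-000001.tar, ...
#
#     with ShardWriter("shard-%06d.tar", maxcount=10000) as sink:
#         for key, image in samples:   # `samples` is a hypothetical iterable
#             sink.write({"__key__": key, "png": image, "cls": b"0"})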