You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
454 lines
13 KiB
454 lines
13 KiB
#
|
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
# This file is part of the WebDataset library.
|
|
# See the LICENSE file for licensing terms (BSD-style).
|
|
# Modified from https://github.com/webdataset/webdataset
|
|
#
|
|
"""Classes and functions for writing tar files and WebDataset files."""
|
|
import io
|
|
import json
|
|
import pickle
|
|
import re
|
|
import tarfile
|
|
import time
|
|
from typing import Any
|
|
from typing import Callable
|
|
from typing import Optional
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
|
|
from . import gopen
|
|
|
|
|
|
def imageencoder(image: Any, format: str="PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return the encoded bytes.

    Accepts either a PIL image or a numpy array. Arrays with dtype "f" or
    "d" must have values in [0, 1] (a tolerance of 0.001 is allowed); they
    are clipped to [0, 1] and scaled to uint8 before encoding.

    :param image: PIL image, or ndarray of rank 2, or of rank 3 whose last
        axis has size 1 or 3
    :param format: compression format (PNG, JPEG, PPM, ...), matched
        case-insensitively; "JPG" is an alias for JPEG and "IMG"/"IMAGE"
        are aliases for PPM
    :returns: the encoded image as bytes
    :raises ValueError: if a float image has values outside the accepted range
    """
    # A bare ``import PIL`` does not load the ``Image`` submodule, so
    # ``PIL.Image`` only worked if something else had imported it first.
    import PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            lowest, highest = np.amin(image), np.amax(image)
            if not (lowest > -0.001 and highest < 1.001):
                raise ValueError(
                    f"image values out of range {lowest} {highest}")
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
            if image.shape[2] == 1:
                # Image.fromarray has no mapping for a trailing singleton
                # channel axis; drop it so the array is read as grayscale.
                image = image[:, :, 0]
        image = PIL.Image.fromarray(image)

    # Normalize once so aliases and the JPEG quality option below apply
    # regardless of the caller's capitalization.
    format = format.upper()
    if format == "JPG":
        format = "JPEG"
    elif format in ["IMG", "IMAGE"]:
        format = "PPM"
    opts = {"quality": 100} if format == "JPEG" else {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()
|
|
|
|
|
|
def bytestr(data: Any):
    """Return ``data`` as a bytestring.

    Bytes are passed through unchanged. Anything else is converted with
    ``str`` (unless it already is a string) and then encoded as ASCII.

    :param data: value to convert
    """
    if isinstance(data, bytes):
        return data
    text = data if isinstance(data, str) else str(data)
    return text.encode("ascii")
|
|
|
|
|
|
def paddle_dumps(data: Any):
    """Serialize ``data`` into a bytestring with ``paddle.save``.

    paddle is imported inside the function, so it is only loaded when this
    handler is actually called.

    :param data: object to be saved
    """
    from io import BytesIO

    import paddle

    buffer = BytesIO()
    paddle.save(data, buffer)
    return buffer.getvalue()
|
|
|
|
|
|
def numpy_dumps(data: np.ndarray):
    """Serialize an array into a bytestring in numpy's npy format.

    :param data: array to be written
    """
    from io import BytesIO

    from numpy.lib import format as npy_format

    with BytesIO() as buffer:
        npy_format.write_array(buffer, data)
        return buffer.getvalue()
|
|
|
|
|
|
def numpy_npz_dumps(data: dict):
    """Dump a mapping of named arrays into a bytestring using numpy npz format.

    The entries of ``data`` are expanded as keyword arguments to
    ``np.savez_compressed``, so each key becomes the name of one array in
    the archive.

    :param data: mapping from array name to the array to be stored
    :returns: the bytes of the compressed npz archive
    :raises TypeError: if ``data`` is not a mapping
    """
    # ``**data`` below needs a mapping; a bare ndarray (as the previous
    # annotation suggested) can never work, so fail with a clear message.
    if not hasattr(data, "keys"):
        raise TypeError(
            f"numpy_npz_dumps expects a mapping of names to arrays, got {type(data)}"
        )
    with io.BytesIO() as stream:
        np.savez_compressed(stream, **data)
        return stream.getvalue()
|
|
|
|
|
|
def tenbin_dumps(x):
    """Encode ``x`` with ``tenbin.encode_buffer`` and return a memoryview of it.

    A value that is not a list is wrapped in a one-element list first.

    :param x: a list of values, or a single value
    """
    from . import tenbin

    items = x if isinstance(x, list) else [x]
    return memoryview(tenbin.encode_buffer(items))
|
|
|
|
|
|
def cbor_dumps(x):
    """Serialize ``x`` with ``cbor.dumps``.

    cbor is imported inside the function, so it is only loaded when this
    handler is actually called.

    :param x: object to be serialized
    """
    from cbor import dumps as cbor_encode

    return cbor_encode(x)
|
|
|
|
|
|
def mp_dumps(x):
    """Serialize ``x`` with ``msgpack.packb``.

    msgpack is imported inside the function, so it is only loaded when this
    handler is actually called.

    :param x: object to be serialized
    """
    from msgpack import packb

    return packb(x)
|
|
|
|
|
|
def add_handlers(d, keys, value):
    """Register ``value`` in ``d`` under every key in ``keys``.

    :param d: mapping to update in place
    :param keys: iterable of keys, or a single string that is split on
        whitespace to obtain the keys
    :param value: object stored under each key
    """
    names = keys.split() if isinstance(keys, str) else keys
    for name in names:
        d[name] = value
|
|
|
|
|
|
def make_handlers():
    """Build the mapping from file extension to encoding function.

    :returns: dict whose keys are extensions (e.g. "txt", "npy", "png") and
        whose values are callables that encode one value for that extension
    """

    def ascii_text(x):
        return str(x).encode("ascii")

    def utf8_text(x):
        return x.encode("utf-8")

    def json_text(x):
        return json.dumps(x).encode("utf-8")

    def image_as(fmt):
        # One closure per format; imageencoder is looked up at call time.
        return lambda data: imageencoder(data, fmt)

    # Registration order is kept as a list so the resulting dict has the
    # same key order as registering each group one after another.
    registrations = [
        ("cls cls2 class count index inx id", ascii_text),
        ("txt text transcript", utf8_text),
        ("html htm", utf8_text),
        ("pyd pickle", pickle.dumps),
        ("pdparams", paddle_dumps),
        ("npy", numpy_dumps),
        ("npz", numpy_npz_dumps),
        ("ten tenbin tb", tenbin_dumps),
        ("json jsn", json_text),
        ("mp msgpack msg", mp_dumps),
        ("cbor", cbor_dumps),
        ("jpg jpeg img image", image_as("jpg")),
        ("png", image_as("png")),
        ("pbm", image_as("pbm")),
        ("pgm", image_as("pgm")),
        ("ppm", image_as("ppm")),
    ]
    handlers = {}
    for extensions, encode in registrations:
        add_handlers(handlers, extensions, encode)
    return handlers
|
|
|
|
|
|
# Handler table built once at import time; make_encoder uses it when the
# encoder spec is ``True``.
default_handlers = make_handlers()
|
|
|
|
|
|
def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode one value according to the extension of its name.

    Names starting with "_" are metadata: their value must already be a
    string and is returned as-is. For other names, bytes are returned
    unchanged, strings are encoded as UTF-8, and everything else is passed
    to the handler registered for the lower-cased extension.

    :param data: value to encode
    :param tname: name of the entry; the text after the last "." is the extension
    :param handlers: dict mapping extensions to encoding functions
    :raises ValueError: if a metadata value is not a string, or if no
        handler is registered for the extension
    """
    if tname[0] == "_":
        if isinstance(data, str):
            return data
        raise ValueError("the values of metadata must be of string type")

    suffix = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")

    encode = handlers.get(suffix)
    if encode is not None:
        return encode(data)
    raise ValueError(f"no handler found for {suffix}")
|
|
|
|
|
|
def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode every entry of a sample with ``encode_based_on_extension1``.

    :param sample: data sample (a dict mapping entry names to values)
    :param handlers: dict mapping extensions to encoding functions
    :returns: a new dict with the same keys and the encoded values
    """
    encoded = {}
    for name, value in list(sample.items()):
        encoded[name] = encode_based_on_extension1(value, name, handlers)
    return encoded
|
|
|
|
|
|
def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: ``False`` or ``None`` for no encoding at all, a callable to
        be used as the encoder directly, a dict mapping extensions to
        handlers, or ``True`` to use ``default_handlers``
    :returns: a callable that takes a sample and returns the encoded sample
    :raises ValueError: if ``spec`` is none of the above
    """
    if spec is False or spec is None:

        def identity(x):
            """Do not encode at all."""
            return x

        return identity

    # Checked before the dict case so a callable spec is always used as-is.
    if callable(spec):
        return spec

    if isinstance(spec, dict):
        handlers = spec
    elif spec is True:
        # Bound now, so later rebinding of the module-level name does not
        # affect encoders that were already created.
        handlers = default_handlers
    else:
        raise ValueError(f"{spec}: unknown encoder spec")

    def by_extension(sample):
        """Encode based on extension."""
        return encode_based_on_extension(sample, handlers)

    return by_extension
|
|
|
|
|
|
class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
        tarwriter = TarWriter(stream)
        image = imread("b.jpg")
        image2 = imread("b.out.jpg")
        sample = {"__key__": "a/b", "png": image, "output.png": image2}
        tarwriter.write(sample)
    ```
    """

    def __init__(
            self,
            fileobj,
            user: str="bigdata",
            group: str="bigdata",
            mode: int=0o0444,
            compress: Optional[bool]=None,
            encoder: Union[None, bool, Callable]=True,
            keep_meta: bool=False, ):
        """Create a tar writer.

        :param fileobj: stream to write data to, or a file name that is
            opened with ``gopen.gopen`` (and then closed by ``close``)
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression; ``True`` forces gzip and
            ``False`` disables it. Otherwise a file name ending in "gz"
            selects gzip, and an already open stream is written uncompressed
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        # Validate the encoder spec first so that a bad spec cannot leave a
        # freshly opened output file behind.
        self.encoder = make_encoder(encoder)
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            # Only streams opened here are closed by close().
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.keep_meta = keep_meta
        self.stream = fileobj
        try:
            self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
        except Exception:
            # Do not leak the stream we opened if the tar stream cannot be
            # created on top of it.
            if self.own_fileobj is not None:
                self.own_fileobj.close()
                self.own_fileobj = None
            raise

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file, and the underlying stream if this writer opened it."""
        try:
            self.tarstream.close()
        finally:
            # Release the owned stream even if finalizing the archive failed.
            if self.own_fileobj is not None:
                self.own_fileobj.close()
                self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        Each entry is stored as a member named ``<__key__>.<entry name>``,
        in sorted order of the entry names. Entries whose name starts with
        "_" are skipped unless ``keep_meta`` is set; ``__key__`` itself is
        never written.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry (total number of payload bytes written)
        :raises ValueError: if ``__key__`` is missing or a value is not
            bytes-like after encoding
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        # Validate every data entry up front so nothing is written for a
        # sample that contains an unencoded value.
        for k, v in obj.items():
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})")
        key = obj["__key__"]
        for k in sorted(obj):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            # Check the type before measuring the value: metadata entries are
            # not covered by the validation loop above.
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"converter didn't yield bytes: {k}, {type(v)}")
            ti = tarfile.TarInfo(key + "." + k)
            # nbytes rather than len(): len() of a memoryview counts items,
            # which undercounts (and truncates) multi-byte item formats.
            ti.size = memoryview(v).nbytes
            ti.mtime = time.time()
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            self.tarstream.addfile(ti, io.BytesIO(v))
            total += ti.size
        return total
|
|
|
|
|
|
class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
            self,
            pattern: str,
            maxcount: int=100000,
            maxsize: float=3e9,
            post: Optional[Callable]=None,
            start_shard: int=0,
            **kw, ):
        """Create a ShardWriter.

        :param pattern: output file pattern; it is %-formatted with the
            shard number to obtain each file name
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with the file name of each
            shard once that shard has been closed
        :param start_shard: number of the first shard
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        # File object backing the current shard. TarWriter does not close
        # streams it did not open itself, so it is tracked and closed here.
        self._stream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total, )
        self.shard += 1
        stream = open(self.fname, "wb")
        try:
            self.tarstream = TarWriter(stream, **self.kw)
        except Exception:
            # Do not leak the file if the writer cannot be created.
            stream.close()
            raise
        self._stream = stream
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        A new shard is started first if there is no open shard or the
        current one has reached ``maxcount`` records or ``maxsize`` bytes.

        :param obj: sample to be written
        """
        if (self.tarstream is None or self.count >= self.maxcount or
                self.size >= self.maxsize):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead).

        Closes the current shard, if any, and then calls ``post`` with its
        file name.
        """
        if self.tarstream is None:
            return
        stream = getattr(self, "_stream", None)
        try:
            self.tarstream.close()
        finally:
            self.tarstream = None
            # Close the backing file before post runs so that the callback
            # sees a fully flushed shard.
            if stream is not None:
                stream.close()
                self._stream = None
        assert self.fname is not None
        if callable(self.post):
            self.post(self.fname)

    def close(self):
        """Close the stream.

        The writer cannot be used for writing afterwards. Calling close
        again is a no-op.
        """
        # The attributes removed below double as the "already closed" marker,
        # e.g. for an explicit close() inside a ``with`` block.
        if not hasattr(self, "tarstream"):
            return
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
|