# # Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # """Classes and functions for writing tar files and WebDataset files.""" import io, json, pickle, re, tarfile, time from typing import Any, Callable, Optional, Union import numpy as np from . import gopen def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622 """Compress an image using PIL and return it as a string. Can handle float or uint8 images. :param image: ndarray representing an image :param format: compression format (PNG, JPEG, PPM) """ import PIL assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image) if isinstance(image, np.ndarray): if image.dtype in [np.dtype("f"), np.dtype("d")]: if not (np.amin(image) > -0.001 and np.amax(image) < 1.001): raise ValueError( f"image values out of range {np.amin(image)} {np.amax(image)}" ) image = np.clip(image, 0.0, 1.0) image = np.array(image * 255.0, "uint8") assert image.ndim in [2, 3] if image.ndim == 3: assert image.shape[2] in [1, 3] image = PIL.Image.fromarray(image) if format.upper() == "JPG": format = "JPEG" elif format.upper() in ["IMG", "IMAGE"]: format = "PPM" if format == "JPEG": opts = dict(quality=100) else: opts = {} with io.BytesIO() as result: image.save(result, format=format, **opts) return result.getvalue() def bytestr(data: Any): """Convert data into a bytestring. Uses str and ASCII encoding for data that isn't already in string format. :param data: data """ if isinstance(data, bytes): return data if isinstance(data, str): return data.encode("ascii") return str(data).encode("ascii") def paddle_dumps(data: Any): """Dump data into a bytestring using paddle.dumps. This delays importing paddle until needed. :param data: data to be dumped """ import io import paddle stream = io.BytesIO() paddle.save(data, stream) return stream.getvalue() def numpy_dumps(data: np.ndarray): """Dump data into a bytestring using numpy npy format. :param data: data to be dumped """ import io import numpy.lib.format stream = io.BytesIO() numpy.lib.format.write_array(stream, data) return stream.getvalue() def numpy_npz_dumps(data: np.ndarray): """Dump data into a bytestring using numpy npz format. :param data: data to be dumped """ import io stream = io.BytesIO() np.savez_compressed(stream, **data) return stream.getvalue() def tenbin_dumps(x): from . import tenbin if isinstance(x, list): return memoryview(tenbin.encode_buffer(x)) else: return memoryview(tenbin.encode_buffer([x])) def cbor_dumps(x): import cbor return cbor.dumps(x) def mp_dumps(x): import msgpack return msgpack.packb(x) def add_handlers(d, keys, value): if isinstance(keys, str): keys = keys.split() for k in keys: d[k] = value def make_handlers(): """Create a list of handlers for encoding data.""" handlers = {} add_handlers( handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii") ) add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) add_handlers(handlers, "pyd pickle", pickle.dumps) add_handlers(handlers, "pdparams", paddle_dumps) add_handlers(handlers, "npy", numpy_dumps) add_handlers(handlers, "npz", numpy_npz_dumps) add_handlers(handlers, "ten tenbin tb", tenbin_dumps) add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) add_handlers(handlers, "mp msgpack msg", mp_dumps) add_handlers(handlers, "cbor", cbor_dumps) add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg")) add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm")) return handlers default_handlers = make_handlers() def encode_based_on_extension1(data: Any, tname: str, handlers: dict): """Encode data based on its extension and a dict of handlers. :param data: data :param tname: file extension :param handlers: handlers """ if tname[0] == "_": if not isinstance(data, str): raise ValueError("the values of metadata must be of string type") return data extension = re.sub(r".*\.", "", tname).lower() if isinstance(data, bytes): return data if isinstance(data, str): return data.encode("utf-8") handler = handlers.get(extension) if handler is None: raise ValueError(f"no handler found for {extension}") return handler(data) def encode_based_on_extension(sample: dict, handlers: dict): """Encode an entire sample with a collection of handlers. :param sample: data sample (a dict) :param handlers: handlers for encoding """ return { k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items()) } def make_encoder(spec: Union[bool, str, dict, Callable]): """Make an encoder function from a specification. :param spec: specification """ if spec is False or spec is None: def encoder(x): """Do not encode at all.""" return x elif callable(spec): encoder = spec elif isinstance(spec, dict): def f(sample): """Encode based on extension.""" return encode_based_on_extension(sample, spec) encoder = f elif spec is True: handlers = default_handlers def g(sample): """Encode based on extension.""" return encode_based_on_extension(sample, handlers) encoder = g else: raise ValueError(f"{spec}: unknown decoder spec") if not callable(encoder): raise ValueError(f"{spec} did not yield a callable encoder") return encoder class TarWriter: """A class for writing dictionaries to tar files. :param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor :param encoder: sample encoding (Default value = True) :param compress: (Default value = None) `True` will use an encoder that behaves similar to the automatic decoder for `Dataset`. `False` disables encoding and expects byte strings (except for metadata, which must be strings). The `encoder` argument can also be a `callable`, or a dictionary mapping extensions to encoders. The following code will add two file to the tar archive: `a/b.png` and `a/b.output.png`. ```Python tarwriter = TarWriter(stream) image = imread("b.jpg") image2 = imread("b.out.jpg") sample = {"__key__": "a/b", "png": image, "output.png": image2} tarwriter.write(sample) ``` """ def __init__( self, fileobj, user: str = "bigdata", group: str = "bigdata", mode: int = 0o0444, compress: Optional[bool] = None, encoder: Union[None, bool, Callable] = True, keep_meta: bool = False, ): """Create a tar writer. :param fileobj: stream to write data to :param user: user for tar files :param group: group for tar files :param mode: mode for tar files :param compress: desired compression :param encoder: encoder function :param keep_meta: keep metadata (entries starting with "_") """ if isinstance(fileobj, str): if compress is False: tarmode = "w|" elif compress is True: tarmode = "w|gz" else: tarmode = "w|gz" if fileobj.endswith("gz") else "w|" fileobj = gopen.gopen(fileobj, "wb") self.own_fileobj = fileobj else: tarmode = "w|gz" if compress is True else "w|" self.own_fileobj = None self.encoder = make_encoder(encoder) self.keep_meta = keep_meta self.stream = fileobj self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode) self.user = user self.group = group self.mode = mode self.compress = compress def __enter__(self): """Enter context.""" return self def __exit__(self, exc_type, exc_val, exc_tb): """Exit context.""" self.close() def close(self): """Close the tar file.""" self.tarstream.close() if self.own_fileobj is not None: self.own_fileobj.close() self.own_fileobj = None def write(self, obj): """Write a dictionary to the tar file. :param obj: dictionary of objects to be stored :returns: size of the entry """ total = 0 obj = self.encoder(obj) if "__key__" not in obj: raise ValueError("object must contain a __key__") for k, v in list(obj.items()): if k[0] == "_": continue if not isinstance(v, (bytes, bytearray, memoryview)): raise ValueError( f"{k} doesn't map to a bytes after encoding ({type(v)})" ) key = obj["__key__"] for k in sorted(obj.keys()): if k == "__key__": continue if not self.keep_meta and k[0] == "_": continue v = obj[k] if isinstance(v, str): v = v.encode("utf-8") now = time.time() ti = tarfile.TarInfo(key + "." + k) ti.size = len(v) ti.mtime = now ti.mode = self.mode ti.uname = self.user ti.gname = self.group if not isinstance(v, (bytes, bytearray, memoryview)): raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}") stream = io.BytesIO(v) self.tarstream.addfile(ti, stream) total += ti.size return total class ShardWriter: """Like TarWriter but splits into multiple shards.""" def __init__( self, pattern: str, maxcount: int = 100000, maxsize: float = 3e9, post: Optional[Callable] = None, start_shard: int = 0, **kw, ): """Create a ShardWriter. :param pattern: output file pattern :param maxcount: maximum number of records per shard (Default value = 100000) :param maxsize: maximum size of each shard (Default value = 3e9) :param kw: other options passed to TarWriter """ self.verbose = 1 self.kw = kw self.maxcount = maxcount self.maxsize = maxsize self.post = post self.tarstream = None self.shard = start_shard self.pattern = pattern self.total = 0 self.count = 0 self.size = 0 self.fname = None self.next_stream() def next_stream(self): """Close the current stream and move to the next.""" self.finish() self.fname = self.pattern % self.shard if self.verbose: print( "# writing", self.fname, self.count, "%.1f GB" % (self.size / 1e9), self.total, ) self.shard += 1 stream = open(self.fname, "wb") self.tarstream = TarWriter(stream, **self.kw) self.count = 0 self.size = 0 def write(self, obj): """Write a sample. :param obj: sample to be written """ if ( self.tarstream is None or self.count >= self.maxcount or self.size >= self.maxsize ): self.next_stream() size = self.tarstream.write(obj) self.count += 1 self.total += 1 self.size += size def finish(self): """Finish all writing (use close instead).""" if self.tarstream is not None: self.tarstream.close() assert self.fname is not None if callable(self.post): self.post(self.fname) self.tarstream = None def close(self): """Close the stream.""" self.finish() del self.tarstream del self.shard del self.count del self.size def __enter__(self): """Enter context.""" return self def __exit__(self, *args, **kw): """Exit context.""" self.close()