PaddleSpeech/paddlespeech/audio/streamdata/writer.py
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union
import numpy as np
from . import gopen


def imageencoder(image: Any, format: str = "PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a string.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)
    """
    import PIL.Image  # delayed import; `import PIL` alone does not expose PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}"
                )
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()
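
# Illustrative sketch (not part of the original module): float inputs to
# imageencoder must lie in [0, 1] and are scaled to uint8 before compression.
#
#     img = np.random.uniform(size=(64, 64, 3)).astype("f")
#     png_bytes = imageencoder(img, "PNG")
#     assert png_bytes[:4] == b"\x89PNG"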


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.

    This delays importing paddle until it is actually needed.

    :param data: data to be dumped
    """
    import io
    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io
    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: dict):
    """Dump a dict of arrays into a bytestring using numpy npz format.

    :param data: dict mapping names to arrays to be dumped
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()
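
# Illustrative sketch (not part of the original module): unlike numpy_dumps,
# which serializes a single array, numpy_npz_dumps expects a dict of named arrays.
#
#     blob = numpy_npz_dumps({"feats": np.zeros((10, 80)), "lens": np.array([10])})
#     restored = np.load(io.BytesIO(blob))
#     assert restored["feats"].shape == (10, 80)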


def tenbin_dumps(x):
    """Dump data into a bytestring using the tenbin tensor format."""
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    """Dump data into a bytestring using CBOR."""
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    """Dump data into a bytestring using msgpack."""
    import msgpack

    return msgpack.packb(x)


def add_handlers(d, keys, value):
    """Register the same handler under each key in a space-separated string or list."""
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a dict of handlers for encoding data, keyed by file extension."""
    handlers = {}
    add_handlers(handlers, "cls cls2 class count index inx id",
                 lambda x: str(x).encode("ascii"))
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image",
                 lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers)
        for k, v in list(sample.items())
    }
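
# Illustrative sketch (not part of the original module): keys are matched by
# their final extension ("output.png" uses the "png" handler), str values are
# encoded as UTF-8, and metadata keys starting with "_" must stay strings.
#
#     sample = {"__key__": "utt0", "txt": "hello", "json": {"dur": 1.5}}
#     encoded = encode_based_on_extension(sample, default_handlers)
#     # encoded["txt"] == b"hello", encoded["json"] == b'{"dur": 1.5}'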


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: specification
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f

    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g

    else:
        raise ValueError(f"{spec}: unknown encoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder
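
# Illustrative sketch (not part of the original module): the accepted specs
# are True (default handlers), False/None (identity), a handler dict, or an
# arbitrary callable.
#
#     enc = make_encoder(True)
#     # enc({"__key__": "a", "cls": 3})["cls"] == b"3"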


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for the tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: desired compression (Default value = None)

    For `encoder`, `True` uses an encoder that behaves similarly to the
    automatic decoder for `Dataset`, while `False` disables encoding and
    expects byte strings (except for metadata, which must be strings). The
    `encoder` argument can also be a `callable`, or a dictionary mapping
    extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[bool] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total
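
# Illustrative usage sketch (not part of the original module; file and key
# names are hypothetical). Each key in the sample becomes a tar entry named
# "<__key__>.<key>":
#
#     with TarWriter("shard-000.tar") as sink:
#         sink.write({"__key__": "utt0", "txt": "hello world", "cls": 0})
#     # shard-000.tar now contains utt0.txt and utt0.cls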


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        **kw,
    ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with each finished shard's file name
        :param start_shard: index of the first shard (Default value = 0)
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post
        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
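
# Illustrative usage sketch (not part of the original module; the pattern is
# hypothetical). A new shard is started whenever maxcount or maxsize is
# reached:
#
#     with ShardWriter("shard-%06d.tar", maxcount=2) as sink:
#         for i in range(5):
#             sink.write({"__key__": "utt%d" % i, "txt": "sample %d" % i})
#     # produces shard-000000.tar, shard-000001.tar, shard-000002.tar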