PaddleSpeech/paddlespeech/audio/streamdata/writer.py


#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import io
import json
import pickle
import re
import tarfile
import time
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union
import numpy as np
from . import gopen


def imageencoder(image: Any, format: str="PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a bytestring.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)
    """
    import PIL.Image

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}")
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()
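
# Usage sketch (illustrative only, not part of the original module; `frame` is
# a hypothetical array). Float images must lie in [0, 1] and are scaled to uint8:
#
#   frame = np.random.rand(32, 32, 3)        # float image in [0, 1]
#   png_bytes = imageencoder(frame, "PNG")   # compressed PNG bytes
#   jpg_bytes = imageencoder(frame, "jpg")   # "jpg" is normalized to "JPEG"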


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.

    This delays importing paddle until needed.

    :param data: data to be dumped
    """
    import io

    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io

    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: dict):
    """Dump a dictionary of arrays into a bytestring using numpy npz format.

    :param data: dictionary mapping names to arrays to be dumped
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()


def tenbin_dumps(x):
    """Dump data into a bytestring using tenbin (tensor binary) format."""
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    """Dump data into a bytestring using CBOR."""
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    """Dump data into a bytestring using msgpack."""
    import msgpack

    return msgpack.packb(x)
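
# Round-trip sketch for the npy serializer (illustrative only; read_array is
# the reader counterpart of the write_array call used above):
#
#   buf = numpy_dumps(np.arange(10))
#   arr = numpy.lib.format.read_array(io.BytesIO(buf))
#   assert (arr == np.arange(10)).all()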


def add_handlers(d, keys, value):
    """Register the same handler under one or more extensions.

    :param d: handler dictionary
    :param keys: extension or space-separated list of extensions
    :param value: encoder function
    """
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a dictionary of handlers for encoding data, keyed by extension."""
    handlers = {}
    add_handlers(handlers, "cls cls2 class count index inx id",
                 lambda x: str(x).encode("ascii"))
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image",
                 lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()
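
# The table maps extensions to encoders, e.g. (illustrative):
#
#   default_handlers["txt"]("hello")      # -> b"hello"
#   default_handlers["json"]({"a": 1})    # -> b'{"a": 1}'
#   default_handlers["npy"](np.zeros(3))  # -> bytes in npy format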


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    # Metadata fields (keys starting with "_") are passed through; they must
    # already be strings.
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers)
        for k, v in list(sample.items())
    }
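
# Example (illustrative): keys select encoders by their final extension, and
# metadata keys (leading "_") pass through unchanged as strings:
#
#   sample = {"__key__": "utt0", "txt": "hello", "npy": np.zeros(3)}
#   encoded = encode_based_on_extension(sample, default_handlers)
#   # encoded["txt"] == b"hello"; encoded["npy"] holds npy-format bytes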


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: specification
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f
    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g
    else:
        raise ValueError(f"{spec}: unknown encoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder
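
# The accepted spec forms, side by side (illustrative):
#
#   enc = make_encoder(True)                   # extension-based, default_handlers
#   enc = make_encoder({"txt": bytestr})       # custom per-extension handlers
#   enc = make_encoder(lambda sample: sample)  # arbitrary callable
#   enc = make_encoder(False)                  # identity; values must be bytes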


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: compress the output (Default value = None)

    `encoder=True` will use an encoder that behaves similarly to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
            self,
            fileobj,
            user: str="bigdata",
            group: str="bigdata",
            mode: int=0o0444,
            compress: Optional[bool]=None,
            encoder: Union[None, bool, Callable]=True,
            keep_meta: bool=False, ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        # Validate that every non-metadata value was encoded to bytes.
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})")
        key = obj["__key__"]
        # Write entries in sorted key order as <key>.<extension>.
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total
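
# Usage sketch (illustrative; "out.tar" is a hypothetical path):
#
#   with TarWriter("out.tar") as writer:
#       writer.write({"__key__": "utt0", "txt": "hello", "json": {"n": 1}})
#   # the archive then contains utt0.txt and utt0.json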


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
            self,
            pattern: str,
            maxcount: int=100000,
            maxsize: float=3e9,
            post: Optional[Callable]=None,
            start_shard: int=0,
            **kw, ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param post: optional callable invoked with each finished shard's file name
        :param start_shard: index of the first shard (Default value = 0)
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post
        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total, )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (self.tarstream is None or self.count >= self.maxcount or
                self.size >= self.maxsize):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
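
# Usage sketch (illustrative; the pattern is %-formatted with the shard index):
#
#   with ShardWriter("shards-%06d.tar", maxcount=1000) as writer:
#       for i in range(5000):
#           writer.write({"__key__": "%06d" % i, "txt": str(i)})
#   # yields shards-000000.tar ... shards-000004.tar, 1000 samples each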