Merge pull request #2062 from Jackwaterveg/webdataset
[Audio] Add webdataset in paddlespeech.audio
commit d1a25f6cb1
@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
|
||||
profiler_options=
|
||||
benchmark_batch_size=0
|
||||
benchmark_max_step=0
|
||||
|
||||
# seed may break model convergence
|
||||
seed=0
|
||||
|
||||
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
if [ ${seed} != 0 ]; then
|
||||
export FLAGS_cudnn_deterministic=True
|
||||
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
|
||||
fi
|
||||
|
||||
if [ $# -lt 2 ] || [ $# -gt 3 ];then  # expect 2 or 3 positional arguments
|
||||
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
config_path=$1
|
||||
ckpt_name=$2
|
||||
ips=$3
|
||||
|
||||
if [ ! $ips ];then
|
||||
ips_config=
|
||||
else
|
||||
ips_config="--ips="${ips}
|
||||
fi
|
||||
echo ${ips_config}
|
||||
|
||||
mkdir -p exp
|
||||
|
||||
if [ ${ngpu} == 0 ]; then
|
||||
python3 -u ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--seed ${seed} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--profiler-options "${profiler_options}" \
|
||||
--benchmark-batch-size ${benchmark_batch_size} \
|
||||
--benchmark-max-step ${benchmark_max_step}
|
||||
else
|
||||
NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--seed ${seed} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--profiler-options "${profiler_options}" \
|
||||
--benchmark-batch-size ${benchmark_batch_size} \
|
||||
--benchmark-max-step ${benchmark_max_step}
|
||||
fi
status=$?  # capture the training exit code before the seed cleanup below overwrites $?
|
||||
|
||||
|
||||
if [ ${seed} != 0 ]; then
|
||||
unset FLAGS_cudnn_deterministic
|
||||
fi
|
||||
|
||||
if [ ${status} -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
# flake8: noqa
|
||||
|
||||
from .cache import (
|
||||
cached_tarfile_samples,
|
||||
cached_tarfile_to_samples,
|
||||
lru_cleanup,
|
||||
pipe_cleaner,
|
||||
)
|
||||
from .compat import WebDataset, WebLoader, FluidWrapper
|
||||
from .extradatasets import MockDataset, with_epoch, with_length
|
||||
from .filters import (
|
||||
associate,
|
||||
batched,
|
||||
decode,
|
||||
detshuffle,
|
||||
extract_keys,
|
||||
getfirst,
|
||||
info,
|
||||
map,
|
||||
map_dict,
|
||||
map_tuple,
|
||||
pipelinefilter,
|
||||
rename,
|
||||
rename_keys,
|
||||
audio_resample,
|
||||
select,
|
||||
shuffle,
|
||||
slice,
|
||||
to_tuple,
|
||||
transform_with,
|
||||
unbatched,
|
||||
xdecode,
|
||||
audio_data_filter,
|
||||
audio_tokenize,
|
||||
|
||||
audio_compute_fbank,
|
||||
audio_spec_aug,
|
||||
sort,
|
||||
audio_padding,
|
||||
audio_cmvn,
|
||||
placeholder,
|
||||
)
|
||||
from .handlers import (
|
||||
ignore_and_continue,
|
||||
ignore_and_stop,
|
||||
reraise_exception,
|
||||
warn_and_continue,
|
||||
warn_and_stop,
|
||||
)
|
||||
from .pipeline import DataPipeline
|
||||
from .shardlists import (
|
||||
MultiShardSample,
|
||||
ResampledShards,
|
||||
SimpleShardList,
|
||||
non_empty,
|
||||
resampled,
|
||||
shardspec,
|
||||
single_node_only,
|
||||
split_by_node,
|
||||
split_by_worker,
|
||||
)
|
||||
from .tariterators import tarfile_samples, tarfile_to_samples
|
||||
from .utils import PipelineStage, repeatedly
|
||||
from .writer import ShardWriter, TarWriter, numpy_dumps
|
||||
from .mix import RandomMix, RoundRobin
|
@ -0,0 +1,445 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
|
||||
"""Automatically decode webdataset samples."""
|
||||
|
||||
import io, json, os, pickle, re, tempfile
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
"""Extensions passed on to the image decoder."""
|
||||
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
|
||||
|
||||
|
||||
################################################################
|
||||
# handle basic datatypes
|
||||
################################################################
|
||||
|
||||
|
||||
def paddle_loads(data):
|
||||
"""Load data using paddle.loads, importing paddle only if needed.
|
||||
|
||||
:param data: data to be decoded
|
||||
"""
|
||||
import io
|
||||
|
||||
import paddle
|
||||
|
||||
stream = io.BytesIO(data)
|
||||
return paddle.load(stream)
|
||||
|
||||
|
||||
def tenbin_loads(data):
|
||||
from . import tenbin
|
||||
|
||||
return tenbin.decode_buffer(data)
|
||||
|
||||
|
||||
def msgpack_loads(data):
|
||||
import msgpack
|
||||
|
||||
return msgpack.unpackb(data)
|
||||
|
||||
|
||||
def npy_loads(data):
|
||||
import numpy.lib.format
|
||||
|
||||
stream = io.BytesIO(data)
|
||||
return numpy.lib.format.read_array(stream)
|
||||
|
||||
|
||||
def cbor_loads(data):
|
||||
import cbor
|
||||
|
||||
return cbor.loads(data)
|
||||
|
||||
|
||||
decoders = {
|
||||
"txt": lambda data: data.decode("utf-8"),
|
||||
"text": lambda data: data.decode("utf-8"),
|
||||
"transcript": lambda data: data.decode("utf-8"),
|
||||
"cls": lambda data: int(data),
|
||||
"cls2": lambda data: int(data),
|
||||
"index": lambda data: int(data),
|
||||
"inx": lambda data: int(data),
|
||||
"id": lambda data: int(data),
|
||||
"json": lambda data: json.loads(data),
|
||||
"jsn": lambda data: json.loads(data),
|
||||
"pyd": lambda data: pickle.loads(data),
|
||||
"pickle": lambda data: pickle.loads(data),
|
||||
"pdparams": lambda data: paddle_loads(data),
|
||||
"ten": tenbin_loads,
|
||||
"tb": tenbin_loads,
|
||||
"mp": msgpack_loads,
|
||||
"msg": msgpack_loads,
|
||||
"npy": npy_loads,
|
||||
"npz": lambda data: np.load(io.BytesIO(data)),
|
||||
"cbor": cbor_loads,
|
||||
}
|
||||
|
||||
|
||||
def basichandlers(key, data):
|
||||
"""Handle basic file decoding.
|
||||
|
||||
This function is usually part of the post= decoders.
|
||||
This handles the following forms of decoding:
|
||||
|
||||
- txt -> unicode string
|
||||
- cls cls2 class count index inx id -> int
|
||||
- json jsn -> JSON decoding
|
||||
- pyd pickle -> pickle decoding
|
||||
- pdparams -> paddle.loads
|
||||
- ten tenbin -> fast tensor loading
|
||||
- mp messagepack msg -> messagepack decoding
|
||||
- npy -> Python NPY decoding
|
||||
|
||||
:param key: file name extension
|
||||
:param data: binary data to be decoded
|
||||
"""
|
||||
extension = re.sub(r".*[.]", "", key)
|
||||
|
||||
if extension in decoders:
|
||||
return decoders[extension](data)
|
||||
|
||||
return None
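def _basichandlers_example():
    # Illustrative sketch only (not part of the library API): basichandlers
    # dispatches on the last extension of the key, so the made-up byte values
    # below decode as shown, and unknown extensions fall through as None so a
    # later handler can try them.
    assert basichandlers("sample.cls", b"7") == 7
    assert basichandlers("sample.txt", b"hello") == "hello"
    assert basichandlers("sample.jpg", b"\xff\xd8") is None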
|
||||
|
||||
|
||||
################################################################
|
||||
# Generic extension handler.
|
||||
################################################################
|
||||
|
||||
|
||||
def call_extension_handler(key, data, f, extensions):
|
||||
"""Call the function f with the given data if the key matches the extensions.
|
||||
|
||||
:param key: actual key found in the sample
|
||||
:param data: binary data
|
||||
:param f: decoder function
|
||||
:param extensions: list of matching extensions
|
||||
"""
|
||||
extension = key.lower().split(".")
|
||||
for target in extensions:
|
||||
target = target.split(".")
|
||||
if len(target) > len(extension):
|
||||
continue
|
||||
if extension[-len(target) :] == target:
|
||||
return f(data)
|
||||
return None
|
||||
|
||||
|
||||
def handle_extension(extensions, f):
|
||||
"""Return a decoder function for the list of extensions.
|
||||
|
||||
Extensions can be a space separated list of extensions.
|
||||
Extensions can contain dots, in which case the corresponding number
|
||||
of extension components must be present in the key given to f.
|
||||
Comparisons are case insensitive.
|
||||
|
||||
Examples:
|
||||
handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg
|
||||
handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg
|
||||
"""
|
||||
extensions = extensions.lower().split()
|
||||
return partial(call_extension_handler, f=f, extensions=extensions)
|
||||
|
||||
|
||||
################################################################
|
||||
# handle images
|
||||
################################################################
|
||||
|
||||
imagespecs = {
|
||||
"l8": ("numpy", "uint8", "l"),
|
||||
"rgb8": ("numpy", "uint8", "rgb"),
|
||||
"rgba8": ("numpy", "uint8", "rgba"),
|
||||
"l": ("numpy", "float", "l"),
|
||||
"rgb": ("numpy", "float", "rgb"),
|
||||
"rgba": ("numpy", "float", "rgba"),
|
||||
"paddlel8": ("paddle", "uint8", "l"),
|
||||
"paddlergb8": ("paddle", "uint8", "rgb"),
|
||||
"paddlergba8": ("paddle", "uint8", "rgba"),
|
||||
"paddlel": ("paddle", "float", "l"),
|
||||
"paddlergb": ("paddle", "float", "rgb"),
|
||||
"paddle": ("paddle", "float", "rgb"),
|
||||
"paddlergba": ("paddle", "float", "rgba"),
|
||||
"pill": ("pil", None, "l"),
|
||||
"pil": ("pil", None, "rgb"),
|
||||
"pilrgb": ("pil", None, "rgb"),
|
||||
"pilrgba": ("pil", None, "rgba"),
|
||||
}
|
||||
|
||||
|
||||
class ImageHandler:
|
||||
"""Decode image data using the given `imagespec`.
|
||||
|
||||
The `imagespec` specifies whether the image is decoded
|
||||
to numpy/paddle/pil, decoded to uint8/float, and decoded
|
||||
to l/rgb/rgba:
|
||||
|
||||
- l8: numpy uint8 l
|
||||
- rgb8: numpy uint8 rgb
|
||||
- rgba8: numpy uint8 rgba
|
||||
- l: numpy float l
|
||||
- rgb: numpy float rgb
|
||||
- rgba: numpy float rgba
|
||||
- paddlel8: paddle uint8 l
|
||||
- paddlergb8: paddle uint8 rgb
|
||||
- paddlergba8: paddle uint8 rgba
|
||||
- paddlel: paddle float l
|
||||
- paddlergb: paddle float rgb
|
||||
- paddle: paddle float rgb
|
||||
- paddlergba: paddle float rgba
|
||||
- pill: pil None l
|
||||
- pil: pil None rgb
|
||||
- pilrgb: pil None rgb
|
||||
- pilrgba: pil None rgba
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, imagespec, extensions=image_extensions):
|
||||
"""Create an image handler.
|
||||
|
||||
:param imagespec: short string indicating the type of decoding
|
||||
:param extensions: list of extensions the image handler is invoked for
|
||||
"""
|
||||
if imagespec not in list(imagespecs.keys()):
|
||||
raise ValueError("Unknown imagespec: %s" % imagespec)
|
||||
self.imagespec = imagespec.lower()
|
||||
self.extensions = extensions
|
||||
|
||||
def __call__(self, key, data):
|
||||
"""Perform image decoding.
|
||||
|
||||
:param key: file name extension
|
||||
:param data: binary data
|
||||
"""
|
||||
import PIL.Image
|
||||
|
||||
extension = re.sub(r".*[.]", "", key)
|
||||
if extension.lower() not in self.extensions:
|
||||
return None
|
||||
imagespec = self.imagespec
|
||||
atype, etype, mode = imagespecs[imagespec]
|
||||
with io.BytesIO(data) as stream:
|
||||
img = PIL.Image.open(stream)
|
||||
img.load()
|
||||
img = img.convert(mode.upper())
|
||||
if atype == "pil":
|
||||
return img
|
||||
elif atype == "numpy":
|
||||
result = np.asarray(img)
|
||||
if result.dtype != np.uint8:
|
||||
raise ValueError("ImageHandler: numpy image must be uint8")
|
||||
if etype == "uint8":
|
||||
return result
|
||||
else:
|
||||
return result.astype("f") / 255.0
|
||||
elif atype == "paddle":
|
||||
import paddle
|
||||
|
||||
result = np.asarray(img)
|
||||
if result.dtype != np.uint8:
|
||||
raise ValueError("ImageHandler: paddle image must be uint8")
|
||||
if etype == "uint8":
|
||||
result = np.array(result.transpose(2, 0, 1))
|
||||
return paddle.to_tensor(result)
|
||||
else:
|
||||
result = np.array(result.transpose(2, 0, 1))
|
||||
return paddle.to_tensor(result) / 255.0
|
||||
return None
|
||||
|
||||
|
||||
def imagehandler(imagespec, extensions=image_extensions):
|
||||
"""Create an image handler.
|
||||
|
||||
This is just a lower case alias for ImageHandler.
|
||||
|
||||
:param imagespec: textual image spec
|
||||
:param extensions: list of extensions the handler should be applied for
|
||||
"""
|
||||
return ImageHandler(imagespec, extensions)
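def _imagehandler_example(jpeg_bytes):
    # Illustrative sketch; `jpeg_bytes` is assumed to be the raw contents of
    # some .jpg file. The "rgb" spec decodes to a float HxWx3 numpy array
    # scaled to [0, 1]; keys with non-image extensions return None.
    handler = imagehandler("rgb")
    return handler("sample.jpg", jpeg_bytes)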
|
||||
|
||||
|
||||
################################################################
|
||||
# torch video
|
||||
################################################################
|
||||
|
||||
'''
|
||||
def torch_video(key, data):
|
||||
"""Decode video using the torchvideo library.
|
||||
|
||||
:param key: file name extension
|
||||
:param data: data to be decoded
|
||||
"""
|
||||
extension = re.sub(r".*[.]", "", key)
|
||||
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
|
||||
return None
|
||||
|
||||
import torchvision.io
|
||||
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
fname = os.path.join(dirname, f"file.{extension}")
|
||||
with open(fname, "wb") as stream:
|
||||
stream.write(data)
|
||||
return torchvision.io.read_video(fname, pts_unit="sec")
|
||||
'''
|
||||
|
||||
|
||||
################################################################
|
||||
# paddlespeech.audio
|
||||
################################################################
|
||||
|
||||
|
||||
def paddle_audio(key, data):
|
||||
"""Decode audio using the paddlespeech.audio library.
|
||||
|
||||
:param key: file name extension
|
||||
:param data: data to be decoded
|
||||
"""
|
||||
extension = re.sub(r".*[.]", "", key)
|
||||
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
|
||||
return None
|
||||
|
||||
import paddlespeech.audio
|
||||
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
fname = os.path.join(dirname, f"file.{extension}")
|
||||
with open(fname, "wb") as stream:
|
||||
stream.write(data)
|
||||
return paddlespeech.audio.load(fname)
|
||||
|
||||
|
||||
################################################################
|
||||
# special class for continuing decoding
|
||||
################################################################
|
||||
|
||||
|
||||
class Continue:
|
||||
"""Special class for continuing decoding.
|
||||
|
||||
This is mostly used for decompression, as in:
|
||||
|
||||
def decompressor(key, data):
|
||||
if key.endswith(".gz"):
|
||||
return Continue(key[:-3], decompress(data))
|
||||
return None
|
||||
"""
|
||||
|
||||
def __init__(self, key, data):
|
||||
"""__init__.
|
||||
|
||||
:param key:
|
||||
:param data:
|
||||
"""
|
||||
self.key, self.data = key, data
|
||||
|
||||
|
||||
def gzfilter(key, data):
|
||||
"""Decode .gz files.
|
||||
|
||||
This decodes compressed files and then continues decoding.
|
||||
|
||||
:param key: file name extension
|
||||
:param data: binary data
|
||||
"""
|
||||
import gzip
|
||||
|
||||
if not key.endswith(".gz"):
|
||||
return None
|
||||
decompressed = gzip.open(io.BytesIO(data)).read()
|
||||
return Continue(key[:-3], decompressed)
|
||||
|
||||
|
||||
################################################################
|
||||
# decode entire training samples
|
||||
################################################################
|
||||
|
||||
|
||||
default_pre_handlers = [gzfilter]
|
||||
default_post_handlers = [basichandlers]
|
||||
|
||||
|
||||
class Decoder:
|
||||
"""Decode samples using a list of handlers.
|
||||
|
||||
For each key/data item, this iterates through the list of
|
||||
handlers until some handler returns something other than None.
|
||||
"""
|
||||
|
||||
def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
|
||||
"""Create a Decoder.
|
||||
|
||||
:param handlers: main list of handlers
|
||||
:param pre: handlers called before the main list (.gz handler by default)
|
||||
:param post: handlers called after the main list (default handlers by default)
|
||||
:param only: a list of extensions; when given, only fields with those extensions are decoded (others are passed through unchanged)
|
||||
:param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
|
||||
"""
|
||||
if isinstance(only, str):
|
||||
only = only.split()
|
||||
self.only = only if only is None else set(only)
|
||||
if pre is None:
|
||||
pre = default_pre_handlers
|
||||
if post is None:
|
||||
post = default_post_handlers
|
||||
assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
|
||||
assert all(callable(h) for h in pre), f"one of {pre} not callable"
|
||||
assert all(callable(h) for h in post), f"one of {post} not callable"
|
||||
self.handlers = pre + handlers + post
|
||||
self.partial = partial
|
||||
|
||||
def decode1(self, key, data):
|
||||
"""Decode a single field of a sample.
|
||||
|
||||
:param key: file name extension
|
||||
:param data: binary data
|
||||
"""
|
||||
key = "." + key
|
||||
for f in self.handlers:
|
||||
result = f(key, data)
|
||||
if isinstance(result, Continue):
|
||||
key, data = result.key, result.data
|
||||
continue
|
||||
if result is not None:
|
||||
return result
|
||||
return data
|
||||
|
||||
def decode(self, sample):
|
||||
"""Decode an entire sample.
|
||||
|
||||
:param sample: the sample, a dictionary of key value pairs
|
||||
"""
|
||||
result = {}
|
||||
assert isinstance(sample, dict), sample
|
||||
for k, v in list(sample.items()):
|
||||
if k[0] == "_":
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode("utf-8")
|
||||
result[k] = v
|
||||
continue
|
||||
if self.only is not None and k not in self.only:
|
||||
result[k] = v
|
||||
continue
|
||||
assert v is not None
|
||||
if self.partial:
|
||||
if isinstance(v, bytes):
|
||||
result[k] = self.decode1(k, v)
|
||||
else:
|
||||
result[k] = v
|
||||
else:
|
||||
assert isinstance(v, bytes)
|
||||
result[k] = self.decode1(k, v)
|
||||
return result
|
||||
|
||||
def __call__(self, sample):
|
||||
"""Decode an entire sample.
|
||||
|
||||
:param sample: the sample
|
||||
"""
|
||||
assert isinstance(sample, dict), (len(sample), sample)
|
||||
return self.decode(sample)
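def _decoder_example(jpeg_bytes):
    # Illustrative sketch of how the handler chain fires: gzfilter runs as a
    # pre handler, the image handler next, and basichandlers as a post
    # handler. `jpeg_bytes` is an assumed raw JPEG payload.
    decoder = Decoder([imagehandler("rgb")])
    sample = {"__key__": "utt0", "jpg": jpeg_bytes, "cls": b"3", "txt": b"hi"}
    decoded = decoder(sample)
    # decoded["jpg"] is a float numpy array, decoded["cls"] == 3,
    # decoded["txt"] == "hi", and "__key__" is passed through unchanged.
    return decoded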
|
@ -0,0 +1,190 @@
|
||||
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
import itertools, os, random, re, sys, time  # time: retry back-off in cached_url_opener
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import filters
|
||||
from . import gopen
|
||||
from .handlers import reraise_exception
|
||||
from .tariterators import tar_file_and_group_expander
|
||||
|
||||
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
|
||||
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))
|
||||
|
||||
|
||||
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
|
||||
"""Performs cleanup of the file cache in cache_dir using an LRU strategy,
|
||||
keeping the total size of all remaining files below cache_size."""
|
||||
if not os.path.exists(cache_dir):
|
||||
return
|
||||
total_size = 0
|
||||
for dirpath, dirnames, filenames in os.walk(cache_dir):
|
||||
for filename in filenames:
|
||||
total_size += os.path.getsize(os.path.join(dirpath, filename))
|
||||
if total_size <= cache_size:
|
||||
return
|
||||
# sort files by last access time
|
||||
files = []
|
||||
for dirpath, dirnames, filenames in os.walk(cache_dir):
|
||||
for filename in filenames:
|
||||
files.append(os.path.join(dirpath, filename))
|
||||
files.sort(key=keyfn, reverse=True)
|
||||
# delete files until we're under the cache size
|
||||
while len(files) > 0 and total_size > cache_size:
|
||||
fname = files.pop()
|
||||
total_size -= os.path.getsize(fname)
|
||||
if verbose:
|
||||
print("# deleting %s" % fname, file=sys.stderr)
|
||||
os.remove(fname)
|
||||
|
||||
|
||||
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
|
||||
"""Download a file from `url` to `dest`."""
|
||||
temp = dest + f".temp{os.getpid()}"
|
||||
with gopen.gopen(url) as stream:
|
||||
with open(temp, "wb") as f:
|
||||
while True:
|
||||
data = stream.read(chunk_size)
|
||||
if not data:
|
||||
break
|
||||
f.write(data)
|
||||
os.rename(temp, dest)
|
||||
|
||||
|
||||
def pipe_cleaner(spec):
|
||||
"""Guess the actual URL from a "pipe:" specification."""
|
||||
if spec.startswith("pipe:"):
|
||||
spec = spec[5:]
|
||||
words = spec.split(" ")
|
||||
for word in words:
|
||||
if re.match(r"^(https?|gs|ais|s3)", word):
|
||||
return word
|
||||
return spec
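def _pipe_cleaner_example():
    # Illustrative sketch: a "pipe:" spec is reduced to its first URL-looking
    # word so the cache can derive a stable local file name for the shard.
    # The URL below is a made-up example.
    spec = "pipe:curl -s -L https://example.com/shards/train-000000.tar"
    assert pipe_cleaner(spec) == "https://example.com/shards/train-000000.tar"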
|
||||
|
||||
|
||||
def get_file_cached(
|
||||
spec,
|
||||
cache_size=-1,
|
||||
cache_dir=None,
|
||||
url_to_name=pipe_cleaner,
|
||||
verbose=False,
|
||||
):
|
||||
if cache_size == -1:
|
||||
cache_size = default_cache_size
|
||||
if cache_dir is None:
|
||||
cache_dir = default_cache_dir
|
||||
url = url_to_name(spec)
|
||||
parsed = urlparse(url)
|
||||
dirname, filename = os.path.split(parsed.path)
|
||||
dirname = dirname.lstrip("/")
|
||||
dirname = re.sub(r"[:/|;]", "_", dirname)
|
||||
destdir = os.path.join(cache_dir, dirname)
|
||||
os.makedirs(destdir, exist_ok=True)
|
||||
dest = os.path.join(cache_dir, dirname, filename)
|
||||
if not os.path.exists(dest):
|
||||
if verbose:
|
||||
print("# downloading %s to %s" % (url, dest), file=sys.stderr)
|
||||
lru_cleanup(cache_dir, cache_size, verbose=verbose)
|
||||
download(spec, dest, verbose=verbose)
|
||||
return dest
|
||||
|
||||
|
||||
def get_filetype(fname):
|
||||
with os.popen("file '%s'" % fname) as f:
|
||||
ftype = f.read()
|
||||
return ftype
|
||||
|
||||
|
||||
def check_tar_format(fname):
|
||||
"""Check whether a file is a tar archive."""
|
||||
ftype = get_filetype(fname)
|
||||
return "tar archive" in ftype or "gzip compressed" in ftype
|
||||
|
||||
|
||||
verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
|
||||
|
||||
|
||||
def cached_url_opener(
|
||||
data,
|
||||
handler=reraise_exception,
|
||||
cache_size=-1,
|
||||
cache_dir=None,
|
||||
url_to_name=pipe_cleaner,
|
||||
validator=check_tar_format,
|
||||
verbose=False,
|
||||
always=False,
|
||||
):
|
||||
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
|
||||
verbose = verbose or verbose_cache
|
||||
for sample in data:
|
||||
assert isinstance(sample, dict), sample
|
||||
assert "url" in sample
|
||||
url = sample["url"]
|
||||
attempts = 5
|
||||
try:
|
||||
if not always and os.path.exists(url):
|
||||
dest = url
|
||||
else:
|
||||
dest = get_file_cached(
|
||||
url,
|
||||
cache_size=cache_size,
|
||||
cache_dir=cache_dir,
|
||||
url_to_name=url_to_name,
|
||||
verbose=verbose,
|
||||
)
|
||||
if verbose:
|
||||
print("# opening %s" % dest, file=sys.stderr)
|
||||
assert os.path.exists(dest)
|
||||
if not validator(dest):
|
||||
ftype = get_filetype(dest)
|
||||
with open(dest, "rb") as f:
|
||||
data = f.read(200)
|
||||
os.remove(dest)
|
||||
raise ValueError(
|
||||
"%s (%s) is not a tar archive, but a %s, contains %s"
|
||||
% (dest, url, ftype, repr(data))
|
||||
)
|
||||
try:
|
||||
stream = open(dest, "rb")
|
||||
sample.update(stream=stream)
|
||||
yield sample
|
||||
except FileNotFoundError as exn:
|
||||
# dealing with race conditions in lru_cleanup
|
||||
attempts -= 1
|
||||
if attempts > 0:
|
||||
time.sleep(random.random() * 10)
|
||||
continue
|
||||
raise exn
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (url,)
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def cached_tarfile_samples(
|
||||
src,
|
||||
handler=reraise_exception,
|
||||
cache_size=-1,
|
||||
cache_dir=None,
|
||||
verbose=False,
|
||||
url_to_name=pipe_cleaner,
|
||||
always=False,
|
||||
):
|
||||
streams = cached_url_opener(
|
||||
src,
|
||||
handler=handler,
|
||||
cache_size=cache_size,
|
||||
cache_dir=cache_dir,
|
||||
verbose=verbose,
|
||||
url_to_name=url_to_name,
|
||||
always=always,
|
||||
)
|
||||
samples = tar_file_and_group_expander(streams, handler=handler)
|
||||
return samples
|
||||
|
||||
|
||||
cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
|
@ -0,0 +1,170 @@
|
||||
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import List
|
||||
|
||||
import braceexpand, yaml
|
||||
|
||||
from . import autodecode
|
||||
from . import cache, filters, shardlists, tariterators
|
||||
from .filters import reraise_exception
|
||||
from .pipeline import DataPipeline
|
||||
from .paddle_utils import DataLoader, IterableDataset
|
||||
|
||||
|
||||
class FluidInterface:
|
||||
def batched(self, batchsize):
|
||||
return self.compose(filters.batched(batchsize))
|
||||
|
||||
def dynamic_batched(self, max_frames_in_batch):
|
||||
return self.compose(filters.dynamic_batched(max_frames_in_batch))
|
||||
|
||||
def unbatched(self):
|
||||
return self.compose(filters.unbatched())
|
||||
|
||||
def listed(self, batchsize, partial=True):
|
||||
return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)
|
||||
|
||||
def unlisted(self):
|
||||
return self.compose(filters.unlisted())
|
||||
|
||||
def log_keys(self, logfile=None):
|
||||
return self.compose(filters.log_keys(logfile))
|
||||
|
||||
def shuffle(self, size, **kw):
|
||||
if size < 1:
|
||||
return self
|
||||
else:
|
||||
return self.compose(filters.shuffle(size, **kw))
|
||||
|
||||
def map(self, f, handler=reraise_exception):
|
||||
return self.compose(filters.map(f, handler=handler))
|
||||
|
||||
def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
|
||||
handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
|
||||
decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
|
||||
return self.map(decoder, handler=handler)
|
||||
|
||||
def map_dict(self, handler=reraise_exception, **kw):
|
||||
return self.compose(filters.map_dict(handler=handler, **kw))
|
||||
|
||||
def select(self, predicate, **kw):
|
||||
return self.compose(filters.select(predicate, **kw))
|
||||
|
||||
def to_tuple(self, *args, handler=reraise_exception):
|
||||
return self.compose(filters.to_tuple(*args, handler=handler))
|
||||
|
||||
def map_tuple(self, *args, handler=reraise_exception):
|
||||
return self.compose(filters.map_tuple(*args, handler=handler))
|
||||
|
||||
def slice(self, *args):
|
||||
return self.compose(filters.slice(*args))
|
||||
|
||||
def rename(self, **kw):
|
||||
return self.compose(filters.rename(**kw))
|
||||
|
||||
def rsample(self, p=0.5):
|
||||
return self.compose(filters.rsample(p))
|
||||
|
||||
def rename_keys(self, *args, **kw):
|
||||
return self.compose(filters.rename_keys(*args, **kw))
|
||||
|
||||
def extract_keys(self, *args, **kw):
|
||||
return self.compose(filters.extract_keys(*args, **kw))
|
||||
|
||||
def xdecode(self, *args, **kw):
|
||||
return self.compose(filters.xdecode(*args, **kw))
|
||||
|
||||
def audio_data_filter(self, *args, **kw):
|
||||
return self.compose(filters.audio_data_filter(*args, **kw))
|
||||
|
||||
def audio_tokenize(self, *args, **kw):
|
||||
return self.compose(filters.audio_tokenize(*args, **kw))
|
||||
|
||||
def resample(self, *args, **kw):
|
||||
return self.compose(filters.resample(*args, **kw))
|
||||
|
||||
def audio_compute_fbank(self, *args, **kw):
|
||||
return self.compose(filters.audio_compute_fbank(*args, **kw))
|
||||
|
||||
def audio_spec_aug(self, *args, **kw):
|
||||
return self.compose(filters.audio_spec_aug(*args, **kw))
|
||||
|
||||
def sort(self, size=500):
|
||||
return self.compose(filters.sort(size))
|
||||
|
||||
def audio_padding(self):
|
||||
return self.compose(filters.audio_padding())
|
||||
|
||||
def audio_cmvn(self, cmvn_file):
|
||||
return self.compose(filters.audio_cmvn(cmvn_file))
|
||||
|
||||
class WebDataset(DataPipeline, FluidInterface):
|
||||
"""Small fluid-interface wrapper for DataPipeline."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
handler=reraise_exception,
|
||||
resampled=False,
|
||||
repeat=False,
|
||||
shardshuffle=None,
|
||||
cache_size=0,
|
||||
cache_dir=None,
|
||||
detshuffle=False,
|
||||
nodesplitter=shardlists.single_node_only,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(urls, IterableDataset):
|
||||
assert not resampled
|
||||
self.append(urls)
|
||||
elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
|
||||
with (open(urls)) as stream:
|
||||
spec = yaml.safe_load(stream)
|
||||
assert "datasets" in spec
|
||||
self.append(shardlists.MultiShardSample(spec))
|
||||
elif isinstance(urls, dict):
|
||||
assert "datasets" in urls
|
||||
self.append(shardlists.MultiShardSample(urls))
|
||||
elif resampled:
|
||||
self.append(shardlists.ResampledShards(urls))
|
||||
else:
|
||||
self.append(shardlists.SimpleShardList(urls))
|
||||
self.append(nodesplitter)
|
||||
self.append(shardlists.split_by_worker)
|
||||
if shardshuffle is True:
|
||||
shardshuffle = 100
|
||||
if shardshuffle is not None:
|
||||
if detshuffle:
|
||||
self.append(filters.detshuffle(shardshuffle))
|
||||
else:
|
||||
self.append(filters.shuffle(shardshuffle))
|
||||
if cache_size == 0:
|
||||
self.append(tariterators.tarfile_to_samples(handler=handler))
|
||||
else:
|
||||
assert cache_size == -1 or cache_size > 0
|
||||
self.append(
|
||||
cache.cached_tarfile_to_samples(
|
||||
handler=handler,
|
||||
verbose=verbose,
|
||||
cache_size=cache_size,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class FluidWrapper(DataPipeline, FluidInterface):
|
||||
"""Small fluid-interface wrapper for DataPipeline."""
|
||||
|
||||
def __init__(self, initial):
|
||||
super().__init__()
|
||||
self.append(initial)
|
||||
|
||||
|
||||
class WebLoader(DataPipeline, FluidInterface):
|
||||
def __init__(self, *args, **kw):
|
||||
super().__init__(DataLoader(*args, **kw))
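def _webdataset_example(shard_pattern):
    # Minimal sketch, not a tuned recipe: shows how the fluent methods above
    # compose filters onto the underlying DataPipeline. `shard_pattern` is an
    # assumed brace-expandable spec such as "shards-{000000..000099}.tar".
    # The audio_* methods (audio_data_filter, audio_compute_fbank, ...) chain
    # in exactly the same way for speech recipes.
    return (WebDataset(shard_pattern, shardshuffle=True)
            .shuffle(1000)
            .decode("rgb")
            .to_tuple("jpg;png", "cls")
            .batched(16))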
|
@ -0,0 +1,141 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
|
||||
|
||||
"""Train PyTorch models directly from POSIX tar archive.
|
||||
|
||||
Code works locally or over HTTP connections.
|
||||
"""
|
||||
|
||||
import itertools as itt
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import braceexpand
|
||||
|
||||
from . import utils
|
||||
from .paddle_utils import IterableDataset
|
||||
from .utils import PipelineStage
|
||||
|
||||
|
||||
class MockDataset(IterableDataset):
|
||||
"""MockDataset.
|
||||
|
||||
A mock dataset for performance testing and unit testing.
|
||||
"""
|
||||
|
||||
def __init__(self, sample, length):
|
||||
"""Create a mock dataset instance.
|
||||
|
||||
:param sample: the sample to be returned repeatedly
|
||||
:param length: the length of the mock dataset
|
||||
"""
|
||||
self.sample = sample
|
||||
self.length = length
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over this mock dataset."""
|
||||
for i in range(self.length):
|
||||
yield self.sample
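def _mock_dataset_example():
    # Illustrative sketch: a MockDataset yields the same sample a fixed number
    # of times, which is handy for benchmarking a pipeline without real I/O.
    ds = MockDataset({"fname": "fake.wav", "label": [1, 2, 3]}, 1000)
    return sum(1 for _ in ds)  # 1000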
|
||||
|
||||
|
||||
class repeatedly(IterableDataset, PipelineStage):
|
||||
"""Repeatedly yield samples from a dataset."""
|
||||
|
||||
def __init__(self, source, nepochs=None, nbatches=None, length=None):
|
||||
"""Create an instance of Repeatedly.
|
||||
|
||||
:param nepochs: repeat for a maximum of nepochs
|
||||
:param nbatches: repeat for a maximum of nbatches
|
||||
"""
|
||||
self.source = source
|
||||
self.length = length
|
||||
self.nepochs = nepochs  # invoke() reads self.nepochs, so store it here
self.nbatches = nbatches
|
||||
|
||||
def invoke(self, source):
|
||||
"""Return an iterator that iterates repeatedly over a source."""
|
||||
return utils.repeatedly(
|
||||
source,
|
||||
nepochs=self.nepochs,
|
||||
nbatches=self.nbatches,
|
||||
)
|
||||
|
||||
|
||||
class with_epoch(IterableDataset):
|
||||
"""Change the actual and nominal length of an IterableDataset.
|
||||
|
||||
This will continuously iterate through the original dataset, but
|
||||
impose new epoch boundaries at the given length/nominal.
|
||||
This exists mainly as a workaround for the odd logic in DataLoader.
|
||||
It is also useful for choosing smaller nominal epoch sizes with
|
||||
very large datasets.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, length):
|
||||
"""Chop the dataset to the given length.
|
||||
|
||||
:param dataset: IterableDataset
|
||||
:param length: declared length of the dataset
|
||||
:param nominal: nominal length of dataset (if different from declared)
|
||||
"""
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.source = None
|
||||
|
||||
def __getstate__(self):
|
||||
"""Return the pickled state of the dataset.
|
||||
|
||||
This resets the dataset iterator, since that can't be pickled.
|
||||
"""
|
||||
result = dict(self.__dict__)
|
||||
result["source"] = None
|
||||
return result
|
||||
|
||||
def invoke(self, dataset):
|
||||
"""Return an iterator over the dataset.
|
||||
|
||||
This iterator returns as many samples as given by the `length`
|
||||
parameter.
|
||||
"""
|
||||
if self.source is None:
|
||||
self.source = iter(dataset)
|
||||
for i in range(self.length):
|
||||
try:
|
||||
sample = next(self.source)
|
||||
except StopIteration:
|
||||
self.source = iter(dataset)
|
||||
try:
|
||||
sample = next(self.source)
|
||||
except StopIteration:
|
||||
return
|
||||
yield sample
|
||||
self.source = None
|
||||
|
||||
|
||||
class with_length(IterableDataset, PipelineStage):
|
||||
"""Repeatedly yield samples from a dataset."""
|
||||
|
||||
def __init__(self, dataset, length):
|
||||
"""Create an instance of Repeatedly.
|
||||
|
||||
:param dataset: source dataset
|
||||
:param length: stated length
|
||||
"""
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
self.length = length
|
||||
|
||||
def invoke(self, dataset):
|
||||
"""Return an iterator that iterates repeatedly over a source."""
|
||||
return iter(dataset)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the user specified length."""
|
||||
return self.length
|
@ -0,0 +1,935 @@
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
#
|
||||
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
||||
"""A collection of iterators for data transformations.
|
||||
|
||||
These functions are plain iterator functions. You can find curried versions
|
||||
in webdataset.filters, and you can find IterableDataset wrappers in
|
||||
webdataset.processing.
|
||||
"""
|
||||
|
||||
import io
|
||||
from fnmatch import fnmatch
|
||||
import re
|
||||
import itertools, os, pickle, random, sys, time  # pickle: used by decode_pickle below
|
||||
from functools import reduce, wraps
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import autodecode
|
||||
from . import utils
|
||||
from .paddle_utils import PaddleTensor
|
||||
from .utils import PipelineStage
|
||||
|
||||
from .. import backends
|
||||
from ..compliance import kaldi
|
||||
import paddle
|
||||
from ..transform.cmvn import GlobalCMVN
|
||||
from ..utils.tensor_utils import pad_sequence
|
||||
from ..transform.spec_augment import time_warp
|
||||
from ..transform.spec_augment import time_mask
|
||||
from ..transform.spec_augment import freq_mask
|
||||
|
||||
class FilterFunction(object):
|
||||
"""Helper class for currying pipeline stages.
|
||||
|
||||
We use this roundabout construct because it can be pickled.
|
||||
"""
|
||||
|
||||
def __init__(self, f, *args, **kw):
|
||||
"""Create a curried function."""
|
||||
self.f = f
|
||||
self.args = args
|
||||
self.kw = kw
|
||||
|
||||
def __call__(self, data):
|
||||
"""Call the curried function with the given argument."""
|
||||
return self.f(data, *self.args, **self.kw)
|
||||
|
||||
def __str__(self):
|
||||
"""Compute a string representation."""
|
||||
return f"<{self.f.__name__} {self.args} {self.kw}>"
|
||||
|
||||
def __repr__(self):
|
||||
"""Compute a string representation."""
|
||||
return f"<{self.f.__name__} {self.args} {self.kw}>"
|
||||
|
||||
|
||||
class RestCurried(object):
|
||||
"""Helper class for currying pipeline stages.
|
||||
|
||||
We use this roundabout construct because it can be pickled.
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
"""Store the function for future currying."""
|
||||
self.f = f
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
"""Curry with the given arguments."""
|
||||
return FilterFunction(self.f, *args, **kw)
|
||||
|
||||
|
||||
def pipelinefilter(f):
|
||||
"""Turn the decorated function into one that is partially applied for
|
||||
all arguments other than the first."""
|
||||
result = RestCurried(f)
|
||||
return result
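def _pipelinefilter_example():
    # Illustrative sketch: pipelinefilter turns _shuffle(data, bufsize, ...)
    # into a stage factory. shuffle(10) returns a picklable FilterFunction
    # that is only applied to the sample iterator when the pipeline runs.
    stage = shuffle(10)
    return list(stage(iter(range(5))))  # the same five items, reordered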
|
||||
|
||||
|
||||
def reraise_exception(exn):
|
||||
"""Reraises the given exception; used as a handler.
|
||||
|
||||
:param exn: exception
|
||||
"""
|
||||
raise exn
|
||||
|
||||
|
||||
def identity(x):
|
||||
"""Return the argument."""
|
||||
return x
|
||||
|
||||
|
||||
def compose2(f, g):
|
||||
"""Compose two functions, g(f(x))."""
|
||||
return lambda x: g(f(x))
|
||||
|
||||
|
||||
def compose(*args):
|
||||
"""Compose a sequence of functions (left-to-right)."""
|
||||
return reduce(compose2, args)
|
||||
|
||||
|
||||
def pipeline(source, *args):
|
||||
"""Write an input pipeline; first argument is source, rest are filters."""
|
||||
if len(args) == 0:
|
||||
return source
|
||||
return compose(*args)(source)
|
||||
|
||||
|
||||
def getfirst(a, keys, default=None, missing_is_error=True):
|
||||
"""Get the first matching key from a dictionary.
|
||||
|
||||
Keys can be specified as a list, or as a string of keys separated by ';'.
|
||||
"""
|
||||
if isinstance(keys, str):
|
||||
assert " " not in keys
|
||||
keys = keys.split(";")
|
||||
for k in keys:
|
||||
if k in a:
|
||||
return a[k]
|
||||
if missing_is_error:
|
||||
raise ValueError(f"didn't find {keys} in {list(a.keys())}")
|
||||
return default
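def _getfirst_example():
    # Illustrative sketch: ";"-separated key alternatives are tried in order,
    # so "jpg;png" falls back to the "png" field of this made-up sample.
    sample = {"png": b"image-bytes", "cls": b"3"}
    return getfirst(sample, "jpg;png")  # b"image-bytes"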
|
||||
|
||||
|
||||
def parse_field_spec(fields):
|
||||
"""Parse a specification for a list of fields to be extracted.
|
||||
|
||||
Keys are separated by spaces in the spec. Each key can itself
|
||||
be composed of key alternatives separated by ';'.
|
||||
"""
|
||||
if isinstance(fields, str):
|
||||
fields = fields.split()
|
||||
return [field.split(";") for field in fields]
|
||||
|
||||
|
||||
def transform_with(sample, transformers):
|
||||
"""Transform a list of values using a list of functions.
|
||||
|
||||
sample: list of values
|
||||
transformers: list of functions
|
||||
|
||||
If there are fewer transformers than inputs, or if a transformer
|
||||
function is None, then the identity function is used for the
|
||||
corresponding sample fields.
|
||||
"""
|
||||
if transformers is None or len(transformers) == 0:
|
||||
return sample
|
||||
result = list(sample)
|
||||
assert len(transformers) <= len(sample)
|
||||
for i in range(len(transformers)): # skipcq: PYL-C0200
|
||||
f = transformers[i]
|
||||
if f is not None:
|
||||
result[i] = f(sample[i])
|
||||
return result
|
||||
|
||||
###
|
||||
# Iterators
|
||||
###
|
||||
|
||||
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
|
||||
"""Print information about the samples that are passing through.
|
||||
|
||||
:param data: source iterator
|
||||
:param fmt: format statement (using sample dict as keyword)
|
||||
:param n: when to stop
|
||||
:param every: how often to print
|
||||
:param width: maximum width
|
||||
:param stream: output stream
|
||||
:param name: identifier printed before any output
|
||||
"""
|
||||
for i, sample in enumerate(data):
|
||||
if i < n or (every > 0 and (i + 1) % every == 0):
|
||||
if fmt is None:
|
||||
print("---", name, file=stream)
|
||||
for k, v in sample.items():
|
||||
print(k, repr(v)[:width], file=stream)
|
||||
else:
|
||||
print(fmt.format(**sample), file=stream)
|
||||
yield sample
|
||||
|
||||
|
||||
info = pipelinefilter(_info)
|
||||
|
||||
|
||||
def pick(buf, rng):
|
||||
k = rng.randint(0, len(buf) - 1)
|
||||
sample = buf[k]
|
||||
buf[k] = buf[-1]
|
||||
buf.pop()
|
||||
return sample
|
||||
|
||||
|
||||
def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
|
||||
"""Shuffle the data in the stream.
|
||||
|
||||
This uses a buffer of size `bufsize`. Shuffling at
|
||||
startup is less random; this is traded off against
|
||||
yielding samples quickly.
|
||||
|
||||
data: iterator
|
||||
bufsize: buffer size for shuffling
|
||||
returns: iterator
|
||||
rng: either random module or random.Random instance
|
||||
|
||||
"""
|
||||
if rng is None:
|
||||
rng = random.Random(int((os.getpid() + time.time()) * 1e9))
|
||||
initial = min(initial, bufsize)
|
||||
buf = []
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) < bufsize:
|
||||
try:
|
||||
buf.append(next(data)) # skipcq: PYL-R1708
|
||||
except StopIteration:
|
||||
pass
|
||||
if len(buf) >= initial:
|
||||
yield pick(buf, rng)
|
||||
while len(buf) > 0:
|
||||
yield pick(buf, rng)
|
||||
|
||||
|
||||
shuffle = pipelinefilter(_shuffle)
|
||||
|
||||
|
||||
class detshuffle(PipelineStage):
|
||||
def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
|
||||
self.bufsize = bufsize
|
||||
self.initial = initial
|
||||
self.seed = seed
|
||||
self.epoch = epoch
|
||||
|
||||
def run(self, src):
|
||||
self.epoch += 1
|
||||
rng = random.Random()
|
||||
rng.seed((self.seed, self.epoch))
|
||||
return _shuffle(src, self.bufsize, self.initial, rng)
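def _detshuffle_example(samples):
    # Illustrative sketch: detshuffle reseeds from (seed, epoch) on every run,
    # so workers configured with the same seed shuffle shards identically
    # within an epoch while still changing the order across epochs.
    stage = detshuffle(bufsize=1000, initial=100, seed=42)
    return list(stage.run(iter(samples)))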
|
||||
|
||||
|
||||
def _select(data, predicate):
|
||||
"""Select samples based on a predicate.
|
||||
|
||||
:param data: source iterator
|
||||
:param predicate: predicate (function)
|
||||
"""
|
||||
for sample in data:
|
||||
if predicate(sample):
|
||||
yield sample
|
||||
|
||||
|
||||
select = pipelinefilter(_select)
|
||||
|
||||
|
||||
def _log_keys(data, logfile=None):
|
||||
import fcntl
|
||||
|
||||
if logfile is None or logfile == "":
|
||||
for sample in data:
|
||||
yield sample
|
||||
else:
|
||||
with open(logfile, "a") as stream:
|
||||
for i, sample in enumerate(data):
|
||||
buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
|
||||
try:
|
||||
fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
|
||||
stream.write(buf)
|
||||
finally:
|
||||
fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
|
||||
yield sample
|
||||
|
||||
|
||||
log_keys = pipelinefilter(_log_keys)
|
||||
|
||||
|
||||
def _decode(data, *args, handler=reraise_exception, **kw):
|
||||
"""Decode data based on the decoding functions given as arguments."""
|
||||
|
||||
decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
|
||||
handlers = [decoder(x) for x in args]
|
||||
f = autodecode.Decoder(handlers, **kw)
|
||||
|
||||
for sample in data:
|
||||
assert isinstance(sample, dict), sample
|
||||
try:
|
||||
decoded = f(sample)
|
||||
except Exception as exn: # skipcq: PYL-W0703
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
yield decoded
|
||||
|
||||
|
||||
decode = pipelinefilter(_decode)
|
||||
|
||||
|
||||
def _map(data, f, handler=reraise_exception):
|
||||
"""Map samples."""
|
||||
for sample in data:
|
||||
try:
|
||||
result = f(sample)
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if result is None:
|
||||
continue
|
||||
if isinstance(sample, dict) and isinstance(result, dict):
|
||||
result["__key__"] = sample.get("__key__")
|
||||
yield result
|
||||
|
||||
|
||||
map = pipelinefilter(_map)
|
||||
|
||||
|
||||
def _rename(data, handler=reraise_exception, keep=True, **kw):
|
||||
"""Rename samples based on keyword arguments."""
|
||||
for sample in data:
|
||||
try:
|
||||
if not keep:
|
||||
yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
|
||||
else:
|
||||
|
||||
def listify(v):
|
||||
return v.split(";") if isinstance(v, str) else v
|
||||
|
||||
to_be_replaced = {x for v in kw.values() for x in listify(v)}
|
||||
result = {k: v for k, v in sample.items() if k not in to_be_replaced}
|
||||
result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
|
||||
yield result
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
rename = pipelinefilter(_rename)
|
||||
|
||||
|
||||
def _associate(data, associator, **kw):
|
||||
"""Associate additional data with samples."""
|
||||
for sample in data:
|
||||
if callable(associator):
|
||||
extra = associator(sample["__key__"])
|
||||
else:
|
||||
extra = associator.get(sample["__key__"], {})
|
||||
sample.update(extra) # destructive
|
||||
yield sample
|
||||
|
||||
|
||||
associate = pipelinefilter(_associate)
|
||||
|
||||
|
||||
def _map_dict(data, handler=reraise_exception, **kw):
|
||||
"""Map the entries in a dict sample with individual functions."""
|
||||
assert len(list(kw.keys())) > 0
|
||||
for key, f in kw.items():
|
||||
assert callable(f), (key, f)
|
||||
|
||||
for sample in data:
|
||||
assert isinstance(sample, dict)
|
||||
try:
|
||||
for k, f in kw.items():
|
||||
sample[k] = f(sample[k])
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
yield sample
|
||||
|
||||
|
||||
map_dict = pipelinefilter(_map_dict)
|
||||
|
||||
|
||||
def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
|
||||
"""Convert dict samples to tuples."""
|
||||
if none_is_error is None:
|
||||
none_is_error = missing_is_error
|
||||
if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
|
||||
args = args[0].split()
|
||||
|
||||
for sample in data:
|
||||
try:
|
||||
result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
|
||||
if none_is_error and any(x is None for x in result):
|
||||
raise ValueError(f"to_tuple {args} got {sample.keys()}")
|
||||
yield result
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
to_tuple = pipelinefilter(_to_tuple)
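def _to_tuple_example():
    # Illustrative sketch: each positional argument selects one field (with
    # ";"-separated alternatives), turning dict samples into plain tuples.
    samples = [{"__key__": "utt0", "wav": b"pcm-bytes", "txt": "hello"}]
    return list(_to_tuple(iter(samples), "wav", "txt"))  # [(b"pcm-bytes", "hello")]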
|
||||
|
||||
|
||||
def _map_tuple(data, *args, handler=reraise_exception):
|
||||
"""Map the entries of a tuple with individual functions."""
|
||||
args = [f if f is not None else utils.identity for f in args]
|
||||
for f in args:
|
||||
assert callable(f), f
|
||||
for sample in data:
|
||||
assert isinstance(sample, (list, tuple))
|
||||
sample = list(sample)
|
||||
n = min(len(args), len(sample))
|
||||
try:
|
||||
for i in range(n):
|
||||
sample[i] = args[i](sample[i])
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
yield tuple(sample)
|
||||
|
||||
|
||||
map_tuple = pipelinefilter(_map_tuple)
|
||||
|
||||
|
||||
def _unlisted(data):
|
||||
"""Turn batched data back into unbatched data."""
|
||||
for batch in data:
|
||||
assert isinstance(batch, list), batch
|
||||
for sample in batch:
|
||||
yield sample
|
||||
|
||||
|
||||
unlisted = pipelinefilter(_unlisted)
|
||||
|
||||
|
||||
def _unbatched(data):
|
||||
"""Turn batched data back into unbatched data."""
|
||||
for sample in data:
|
||||
assert isinstance(sample, (tuple, list)), sample
|
||||
assert len(sample) > 0
|
||||
for i in range(len(sample[0])):
|
||||
yield tuple(x[i] for x in sample)
|
||||
|
||||
|
||||
unbatched = pipelinefilter(_unbatched)
|
||||
|
||||
|
||||
def _rsample(data, p=0.5):
|
||||
"""Randomly subsample a stream of data."""
|
||||
assert p >= 0.0 and p <= 1.0
|
||||
for sample in data:
|
||||
if random.uniform(0.0, 1.0) < p:
|
||||
yield sample
|
||||
|
||||
|
||||
rsample = pipelinefilter(_rsample)
|
||||
|
||||
slice = pipelinefilter(itertools.islice)
|
||||
|
||||
|
||||
def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
|
||||
for sample in source:
|
||||
result = []
|
||||
for pattern in patterns:
|
||||
pattern = pattern.split(";") if isinstance(pattern, str) else pattern
|
||||
matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
|
||||
if len(matches) == 0:
|
||||
if ignore_missing:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
|
||||
if len(matches) > 1 and duplicate_is_error:
|
||||
raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
|
||||
value = sample[matches[0]]
|
||||
result.append(value)
|
||||
yield tuple(result)
|
||||
|
||||
|
||||
extract_keys = pipelinefilter(_extract_keys)
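def _extract_keys_example():
    # Illustrative sketch: patterns are fnmatch-style and matched against
    # "." + key, so "*.txt" selects the "txt" field of each sample.
    samples = [{"wav": b"pcm-bytes", "txt": "hello"}]
    return list(_extract_keys(iter(samples), "*.wav", "*.txt"))  # [(b"pcm-bytes", "hello")]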
|
||||
|
||||
|
||||
def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
|
||||
renamings = [(pattern, output) for output, pattern in args]
|
||||
renamings += [(pattern, output) for output, pattern in kw.items()]
|
||||
for sample in source:
|
||||
new_sample = {}
|
||||
matched = {k: False for k, _ in renamings}
|
||||
for path, value in sample.items():
|
||||
fname = re.sub(r".*/", "", path)
|
||||
new_name = None
|
||||
for pattern, name in renamings[::-1]:
|
||||
if fnmatch(fname.lower(), pattern):
|
||||
matched[pattern] = True
|
||||
new_name = name
|
||||
break
|
||||
if new_name is None:
|
||||
if keep_unselected:
|
||||
new_sample[path] = value
|
||||
continue
|
||||
if new_name in new_sample:
|
||||
if duplicate_is_error:
|
||||
raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
|
||||
continue
|
||||
new_sample[new_name] = value
|
||||
if must_match and not all(matched.values()):
|
||||
raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
|
||||
|
||||
yield new_sample
|
||||
|
||||
|
||||
rename_keys = pipelinefilter(_rename_keys)
|
||||
|
||||
|
||||
def decode_bin(stream):
|
||||
return stream.read()
|
||||
|
||||
|
||||
def decode_text(stream):
|
||||
binary = stream.read()
|
||||
return binary.decode("utf-8")
|
||||
|
||||
|
||||
def decode_pickle(stream):
|
||||
return pickle.load(stream)
|
||||
|
||||
|
||||
default_decoders = [
|
||||
("*.bin", decode_bin),
|
||||
("*.txt", decode_text),
|
||||
("*.pyd", decode_pickle),
|
||||
]
|
||||
|
||||
|
||||
def find_decoder(decoders, path):
|
||||
fname = re.sub(r".*/", "", path)
|
||||
if fname.startswith("__"):
|
||||
return lambda x: x
|
||||
for pattern, fun in decoders[::-1]:
|
||||
if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
|
||||
return fun
|
||||
return None
|
||||
|
||||
|
||||
def _xdecode(
|
||||
source,
|
||||
*args,
|
||||
must_decode=True,
|
||||
defaults=default_decoders,
|
||||
**kw,
|
||||
):
|
||||
decoders = list(defaults) + list(args)
|
||||
decoders += [("*." + k, v) for k, v in kw.items()]
|
||||
for sample in source:
|
||||
new_sample = {}
|
||||
for path, data in sample.items():
|
||||
if path.startswith("__"):
|
||||
new_sample[path] = data
|
||||
continue
|
||||
decoder = find_decoder(decoders, path)
|
||||
if decoder is False:
|
||||
value = data
|
||||
elif decoder is None:
|
||||
if must_decode:
|
||||
raise ValueError(f"No decoder found for {path}.")
|
||||
value = data
|
||||
else:
|
||||
if isinstance(data, bytes):
|
||||
data = io.BytesIO(data)
|
||||
value = decoder(data)
|
||||
new_sample[path] = value
|
||||
yield new_sample
|
||||
|
||||
xdecode = pipelinefilter(_xdecode)
|
||||
|
||||
|
||||
|
||||
def _audio_data_filter(source,
|
||||
frame_shift=10,
|
||||
max_length=10240,
|
||||
min_length=10,
|
||||
token_max_length=200,
|
||||
token_min_length=1,
|
||||
min_output_input_ratio=0.0005,
|
||||
max_output_input_ratio=1):
|
||||
""" Filter sample according to feature and label length
|
||||
In-place operation.

Args:
source: Iterable[{fname, wav, label, sample_rate}]
frame_shift: length of the frame shift (ms)
max_length: drop utterances longer than max_length (in 10 ms frames)
min_length: drop utterances shorter than min_length (in 10 ms frames)
token_max_length: drop utterances whose label is longer than
token_max_length, especially when char units are used for
English modeling
token_min_length: drop utterances whose label is shorter than
token_min_length
min_output_input_ratio: minimal ratio of
token_length / feats_length (in 10 ms frames)
max_output_input_ratio: maximum ratio of
token_length / feats_length (in 10 ms frames)

Returns:
Iterable[{fname, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in source:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'wav' in sample
|
||||
assert 'label' in sample
|
||||
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
|
||||
num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
|
||||
if num_frames < min_length:
|
||||
continue
|
||||
if num_frames > max_length:
|
||||
continue
|
||||
if len(sample['label']) < token_min_length:
|
||||
continue
|
||||
if len(sample['label']) > token_max_length:
|
||||
continue
|
||||
if num_frames != 0:
|
||||
if len(sample['label']) / num_frames < min_output_input_ratio:
|
||||
continue
|
||||
if len(sample['label']) / num_frames > max_output_input_ratio:
|
||||
continue
|
||||
yield sample
|
||||
|
||||
audio_data_filter = pipelinefilter(_audio_data_filter)
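def _num_frames_example(wav, sample_rate):
    # Illustrative sketch of the length check above: a [1, T] waveform tensor
    # is converted to a frame count at a 10 ms shift, e.g. 16000 samples at
    # 16 kHz give 100 frames, which is then compared against min/max_length.
    frame_shift = 10
    return wav.shape[1] / sample_rate * (1000 / frame_shift)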
|
||||
|
||||
def _audio_tokenize(source,
|
||||
symbol_table,
|
||||
bpe_model=None,
|
||||
non_lang_syms=None,
|
||||
split_with_space=False):
|
||||
""" Decode text to chars or BPE
|
||||
Inplace operation
|
||||
|
||||
Args:
|
||||
source: Iterable[{fname, wav, txt, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{fname, wav, txt, tokens, label, sample_rate}]
|
||||
"""
|
||||
if non_lang_syms is not None:
|
||||
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
|
||||
else:
|
||||
non_lang_syms = {}
|
||||
non_lang_syms_pattern = None
|
||||
|
||||
if bpe_model is not None:
|
||||
import sentencepiece as spm
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpe_model)
|
||||
else:
|
||||
sp = None
|
||||
|
||||
for sample in source:
|
||||
assert 'txt' in sample
|
||||
txt = sample['txt'].strip()
|
||||
if non_lang_syms_pattern is not None:
|
||||
parts = non_lang_syms_pattern.split(txt.upper())
|
||||
parts = [w for w in parts if len(w.strip()) > 0]
|
||||
else:
|
||||
parts = [txt]
|
||||
|
||||
label = []
|
||||
tokens = []
|
||||
for part in parts:
|
||||
if part in non_lang_syms:
|
||||
tokens.append(part)
|
||||
else:
|
||||
if bpe_model is not None:
|
||||
tokens.extend(__tokenize_by_bpe_model(sp, part))
|
||||
else:
|
||||
if split_with_space:
|
||||
part = part.split(" ")
|
||||
for ch in part:
|
||||
if ch == ' ':
|
||||
ch = "<space>"
|
||||
tokens.append(ch)
|
||||
|
||||
for ch in tokens:
|
||||
if ch in symbol_table:
|
||||
label.append(symbol_table[ch])
|
||||
elif '<unk>' in symbol_table:
|
||||
label.append(symbol_table['<unk>'])
|
||||
|
||||
sample['tokens'] = tokens
|
||||
sample['label'] = label
|
||||
yield sample
|
||||
|
||||
audio_tokenize = pipelinefilter(_audio_tokenize)
|
||||
|
||||
def _audio_resample(source, resample_rate=16000):
|
||||
""" Resample data.
|
||||
Inplace operation.
|
||||
|
||||
Args:
|
||||
data: Iterable[{fname, wav, label, sample_rate}]
|
||||
resample_rate: target resample rate
|
||||
|
||||
Returns:
|
||||
Iterable[{fname, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in source:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'wav' in sample
|
||||
sample_rate = sample['sample_rate']
|
||||
waveform = sample['wav']
|
||||
if sample_rate != resample_rate:
|
||||
sample['sample_rate'] = resample_rate
|
||||
sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
|
||||
waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate
|
||||
))
|
||||
yield sample
|
||||
|
||||
audio_resample = pipelinefilter(_audio_resample)
|
||||
|
||||
def _audio_compute_fbank(source,
|
||||
num_mel_bins=80,
|
||||
frame_length=25,
|
||||
frame_shift=10,
|
||||
dither=0.0):
|
||||
""" Extract fbank
|
||||
|
||||
Args:
|
||||
source: Iterable[{fname, wav, label, sample_rate}]
|
||||
num_mel_bins: number of mel filter bank
|
||||
frame_length: length of one frame (ms)
|
||||
frame_shift: length of frame shift (ms)
|
||||
dither: value of dither
|
||||
|
||||
Returns:
|
||||
Iterable[{fname, feat, label}]
|
||||
"""
|
||||
for sample in source:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'wav' in sample
|
||||
assert 'fname' in sample
|
||||
assert 'label' in sample
|
||||
sample_rate = sample['sample_rate']
|
||||
waveform = sample['wav']
|
||||
waveform = waveform * (1 << 15)
|
||||
# Only keep fname, feat, label
|
||||
mat = kaldi.fbank(waveform,
|
||||
n_mels=num_mel_bins,
|
||||
frame_length=frame_length,
|
||||
frame_shift=frame_shift,
|
||||
dither=dither,
|
||||
energy_floor=0.0,
|
||||
sr=sample_rate)
|
||||
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
|
||||
|
||||
|
||||
audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
|
||||
|
||||
def _audio_spec_aug(source,
|
||||
max_w=5,
|
||||
w_inplace=True,
|
||||
w_mode="PIL",
|
||||
max_f=30,
|
||||
num_f_mask=2,
|
||||
f_inplace=True,
|
||||
f_replace_with_zero=False,
|
||||
max_t=40,
|
||||
num_t_mask=2,
|
||||
t_inplace=True,
|
||||
t_replace_with_zero=False,):
|
||||
""" Do spec augmentation
|
||||
Inplace operation
|
||||
|
||||
Args:
|
||||
source: Iterable[{fname, feat, label}]
|
||||
max_w: max width of time warp
|
||||
w_inplace: whether to inplace the original data while time warping
|
||||
w_mode: time warp mode
|
||||
max_f: max width of freq mask
|
||||
num_f_mask: number of freq mask to apply
|
||||
f_inplace: whether to inplace the original data while frequency masking
|
||||
f_replace_with_zero: use zero to mask
|
||||
max_t: max width of time mask
|
||||
num_t_mask: number of time mask to apply
|
||||
t_inplace: whether to inplace the original data while time masking
|
||||
t_replace_with_zero: use zero to mask
|
||||
|
||||
Returns
|
||||
Iterable[{fname, feat, label}]
|
||||
"""
|
||||
for sample in source:
|
||||
x = sample['feat']
|
||||
x = x.numpy()
|
||||
x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode)
|
||||
x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero)
|
||||
x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero)
|
||||
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
|
||||
yield sample
|
||||
|
||||
audio_spec_aug = pipelinefilter(_audio_spec_aug)
|
||||
|
||||
|
||||
def _sort(source, sort_size=500):
|
||||
""" Sort the data by feature length.
|
||||
Sort is used after shuffle and before batch, so we can group
|
||||
utts with similar lengths into a batch, and `sort_size` should
|
||||
be less than `shuffle_size`
|
||||
|
||||
Args:
|
||||
source: Iterable[{fname, feat, label}]
|
||||
sort_size: buffer size for sort
|
||||
|
||||
Returns:
|
||||
Iterable[{fname, feat, label}]
|
||||
"""
|
||||
|
||||
buf = []
|
||||
for sample in source:
|
||||
buf.append(sample)
|
||||
if len(buf) >= sort_size:
|
||||
buf.sort(key=lambda x: x['feat'].shape[0])
|
||||
for x in buf:
|
||||
yield x
|
||||
buf = []
|
||||
# The sample left over
|
||||
buf.sort(key=lambda x: x['feat'].shape[0])
|
||||
for x in buf:
|
||||
yield x
|
||||
|
||||
sort = pipelinefilter(_sort)
|
||||
|
||||
def _batched(source, batch_size=16):
|
||||
""" Static batch the data by `batch_size`
|
||||
|
||||
Args:
|
||||
data: Iterable[{fname, feat, label}]
|
||||
batch_size: batch size
|
||||
|
||||
Returns:
|
||||
Iterable[List[{fname, feat, label}]]
|
||||
"""
|
||||
buf = []
|
||||
for sample in source:
|
||||
buf.append(sample)
|
||||
if len(buf) >= batch_size:
|
||||
yield buf
|
||||
buf = []
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
|
||||
batched = pipelinefilter(_batched)
|
||||
|
||||
def dynamic_batched(source, max_frames_in_batch=12000):
|
||||
""" Dynamic batch the data until the total frames in batch
|
||||
reach `max_frames_in_batch`
|
||||
|
||||
Args:
|
||||
source: Iterable[{fname, feat, label}]
|
||||
max_frames_in_batch: max_frames in one batch
|
||||
|
||||
Returns:
|
||||
Iterable[List[{fname, feat, label}]]
|
||||
"""
|
||||
buf = []
|
||||
longest_frames = 0
|
||||
for sample in source:
|
||||
assert 'feat' in sample
|
||||
assert isinstance(sample['feat'], paddle.Tensor)
|
||||
        new_sample_frames = sample['feat'].shape[0]  # paddle.Tensor has no callable .size()
|
||||
longest_frames = max(longest_frames, new_sample_frames)
|
||||
frames_after_padding = longest_frames * (len(buf) + 1)
|
||||
if frames_after_padding > max_frames_in_batch:
|
||||
yield buf
|
||||
buf = [sample]
|
||||
longest_frames = new_sample_frames
|
||||
else:
|
||||
buf.append(sample)
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
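# Illustrative note (not part of the original patch): `dynamic_batched` grows a
# batch while `longest_frames * (len(buf) + 1)` stays within
# `max_frames_in_batch`, so short utterances pack into larger batches than long
# ones.  With max_frames_in_batch=12000:
#
#   utterances of  300 frames -> up to 40 samples per batch
#   utterances of 1200 frames -> up to 10 samples per batch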
|
||||
|
||||
def _audio_padding(source):
    """ Pad the data into training batches
|
||||
|
||||
Args:
|
||||
source: Iterable[List[{fname, feat, label}]]
|
||||
|
||||
Returns:
|
||||
Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
|
||||
"""
|
||||
for sample in source:
|
||||
assert isinstance(sample, list)
|
||||
feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
|
||||
dtype="int64")
|
||||
order = paddle.argsort(feats_length, descending=True)
|
||||
feats_lengths = paddle.to_tensor(
|
||||
[sample[i]['feat'].shape[0] for i in order], dtype="int64")
|
||||
sorted_feats = [sample[i]['feat'] for i in order]
|
||||
sorted_keys = [sample[i]['fname'] for i in order]
|
||||
sorted_labels = [
|
||||
paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
|
||||
]
|
||||
label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
|
||||
dtype="int64")
|
||||
padded_feats = pad_sequence(sorted_feats,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
padding_labels = pad_sequence(sorted_labels,
|
||||
batch_first=True,
|
||||
padding_value=-1)
|
||||
|
||||
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
|
||||
label_lengths)
|
||||
|
||||
audio_padding = pipelinefilter(_audio_padding)
|
||||
|
||||
def _audio_cmvn(source, cmvn_file):
|
||||
global_cmvn = GlobalCMVN(cmvn_file)
|
||||
for batch in source:
|
||||
sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
|
||||
padded_feats = padded_feats.numpy()
|
||||
padded_feats = global_cmvn(padded_feats)
|
||||
padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
|
||||
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
|
||||
label_lengths)
|
||||
|
||||
audio_cmvn = pipelinefilter(_audio_cmvn)
|
||||
|
||||
def _placeholder(source):
|
||||
for data in source:
|
||||
yield data
|
||||
|
||||
placeholder = pipelinefilter(_placeholder)
|
@ -0,0 +1,340 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
#
|
||||
|
||||
|
||||
"""Open URLs by calling subcommands."""
|
||||
|
||||
import os, sys, re
|
||||
from subprocess import PIPE, Popen
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# global used for printing additional node information during verbose output
|
||||
info = {}
|
||||
|
||||
|
||||
class Pipe:
    """Wrapper class for subprocess.Popen.
|
||||
|
||||
This class looks like a stream from the outside, but it checks
|
||||
subprocess status and handles timeouts with exceptions.
|
||||
This way, clients of the class do not need to know that they are
|
||||
dealing with subprocesses.
|
||||
|
||||
    :param *args: passed to `subprocess.Popen`
|
||||
    :param **kw: passed to `subprocess.Popen`
|
||||
:param timeout: timeout for closing/waiting
|
||||
:param ignore_errors: don't raise exceptions on subprocess errors
|
||||
:param ignore_status: list of status codes to ignore
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
mode=None,
|
||||
timeout=7200.0,
|
||||
ignore_errors=False,
|
||||
ignore_status=[],
|
||||
**kw,
|
||||
):
|
||||
"""Create an IO Pipe."""
|
||||
self.ignore_errors = ignore_errors
|
||||
self.ignore_status = [0] + ignore_status
|
||||
self.timeout = timeout
|
||||
self.args = (args, kw)
|
||||
if mode[0] == "r":
|
||||
self.proc = Popen(*args, stdout=PIPE, **kw)
|
||||
self.stream = self.proc.stdout
|
||||
if self.stream is None:
|
||||
raise ValueError(f"{args}: couldn't open")
|
||||
elif mode[0] == "w":
|
||||
self.proc = Popen(*args, stdin=PIPE, **kw)
|
||||
self.stream = self.proc.stdin
|
||||
if self.stream is None:
|
||||
raise ValueError(f"{args}: couldn't open")
|
||||
self.status = None
|
||||
|
||||
def __str__(self):
|
||||
return f"<Pipe {self.args}>"
|
||||
|
||||
def check_status(self):
|
||||
"""Poll the process and handle any errors."""
|
||||
status = self.proc.poll()
|
||||
if status is not None:
|
||||
self.wait_for_child()
|
||||
|
||||
def wait_for_child(self):
|
||||
"""Check the status variable and raise an exception if necessary."""
|
||||
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
|
||||
if self.status is not None and verbose:
|
||||
# print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
|
||||
return
|
||||
self.status = self.proc.wait()
|
||||
if verbose:
|
||||
print(
|
||||
f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if self.status not in self.ignore_status and not self.ignore_errors:
|
||||
raise Exception(f"{self.args}: exit {self.status} (read) {info}")
|
||||
|
||||
def read(self, *args, **kw):
|
||||
"""Wrap stream.read and checks status."""
|
||||
result = self.stream.read(*args, **kw)
|
||||
self.check_status()
|
||||
return result
|
||||
|
||||
def write(self, *args, **kw):
|
||||
"""Wrap stream.write and checks status."""
|
||||
result = self.stream.write(*args, **kw)
|
||||
self.check_status()
|
||||
return result
|
||||
|
||||
    def readline(self, *args, **kw):
        """Wrap stream.readline and checks status."""
        result = self.stream.readline(*args, **kw)
|
||||
self.status = self.proc.poll()
|
||||
self.check_status()
|
||||
return result
|
||||
|
||||
def close(self):
|
||||
"""Wrap stream.close, wait for the subprocess, and handle errors."""
|
||||
self.stream.close()
|
||||
self.status = self.proc.wait(self.timeout)
|
||||
self.wait_for_child()
|
||||
|
||||
def __enter__(self):
|
||||
"""Context handler."""
|
||||
return self
|
||||
|
||||
def __exit__(self, etype, value, traceback):
|
||||
"""Context handler."""
|
||||
self.close()
|
||||
|
||||
|
||||
def set_options(
|
||||
obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
|
||||
):
|
||||
"""Set options for Pipes.
|
||||
|
||||
This function can be called on any stream. It will set pipe options only
|
||||
when its argument is a pipe.
|
||||
|
||||
:param obj: any kind of stream
|
||||
:param timeout: desired timeout
|
||||
:param ignore_errors: desired ignore_errors setting
|
||||
:param ignore_status: desired ignore_status setting
|
||||
:param handler: desired error handler
|
||||
"""
|
||||
if not isinstance(obj, Pipe):
|
||||
return False
|
||||
if timeout is not None:
|
||||
obj.timeout = timeout
|
||||
if ignore_errors is not None:
|
||||
obj.ignore_errors = ignore_errors
|
||||
if ignore_status is not None:
|
||||
obj.ignore_status = ignore_status
|
||||
if handler is not None:
|
||||
obj.handler = handler
|
||||
return True
|
||||
|
||||
|
||||
def gopen_file(url, mode="rb", bufsize=8192):
|
||||
"""Open a file.
|
||||
|
||||
    This opens a local file with the built-in open(); other schemes are handled by the gopen_* helpers below.
|
||||
|
||||
:param url: URL to be opened
|
||||
:param mode: mode to open it with
|
||||
:param bufsize: requested buffer size
|
||||
"""
|
||||
return open(url, mode)
|
||||
|
||||
|
||||
def gopen_pipe(url, mode="rb", bufsize=8192):
|
||||
"""Use gopen to open a pipe.
|
||||
|
||||
:param url: a pipe: URL
|
||||
:param mode: desired mode
|
||||
:param bufsize: desired buffer size
|
||||
"""
|
||||
assert url.startswith("pipe:")
|
||||
cmd = url[5:]
|
||||
if mode[0] == "r":
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141],
|
||||
) # skipcq: BAN-B604
|
||||
elif mode[0] == "w":
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141],
|
||||
) # skipcq: BAN-B604
|
||||
else:
|
||||
raise ValueError(f"{mode}: unknown mode")
|
||||
|
||||
|
||||
def gopen_curl(url, mode="rb", bufsize=8192):
|
||||
"""Open a URL with `curl`.
|
||||
|
||||
:param url: url (usually, http:// etc.)
|
||||
:param mode: file mode
|
||||
:param bufsize: buffer size
|
||||
"""
|
||||
if mode[0] == "r":
|
||||
cmd = f"curl -s -L '{url}'"
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141, 23],
|
||||
) # skipcq: BAN-B604
|
||||
elif mode[0] == "w":
|
||||
cmd = f"curl -s -L -T - '{url}'"
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141, 26],
|
||||
) # skipcq: BAN-B604
|
||||
else:
|
||||
raise ValueError(f"{mode}: unknown mode")
|
||||
|
||||
|
||||
def gopen_htgs(url, mode="rb", bufsize=8192):
|
||||
"""Open a URL with `curl`.
|
||||
|
||||
:param url: url (usually, http:// etc.)
|
||||
:param mode: file mode
|
||||
:param bufsize: buffer size
|
||||
"""
|
||||
if mode[0] == "r":
|
||||
url = re.sub(r"(?i)^htgs://", "gs://", url)
|
||||
cmd = f"curl -s -L '{url}'"
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141, 23],
|
||||
) # skipcq: BAN-B604
|
||||
elif mode[0] == "w":
|
||||
raise ValueError(f"{mode}: cannot write")
|
||||
else:
|
||||
raise ValueError(f"{mode}: unknown mode")
|
||||
|
||||
|
||||
|
||||
def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `gsutil`.
|
||||
|
||||
:param url: url (usually, http:// etc.)
|
||||
:param mode: file mode
|
||||
:param bufsize: buffer size
|
||||
"""
|
||||
if mode[0] == "r":
|
||||
cmd = f"gsutil cat '{url}'"
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141, 23],
|
||||
) # skipcq: BAN-B604
|
||||
elif mode[0] == "w":
|
||||
cmd = f"gsutil cp - '{url}'"
|
||||
return Pipe(
|
||||
cmd,
|
||||
mode=mode,
|
||||
shell=True,
|
||||
bufsize=bufsize,
|
||||
ignore_status=[141, 26],
|
||||
) # skipcq: BAN-B604
|
||||
else:
|
||||
raise ValueError(f"{mode}: unknown mode")
|
||||
|
||||
|
||||
|
||||
def gopen_error(url, *args, **kw):
|
||||
"""Raise a value error.
|
||||
|
||||
:param url: url
|
||||
:param args: other arguments
|
||||
:param kw: other keywords
|
||||
"""
|
||||
raise ValueError(f"{url}: no gopen handler defined")
|
||||
|
||||
|
||||
"""A dispatch table mapping URL schemes to handlers."""
|
||||
gopen_schemes = dict(
|
||||
__default__=gopen_error,
|
||||
pipe=gopen_pipe,
|
||||
http=gopen_curl,
|
||||
https=gopen_curl,
|
||||
sftp=gopen_curl,
|
||||
ftps=gopen_curl,
|
||||
scp=gopen_curl,
|
||||
gs=gopen_gsutil,
|
||||
htgs=gopen_htgs,
|
||||
)
|
||||
|
||||
|
||||
def gopen(url, mode="rb", bufsize=8192, **kw):
|
||||
"""Open the URL.
|
||||
|
||||
This uses the `gopen_schemes` dispatch table to dispatch based
|
||||
on scheme.
|
||||
|
||||
Support for the following schemes is built-in: pipe, file,
|
||||
    http, https, sftp, ftps, scp, gs, and htgs.
|
||||
|
||||
When no scheme is given the url is treated as a file.
|
||||
|
||||
    You can set the GOPEN_VERBOSE environment variable to get info about
|
||||
files being opened.
|
||||
|
||||
:param url: the source URL
|
||||
:param mode: the mode ("rb", "r")
|
||||
:param bufsize: the buffer size
|
||||
"""
|
||||
global fallback_gopen
|
||||
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
|
||||
if verbose:
|
||||
print("GOPEN", url, info, file=sys.stderr)
|
||||
assert mode in ["rb", "wb"], mode
|
||||
if url == "-":
|
||||
if mode == "rb":
|
||||
return sys.stdin.buffer
|
||||
elif mode == "wb":
|
||||
return sys.stdout.buffer
|
||||
else:
|
||||
raise ValueError(f"unknown mode {mode}")
|
||||
pr = urlparse(url)
|
||||
if pr.scheme == "":
|
||||
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
|
||||
return open(url, mode, buffering=bufsize)
|
||||
if pr.scheme == "file":
|
||||
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
|
||||
return open(pr.path, mode, buffering=bufsize)
|
||||
handler = gopen_schemes["__default__"]
|
||||
handler = gopen_schemes.get(pr.scheme, handler)
|
||||
return handler(url, mode, bufsize, **kw)
|
||||
|
||||
|
||||
def reader(url, **kw):
|
||||
"""Open url with gopen and mode "rb".
|
||||
|
||||
:param url: source URL
|
||||
:param kw: other keywords forwarded to gopen
|
||||
"""
|
||||
return gopen(url, "rb", **kw)
|
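# Usage sketch (assumption, not part of the original patch): the URL scheme
# selects a handler from `gopen_schemes`, e.g.
#
#   with gopen("pipe:curl -s -L http://example.com/shard-000000.tar") as stream:
#       data = stream.read()                  # served through Pipe/gopen_pipe
#
#   stream = gopen("data/shard-000000.tar")   # no scheme -> plain local open()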
@ -0,0 +1,47 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
#
|
||||
|
||||
"""Pluggable exception handlers.
|
||||
|
||||
These are functions that take an exception as an argument and then return...
|
||||
|
||||
- the exception (in order to re-raise it)
|
||||
- True (in order to continue and ignore the exception)
|
||||
- False (in order to ignore the exception and stop processing)
|
||||
|
||||
They are used as handler= arguments in much of the library.
|
||||
"""
|
||||
|
||||
import time, warnings
|
||||
|
||||
|
||||
def reraise_exception(exn):
|
||||
"""Call in an exception handler to re-raise the exception."""
|
||||
raise exn
|
||||
|
||||
|
||||
def ignore_and_continue(exn):
|
||||
"""Call in an exception handler to ignore any exception and continue."""
|
||||
return True
|
||||
|
||||
|
||||
def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
|
||||
warnings.warn(repr(exn))
|
||||
time.sleep(0.5)
|
||||
return True
|
||||
|
||||
|
||||
def ignore_and_stop(exn):
|
||||
"""Call in an exception handler to ignore any exception and stop further processing."""
|
||||
return False
|
||||
|
||||
|
||||
def warn_and_stop(exn):
|
||||
"""Call in an exception handler to ignore any exception and stop further processing."""
|
||||
warnings.warn(repr(exn))
|
||||
time.sleep(0.5)
|
||||
return False
|
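# Usage sketch (assumption, not part of the original patch): these callables are
# passed as the handler= argument of the url/tar iterators in this package, e.g.
#
#   samples = tarfile_samples(shards, handler=warn_and_continue)
#
# Returning True skips the failing shard or sample and keeps iterating,
# returning False stops quietly, and reraise_exception propagates the error.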
@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
|
||||
"""Classes for mixing samples from multiple sources."""
|
||||
|
||||
import itertools, os, random, time, sys
|
||||
from functools import reduce, wraps
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import autodecode, utils
|
||||
from .paddle_utils import PaddleTensor, IterableDataset
|
||||
from .utils import PipelineStage
|
||||
|
||||
|
||||
def round_robin_shortest(*sources):
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
sample = next(sources[i % len(sources)])
|
||||
yield sample
|
||||
except StopIteration:
|
||||
break
|
||||
i += 1
|
||||
|
||||
|
||||
def round_robin_longest(*sources):
|
||||
i = 0
|
||||
while len(sources) > 0:
|
||||
try:
|
||||
sample = next(sources[i])
|
||||
i += 1
|
||||
yield sample
|
||||
except StopIteration:
|
||||
del sources[i]
|
||||
|
||||
|
||||
class RoundRobin(IterableDataset):
|
||||
def __init__(self, datasets, longest=False):
|
||||
self.datasets = datasets
|
||||
self.longest = longest
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the sources."""
|
||||
sources = [iter(d) for d in self.datasets]
|
||||
if self.longest:
|
||||
return round_robin_longest(*sources)
|
||||
else:
|
||||
return round_robin_shortest(*sources)
|
||||
|
||||
|
||||
def random_samples(sources, probs=None, longest=False):
|
||||
if probs is None:
|
||||
probs = [1] * len(sources)
|
||||
else:
|
||||
probs = list(probs)
|
||||
while len(sources) > 0:
|
||||
cum = (np.array(probs) / np.sum(probs)).cumsum()
|
||||
r = random.random()
|
||||
i = np.searchsorted(cum, r)
|
||||
try:
|
||||
yield next(sources[i])
|
||||
except StopIteration:
|
||||
if longest:
|
||||
del sources[i]
|
||||
del probs[i]
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
class RandomMix(IterableDataset):
|
||||
def __init__(self, datasets, probs=None, longest=False):
|
||||
self.datasets = datasets
|
||||
self.probs = probs
|
||||
self.longest = longest
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the sources."""
|
||||
sources = [iter(d) for d in self.datasets]
|
||||
return random_samples(sources, self.probs, longest=self.longest)
|
@ -0,0 +1,33 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
|
||||
"""Mock implementations of paddle interfaces when paddle is not available."""
|
||||
|
||||
|
||||
try:
|
||||
from paddle.io import DataLoader, IterableDataset
|
||||
except ModuleNotFoundError:
|
||||
|
||||
class IterableDataset:
|
||||
"""Empty implementation of IterableDataset when paddle is not available."""
|
||||
|
||||
pass
|
||||
|
||||
class DataLoader:
|
||||
"""Empty implementation of DataLoader when paddle is not available."""
|
||||
|
||||
pass
|
||||
|
||||
try:
|
||||
from paddle import Tensor as PaddleTensor
|
||||
except ModuleNotFoundError:
|
||||
|
||||
    class PaddleTensor:
|
||||
"""Empty implementation of PaddleTensor when paddle is not available."""
|
||||
|
||||
pass
|
@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#%%
|
||||
import copy, os, random, sys, time
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import List
|
||||
|
||||
import braceexpand, yaml
|
||||
|
||||
from .handlers import reraise_exception
|
||||
from .paddle_utils import DataLoader, IterableDataset
|
||||
from .utils import PipelineStage
|
||||
|
||||
|
||||
def add_length_method(obj):
|
||||
def length(self):
|
||||
return self.size
|
||||
|
||||
Combined = type(
|
||||
obj.__class__.__name__ + "_Length",
|
||||
(obj.__class__, IterableDataset),
|
||||
{"__len__": length},
|
||||
)
|
||||
obj.__class__ = Combined
|
||||
return obj
|
||||
|
||||
|
||||
class DataPipeline(IterableDataset, PipelineStage):
|
||||
"""A pipeline starting with an IterableDataset and a series of filters."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.pipeline = []
|
||||
self.length = -1
|
||||
self.repetitions = 1
|
||||
self.nsamples = -1
|
||||
for arg in args:
|
||||
if arg is None:
|
||||
continue
|
||||
if isinstance(arg, list):
|
||||
self.pipeline.extend(arg)
|
||||
else:
|
||||
self.pipeline.append(arg)
|
||||
|
||||
def invoke(self, f, *args, **kwargs):
|
||||
"""Apply a pipeline stage, possibly to the output of a previous stage."""
|
||||
if isinstance(f, PipelineStage):
|
||||
return f.run(*args, **kwargs)
|
||||
if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
|
||||
return iter(f)
|
||||
if isinstance(f, list):
|
||||
return iter(f)
|
||||
if callable(f):
|
||||
result = f(*args, **kwargs)
|
||||
return result
|
||||
raise ValueError(f"{f}: not a valid pipeline stage")
|
||||
|
||||
def iterator1(self):
|
||||
"""Create an iterator through one epoch in the pipeline."""
|
||||
source = self.invoke(self.pipeline[0])
|
||||
for step in self.pipeline[1:]:
|
||||
source = self.invoke(step, source)
|
||||
return source
|
||||
|
||||
def iterator(self):
|
||||
"""Create an iterator through the entire dataset, using the given number of repetitions."""
|
||||
for i in range(self.repetitions):
|
||||
for sample in self.iterator1():
|
||||
yield sample
|
||||
|
||||
def __iter__(self):
|
||||
"""Create an iterator through the pipeline, repeating and slicing as requested."""
|
||||
if self.repetitions != 1:
|
||||
if self.nsamples > 0:
|
||||
return islice(self.iterator(), self.nsamples)
|
||||
else:
|
||||
return self.iterator()
|
||||
else:
|
||||
return self.iterator()
|
||||
|
||||
def stage(self, i):
|
||||
"""Return pipeline stage i."""
|
||||
return self.pipeline[i]
|
||||
|
||||
def append(self, f):
|
||||
"""Append a pipeline stage (modifies the object)."""
|
||||
self.pipeline.append(f)
|
||||
return self
|
||||
|
||||
def append_list(self, *args):
|
||||
for arg in args:
|
||||
self.pipeline.append(arg)
|
||||
return self
|
||||
|
||||
def compose(self, *args):
|
||||
"""Append a pipeline stage to a copy of the pipeline and returns the copy."""
|
||||
result = copy.copy(self)
|
||||
for arg in args:
|
||||
result.append(arg)
|
||||
return result
|
||||
|
||||
def with_length(self, n):
|
||||
"""Add a __len__ method returning the desired value.
|
||||
|
||||
This does not change the actual number of samples in an epoch.
|
||||
        An IterableDataset should not normally have a __len__ method.
|
||||
This is provided only as a workaround for some broken training environments
|
||||
that require a __len__ method.
|
||||
"""
|
||||
self.size = n
|
||||
return add_length_method(self)
|
||||
|
||||
def with_epoch(self, nsamples=-1, nbatches=-1):
|
||||
"""Change the epoch to return the given number of samples/batches.
|
||||
|
||||
The two arguments mean the same thing."""
|
||||
self.repetitions = sys.maxsize
|
||||
self.nsamples = max(nsamples, nbatches)
|
||||
return self
|
||||
|
||||
def repeat(self, nepochs=-1, nbatches=-1):
|
||||
"""Repeat iterating through the dataset for the given #epochs up to the given #samples."""
|
||||
if nepochs > 0:
|
||||
self.repetitions = nepochs
|
||||
self.nsamples = nbatches
|
||||
else:
|
||||
self.repetitions = sys.maxsize
|
||||
self.nsamples = nbatches
|
||||
return self
|
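# Usage sketch (assumption, not part of the original patch): stages are applied
# in order, each one receiving the previous stage's iterator, e.g.
#
#   pipeline = DataPipeline(
#       SimpleShardList("shards/train-{000000..000009}.tar"),
#       split_by_node,
#       split_by_worker,
#       tarfile_to_samples(),
#   ).with_epoch(10000)          # cap one nominal epoch at 10000 samples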
@ -0,0 +1,261 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
#
|
||||
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
|
"""Train Paddle models directly from POSIX tar archives.
|
||||
|
||||
Code works locally or over HTTP connections.
|
||||
"""
|
||||
|
||||
import os, random, sys, time
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import islice
|
||||
from typing import List
|
||||
|
||||
import braceexpand, yaml
|
||||
|
||||
from . import utils
|
||||
from .filters import pipelinefilter
|
||||
from .paddle_utils import IterableDataset
|
||||
|
||||
|
||||
from ..utils.log import Logger
|
||||
logger = Logger(__name__)
|
||||
def expand_urls(urls):
|
||||
if isinstance(urls, str):
|
||||
urllist = urls.split("::")
|
||||
result = []
|
||||
for url in urllist:
|
||||
result.extend(braceexpand.braceexpand(url))
|
||||
return result
|
||||
else:
|
||||
return list(urls)
|
||||
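# Illustrative examples (not part of the original patch): brace notation and
# "::"-separated lists are both accepted, e.g.
#
#   expand_urls("a-{000..002}.tar")
#   # -> ['a-000.tar', 'a-001.tar', 'a-002.tar']
#   expand_urls("a-000.tar::b-{000..001}.tar")
#   # -> ['a-000.tar', 'b-000.tar', 'b-001.tar']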
|
||||
|
||||
class SimpleShardList(IterableDataset):
|
||||
"""An iterable dataset yielding a list of urls."""
|
||||
|
||||
def __init__(self, urls, seed=None):
|
||||
"""Iterate through the list of shards.
|
||||
|
||||
:param urls: a list of URLs as a Python list or brace notation string
|
||||
"""
|
||||
super().__init__()
|
||||
urls = expand_urls(urls)
|
||||
self.urls = urls
|
||||
assert isinstance(self.urls[0], str)
|
||||
self.seed = seed
|
||||
|
||||
def __len__(self):
|
||||
return len(self.urls)
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the shards."""
|
||||
urls = self.urls.copy()
|
||||
if self.seed is not None:
|
||||
random.Random(self.seed).shuffle(urls)
|
||||
for url in urls:
|
||||
yield dict(url=url)
|
||||
|
||||
|
||||
def split_by_node(src, group=None):
|
||||
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
|
||||
logger.info(f"world_size:{world_size}, rank:{rank}")
|
||||
if world_size > 1:
|
||||
for s in islice(src, rank, None, world_size):
|
||||
yield s
|
||||
else:
|
||||
for s in src:
|
||||
yield s
|
||||
|
||||
|
||||
def single_node_only(src, group=None):
|
||||
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
|
||||
if world_size > 1:
|
||||
raise ValueError("input pipeline needs to be reconfigured for multinode training")
|
||||
for s in src:
|
||||
yield s
|
||||
|
||||
|
||||
def split_by_worker(src):
|
||||
rank, world_size, worker, num_workers = utils.paddle_worker_info()
|
||||
logger.info(f"num_workers:{num_workers}, worker:{worker}")
|
||||
if num_workers > 1:
|
||||
for s in islice(src, worker, None, num_workers):
|
||||
yield s
|
||||
else:
|
||||
for s in src:
|
||||
yield s
|
||||
|
||||
|
||||
def resampled_(src, n=sys.maxsize):
|
||||
import random
|
||||
|
||||
seed = time.time()
|
||||
try:
|
||||
seed = open("/dev/random", "rb").read(20)
|
||||
except Exception as exn:
|
||||
print(repr(exn)[:50], file=sys.stderr)
|
||||
rng = random.Random(seed)
|
||||
print("# resampled loading", file=sys.stderr)
|
||||
items = list(src)
|
||||
print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
|
||||
for i in range(n):
|
||||
yield rng.choice(items)
|
||||
|
||||
|
||||
resampled = pipelinefilter(resampled_)
|
||||
|
||||
|
||||
def non_empty(src):
|
||||
count = 0
|
||||
for s in src:
|
||||
yield s
|
||||
count += 1
|
||||
if count == 0:
|
||||
raise ValueError("pipeline stage received no data at all and this was declared as an error")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MSSource:
|
||||
"""Class representing a data source."""
|
||||
|
||||
name: str = ""
|
||||
perepoch: int = -1
|
||||
resample: bool = False
|
||||
urls: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
default_rng = random.Random()
|
||||
|
||||
|
||||
def expand(s):
|
||||
return os.path.expanduser(os.path.expandvars(s))
|
||||
|
||||
|
||||
class MultiShardSample(IterableDataset):
|
||||
def __init__(self, fname):
|
||||
"""Construct a shardlist from multiple sources using a YAML spec."""
|
||||
self.epoch = -1
|
||||
self.parse_spec(fname)
|
||||
|
||||
def parse_spec(self, fname):
|
||||
self.rng = default_rng # capture default_rng if we fork
|
||||
if isinstance(fname, dict):
|
||||
spec = fname
|
||||
fname = "{dict}"
|
||||
else:
|
||||
with open(fname) as stream:
|
||||
spec = yaml.safe_load(stream)
|
||||
assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
|
||||
prefix = expand(spec.get("prefix", ""))
|
||||
self.sources = []
|
||||
for ds in spec["datasets"]:
|
||||
assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
|
||||
ds.keys()
|
||||
)
|
||||
buckets = ds.get("buckets", spec.get("buckets", []))
|
||||
if isinstance(buckets, str):
|
||||
buckets = [buckets]
|
||||
buckets = [expand(s) for s in buckets]
|
||||
if buckets == []:
|
||||
buckets = [""]
|
||||
assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
|
||||
bucket = buckets[0]
|
||||
name = ds.get("name", "@" + bucket)
|
||||
urls = ds["shards"]
|
||||
if isinstance(urls, str):
|
||||
urls = [urls]
|
||||
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
|
||||
urls = [
|
||||
prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
|
||||
]
|
||||
resample = ds.get("resample", -1)
|
||||
nsample = ds.get("choose", -1)
|
||||
if nsample > len(urls):
|
||||
raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
|
||||
if (nsample > 0) and (resample > 0):
|
||||
raise ValueError("specify only one of perepoch or choose")
|
||||
entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
|
||||
self.sources.append(entry)
|
||||
print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
|
||||
|
||||
def set_epoch(self, seed):
|
||||
"""Set the current epoch (for consistent shard selection among nodes)."""
|
||||
self.rng = random.Random(seed)
|
||||
|
||||
def get_shards_for_epoch(self):
|
||||
result = []
|
||||
for source in self.sources:
|
||||
if source.resample > 0:
|
||||
# sample with replacement
|
||||
l = self.rng.choices(source.urls, k=source.resample)
|
||||
elif source.perepoch > 0:
|
||||
# sample without replacement
|
||||
l = list(source.urls)
|
||||
self.rng.shuffle(l)
|
||||
l = l[: source.perepoch]
|
||||
else:
|
||||
l = list(source.urls)
|
||||
result += l
|
||||
self.rng.shuffle(result)
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
shards = self.get_shards_for_epoch()
|
||||
for shard in shards:
|
||||
yield dict(url=shard)
|
||||
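# Illustrative sketch (assumption, not part of the original patch): a YAML spec
# accepted by parse_spec above could look like
#
#   prefix: /data/shards/
#   datasets:
#     - name: train-clean
#       shards: "clean-{000000..000099}.tar"
#       choose: 50        # 50 shards per epoch, sampled without replacement
#     - name: train-other
#       shards: "other-{000000..000031}.tar"
#       resample: 16      # 16 shards per epoch, sampled with replacement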
|
||||
|
||||
def shardspec(spec):
|
||||
if spec.endswith(".yaml"):
|
||||
return MultiShardSample(spec)
|
||||
else:
|
||||
return SimpleShardList(spec)
|
||||
|
||||
|
||||
class ResampledShards(IterableDataset):
|
||||
"""An iterable dataset yielding a list of urls."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
nshards=sys.maxsize,
|
||||
worker_seed=None,
|
||||
deterministic=False,
|
||||
):
|
||||
"""Sample shards from the shard list with replacement.
|
||||
|
||||
:param urls: a list of URLs as a Python list or brace notation string
|
||||
"""
|
||||
super().__init__()
|
||||
urls = expand_urls(urls)
|
||||
self.urls = urls
|
||||
assert isinstance(self.urls[0], str)
|
||||
self.nshards = nshards
|
||||
self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
|
||||
self.deterministic = deterministic
|
||||
self.epoch = -1
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the shards."""
|
||||
self.epoch += 1
|
||||
if self.deterministic:
|
||||
seed = utils.make_seed(self.worker_seed(), self.epoch)
|
||||
else:
|
||||
seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
|
||||
if os.environ.get("WDS_SHOW_SEED", "0") == "1":
|
||||
print(f"# ResampledShards seed {seed}")
|
||||
self.rng = random.Random(seed)
|
||||
for _ in range(self.nshards):
|
||||
index = self.rng.randint(0, len(self.urls) - 1)
|
||||
yield dict(url=self.urls[index])
|
@ -0,0 +1,283 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
||||
|
||||
"""Low level iteration functions for tar archives."""
|
||||
|
||||
import random, re, tarfile
|
||||
|
||||
import braceexpand
|
||||
|
||||
from . import filters
|
||||
from . import gopen
|
||||
from .handlers import reraise_exception
|
||||
|
||||
trace = False
|
||||
meta_prefix = "__"
|
||||
meta_suffix = "__"
|
||||
|
||||
import paddlespeech
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
||||
|
||||
def base_plus_ext(path):
|
||||
"""Split off all file extensions.
|
||||
|
||||
Returns base, allext.
|
||||
|
||||
:param path: path with extensions
|
||||
:param returns: path with all extensions removed
|
||||
|
||||
"""
|
||||
match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
|
||||
if not match:
|
||||
return None, None
|
||||
return match.group(1), match.group(2)
|
||||
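# Illustrative examples (not part of the original patch): everything after the
# first "." of the file name is treated as the extension, e.g.
#
#   base_plus_ext("data/utt1.wav")       # -> ("data/utt1", "wav")
#   base_plus_ext("data/utt1.spec.npy")  # -> ("data/utt1", "spec.npy")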
|
||||
|
||||
def valid_sample(sample):
|
||||
"""Check whether a sample is valid.
|
||||
|
||||
:param sample: sample to be checked
|
||||
"""
|
||||
return (
|
||||
sample is not None
|
||||
and isinstance(sample, dict)
|
||||
and len(list(sample.keys())) > 0
|
||||
and not sample.get("__bad__", False)
|
||||
)
|
||||
|
||||
|
||||
# FIXME: UNUSED
|
||||
def shardlist(urls, *, shuffle=False):
|
||||
"""Given a list of URLs, yields that list, possibly shuffled."""
|
||||
if isinstance(urls, str):
|
||||
urls = braceexpand.braceexpand(urls)
|
||||
else:
|
||||
urls = list(urls)
|
||||
if shuffle:
|
||||
random.shuffle(urls)
|
||||
for url in urls:
|
||||
yield dict(url=url)
|
||||
|
||||
|
||||
def url_opener(data, handler=reraise_exception, **kw):
|
||||
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
|
||||
for sample in data:
|
||||
assert isinstance(sample, dict), sample
|
||||
assert "url" in sample
|
||||
url = sample["url"]
|
||||
try:
|
||||
stream = gopen.gopen(url, **kw)
|
||||
sample.update(stream=stream)
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (url,)
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def tar_file_iterator(
|
||||
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
|
||||
):
|
||||
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
|
||||
|
||||
:param fileobj: byte stream suitable for tarfile
|
||||
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
|
||||
|
||||
"""
|
||||
stream = tarfile.open(fileobj=fileobj, mode="r:*")
|
||||
for tarinfo in stream:
|
||||
fname = tarinfo.name
|
||||
try:
|
||||
if not tarinfo.isreg():
|
||||
continue
|
||||
if fname is None:
|
||||
continue
|
||||
if (
|
||||
"/" not in fname
|
||||
and fname.startswith(meta_prefix)
|
||||
and fname.endswith(meta_suffix)
|
||||
):
|
||||
# skipping metadata for now
|
||||
continue
|
||||
if skip_meta is not None and re.match(skip_meta, fname):
|
||||
continue
|
||||
|
||||
name = tarinfo.name
|
||||
pos = name.rfind('.')
|
||||
assert pos > 0
|
||||
prefix, postfix = name[:pos], name[pos + 1:]
|
||||
if postfix == 'wav':
|
||||
waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
|
||||
result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
|
||||
else:
|
||||
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
|
||||
result = dict(fname=prefix, txt=txt)
|
||||
#result = dict(fname=fname, data=data)
|
||||
yield result
|
||||
stream.members = []
|
||||
except Exception as exn:
|
||||
if hasattr(exn, "args") and len(exn.args) > 0:
|
||||
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
del stream
|
||||
|
||||
def tar_file_and_group_iterator(
|
||||
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
|
||||
):
|
||||
""" Expand a stream of open tar files into a stream of tar file contents.
|
||||
    Files that share the same prefix are grouped into one sample.
|
||||
|
||||
Args:
|
||||
data: Iterable[{src, stream}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, txt, sample_rate}]
|
||||
"""
|
||||
stream = tarfile.open(fileobj=fileobj, mode="r:*")
|
||||
prev_prefix = None
|
||||
example = {}
|
||||
valid = True
|
||||
for tarinfo in stream:
|
||||
name = tarinfo.name
|
||||
pos = name.rfind('.')
|
||||
assert pos > 0
|
||||
prefix, postfix = name[:pos], name[pos + 1:]
|
||||
if prev_prefix is not None and prefix != prev_prefix:
|
||||
example['fname'] = prev_prefix
|
||||
if valid:
|
||||
yield example
|
||||
example = {}
|
||||
valid = True
|
||||
with stream.extractfile(tarinfo) as file_obj:
|
||||
try:
|
||||
if postfix == 'txt':
|
||||
example['txt'] = file_obj.read().decode('utf8').strip()
|
||||
elif postfix in AUDIO_FORMAT_SETS:
|
||||
waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
|
||||
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
|
||||
|
||||
example['wav'] = waveform
|
||||
example['sample_rate'] = sample_rate
|
||||
else:
|
||||
example[postfix] = file_obj.read()
|
||||
except Exception as exn:
|
||||
if hasattr(exn, "args") and len(exn.args) > 0:
|
||||
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
valid = False
|
||||
# logging.warning('error to parse {}'.format(name))
|
||||
prev_prefix = prefix
|
||||
if prev_prefix is not None:
|
||||
example['fname'] = prev_prefix
|
||||
yield example
|
||||
stream.close()
|
||||
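# Illustrative note (not part of the original patch): the iterator above expects
# each utterance in the shard to be stored as files sharing one prefix, e.g.
#
#   utt_0001.txt   -> example['txt']
#   utt_0001.wav   -> example['wav'], example['sample_rate']
#
# Consecutive entries with the same prefix are merged into one sample whose
# 'fname' is that prefix.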
|
||||
def tar_file_expander(data, handler=reraise_exception):
|
||||
"""Expand a stream of open tar files into a stream of tar file contents.
|
||||
|
||||
This returns an iterator over (filename, file_contents).
|
||||
"""
|
||||
for source in data:
|
||||
url = source["url"]
|
||||
try:
|
||||
assert isinstance(source, dict)
|
||||
assert "stream" in source
|
||||
for sample in tar_file_iterator(source["stream"]):
|
||||
assert (
|
||||
isinstance(sample, dict) and "data" in sample and "fname" in sample
|
||||
)
|
||||
sample["__url__"] = url
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (source.get("stream"), source.get("url"))
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
def tar_file_and_group_expander(data, handler=reraise_exception):
|
||||
"""Expand a stream of open tar files into a stream of tar file contents.
|
||||
|
||||
This returns an iterator over (filename, file_contents).
|
||||
"""
|
||||
for source in data:
|
||||
url = source["url"]
|
||||
try:
|
||||
assert isinstance(source, dict)
|
||||
assert "stream" in source
|
||||
for sample in tar_file_and_group_iterator(source["stream"]):
|
||||
assert (
|
||||
isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
|
||||
)
|
||||
sample["__url__"] = url
|
||||
yield sample
|
||||
except Exception as exn:
|
||||
exn.args = exn.args + (source.get("stream"), source.get("url"))
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
|
||||
"""Return function over iterator that groups key, value pairs into samples.
|
||||
|
||||
:param keys: function that splits the key into key and extension (base_plus_ext)
|
||||
:param lcase: convert suffixes to lower case (Default value = True)
|
||||
"""
|
||||
current_sample = None
|
||||
for filesample in data:
|
||||
assert isinstance(filesample, dict)
|
||||
fname, value = filesample["fname"], filesample["data"]
|
||||
prefix, suffix = keys(fname)
|
||||
if trace:
|
||||
print(
|
||||
prefix,
|
||||
suffix,
|
||||
current_sample.keys() if isinstance(current_sample, dict) else None,
|
||||
)
|
||||
if prefix is None:
|
||||
continue
|
||||
if lcase:
|
||||
suffix = suffix.lower()
|
||||
if current_sample is None or prefix != current_sample["__key__"]:
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
||||
if suffix in current_sample:
|
||||
raise ValueError(
|
||||
f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
|
||||
)
|
||||
if suffixes is None or suffix in suffixes:
|
||||
current_sample[suffix] = value
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
|
||||
|
||||
def tarfile_samples(src, handler=reraise_exception):
|
||||
streams = url_opener(src, handler=handler)
|
||||
samples = tar_file_and_group_expander(streams, handler=handler)
|
||||
return samples
|
||||
|
||||
|
||||
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
|
@ -0,0 +1,132 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
#
|
||||
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
|
||||
"""Miscellaneous utility functions."""
|
||||
|
||||
import importlib
|
||||
import itertools as itt
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Callable, Iterator, Optional, Union
|
||||
|
||||
from ..utils.log import Logger
|
||||
|
||||
logger = Logger(__name__)
|
||||
|
||||
def make_seed(*args):
|
||||
seed = 0
|
||||
for arg in args:
|
||||
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
|
||||
return seed
|
||||
|
||||
|
||||
class PipelineStage:
|
||||
def invoke(self, *args, **kw):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def identity(x: Any) -> Any:
|
||||
"""Return the argument as is."""
|
||||
return x
|
||||
|
||||
|
||||
def safe_eval(s: str, expr: str = "{}"):
|
||||
"""Evaluate the given expression more safely."""
|
||||
if re.sub("[^A-Za-z0-9_]", "", s) != s:
|
||||
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
|
||||
return eval(expr.format(s))
|
||||
|
||||
|
||||
def lookup_sym(sym: str, modules: list):
|
||||
"""Look up a symbol in a list of modules."""
|
||||
for mname in modules:
|
||||
module = importlib.import_module(mname, package="webdataset")
|
||||
result = getattr(module, sym, None)
|
||||
if result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def repeatedly0(
|
||||
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
|
||||
):
|
||||
"""Repeatedly returns batches from a DataLoader."""
|
||||
for epoch in range(nepochs):
|
||||
for sample in itt.islice(loader, nbatches):
|
||||
yield sample
|
||||
|
||||
|
||||
def guess_batchsize(batch: Union[tuple, list]):
|
||||
"""Guess the batch size by looking at the length of the first element in a tuple."""
|
||||
return len(batch[0])
|
||||
|
||||
|
||||
def repeatedly(
|
||||
source: Iterator,
|
||||
nepochs: int = None,
|
||||
nbatches: int = None,
|
||||
nsamples: int = None,
|
||||
batchsize: Callable[..., int] = guess_batchsize,
|
||||
):
|
||||
"""Repeatedly yield samples from an iterator."""
|
||||
epoch = 0
|
||||
batch = 0
|
||||
total = 0
|
||||
while True:
|
||||
for sample in source:
|
||||
yield sample
|
||||
batch += 1
|
||||
if nbatches is not None and batch >= nbatches:
|
||||
return
|
||||
if nsamples is not None:
|
||||
total += guess_batchsize(sample)
|
||||
if total >= nsamples:
|
||||
return
|
||||
epoch += 1
|
||||
if nepochs is not None and epoch >= nepochs:
|
||||
return
|
||||
|
||||
def paddle_worker_info(group=None):
    """Return node and worker info for Paddle and some distributed environments."""
|
||||
rank = 0
|
||||
world_size = 1
|
||||
worker = 0
|
||||
num_workers = 1
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
else:
|
||||
try:
|
||||
import paddle.distributed
|
||||
group = group or paddle.distributed.get_group()
|
||||
rank = paddle.distributed.get_rank()
|
||||
world_size = paddle.distributed.get_world_size()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
|
||||
worker = int(os.environ["WORKER"])
|
||||
num_workers = int(os.environ["NUM_WORKERS"])
|
||||
else:
|
||||
try:
|
||||
from paddle.io import get_worker_info
|
||||
            worker_info = get_worker_info()  # use the name imported above; the bare `paddle` module is not imported here
|
||||
if worker_info is not None:
|
||||
worker = worker_info.id
|
||||
num_workers = worker_info.num_workers
|
||||
except ModuleNotFoundError as E:
|
||||
logger.info(f"not found {E}")
|
||||
exit(-1)
|
||||
|
||||
return rank, world_size, worker, num_workers
|
||||
|
||||
def paddle_worker_seed(group=None):
|
||||
"""Compute a distinct, deterministic RNG seed for each worker and node."""
|
||||
rank, world_size, worker, num_workers = paddle_worker_info(group=group)
|
||||
return rank * 1000 + worker
|
@ -0,0 +1,450 @@
|
||||
#
|
||||
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# This file is part of the WebDataset library.
|
||||
# See the LICENSE file for licensing terms (BSD-style).
|
||||
# Modified from https://github.com/webdataset/webdataset
|
||||
#
|
||||
|
||||
"""Classes and functions for writing tar files and WebDataset files."""
|
||||
|
||||
import io, json, pickle, re, tarfile, time
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import gopen
|
||||
|
||||
|
||||
def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622
|
||||
"""Compress an image using PIL and return it as a string.
|
||||
|
||||
Can handle float or uint8 images.
|
||||
|
||||
:param image: ndarray representing an image
|
||||
:param format: compression format (PNG, JPEG, PPM)
|
||||
|
||||
"""
|
||||
import PIL
|
||||
|
||||
assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
if image.dtype in [np.dtype("f"), np.dtype("d")]:
|
||||
if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
|
||||
raise ValueError(
|
||||
f"image values out of range {np.amin(image)} {np.amax(image)}"
|
||||
)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
image = np.array(image * 255.0, "uint8")
|
||||
assert image.ndim in [2, 3]
|
||||
if image.ndim == 3:
|
||||
assert image.shape[2] in [1, 3]
|
||||
image = PIL.Image.fromarray(image)
|
||||
if format.upper() == "JPG":
|
||||
format = "JPEG"
|
||||
elif format.upper() in ["IMG", "IMAGE"]:
|
||||
format = "PPM"
|
||||
if format == "JPEG":
|
||||
opts = dict(quality=100)
|
||||
else:
|
||||
opts = {}
|
||||
with io.BytesIO() as result:
|
||||
image.save(result, format=format, **opts)
|
||||
return result.getvalue()
|
||||
|
||||
|
||||
def bytestr(data: Any):
|
||||
"""Convert data into a bytestring.
|
||||
|
||||
Uses str and ASCII encoding for data that isn't already in string format.
|
||||
|
||||
:param data: data
|
||||
"""
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
if isinstance(data, str):
|
||||
return data.encode("ascii")
|
||||
return str(data).encode("ascii")
|
||||
|
||||
def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.save.
|
||||
|
||||
This delays importing paddle until needed.
|
||||
|
||||
:param data: data to be dumped
|
||||
"""
|
||||
import io
|
||||
|
||||
import paddle
|
||||
|
||||
stream = io.BytesIO()
|
||||
paddle.save(data, stream)
|
||||
return stream.getvalue()
|
||||
|
||||
def numpy_dumps(data: np.ndarray):
|
||||
"""Dump data into a bytestring using numpy npy format.
|
||||
|
||||
:param data: data to be dumped
|
||||
"""
|
||||
import io
|
||||
|
||||
import numpy.lib.format
|
||||
|
||||
stream = io.BytesIO()
|
||||
numpy.lib.format.write_array(stream, data)
|
||||
return stream.getvalue()
|
||||
|
||||
|
||||
def numpy_npz_dumps(data: np.ndarray):
|
||||
"""Dump data into a bytestring using numpy npz format.
|
||||
|
||||
:param data: data to be dumped
|
||||
"""
|
||||
import io
|
||||
|
||||
stream = io.BytesIO()
|
||||
np.savez_compressed(stream, **data)
|
||||
return stream.getvalue()
|
||||
|
||||
|
||||
def tenbin_dumps(x):
|
||||
from . import tenbin
|
||||
|
||||
if isinstance(x, list):
|
||||
return memoryview(tenbin.encode_buffer(x))
|
||||
else:
|
||||
return memoryview(tenbin.encode_buffer([x]))
|
||||
|
||||
|
||||
def cbor_dumps(x):
|
||||
import cbor
|
||||
|
||||
return cbor.dumps(x)
|
||||
|
||||
|
||||
def mp_dumps(x):
|
||||
import msgpack
|
||||
|
||||
return msgpack.packb(x)
|
||||
|
||||
|
||||
def add_handlers(d, keys, value):
|
||||
if isinstance(keys, str):
|
||||
keys = keys.split()
|
||||
for k in keys:
|
||||
d[k] = value
|
||||
|
||||
|
||||
def make_handlers():
|
||||
"""Create a list of handlers for encoding data."""
|
||||
handlers = {}
|
||||
add_handlers(
|
||||
handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
|
||||
)
|
||||
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
|
||||
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
|
||||
add_handlers(handlers, "pyd pickle", pickle.dumps)
|
||||
add_handlers(handlers, "pdparams", paddle_dumps)
|
||||
add_handlers(handlers, "npy", numpy_dumps)
|
||||
add_handlers(handlers, "npz", numpy_npz_dumps)
|
||||
add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
|
||||
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
|
||||
add_handlers(handlers, "mp msgpack msg", mp_dumps)
|
||||
add_handlers(handlers, "cbor", cbor_dumps)
|
||||
add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
|
||||
add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
|
||||
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
|
||||
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
|
||||
add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
|
||||
return handlers
|
||||
|
||||
|
||||
default_handlers = make_handlers()
|
||||
|
||||
|
||||
def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
|
||||
"""Encode data based on its extension and a dict of handlers.
|
||||
|
||||
:param data: data
|
||||
:param tname: file extension
|
||||
:param handlers: handlers
|
||||
"""
|
||||
if tname[0] == "_":
|
||||
if not isinstance(data, str):
|
||||
raise ValueError("the values of metadata must be of string type")
|
||||
return data
|
||||
extension = re.sub(r".*\.", "", tname).lower()
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
if isinstance(data, str):
|
||||
return data.encode("utf-8")
|
||||
handler = handlers.get(extension)
|
||||
if handler is None:
|
||||
raise ValueError(f"no handler found for {extension}")
|
||||
return handler(data)
|
||||
|
||||
|
||||
def encode_based_on_extension(sample: dict, handlers: dict):
|
||||
"""Encode an entire sample with a collection of handlers.
|
||||
|
||||
:param sample: data sample (a dict)
|
||||
:param handlers: handlers for encoding
|
||||
"""
|
||||
return {
|
||||
k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
|
||||
}
|
||||
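# Usage sketch (not part of the original patch): sample keys are treated as file
# extensions and routed through the handler table, e.g.
#
#   sample = {"__key__": "utt_0001", "txt": "hello", "npy": np.zeros(3)}
#   encoded = encode_based_on_extension(sample, default_handlers)
#   # "txt" -> UTF-8 bytes, "npy" -> numpy .npy bytes, "__key__" passes through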
|
||||
|
||||
def make_encoder(spec: Union[bool, str, dict, Callable]):
|
||||
"""Make an encoder function from a specification.
|
||||
|
||||
:param spec: specification
|
||||
"""
|
||||
if spec is False or spec is None:
|
||||
|
||||
def encoder(x):
|
||||
"""Do not encode at all."""
|
||||
return x
|
||||
|
||||
elif callable(spec):
|
||||
encoder = spec
|
||||
elif isinstance(spec, dict):
|
||||
|
||||
def f(sample):
|
||||
"""Encode based on extension."""
|
||||
return encode_based_on_extension(sample, spec)
|
||||
|
||||
encoder = f
|
||||
|
||||
elif spec is True:
|
||||
handlers = default_handlers
|
||||
|
||||
def g(sample):
|
||||
"""Encode based on extension."""
|
||||
return encode_based_on_extension(sample, handlers)
|
||||
|
||||
encoder = g
|
||||
|
||||
else:
|
||||
raise ValueError(f"{spec}: unknown decoder spec")
|
||||
if not callable(encoder):
|
||||
raise ValueError(f"{spec} did not yield a callable encoder")
|
||||
return encoder
|
||||
|
||||
|
||||
class TarWriter:
|
||||
"""A class for writing dictionaries to tar files.
|
||||
|
||||
:param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor
|
||||
:param encoder: sample encoding (Default value = True)
|
||||
:param compress: (Default value = None)
|
||||
|
||||
`True` will use an encoder that behaves similar to the automatic
|
||||
decoder for `Dataset`. `False` disables encoding and expects byte strings
|
||||
(except for metadata, which must be strings). The `encoder` argument can
|
||||
also be a `callable`, or a dictionary mapping extensions to encoders.
|
||||
|
||||
    The following code will add two files to the tar archive: `a/b.png` and
|
||||
`a/b.output.png`.
|
||||
|
||||
```Python
|
||||
tarwriter = TarWriter(stream)
|
||||
image = imread("b.jpg")
|
||||
image2 = imread("b.out.jpg")
|
||||
sample = {"__key__": "a/b", "png": image, "output.png": image2}
|
||||
tarwriter.write(sample)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fileobj,
|
||||
user: str = "bigdata",
|
||||
group: str = "bigdata",
|
||||
mode: int = 0o0444,
|
||||
compress: Optional[bool] = None,
|
||||
encoder: Union[None, bool, Callable] = True,
|
||||
keep_meta: bool = False,
|
||||
):
|
||||
"""Create a tar writer.
|
||||
|
||||
:param fileobj: stream to write data to
|
||||
:param user: user for tar files
|
||||
:param group: group for tar files
|
||||
:param mode: mode for tar files
|
||||
:param compress: desired compression
|
||||
:param encoder: encoder function
|
||||
:param keep_meta: keep metadata (entries starting with "_")
|
||||
"""
|
||||
if isinstance(fileobj, str):
|
||||
if compress is False:
|
||||
tarmode = "w|"
|
||||
elif compress is True:
|
||||
tarmode = "w|gz"
|
||||
else:
|
||||
tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
|
||||
fileobj = gopen.gopen(fileobj, "wb")
|
||||
self.own_fileobj = fileobj
|
||||
else:
|
||||
tarmode = "w|gz" if compress is True else "w|"
|
||||
self.own_fileobj = None
|
||||
self.encoder = make_encoder(encoder)
|
||||
self.keep_meta = keep_meta
|
||||
self.stream = fileobj
|
||||
self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
|
||||
|
||||
self.user = user
|
||||
self.group = group
|
||||
self.mode = mode
|
||||
self.compress = compress
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit context."""
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the tar file."""
|
||||
self.tarstream.close()
|
||||
if self.own_fileobj is not None:
|
||||
self.own_fileobj.close()
|
||||
self.own_fileobj = None
|
||||
|
||||
def write(self, obj):
|
||||
"""Write a dictionary to the tar file.
|
||||
|
||||
:param obj: dictionary of objects to be stored
|
||||
:returns: size of the entry
|
||||
|
||||
"""
|
||||
total = 0
|
||||
obj = self.encoder(obj)
|
||||
if "__key__" not in obj:
|
||||
raise ValueError("object must contain a __key__")
|
||||
for k, v in list(obj.items()):
|
||||
if k[0] == "_":
|
||||
continue
|
||||
if not isinstance(v, (bytes, bytearray, memoryview)):
|
||||
raise ValueError(
|
||||
f"{k} doesn't map to a bytes after encoding ({type(v)})"
|
||||
)
|
||||
key = obj["__key__"]
|
||||
for k in sorted(obj.keys()):
|
||||
if k == "__key__":
|
||||
continue
|
||||
if not self.keep_meta and k[0] == "_":
|
||||
continue
|
||||
v = obj[k]
|
||||
if isinstance(v, str):
|
||||
v = v.encode("utf-8")
|
||||
now = time.time()
|
||||
ti = tarfile.TarInfo(key + "." + k)
|
||||
ti.size = len(v)
|
||||
ti.mtime = now
|
||||
ti.mode = self.mode
|
||||
ti.uname = self.user
|
||||
ti.gname = self.group
|
||||
if not isinstance(v, (bytes, bytearray, memoryview)):
|
||||
raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
|
||||
stream = io.BytesIO(v)
|
||||
self.tarstream.addfile(ti, stream)
|
||||
total += ti.size
|
||||
return total
|
||||
|
||||
|
||||
class ShardWriter:
|
||||
"""Like TarWriter but splits into multiple shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pattern: str,
|
||||
maxcount: int = 100000,
|
||||
maxsize: float = 3e9,
|
||||
post: Optional[Callable] = None,
|
||||
start_shard: int = 0,
|
||||
**kw,
|
||||
):
|
||||
"""Create a ShardWriter.
|
||||
|
||||
:param pattern: output file pattern
|
||||
:param maxcount: maximum number of records per shard (Default value = 100000)
|
||||
:param maxsize: maximum size of each shard (Default value = 3e9)
|
||||
:param post: callable applied to each finished shard file name (Default value = None)
:param start_shard: index of the first shard (Default value = 0)
:param kw: other options passed to TarWriter
|
||||
"""
|
||||
self.verbose = 1
|
||||
self.kw = kw
|
||||
self.maxcount = maxcount
|
||||
self.maxsize = maxsize
|
||||
self.post = post
|
||||
|
||||
self.tarstream = None
|
||||
self.shard = start_shard
|
||||
self.pattern = pattern
|
||||
self.total = 0
|
||||
self.count = 0
|
||||
self.size = 0
|
||||
self.fname = None
|
||||
self.next_stream()
|
||||
|
||||
def next_stream(self):
|
||||
"""Close the current stream and move to the next."""
|
||||
self.finish()
|
||||
self.fname = self.pattern % self.shard
|
||||
if self.verbose:
|
||||
print(
|
||||
"# writing",
|
||||
self.fname,
|
||||
self.count,
|
||||
"%.1f GB" % (self.size / 1e9),
|
||||
self.total,
|
||||
)
|
||||
self.shard += 1
|
||||
stream = open(self.fname, "wb")
|
||||
self.tarstream = TarWriter(stream, **self.kw)
|
||||
self.count = 0
|
||||
self.size = 0
|
||||
|
||||
def write(self, obj):
|
||||
"""Write a sample.
|
||||
|
||||
:param obj: sample to be written
|
||||
"""
|
||||
if (
|
||||
self.tarstream is None
|
||||
or self.count >= self.maxcount
|
||||
or self.size >= self.maxsize
|
||||
):
|
||||
self.next_stream()
|
||||
size = self.tarstream.write(obj)
|
||||
self.count += 1
|
||||
self.total += 1
|
||||
self.size += size
|
||||
|
||||
def finish(self):
|
||||
"""Finish all writing (use close instead)."""
|
||||
if self.tarstream is not None:
|
||||
self.tarstream.close()
|
||||
assert self.fname is not None
|
||||
if callable(self.post):
|
||||
self.post(self.fname)
|
||||
self.tarstream = None
|
||||
|
||||
def close(self):
|
||||
"""Close the stream."""
|
||||
self.finish()
|
||||
del self.tarstream
|
||||
del self.shard
|
||||
del self.count
|
||||
del self.size
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kw):
|
||||
"""Exit context."""
|
||||
self.close()
|
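A minimal usage sketch for the `TarWriter`/`ShardWriter` classes above; the shard pattern, keys, and placeholder audio are illustrative assumptions, not part of the diff:

```Python
import numpy as np

# Hypothetical example: pack samples into tar shards of at most 1000 records.
with ShardWriter("shards-%06d.tar", maxcount=1000) as writer:
    for i in range(3000):
        wav_bytes = np.zeros(16000, dtype="int16").tobytes()  # placeholder audio
        writer.write({
            "__key__": f"utt_{i:06d}",   # required by TarWriter.write
            "wav": wav_bytes,            # bytes are stored as-is
            "txt": f"transcript {i}",    # str values are utf-8 encoded
        })
```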
@ -0,0 +1,235 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains the text featurizer class."""
|
||||
from pprint import pformat
|
||||
from typing import Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from .utility import BLANK
|
||||
from .utility import EOS
|
||||
from .utility import load_dict
|
||||
from .utility import MASKCTC
|
||||
from .utility import SOS
|
||||
from .utility import SPACE
|
||||
from .utility import UNK
|
||||
from ..utils.log import Logger
|
||||
|
||||
logger = Logger(__name__)
|
||||
|
||||
__all__ = ["TextFeaturizer"]
|
||||
|
||||
|
||||
class TextFeaturizer():
|
||||
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
|
||||
"""Text featurizer, for processing or extracting features from text.
|
||||
|
||||
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
|
||||
a list of token indices. Note that the token indexing order follows the
|
||||
given vocabulary file.
|
||||
|
||||
Args:
|
||||
unit_type (str): unit type, e.g. char, word, spm
|
||||
vocab (Union[str, list]): Filepath to load vocabulary for token indices conversion, or vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
maskctc (bool, optional): whether to add the <mask> token for MaskCTC. Defaults to False.
|
||||
"""
|
||||
assert unit_type in ('char', 'spm', 'word')
|
||||
self.unit_type = unit_type
|
||||
self.unk = UNK
|
||||
self.maskctc = maskctc
|
||||
|
||||
if vocab:
|
||||
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
|
||||
vocab, maskctc)
|
||||
self.vocab_size = len(self.vocab_list)
|
||||
else:
|
||||
logger.warning("TextFeaturizer: not have vocab file or vocab list.")
|
||||
|
||||
if unit_type == 'spm':
|
||||
spm_model = spm_model_prefix + '.model'
|
||||
self.sp = spm.SentencePieceProcessor()
|
||||
self.sp.Load(spm_model)
|
||||
|
||||
def tokenize(self, text, replace_space=True):
|
||||
if self.unit_type == 'char':
|
||||
tokens = self.char_tokenize(text, replace_space)
|
||||
elif self.unit_type == 'word':
|
||||
tokens = self.word_tokenize(text)
|
||||
else: # spm
|
||||
tokens = self.spm_tokenize(text)
|
||||
return tokens
|
||||
|
||||
def detokenize(self, tokens):
|
||||
if self.unit_type == 'char':
|
||||
text = self.char_detokenize(tokens)
|
||||
elif self.unit_type == 'word':
|
||||
text = self.word_detokenize(tokens)
|
||||
else: # spm
|
||||
text = self.spm_detokenize(tokens)
|
||||
return text
|
||||
|
||||
def featurize(self, text):
|
||||
"""Convert text string to a list of token indices.
|
||||
|
||||
Args:
|
||||
text (str): Text to process.
|
||||
|
||||
Returns:
|
||||
List[int]: List of token indices.
|
||||
"""
|
||||
tokens = self.tokenize(text)
|
||||
ids = []
|
||||
for token in tokens:
|
||||
if token not in self.vocab_dict:
|
||||
logger.debug(f"Text Token: {token} -> {self.unk}")
|
||||
token = self.unk
|
||||
ids.append(self.vocab_dict[token])
|
||||
return ids
|
||||
|
||||
def defeaturize(self, idxs):
|
||||
"""Convert a list of token indices to text string,
|
||||
ignore index after eos_id.
|
||||
|
||||
Args:
|
||||
idxs (List[int]): List of token indices.
|
||||
|
||||
Returns:
|
||||
str: Text.
|
||||
"""
|
||||
tokens = []
|
||||
for idx in idxs:
|
||||
if idx == self.eos_id:
|
||||
break
|
||||
tokens.append(self._id2token[idx])
|
||||
text = self.detokenize(tokens)
|
||||
return text
|
||||
|
||||
def char_tokenize(self, text, replace_space=True):
|
||||
"""Character tokenizer.
|
||||
|
||||
Args:
|
||||
text (str): text string.
|
||||
replace_space (bool): whether to replace " " with <space>; set to False only by build_vocab.py.
|
||||
|
||||
Returns:
|
||||
List[str]: tokens.
|
||||
"""
|
||||
text = text.strip()
|
||||
if replace_space:
|
||||
text_list = [SPACE if item == " " else item for item in list(text)]
|
||||
else:
|
||||
text_list = list(text)
|
||||
return text_list
|
||||
|
||||
def char_detokenize(self, tokens):
|
||||
"""Character detokenizer.
|
||||
|
||||
Args:
|
||||
tokens (List[str]): tokens.
|
||||
|
||||
Returns:
|
||||
str: text string.
|
||||
"""
|
||||
tokens = [t.replace(SPACE, " ") for t in tokens]
|
||||
return "".join(tokens)
|
||||
|
||||
def word_tokenize(self, text):
|
||||
"""Word tokenizer, separate by <space>."""
|
||||
return text.strip().split()
|
||||
|
||||
def word_detokenize(self, tokens):
|
||||
"""Word detokenizer, separate by <space>."""
|
||||
return " ".join(tokens)
|
||||
|
||||
def spm_tokenize(self, text):
|
||||
"""spm tokenize.
|
||||
|
||||
Args:
|
||||
text (str): text string.
|
||||
|
||||
Returns:
|
||||
List[str]: sentence piece tokens.
|
||||
"""
|
||||
stats = {"num_empty": 0, "num_filtered": 0}
|
||||
|
||||
def valid(line):
|
||||
return True
|
||||
|
||||
def encode(l):
|
||||
return self.sp.EncodeAsPieces(l)
|
||||
|
||||
def encode_line(line):
|
||||
line = line.strip()
|
||||
if len(line) > 0:
|
||||
line = encode(line)
|
||||
if valid(line):
|
||||
return line
|
||||
else:
|
||||
stats["num_filtered"] += 1
|
||||
else:
|
||||
stats["num_empty"] += 1
|
||||
return None
|
||||
|
||||
enc_line = encode_line(text)
|
||||
return enc_line
|
||||
|
||||
def spm_detokenize(self, tokens, input_format='piece'):
|
||||
"""spm detokenize.
|
||||
|
||||
Args:
|
||||
tokens (List[str]): sentence piece tokens (pieces, or ids when input_format='id').
|
||||
|
||||
Returns:
|
||||
str: text
|
||||
"""
|
||||
if input_format == "piece":
|
||||
|
||||
def decode(l):
|
||||
return "".join(self.sp.DecodePieces(l))
|
||||
elif input_format == "id":
|
||||
|
||||
def decode(l):
|
||||
return "".join(self.sp.DecodeIds(l))
|
||||
|
||||
return decode(tokens)
|
||||
|
||||
def _load_vocabulary_from_file(self, vocab: Union[str, list],
|
||||
maskctc: bool):
|
||||
"""Load vocabulary from file."""
|
||||
if isinstance(vocab, list):
|
||||
vocab_list = vocab
|
||||
else:
|
||||
vocab_list = load_dict(vocab, maskctc)
|
||||
assert vocab_list is not None
|
||||
logger.debug(f"Vocab: {pformat(vocab_list)}")
|
||||
|
||||
id2token = dict(
|
||||
[(idx, token) for (idx, token) in enumerate(vocab_list)])
|
||||
token2id = dict(
|
||||
[(token, idx) for (idx, token) in enumerate(vocab_list)])
|
||||
|
||||
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
|
||||
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
|
||||
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
|
||||
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
|
||||
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
|
||||
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
|
||||
|
||||
logger.info(f"BLANK id: {blank_id}")
|
||||
logger.info(f"UNK id: {unk_id}")
|
||||
logger.info(f"EOS id: {eos_id}")
|
||||
logger.info(f"SOS id: {sos_id}")
|
||||
logger.info(f"SPACE id: {space_id}")
|
||||
logger.info(f"MASKCTC id: {maskctc_id}")
|
||||
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
|
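A minimal round-trip sketch for `TextFeaturizer`; the toy vocabulary list below is an assumption for illustration, not taken from the diff:

```Python
# Hypothetical example: char-level featurizer built from an in-memory vocab list.
vocab = ["<blank>", "<space>", "h", "i", "t", "e", "r", "<eos>"]
feat = TextFeaturizer(unit_type="char", vocab=vocab)

ids = feat.featurize("hi there")   # " " becomes the <space> token
text = feat.defeaturize(ids)       # stops at <eos>, then detokenizes
print(ids)   # token indices into `vocab`
print(text)  # "hi there"
```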
@ -0,0 +1,393 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains data helper functions."""
|
||||
import json
|
||||
import math
import sys
|
||||
import tarfile
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Text
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = [
|
||||
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
|
||||
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
|
||||
"EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
|
||||
"convert_samples_from_float32"
|
||||
]
|
||||
|
||||
IGNORE_ID = -1
|
||||
# `sos` and `eos` using same token
|
||||
SOS = "<eos>"
|
||||
EOS = SOS
|
||||
UNK = "<unk>"
|
||||
BLANK = "<blank>"
|
||||
MASKCTC = "<mask>"
|
||||
SPACE = "<space>"
|
||||
|
||||
|
||||
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
|
||||
if dict_path is None:
|
||||
return None
|
||||
|
||||
with open(dict_path, "r") as f:
|
||||
dictionary = f.readlines()
|
||||
# first token is `<blank>`
|
||||
# multi line: `<blank> 0\n`
|
||||
# one line: `<blank>`
|
||||
# space is replaced with <space>
|
||||
char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
|
||||
if BLANK not in char_list:
|
||||
char_list.insert(0, BLANK)
|
||||
if EOS not in char_list:
|
||||
char_list.append(EOS)
|
||||
# for non-autoregressive maskctc model
|
||||
if maskctc and MASKCTC not in char_list:
|
||||
char_list.append(MASKCTC)
|
||||
return char_list
|
||||
|
||||
|
||||
def read_manifest(
|
||||
manifest_path,
|
||||
max_input_len=float('inf'),
|
||||
min_input_len=0.0,
|
||||
max_output_len=float('inf'),
|
||||
min_output_len=0.0,
|
||||
max_output_input_ratio=float('inf'),
|
||||
min_output_input_ratio=0.0, ):
|
||||
"""Load and parse manifest file.
|
||||
|
||||
Args:
|
||||
manifest_path ([type]): Manifest file to load and parse.
|
||||
max_input_len ([type], optional): maximum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum output seq length,
in modeling units. Defaults to float('inf').
min_output_len (float, optional): minimum output seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length / input seq length ratio. Defaults to float('inf').
min_output_input_ratio (float, optional):
minimum output seq length / input seq length ratio. Defaults to 0.0.
|
||||
|
||||
Raises:
|
||||
IOError: If failed to parse the manifest.
|
||||
|
||||
Returns:
|
||||
List[dict]: Manifest parsing results.
|
||||
"""
|
||||
manifest = []
|
||||
with jsonlines.open(manifest_path, 'r') as reader:
|
||||
for json_data in reader:
|
||||
feat_len = json_data["input"][0]["shape"][
|
||||
0] if "input" in json_data and "shape" in json_data["input"][
|
||||
0] else 1.0
|
||||
token_len = json_data["output"][0]["shape"][
|
||||
0] if "output" in json_data and "shape" in json_data["output"][
|
||||
0] else 1.0
|
||||
conditions = [
|
||||
feat_len >= min_input_len,
|
||||
feat_len <= max_input_len,
|
||||
token_len >= min_output_len,
|
||||
token_len <= max_output_len,
|
||||
token_len / feat_len >= min_output_input_ratio,
|
||||
token_len / feat_len <= max_output_input_ratio,
|
||||
]
|
||||
if all(conditions):
|
||||
manifest.append(json_data)
|
||||
return manifest
|
||||
|
||||
|
||||
# Tar File read
|
||||
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
|
||||
|
||||
|
||||
def parse_tar(file):
|
||||
"""Parse a tar file to get a tarfile object
|
||||
and a map from member names to their TarInfo objects.
|
||||
"""
|
||||
result = {}
|
||||
f = tarfile.open(file)
|
||||
for tarinfo in f.getmembers():
|
||||
result[tarinfo.name] = tarinfo
|
||||
return f, result
|
||||
|
||||
|
||||
def subfile_from_tar(file, local_data=None):
|
||||
"""Get subfile object from tar.
|
||||
|
||||
tar:tarpath#filename
|
||||
|
||||
It will return a subfile object from tar file
|
||||
and cached tar file info for next reading request.
|
||||
"""
|
||||
tarpath, filename = file.split(':', 1)[1].split('#', 1)
|
||||
|
||||
if local_data is None:
|
||||
local_data = TarLocalData(tar2info={}, tar2object={})
|
||||
|
||||
assert isinstance(local_data, TarLocalData)
|
||||
|
||||
if 'tar2info' not in local_data.__dict__:
|
||||
local_data.tar2info = {}
|
||||
if 'tar2object' not in local_data.__dict__:
|
||||
local_data.tar2object = {}
|
||||
|
||||
if tarpath not in local_data.tar2info:
|
||||
fobj, infos = parse_tar(tarpath)
|
||||
local_data.tar2info[tarpath] = infos
|
||||
local_data.tar2object[tarpath] = fobj
|
||||
else:
|
||||
fobj = local_data.tar2object[tarpath]
|
||||
infos = local_data.tar2info[tarpath]
|
||||
return fobj.extractfile(infos[filename])
|
||||
|
||||
|
||||
def rms_to_db(rms: float):
|
||||
"""Root Mean Square to dB.
|
||||
|
||||
Args:
|
||||
rms ([float]): root mean square
|
||||
|
||||
Returns:
|
||||
float: dB
|
||||
"""
|
||||
return 20.0 * math.log10(max(1e-16, rms))
|
||||
|
||||
|
||||
def rms_to_dbfs(rms: float):
|
||||
"""Root Mean Square to dBFS.
|
||||
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
|
||||
Audio is treated as a mix of sine waves; a full-scale (1.0 amplitude) sine wave has an RMS of 0.7071, i.e. -3.0103 dBFS.

dB = dBFS + 3.0103
dBFS = dB - 3.0103
e.g. 0 dB = -3.0103 dBFS
|
||||
|
||||
Args:
|
||||
rms ([float]): root mean square
|
||||
|
||||
Returns:
|
||||
float: dBFS
|
||||
"""
|
||||
return rms_to_db(rms) - 3.0103
|
||||
|
||||
|
||||
def max_dbfs(sample_data: np.ndarray):
|
||||
"""Peak dBFS based on the maximum energy sample.
|
||||
|
||||
Args:
|
||||
sample_data ([np.ndarray]): float array, [-1, 1].
|
||||
|
||||
Returns:
|
||||
float: dBFS
|
||||
"""
|
||||
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
|
||||
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
|
||||
|
||||
|
||||
def mean_dbfs(sample_data):
|
||||
"""Peak dBFS based on the RMS energy.
|
||||
|
||||
Args:
|
||||
sample_data ([np.ndarray]): float array, [-1, 1].
|
||||
|
||||
Returns:
|
||||
float: dBFS
|
||||
"""
|
||||
return rms_to_dbfs(
|
||||
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
|
||||
|
||||
|
||||
def gain_db_to_ratio(gain_db: float):
|
||||
"""dB to ratio
|
||||
|
||||
Args:
|
||||
gain_db (float): gain in dB
|
||||
|
||||
Returns:
|
||||
float: scale in amp
|
||||
"""
|
||||
return math.pow(10.0, gain_db / 20.0)
|
||||
|
||||
|
||||
def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
|
||||
"""Nomalize audio to dBFS.
|
||||
|
||||
Args:
|
||||
sample_data (np.ndarray): input wave samples, [-1, 1].
|
||||
dbfs (float, optional): target dBFS. Defaults to -3.0103.
|
||||
|
||||
Returns:
|
||||
np.ndarray: normalized wave
|
||||
"""
|
||||
return np.maximum(
|
||||
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
|
||||
1.0), -1.0)
|
||||
|
||||
|
||||
def _load_json_cmvn(json_cmvn_file):
|
||||
""" Load the json format cmvn stats file and calculate cmvn
|
||||
|
||||
Args:
|
||||
json_cmvn_file: cmvn stats file in json format
|
||||
|
||||
Returns:
|
||||
a numpy array of [means, vars]
|
||||
"""
|
||||
with open(json_cmvn_file) as f:
|
||||
cmvn_stats = json.load(f)
|
||||
|
||||
means = cmvn_stats['mean_stat']
|
||||
variance = cmvn_stats['var_stat']
|
||||
count = cmvn_stats['frame_num']
|
||||
for i in range(len(means)):
|
||||
means[i] /= count
|
||||
variance[i] = variance[i] / count - means[i] * means[i]
|
||||
if variance[i] < 1.0e-20:
|
||||
variance[i] = 1.0e-20
|
||||
variance[i] = 1.0 / math.sqrt(variance[i])
|
||||
cmvn = np.array([means, variance])
|
||||
return cmvn
|
||||
|
||||
|
||||
def _load_kaldi_cmvn(kaldi_cmvn_file):
|
||||
""" Load the kaldi format cmvn stats file and calculate cmvn
|
||||
|
||||
Args:
|
||||
kaldi_cmvn_file: kaldi text style global cmvn file, which
|
||||
is generated by:
|
||||
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
|
||||
|
||||
Returns:
|
||||
a numpy array of [means, vars]
|
||||
"""
|
||||
means = []
|
||||
variance = []
|
||||
with open(kaldi_cmvn_file, 'r') as fid:
|
||||
# kaldi binary file start with '\0B'
|
||||
if fid.read(2) == '\0B':
|
||||
logger.error('kaldi cmvn binary file is not supported, please '
|
||||
'recompute it by: compute-cmvn-stats --binary=false '
|
||||
' scp:feats.scp global_cmvn')
|
||||
sys.exit(1)
|
||||
fid.seek(0)
|
||||
arr = fid.read().split()
|
||||
assert (arr[0] == '[')
|
||||
assert (arr[-2] == '0')
|
||||
assert (arr[-1] == ']')
|
||||
feat_dim = int((len(arr) - 2 - 2) / 2)
|
||||
for i in range(1, feat_dim + 1):
|
||||
means.append(float(arr[i]))
|
||||
count = float(arr[feat_dim + 1])
|
||||
for i in range(feat_dim + 2, 2 * feat_dim + 2):
|
||||
variance.append(float(arr[i]))
|
||||
|
||||
for i in range(len(means)):
|
||||
means[i] /= count
|
||||
variance[i] = variance[i] / count - means[i] * means[i]
|
||||
if variance[i] < 1.0e-20:
|
||||
variance[i] = 1.0e-20
|
||||
variance[i] = 1.0 / math.sqrt(variance[i])
|
||||
cmvn = np.array([means, variance])
|
||||
return cmvn
|
||||
|
||||
|
||||
def load_cmvn(cmvn_file: str, filetype: str):
|
||||
"""load cmvn from file.
|
||||
|
||||
Args:
|
||||
cmvn_file (str): cmvn path.
|
||||
filetype (str): file type, optional[npz, json, kaldi].
|
||||
|
||||
Raises:
|
||||
ValueError: file type not support.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: mean, istd
|
||||
"""
|
||||
assert filetype in ['npz', 'json', 'kaldi'], filetype
|
||||
filetype = filetype.lower()
|
||||
if filetype == "json":
|
||||
cmvn = _load_json_cmvn(cmvn_file)
|
||||
elif filetype == "kaldi":
|
||||
cmvn = _load_kaldi_cmvn(cmvn_file)
|
||||
elif filetype == "npz":
|
||||
eps = 1e-14
|
||||
npzfile = np.load(cmvn_file)
|
||||
mean = np.squeeze(npzfile["mean"])
|
||||
std = np.squeeze(npzfile["std"])
|
||||
istd = 1 / (std + eps)
|
||||
cmvn = [mean, istd]
|
||||
else:
|
||||
raise ValueError(f"cmvn file type no support: {filetype}")
|
||||
return cmvn[0], cmvn[1]
|
||||
|
||||
|
||||
def convert_samples_to_float32(samples):
|
||||
"""Convert sample type to float32.
|
||||
|
||||
Audio sample type is usually integer or float-point.
|
||||
Integers will be scaled to [-1, 1] in float32.
|
||||
|
||||
PCM16 -> PCM32
|
||||
"""
|
||||
float32_samples = samples.astype('float32')
|
||||
if samples.dtype in np.sctypes['int']:
|
||||
bits = np.iinfo(samples.dtype).bits
|
||||
float32_samples *= (1. / 2**(bits - 1))
|
||||
elif samples.dtype in np.sctypes['float']:
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
||||
return float32_samples
|
||||
|
||||
|
||||
def convert_samples_from_float32(samples, dtype):
|
||||
"""Convert sample type from float32 to dtype.
|
||||
|
||||
Audio sample type is usually integer or float-point. For integer
|
||||
type, float32 will be rescaled from [-1, 1] to the maximum range
|
||||
supported by the integer type.
|
||||
|
||||
PCM32 -> PCM16
|
||||
"""
|
||||
dtype = np.dtype(dtype)
|
||||
output_samples = samples.copy()
|
||||
if dtype in np.sctypes['int']:
|
||||
bits = np.iinfo(dtype).bits
|
||||
output_samples *= (2**(bits - 1) / 1.)
|
||||
min_val = np.iinfo(dtype).min
|
||||
max_val = np.iinfo(dtype).max
|
||||
output_samples[output_samples > max_val] = max_val
|
||||
output_samples[output_samples < min_val] = min_val
|
||||
elif dtype in np.sctypes['float']:
|
||||
min_val = np.finfo(dtype).min
|
||||
max_val = np.finfo(dtype).max
|
||||
output_samples[output_samples > max_val] = max_val
|
||||
output_samples[output_samples < min_val] = min_val
|
||||
else:
|
||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
||||
return output_samples.astype(dtype)
|
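A minimal sketch of the dBFS helpers above; the 440 Hz test tone is an illustrative assumption:

```Python
import numpy as np

# Hypothetical example: peak-normalize a quiet test tone to -3.0103 dBFS.
t = np.arange(16000) / 16000.0
wav = (0.1 * np.sin(2 * np.pi * 440 * t)).astype("float32")

print(max_dbfs(wav))          # roughly -23 dBFS for a 0.1-amplitude sine
wav_norm = normalize_audio(wav, dbfs=-3.0103)
print(max_dbfs(wav_norm))     # approximately -3.0103 dBFS after normalization
```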
@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import inspect
|
||||
|
||||
|
||||
def check_kwargs(func, kwargs, name=None):
|
||||
"""check kwargs are valid for func
|
||||
|
||||
If kwargs are invalid, raise TypeError as same as python default
|
||||
:param function func: function to be validated
|
||||
:param dict kwargs: keyword arguments for func
|
||||
:param str name: name used in TypeError (default is func name)
|
||||
"""
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
except ValueError:
|
||||
return
|
||||
if name is None:
|
||||
name = func.__name__
|
||||
for k in kwargs.keys():
|
||||
if k not in params:
|
||||
raise TypeError(
|
||||
f"{name}() got an unexpected keyword argument '{k}'")
|
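A minimal sketch of `check_kwargs`; the `build_model` function and its config dict are hypothetical:

```Python
# Hypothetical example: validate a config dict before a deferred call.
def build_model(hidden_dim=256, num_layers=4):
    return {"hidden_dim": hidden_dim, "num_layers": num_layers}

cfg = {"hidden_dim": 512, "nlayers": 6}   # typo: should be num_layers
check_kwargs(build_model, cfg)            # raises TypeError mentioning 'nlayers'
```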
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import importlib
|
||||
|
||||
__all__ = ["dynamic_import"]
|
||||
|
||||
|
||||
def dynamic_import(import_path, alias=dict()):
|
||||
"""dynamic import module and class
|
||||
|
||||
:param str import_path: syntax 'module_name:class_name'
|
||||
e.g., 'paddlespeech.s2t.models.u2:U2Model'
|
||||
:param dict alias: shortcut for registered class
|
||||
:return: imported class
|
||||
"""
|
||||
if import_path not in alias and ":" not in import_path:
|
||||
raise ValueError(
|
||||
"import_path should be one of {} or "
|
||||
'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
|
||||
"{}".format(set(alias), import_path))
|
||||
if ":" not in import_path:
|
||||
import_path = alias[import_path]
|
||||
|
||||
module_name, objname = import_path.split(":")
|
||||
m = importlib.import_module(module_name)
|
||||
return getattr(m, objname)
|
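A minimal sketch of `dynamic_import`; the alias table below is an assumption, while the `module:ClassName` path comes from the docstring above:

```Python
# Hypothetical example: resolve a class either through an alias or explicitly.
model_alias = {"u2": "paddlespeech.s2t.models.u2:U2Model"}

cls_from_alias = dynamic_import("u2", alias=model_alias)
cls_explicit = dynamic_import("paddlespeech.s2t.models.u2:U2Model")
assert cls_from_alias is cls_explicit
```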
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unility functions for Transformer."""
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
|
||||
from .log import Logger
|
||||
|
||||
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
|
||||
|
||||
logger = Logger(__name__)
|
||||
|
||||
|
||||
def has_tensor(val):
|
||||
if isinstance(val, (list, tuple)):
|
||||
for item in val:
|
||||
if has_tensor(item):
|
||||
return True
|
||||
elif isinstance(val, dict):
|
||||
for k, v in val.items():
|
||||
print(k)
|
||||
if has_tensor(v):
|
||||
return True
|
||||
else:
|
||||
return paddle.is_tensor(val)
|
||||
|
||||
|
||||
def pad_sequence(sequences: List[paddle.Tensor],
|
||||
batch_first: bool=False,
|
||||
padding_value: float=0.0) -> paddle.Tensor:
|
||||
r"""Pad a list of variable length Tensors with ``padding_value``
|
||||
|
||||
``pad_sequence`` stacks a list of Tensors along a new dimension,
|
||||
and pads them to equal length. For example, if the input is list of
|
||||
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
|
||||
otherwise.
|
||||
|
||||
`B` is batch size. It is equal to the number of elements in ``sequences``.
|
||||
`T` is length of the longest sequence.
|
||||
`L` is length of the sequence.
|
||||
`*` is any number of trailing dimensions, including none.
|
||||
|
||||
Example:
|
||||
>>> from paddle.nn.utils.rnn import pad_sequence
|
||||
>>> a = paddle.ones([25, 300])
>>> b = paddle.ones([22, 300])
>>> c = paddle.ones([15, 300])
|
||||
>>> pad_sequence([a, b, c]).shape
|
||||
[25, 3, 300]
|
||||
|
||||
Note:
|
||||
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
|
||||
where `T` is the length of the longest sequence. This function assumes
|
||||
trailing dimensions and type of all the Tensors in sequences are same.
|
||||
|
||||
Args:
|
||||
sequences (list[Tensor]): list of variable length sequences.
|
||||
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
|
||||
``T x B x *`` otherwise
|
||||
padding_value (float, optional): value for padded elements. Default: 0.
|
||||
|
||||
Returns:
|
||||
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
|
||||
Tensor of size ``B x T x *`` otherwise
|
||||
"""
|
||||
|
||||
# assuming trailing dimensions and type of all the Tensors
|
||||
# in sequences are same and fetching those from sequences[0]
|
||||
max_size = paddle.shape(sequences[0])
|
||||
# (TODO Hui Zhang): slice does not support `end==start`
|
||||
# trailing_dims = max_size[1:]
|
||||
trailing_dims = tuple(
|
||||
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
|
||||
max_len = max([s.shape[0] for s in sequences])
|
||||
if batch_first:
|
||||
out_dims = (len(sequences), max_len) + trailing_dims
|
||||
else:
|
||||
out_dims = (max_len, len(sequences)) + trailing_dims
|
||||
out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
|
||||
for i, tensor in enumerate(sequences):
|
||||
length = tensor.shape[0]
|
||||
# use index notation to prevent duplicate references to the tensor
|
||||
if batch_first:
|
||||
# TODO (Hui Zhang): set_value op does not support `end==start`
|
||||
# TODO (Hui Zhang): set_value op not support int16
|
||||
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
|
||||
# out_tensor[i, :length, ...] = tensor
|
||||
if length != 0:
|
||||
out_tensor[i, :length] = tensor
|
||||
else:
|
||||
out_tensor[i, length] = tensor
|
||||
else:
|
||||
# TODO (Hui Zhang): set_value op does not support `end==start`
|
||||
# out_tensor[:length, i, ...] = tensor
|
||||
if length != 0:
|
||||
out_tensor[:length, i] = tensor
|
||||
else:
|
||||
out_tensor[length, i] = tensor
|
||||
|
||||
return out_tensor
|
||||
|
||||
|
||||
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
|
||||
ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""Add <sos> and <eos> labels.
|
||||
Args:
|
||||
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
|
||||
sos (int): index of <sos>
|
||||
eos (int): index of <eos>
|
||||
ignore_id (int): index of padding
|
||||
Returns:
|
||||
ys_in (paddle.Tensor) : (B, Lmax + 1)
|
||||
ys_out (paddle.Tensor) : (B, Lmax + 1)
|
||||
Examples:
|
||||
>>> sos_id = 10
|
||||
>>> eos_id = 11
|
||||
>>> ignore_id = -1
|
||||
>>> ys_pad
|
||||
tensor([[ 1, 2, 3, 4, 5],
|
||||
[ 4, 5, 6, -1, -1],
|
||||
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
|
||||
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
|
||||
>>> ys_in
|
||||
tensor([[10, 1, 2, 3, 4, 5],
|
||||
[10, 4, 5, 6, 11, 11],
|
||||
[10, 7, 8, 9, 11, 11]])
|
||||
>>> ys_out
|
||||
tensor([[ 1, 2, 3, 4, 5, 11],
|
||||
[ 4, 5, 6, 11, -1, -1],
|
||||
[ 7, 8, 9, 11, -1, -1]])
|
||||
"""
|
||||
# TODO(Hui Zhang): using comment code,
|
||||
#_sos = paddle.to_tensor(
|
||||
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
|
||||
#_eos = paddle.to_tensor(
|
||||
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
|
||||
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
|
||||
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
|
||||
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
|
||||
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
|
||||
B = ys_pad.shape[0]
|
||||
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
|
||||
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
|
||||
ys_in = paddle.cat([_sos, ys_pad], dim=1)
|
||||
mask_pad = (ys_in == ignore_id)
|
||||
ys_in = ys_in.masked_fill(mask_pad, eos)
|
||||
|
||||
ys_out = paddle.cat([ys_pad, _eos], dim=1)
|
||||
ys_out = ys_out.masked_fill(mask_pad, eos)
|
||||
mask_eos = (ys_out == ignore_id)
|
||||
ys_out = ys_out.masked_fill(mask_eos, eos)
|
||||
ys_out = ys_out.masked_fill(mask_pad, ignore_id)
|
||||
return ys_in, ys_out
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs: paddle.Tensor,
|
||||
pad_targets: paddle.Tensor,
|
||||
ignore_label: int) -> float:
|
||||
"""Calculate accuracy.
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
||||
ignore_label (int): Ignore label id.
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
"""
|
||||
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
|
||||
pad_outputs.shape[1]).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
#TODO(Hui Zhang): sum not support bool type
|
||||
# numerator = paddle.sum(
|
||||
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
numerator = (
|
||||
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
numerator = paddle.sum(numerator.type_as(pad_targets))
|
||||
#TODO(Hui Zhang): sum not support bool type
|
||||
# denominator = paddle.sum(mask)
|
||||
denominator = paddle.sum(mask.type_as(pad_targets))
|