You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
284 lines
9.2 KiB
284 lines
9.2 KiB
#
|
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
# This file is part of the WebDataset library.
|
|
# See the LICENSE file for licensing terms (BSD-style).
|
|
|
|
# Modified from https://github.com/webdataset/webdataset
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
|
|
"""Low level iteration functions for tar archives."""
|
|
|
|
import random, re, tarfile
|
|
|
|
import braceexpand
|
|
|
|
from . import filters
|
|
from webdataset import gopen
|
|
from webdataset.handlers import reraise_exception
|
|
|
|
# Debug switch: when true, group_by_keys prints the prefix, the suffix and
# the keys of the sample being assembled for every entry it processes.
trace = False
|
|
# tar_file_iterator skips top-level members (no "/" in the name) whose name
# starts with this prefix and ends with meta_suffix.
meta_prefix = "__"
|
|
# tar_file_iterator skips top-level members (no "/" in the name) whose name
# starts with meta_prefix and ends with this suffix.
meta_suffix = "__"
|
|
|
|
from ... import audio as paddleaudio
|
|
import paddle
|
|
import numpy as np
|
|
|
|
# File extensions that tar_file_and_group_iterator decodes as audio.
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
|
|
|
def base_plus_ext(path):
    """Separate a path into its base and its complete extension.

    Everything after the first dot of the last path component is treated as
    the extension, so ``"a/b.tar.gz"`` gives ``("a/b", "tar.gz")``.

    :param path: path with extensions
    :returns: ``(base, allext)``; ``(None, None)`` when the path does not
        match the pattern (for example when it contains no dot)
    """
    matched = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if matched is None:
        return None, None
    base, all_ext = matched.groups()
    return base, all_ext
|
|
|
|
|
|
def valid_sample(sample):
    """Tell whether a sample may be passed on.

    A sample is valid when it is a non-empty dict whose ``"__bad__"`` entry
    is absent or falsy.

    :param sample: sample to be checked
    :returns: True for a valid sample, False otherwise
    """
    # None is rejected here as well, since it is not a dict.
    if not isinstance(sample, dict):
        return False
    if not list(sample.keys()):
        return False
    return not sample.get("__bad__", False)
|
|
|
|
|
|
# FIXME: UNUSED
|
|
def shardlist(urls, *, shuffle=False):
    """Yield one ``dict(url=...)`` per shard URL, optionally in random order.

    :param urls: either a single string, which is expanded with
        ``braceexpand.braceexpand``, or an iterable of URLs, which is copied
        into a list (the caller's object is not modified)
    :param shuffle: when true, the URL list is shuffled with
        ``random.shuffle`` before anything is yielded (Default value = False)
    """
    if isinstance(urls, str):
        # braceexpand() produces a lazy iterator; random.shuffle needs a
        # mutable sequence, so the expansion has to be materialized first.
        urls = list(braceexpand.braceexpand(urls))
    else:
        urls = list(urls)
    if shuffle:
        random.shuffle(urls)
    for url in urls:
        yield dict(url=url)
|
|
|
|
|
|
def url_opener(data, handler=reraise_exception, **kw):
    """Open the URL of every incoming sample and yield the sample with its stream.

    Each sample must be a dict with a ``"url"`` entry; the object returned
    by ``gopen.gopen`` is stored under ``"stream"`` in that same dict.

    :param data: iterable of ``dict(url=url)`` samples
    :param handler: called with any exception raised inside the ``try``
        (the URL is appended to the exception's ``args`` first); a truthy
        result moves on to the next sample, a falsy one ends the iteration
    :param kw: keyword arguments forwarded to ``gopen.gopen``
    """
    for entry in data:
        assert isinstance(entry, dict), entry
        assert "url" in entry
        target = entry["url"]
        try:
            entry.update(stream=gopen.gopen(target, **kw))
            yield entry
        except Exception as exn:
            exn.args = exn.args + (target,)
            if not handler(exn):
                break
|
|
|
|
|
|
def tar_file_iterator(
    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
    """Iterate over a tar stream, yielding one dict per regular member.

    A member whose last extension is ``wav`` is decoded with
    ``paddleaudio.load`` and yielded as
    ``dict(fname=<name without last extension>, wav=..., sample_rate=...)``.
    Every other member is read, decoded as UTF-8, stripped, and yielded as
    ``dict(fname=<name without last extension>, txt=...)``.

    :param fileobj: byte stream suitable for tarfile
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param handler: called with any exception raised while a member is
        processed; a truthy result continues with the next member, a falsy
        one stops the iteration
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if (
                    "/" not in fname
                    and fname.startswith(meta_prefix)
                    and fname.endswith(meta_suffix)
                ):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue

                pos = fname.rfind('.')
                # A name without an extension (or starting with its only dot)
                # is reported to the handler as an AssertionError.
                assert pos > 0
                prefix, postfix = fname[:pos], fname[pos + 1:]
                # Close the extracted member as soon as it has been consumed.
                with stream.extractfile(tarinfo) as file_obj:
                    if postfix == 'wav':
                        waveform, sample_rate = paddleaudio.load(
                            file_obj, normal=False)
                        result = dict(
                            fname=prefix, wav=waveform, sample_rate=sample_rate)
                    else:
                        txt = file_obj.read().decode('utf8').strip()
                        result = dict(fname=prefix, txt=txt)
                yield result
                # Reset the TarFile's accumulated list of members so it does
                # not keep growing while the archive is streamed.
                stream.members = []
            except Exception as exn:
                if exn.args:
                    if isinstance(exn.args[0], str):
                        exn.args = (
                            exn.args[0] + " @ " + str(fileobj),
                        ) + exn.args[1:]
                    else:
                        # e.g. OSError(errno, msg): the first argument is not
                        # text, so concatenating would raise TypeError and hide
                        # the real error; append the location instead.
                        exn.args = exn.args + (str(fileobj),)
                if not handler(exn):
                    break
    finally:
        # Runs on normal exhaustion, on handler-requested stop, on re-raised
        # errors and when the consumer abandons the generator.
        stream.close()
|
|
|
|
def tar_file_and_group_iterator(
    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
    """Read a tar stream and group consecutive members that share a prefix.

    The prefix of a member is its name up to (not including) the last dot.
    Consecutive members with the same prefix are merged into one dict:

    * a ``txt`` member is stored under ``"txt"`` as stripped UTF-8 text;
    * a member whose extension is in ``AUDIO_FORMAT_SETS`` is decoded with
      ``paddleaudio.load``, given a new leading axis, converted to a
      ``paddle.float32`` tensor and stored under ``"wav"``, with the sample
      rate under ``"sample_rate"``;
    * any other member is stored as raw bytes under its extension.

    The shared prefix is stored under ``"fname"``.  A group in which any
    member failed to decode is dropped instead of being yielded.

    :param fileobj: byte stream suitable for tarfile
    :param skip_meta: accepted for signature parity with
        ``tar_file_iterator``; this function does not consult it
    :param handler: called with any exception raised while a member is
        decoded; a truthy result continues with the next member, a falsy
        one stops the iteration
    :returns: generator of dicts with ``fname`` plus the decoded members
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    prev_prefix = None
    example = {}
    valid = True
    try:
        for tarinfo in stream:
            if not tarinfo.isreg():
                # Directories and other non-regular members carry no sample
                # data, and extractfile() gives no file object for them.
                continue
            name = tarinfo.name
            pos = name.rfind('.')
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if prev_prefix is not None and prefix != prev_prefix:
                # A new prefix starts: flush the finished group.
                example['fname'] = prev_prefix
                if valid:
                    yield example
                example = {}
                valid = True
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == 'txt':
                        example['txt'] = file_obj.read().decode('utf8').strip()
                    elif postfix in AUDIO_FORMAT_SETS:
                        waveform, sample_rate = paddleaudio.load(
                            file_obj, normal=False)
                        waveform = paddle.to_tensor(
                            np.expand_dims(np.array(waveform), 0),
                            dtype=paddle.float32)
                        example['wav'] = waveform
                        example['sample_rate'] = sample_rate
                    else:
                        example[postfix] = file_obj.read()
                except Exception as exn:
                    if exn.args:
                        if isinstance(exn.args[0], str):
                            exn.args = (
                                exn.args[0] + " @ " + str(fileobj),
                            ) + exn.args[1:]
                        else:
                            # Non-text first argument (e.g. an errno): append
                            # the location rather than concatenating.
                            exn.args = exn.args + (str(fileobj),)
                    # Mark the group as broken before asking the handler, so
                    # it is dropped whether iteration continues or stops.
                    valid = False
                    if not handler(exn):
                        break
            # Reached after a handled error as well, so the remaining members
            # of the broken group are still attributed to it.
            prev_prefix = prefix
        if prev_prefix is not None and valid:
            example['fname'] = prev_prefix
            yield example
    finally:
        stream.close()
|
|
|
|
def tar_file_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    Every source's ``"stream"`` is read with ``tar_file_iterator``; each
    dict it yields is tagged with the source URL under ``"__url__"`` and
    passed on.

    :param data: iterable of dicts with ``"url"`` and ``"stream"`` entries
    :param handler: called with any exception raised while a source is
        expanded (the stream and URL are appended to the exception's
        ``args`` first); a truthy result continues with the next source, a
        falsy one ends the iteration
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator(source["stream"]):
                # tar_file_iterator yields dict(fname=..., wav=..., sample_rate=...)
                # or dict(fname=..., txt=...); it never produces a "data"
                # entry, so only "fname" can be required here.
                assert isinstance(sample, dict) and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if not handler(exn):
                break
|
|
|
|
|
|
|
|
|
|
def tar_file_and_group_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of grouped samples.

    Every source's ``"stream"`` is read with ``tar_file_and_group_iterator``.
    Each grouped sample must contain ``"wav"``, ``"txt"`` and ``"fname"``;
    it is tagged with the source URL under ``"__url__"`` and yielded.

    :param data: iterable of dicts with ``"url"`` and ``"stream"`` entries
    :param handler: called with any exception raised while a source is
        expanded (the stream and URL are appended to the exception's
        ``args`` first); a truthy result continues with the next source, a
        falsy one ends the iteration
    """
    required_keys = ("wav", "txt", "fname")
    for shard in data:
        shard_url = shard["url"]
        try:
            assert isinstance(shard, dict)
            assert "stream" in shard
            for grouped in tar_file_and_group_iterator(shard["stream"]):
                assert isinstance(grouped, dict)
                assert all(key in grouped for key in required_keys)
                grouped["__url__"] = shard_url
                yield grouped
        except Exception as exn:
            exn.args = exn.args + (shard.get("stream"), shard.get("url"))
            if not handler(exn):
                break
|
|
|
|
|
|
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Merge consecutive file entries that share a key into sample dicts.

    Each entry of ``data`` must be a dict with ``"fname"`` and ``"data"``;
    ``"__url__"`` is read from the entry that starts a new sample.  A sample
    is yielded when the key changes and once more at the end, provided
    ``valid_sample`` accepts it.

    :param data: iterable of file entry dicts
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: when not None, only suffixes contained in it are stored
    :param handler: not used by this function
    :raises ValueError: when a suffix is already present in the current sample
    """
    pending = None
    for entry in data:
        assert isinstance(entry, dict)
        fname = entry["fname"]
        value = entry["data"]
        prefix, suffix = keys(fname)
        if trace:
            pending_keys = pending.keys() if isinstance(pending, dict) else None
            print(prefix, suffix, pending_keys)
        if prefix is None:
            # The key function could not split this name; ignore the entry.
            continue
        if lcase:
            suffix = suffix.lower()
        starts_new_sample = pending is None or prefix != pending["__key__"]
        if starts_new_sample:
            if valid_sample(pending):
                yield pending
            pending = dict(__key__=prefix, __url__=entry["__url__"])
        if suffix in pending:
            raise ValueError(
                f"{fname}: duplicate file name in tar file {suffix} {pending.keys()}"
            )
        if suffixes is None or suffix in suffixes:
            pending[suffix] = value
    if valid_sample(pending):
        yield pending
|
|
|
|
|
|
def tarfile_samples(src, handler=reraise_exception):
    """Turn a stream of ``dict(url=...)`` items into grouped tar samples.

    The URLs are opened with ``url_opener`` and the resulting streams are
    expanded with ``tar_file_and_group_expander``; the same ``handler`` is
    given to both stages.

    :param src: iterable of ``dict(url=url)`` items
    :param handler: exception handler passed to both stages
    :returns: the generator produced by ``tar_file_and_group_expander``
    """
    return tar_file_and_group_expander(
        url_opener(src, handler=handler), handler=handler
    )
|
|
|
|
|
|
# tarfile_samples wrapped by filters.pipelinefilter (from the sibling
# `filters` module) and exported under the name tarfile_to_samples.
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
|