From 0c7abc1f1753d4af1b104a3aae30b5f55661b8c1 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 27 Jun 2022 11:06:30 +0000 Subject: [PATCH] add training scripts --- examples/wenetspeech/asr1/conf/conformer.yaml | 28 +- examples/wenetspeech/asr1/local/data.sh | 125 ++--- .../asr1/local/wenetspeech_data_prep.sh | 4 +- .../{stream_data => streamdata}/__init__.py | 6 +- paddlespeech/audio/streamdata/autodecode.py | 445 +++++++++++++++++ .../{stream_data => streamdata}/cache.py | 4 +- .../{stream_data => streamdata}/compat.py | 2 +- .../audio/streamdata/extradatasets.py | 141 ++++++ .../{stream_data => streamdata}/filters.py | 4 +- paddlespeech/audio/streamdata/gopen.py | 340 +++++++++++++ paddlespeech/audio/streamdata/handlers.py | 47 ++ paddlespeech/audio/streamdata/mix.py | 85 ++++ .../paddle_utils.py | 0 .../{stream_data => streamdata}/pipeline.py | 3 +- .../{stream_data => streamdata}/shardlists.py | 0 .../tariterators.py | 4 +- .../{stream_data => streamdata}/utils.py | 0 paddlespeech/audio/streamdata/writer.py | 450 ++++++++++++++++++ paddlespeech/s2t/io/dataloader.py | 61 ++- setup.py | 3 +- 20 files changed, 1620 insertions(+), 132 deletions(-) rename paddlespeech/audio/{stream_data => streamdata}/__init__.py (87%) create mode 100644 paddlespeech/audio/streamdata/autodecode.py rename paddlespeech/audio/{stream_data => streamdata}/cache.py (98%) rename paddlespeech/audio/{stream_data => streamdata}/compat.py (99%) create mode 100644 paddlespeech/audio/streamdata/extradatasets.py rename paddlespeech/audio/{stream_data => streamdata}/filters.py (99%) create mode 100644 paddlespeech/audio/streamdata/gopen.py create mode 100644 paddlespeech/audio/streamdata/handlers.py create mode 100644 paddlespeech/audio/streamdata/mix.py rename paddlespeech/audio/{stream_data => streamdata}/paddle_utils.py (100%) rename paddlespeech/audio/{stream_data => streamdata}/pipeline.py (96%) rename paddlespeech/audio/{stream_data => streamdata}/shardlists.py (100%) rename paddlespeech/audio/{stream_data => streamdata}/tariterators.py (99%) rename paddlespeech/audio/{stream_data => streamdata}/utils.py (100%) create mode 100644 paddlespeech/audio/streamdata/writer.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index dd4ff0e2..f46d4bd9 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -1,7 +1,6 @@ ############################################ # Network Architecture # ############################################ -cmvn_file: cmvn_file_type: "json" # encoder related encoder: conformer @@ -43,9 +42,9 @@ model_conf: ########################################### # Data # ########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test +train_manifest: data/train_l/data.list +dev_manifest: data/dev/data.list +test_manifest: data/test_meeting/data.list ########################################### # Dataloader # @@ -54,23 +53,22 @@ use_stream_data: True unit_type: 'char' vocab_filepath: data/lang_char/vocab.txt cmvn_file: data/mean_std.json -preprocess_config: conf/preprocess.yaml spm_model_prefix: '' feat_dim: 80 stride_ms: 10.0 window_ms: 25.0 dither: 0.1 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs -batch_size: 64 +batch_size: 32 minlen_in: 10 -maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced 
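The manifests above now point at `data.list` files instead of the old JSON manifests. Judging from the loader changes later in this patch (each line of the manifest is stripped and handed to `SimpleShardList`, then expanded by `tarfile_to_samples`), every line is expected to name one tar shard produced by stage 3 of `data.sh`. A small, hypothetical sketch of how the list is consumed (paths are made up):

```python
# Illustrative only: mirror of how StreamDataLoader reads data.list,
# one tar shard path (or URL) per line.
shardlist = []
with open("data/train_l/data.list") as f:
    for line in f:
        shardlist.append(line.strip())
# e.g. shardlist == ["shards/train_l/shards_000000000.tar", ...]
```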
+maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed minlen_out: 0 -maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed resample_rate: 16000 -shuffle_size: 10000 -sort_size: 500 -num_workers: 4 -prefetch_factor: 100 +shuffle_size: 1500 +sort_size: 1000 +num_workers: 0 +prefetch_factor: 10 dist_sampler: True num_encs: 1 augment_conf: @@ -90,10 +88,10 @@ augment_conf: ########################################### # Training # ########################################### -n_epoch: 240 -accum_grad: 16 +n_epoch: 30 +accum_grad: 32 global_grad_clip: 5.0 -log_interval: 1 +log_interval: 100 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index d216dd84..b3472a8f 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -2,6 +2,8 @@ # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang) # NPU, ASLP Group (Author: Qijie Shao) +# +# Modified from wenet(https://github.com/wenet-e2e/wenet) stage=-1 stop_stage=100 @@ -30,7 +32,7 @@ mkdir -p data TARGET_DIR=${MAIN_ROOT}/dataset mkdir -p ${TARGET_DIR} -if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then # download data echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data." exit 0; @@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then data || exit 1; fi -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # generate manifests - python3 ${TARGET_DIR}/aishell/aishell.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/aishell" - - if [ $? -ne 0 ]; then - echo "Prepare Aishell failed. Terminated." - exit 1 - fi - - for dataset in train dev test; do - mv data/manifest.${dataset} data/manifest.${dataset}.raw - done -fi - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # compute mean and stddev for normalizer - if $cmvn; then - full_size=`cat data/${train_set}/wav.scp | wc -l` - sampling_size=$((full_size / cmvn_sampling_divisor)) - shuf -n $sampling_size data/$train_set/wav.scp \ - > data/$train_set/wav.scp.sampled - num_workers=$(nproc) - - python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ - --manifest_path="data/manifest.train.raw" \ - --spectrum_type="fbank" \ - --feat_dim=80 \ - --delta_delta=false \ - --stride_ms=10 \ - --window_ms=25 \ - --sample_rate=16000 \ - --use_dB_normalization=False \ - --num_samples=-1 \ - --num_workers=${num_workers} \ - --output_path="data/mean_std.json" - - if [ $? -ne 0 ]; then - echo "Compute mean and stddev failed. Terminated." - exit 1 - fi - fi -fi - -dict=data/dict/lang_char.txt +dict=data/lang_char/vocab.txt if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # download data, generate manifests - # build vocabulary - python3 ${MAIN_ROOT}/utils/build_vocab.py \ - --unit_type="char" \ - --count_threshold=0 \ - --vocab_path="data/lang_char/vocab.txt" \ - --manifest_paths "data/manifest.train.raw" - - if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." 
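Compared with the old manifest-based loader, `maxlen_in`/`maxlen_out` are now hard filters measured in frames and tokens respectively: an utterance outside the range is dropped from the stream rather than shrinking the batch. A minimal sketch of that behaviour, with assumed sample field names (the actual filtering is done by the `data_filter` stage used in `StreamDataLoader` later in this patch):

```python
# Sketch of a frame/token length filter over a sample stream.
# "num_frames" and "tokens" are assumed field names, not the library's.
def length_filter(samples, minlen_in=10, maxlen_in=1200,
                  minlen_out=0, maxlen_out=150):
    for sample in samples:
        num_frames = sample["num_frames"]
        num_tokens = len(sample["tokens"])
        if (minlen_in <= num_frames <= maxlen_in
                and minlen_out <= num_tokens <= maxlen_out):
            yield sample
```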
- exit 1 - fi + echo "Make a dictionary" + echo "dictionary: ${dict}" + mkdir -p $(dirname $dict) + echo "" > ${dict} # 0 will be used for "blank" in CTC + echo "" >> ${dict} # must be 1 + echo "▁" >> ${dict} # ▁ is for space + utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \ + | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' \ + | grep -v "▁" \ + | awk '{print $0}' >> ${dict} \ + || exit 1; + echo "" >> $dict fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for dataset in train dev test; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.json" \ - --unit_type "char" \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${dataset}.raw" \ - --output_path="data/manifest.${dataset}" + echo "Compute cmvn" + # Here we use all the training data, you can sample some some data to save time + # BUG!!! We should use the segmented data for CMVN + if $cmvn; then + full_size=`cat data/${train_set}/wav.scp | wc -l` + sampling_size=$((full_size / cmvn_sampling_divisor)) + shuf -n $sampling_size data/$train_set/wav.scp \ + > data/$train_set/wav.scp.sampled + python3 utils/compute_cmvn_stats.py \ + --num_workers 16 \ + --train_config $train_config \ + --in_scp data/$train_set/wav.scp.sampled \ + --out_cmvn data/$train_set/mean_std.json \ + || exit 1; + fi +fi - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 - fi - } & - done - wait +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Making shards, please wait..." + RED='\033[0;31m' + NOCOLOR='\033[0m' + echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space" + echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads" + for x in $dev_set $test_sets ${train_set}; do + dst=$shards_dir/$x + mkdir -p $dst + utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \ + --do_filter --num_node 1 --num_gpus_per_node 8 \ + --num_threads 32 --segments data/$x/segments \ + data/$x/wav.scp data/$x/text \ + $(realpath $dst) data/$x/data.list + done fi -echo "Aishell data preparation done." +echo "Wenetspeech data preparation done." exit 0 diff --git a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh index 85853053..baa2b32d 100755 --- a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh +++ b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh @@ -24,7 +24,7 @@ stage=1 prefix= train_subset=L -. ./tools/parse_options.sh || exit 1; +. 
./utils/parse_options.sh || exit 1; filter_by_id () { idlist=$1 @@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then done fi -echo "$0: Done" \ No newline at end of file +echo "$0: Done" diff --git a/paddlespeech/audio/stream_data/__init__.py b/paddlespeech/audio/streamdata/__init__.py similarity index 87% rename from paddlespeech/audio/stream_data/__init__.py rename to paddlespeech/audio/streamdata/__init__.py index e9706d4e..d84fbb52 100644 --- a/paddlespeech/audio/stream_data/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -11,7 +11,7 @@ from .cache import ( pipe_cleaner, ) from .compat import WebDataset, WebLoader, FluidWrapper -from webdataset.extradatasets import MockDataset, with_epoch, with_length +from .extradatasets import MockDataset, with_epoch, with_length from .filters import ( associate, batched, @@ -65,5 +65,5 @@ from .shardlists import ( ) from .tariterators import tarfile_samples, tarfile_to_samples from .utils import PipelineStage, repeatedly -from webdataset.writer import ShardWriter, TarWriter, numpy_dumps -from webdataset.mix import RandomMix, RoundRobin +from .writer import ShardWriter, TarWriter, numpy_dumps +from .mix import RandomMix, RoundRobin diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py new file mode 100644 index 00000000..8c74b685 --- /dev/null +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -0,0 +1,445 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Automatically decode webdataset samples.""" + +import io, json, os, pickle, re, tempfile +from functools import partial + +import numpy as np + +"""Extensions passed on to the image decoder.""" +image_extensions = "jpg jpeg png ppm pgm pbm pnm".split() + + +################################################################ +# handle basic datatypes +################################################################ + + +def paddle_loads(data): + """Load data using paddle.loads, importing paddle only if needed. + + :param data: data to be decoded + """ + import io + + import paddle + + stream = io.BytesIO(data) + return paddle.load(stream) + + +def tenbin_loads(data): + from . import tenbin + + return tenbin.decode_buffer(data) + + +def msgpack_loads(data): + import msgpack + + return msgpack.unpackb(data) + + +def npy_loads(data): + import numpy.lib.format + + stream = io.BytesIO(data) + return numpy.lib.format.read_array(stream) + + +def cbor_loads(data): + import cbor + + return cbor.loads(data) + + +decoders = { + "txt": lambda data: data.decode("utf-8"), + "text": lambda data: data.decode("utf-8"), + "transcript": lambda data: data.decode("utf-8"), + "cls": lambda data: int(data), + "cls2": lambda data: int(data), + "index": lambda data: int(data), + "inx": lambda data: int(data), + "id": lambda data: int(data), + "json": lambda data: json.loads(data), + "jsn": lambda data: json.loads(data), + "pyd": lambda data: pickle.loads(data), + "pickle": lambda data: pickle.loads(data), + "pdparams": lambda data: paddle_loads(data), + "ten": tenbin_loads, + "tb": tenbin_loads, + "mp": msgpack_loads, + "msg": msgpack_loads, + "npy": npy_loads, + "npz": lambda data: np.load(io.BytesIO(data)), + "cbor": cbor_loads, +} + + +def basichandlers(key, data): + """Handle basic file decoding. 
+ + This function is usually part of the post= decoders. + This handles the following forms of decoding: + + - txt -> unicode string + - cls cls2 class count index inx id -> int + - json jsn -> JSON decoding + - pyd pickle -> pickle decoding + - pdparams -> paddle.loads + - ten tenbin -> fast tensor loading + - mp messagepack msg -> messagepack decoding + - npy -> Python NPY decoding + + :param key: file name extension + :param data: binary data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + + if extension in decoders: + return decoders[extension](data) + + return None + + +################################################################ +# Generic extension handler. +################################################################ + + +def call_extension_handler(key, data, f, extensions): + """Call the function f with the given data if the key matches the extensions. + + :param key: actual key found in the sample + :param data: binary data + :param f: decoder function + :param extensions: list of matching extensions + """ + extension = key.lower().split(".") + for target in extensions: + target = target.split(".") + if len(target) > len(extension): + continue + if extension[-len(target) :] == target: + return f(data) + return None + + +def handle_extension(extensions, f): + """Return a decoder function for the list of extensions. + + Extensions can be a space separated list of extensions. + Extensions can contain dots, in which case the corresponding number + of extension components must be present in the key given to f. + Comparisons are case insensitive. + + Examples: + handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg + handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg + """ + extensions = extensions.lower().split() + return partial(call_extension_handler, f=f, extensions=extensions) + + +################################################################ +# handle images +################################################################ + +imagespecs = { + "l8": ("numpy", "uint8", "l"), + "rgb8": ("numpy", "uint8", "rgb"), + "rgba8": ("numpy", "uint8", "rgba"), + "l": ("numpy", "float", "l"), + "rgb": ("numpy", "float", "rgb"), + "rgba": ("numpy", "float", "rgba"), + "paddlel8": ("paddle", "uint8", "l"), + "paddlergb8": ("paddle", "uint8", "rgb"), + "paddlergba8": ("paddle", "uint8", "rgba"), + "paddlel": ("paddle", "float", "l"), + "paddlergb": ("paddle", "float", "rgb"), + "paddle": ("paddle", "float", "rgb"), + "paddlergba": ("paddle", "float", "rgba"), + "pill": ("pil", None, "l"), + "pil": ("pil", None, "rgb"), + "pilrgb": ("pil", None, "rgb"), + "pilrgba": ("pil", None, "rgba"), +} + + +class ImageHandler: + """Decode image data using the given `imagespec`. + + The `imagespec` specifies whether the image is decoded + to numpy/paddle/pi, decoded to uint8/float, and decoded + to l/rgb/rgba: + + - l8: numpy uint8 l + - rgb8: numpy uint8 rgb + - rgba8: numpy uint8 rgba + - l: numpy float l + - rgb: numpy float rgb + - rgba: numpy float rgba + - paddlel8: paddle uint8 l + - paddlergb8: paddle uint8 rgb + - paddlergba8: paddle uint8 rgba + - paddlel: paddle float l + - paddlergb: paddle float rgb + - paddle: paddle float rgb + - paddlergba: paddle float rgba + - pill: pil None l + - pil: pil None rgb + - pilrgb: pil None rgb + - pilrgba: pil None rgba + + """ + + def __init__(self, imagespec, extensions=image_extensions): + """Create an image handler. 
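A quick, hedged illustration of the extension dispatch above (the keys and payloads are made up; this assumes the package is importable with this patch applied):

```python
from paddlespeech.audio.streamdata.autodecode import basichandlers, handle_extension

# basichandlers() dispatches on the last extension of the key.
print(basichandlers("utt1.txt", b"hello"))   # -> "hello" (utf-8 decode)
print(basichandlers("utt1.cls", b"7"))       # -> 7 (int)
print(basichandlers("utt1.bin", b"\x00"))    # -> None (no handler registered)

# handle_extension() wraps a custom decoder so it only fires for matching keys.
size_of_audio = handle_extension("wav flac", lambda data: len(data))
print(size_of_audio("utt1.wav", b"\x00" * 16))  # -> 16
print(size_of_audio("utt1.txt", b"hello"))      # -> None (extension not matched)
```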
+ + :param imagespec: short string indicating the type of decoding + :param extensions: list of extensions the image handler is invoked for + """ + if imagespec not in list(imagespecs.keys()): + raise ValueError("Unknown imagespec: %s" % imagespec) + self.imagespec = imagespec.lower() + self.extensions = extensions + + def __call__(self, key, data): + """Perform image decoding. + + :param key: file name extension + :param data: binary data + """ + import PIL.Image + + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in self.extensions: + return None + imagespec = self.imagespec + atype, etype, mode = imagespecs[imagespec] + with io.BytesIO(data) as stream: + img = PIL.Image.open(stream) + img.load() + img = img.convert(mode.upper()) + if atype == "pil": + return img + elif atype == "numpy": + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: numpy image must be uint8") + if etype == "uint8": + return result + else: + return result.astype("f") / 255.0 + elif atype == "paddle": + import paddle + + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: paddle image must be uint8") + if etype == "uint8": + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) + else: + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) / 255.0 + return None + + +def imagehandler(imagespec, extensions=image_extensions): + """Create an image handler. + + This is just a lower case alias for ImageHander. + + :param imagespec: textual image spec + :param extensions: list of extensions the handler should be applied for + """ + return ImageHandler(imagespec, extensions) + + +################################################################ +# torch video +################################################################ + +''' +def torch_video(key, data): + """Decode video using the torchvideo library. + + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + import torchvision.io + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchvision.io.read_video(fname, pts_unit="sec") +''' + + +################################################################ +# paddleaudio +################################################################ + + +def paddle_audio(key, data): + """Decode audio using the paddleaudio library. + + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: + return None + + import paddleaudio + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return paddleaudio.load(fname) + + +################################################################ +# special class for continuing decoding +################################################################ + + +class Continue: + """Special class for continuing decoding. + + This is mostly used for decompression, as in: + + def decompressor(key, data): + if key.endswith(".gz"): + return Continue(key[:-3], decompress(data)) + return None + """ + + def __init__(self, key, data): + """__init__. 
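A hedged usage sketch for the image decoder above (requires Pillow; the input file is hypothetical):

```python
from paddlespeech.audio.streamdata.autodecode import imagehandler

decode_rgb8 = imagehandler("rgb8")            # numpy, uint8, RGB per the table above
with open("cover.jpg", "rb") as f:            # hypothetical local file
    img = decode_rgb8("sample.jpg", f.read())
print(img.shape, img.dtype)                   # (H, W, 3) uint8

# Keys whose extension is not an image type are ignored (handler returns None),
# so the decoder can sit in a chain with the other handlers.
assert decode_rgb8("sample.txt", b"hello") is None
```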
+ + :param key: + :param data: + """ + self.key, self.data = key, data + + +def gzfilter(key, data): + """Decode .gz files. + + This decodes compressed files and the continues decoding. + + :param key: file name extension + :param data: binary data + """ + import gzip + + if not key.endswith(".gz"): + return None + decompressed = gzip.open(io.BytesIO(data)).read() + return Continue(key[:-3], decompressed) + + +################################################################ +# decode entire training amples +################################################################ + + +default_pre_handlers = [gzfilter] +default_post_handlers = [basichandlers] + + +class Decoder: + """Decode samples using a list of handlers. + + For each key/data item, this iterates through the list of + handlers until some handler returns something other than None. + """ + + def __init__(self, handlers, pre=None, post=None, only=None, partial=False): + """Create a Decoder. + + :param handlers: main list of handlers + :param pre: handlers called before the main list (.gz handler by default) + :param post: handlers called after the main list (default handlers by default) + :param only: a list of extensions; when give, only ignores files with those extensions + :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes) + """ + if isinstance(only, str): + only = only.split() + self.only = only if only is None else set(only) + if pre is None: + pre = default_pre_handlers + if post is None: + post = default_post_handlers + assert all(callable(h) for h in handlers), f"one of {handlers} not callable" + assert all(callable(h) for h in pre), f"one of {pre} not callable" + assert all(callable(h) for h in post), f"one of {post} not callable" + self.handlers = pre + handlers + post + self.partial = partial + + def decode1(self, key, data): + """Decode a single field of a sample. + + :param key: file name extension + :param data: binary data + """ + key = "." + key + for f in self.handlers: + result = f(key, data) + if isinstance(result, Continue): + key, data = result.key, result.data + continue + if result is not None: + return result + return data + + def decode(self, sample): + """Decode an entire sample. + + :param sample: the sample, a dictionary of key value pairs + """ + result = {} + assert isinstance(sample, dict), sample + for k, v in list(sample.items()): + if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + if self.only is not None and k not in self.only: + result[k] = v + continue + assert v is not None + if self.partial: + if isinstance(v, bytes): + result[k] = self.decode1(k, v) + else: + result[k] = v + else: + assert isinstance(v, bytes) + result[k] = self.decode1(k, v) + return result + + def __call__(self, sample): + """Decode an entire sample. + + :param sample: the sample + """ + assert isinstance(sample, dict), (len(sample), sample) + return self.decode(sample) diff --git a/paddlespeech/audio/stream_data/cache.py b/paddlespeech/audio/streamdata/cache.py similarity index 98% rename from paddlespeech/audio/stream_data/cache.py rename to paddlespeech/audio/streamdata/cache.py index 724f6911..e7bbffa1 100644 --- a/paddlespeech/audio/stream_data/cache.py +++ b/paddlespeech/audio/streamdata/cache.py @@ -6,8 +6,8 @@ import itertools, os, random, re, sys from urllib.parse import urlparse from . import filters -from webdataset import gopen -from webdataset.handlers import reraise_exception +from . 
import gopen +from .handlers import reraise_exception from .tariterators import tar_file_and_group_expander default_cache_dir = os.environ.get("WDS_CACHE", "./_cache") diff --git a/paddlespeech/audio/stream_data/compat.py b/paddlespeech/audio/streamdata/compat.py similarity index 99% rename from paddlespeech/audio/stream_data/compat.py rename to paddlespeech/audio/streamdata/compat.py index ee564431..11308d03 100644 --- a/paddlespeech/audio/stream_data/compat.py +++ b/paddlespeech/audio/streamdata/compat.py @@ -8,7 +8,7 @@ from typing import List import braceexpand, yaml -from webdataset import autodecode +from . import autodecode from . import cache, filters, shardlists, tariterators from .filters import reraise_exception from .pipeline import DataPipeline diff --git a/paddlespeech/audio/streamdata/extradatasets.py b/paddlespeech/audio/streamdata/extradatasets.py new file mode 100644 index 00000000..e6d61772 --- /dev/null +++ b/paddlespeech/audio/streamdata/extradatasets.py @@ -0,0 +1,141 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + + +"""Train PyTorch models directly from POSIX tar archive. + +Code works locally or over HTTP connections. +""" + +import itertools as itt +import os +import random +import sys + +import braceexpand + +from . import utils +from .paddle_utils import IterableDataset +from .utils import PipelineStage + + +class MockDataset(IterableDataset): + """MockDataset. + + A mock dataset for performance testing and unit testing. + """ + + def __init__(self, sample, length): + """Create a mock dataset instance. + + :param sample: the sample to be returned repeatedly + :param length: the length of the mock dataset + """ + self.sample = sample + self.length = length + + def __iter__(self): + """Return an iterator over this mock dataset.""" + for i in range(self.length): + yield self.sample + + +class repeatedly(IterableDataset, PipelineStage): + """Repeatedly yield samples from a dataset.""" + + def __init__(self, source, nepochs=None, nbatches=None, length=None): + """Create an instance of Repeatedly. + + :param nepochs: repeat for a maximum of nepochs + :param nbatches: repeat for a maximum of nbatches + """ + self.source = source + self.length = length + self.nbatches = nbatches + + def invoke(self, source): + """Return an iterator that iterates repeatedly over a source.""" + return utils.repeatedly( + source, + nepochs=self.nepochs, + nbatches=self.nbatches, + ) + + +class with_epoch(IterableDataset): + """Change the actual and nominal length of an IterableDataset. + + This will continuously iterate through the original dataset, but + impose new epoch boundaries at the given length/nominal. + This exists mainly as a workaround for the odd logic in DataLoader. + It is also useful for choosing smaller nominal epoch sizes with + very large datasets. + + """ + + def __init__(self, dataset, length): + """Chop the dataset to the given length. + + :param dataset: IterableDataset + :param length: declared length of the dataset + :param nominal: nominal length of dataset (if different from declared) + """ + super().__init__() + self.length = length + self.source = None + + def __getstate__(self): + """Return the pickled state of the dataset. + + This resets the dataset iterator, since that can't be pickled. 
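`MockDataset` above simply replays one sample; a small sketch of using it to time a pipeline without touching real shards (the sample contents are arbitrary):

```python
from paddlespeech.audio.streamdata.extradatasets import MockDataset

fake = MockDataset({"feat": None, "label": "placeholder"}, length=1000)
print(sum(1 for _ in fake))  # 1000 identical samples
```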
+ """ + result = dict(self.__dict__) + result["source"] = None + return result + + def invoke(self, dataset): + """Return an iterator over the dataset. + + This iterator returns as many samples as given by the `length` + parameter. + """ + if self.source is None: + self.source = iter(dataset) + for i in range(self.length): + try: + sample = next(self.source) + except StopIteration: + self.source = iter(dataset) + try: + sample = next(self.source) + except StopIteration: + return + yield sample + self.source = None + + +class with_length(IterableDataset, PipelineStage): + """Repeatedly yield samples from a dataset.""" + + def __init__(self, dataset, length): + """Create an instance of Repeatedly. + + :param dataset: source dataset + :param length: stated length + """ + super().__init__() + self.dataset = dataset + self.length = length + + def invoke(self, dataset): + """Return an iterator that iterates repeatedly over a source.""" + return iter(dataset) + + def __len__(self): + """Return the user specified length.""" + return self.length diff --git a/paddlespeech/audio/stream_data/filters.py b/paddlespeech/audio/streamdata/filters.py similarity index 99% rename from paddlespeech/audio/stream_data/filters.py rename to paddlespeech/audio/streamdata/filters.py index db3e037a..0ade66f9 100644 --- a/paddlespeech/audio/stream_data/filters.py +++ b/paddlespeech/audio/streamdata/filters.py @@ -21,7 +21,7 @@ from functools import reduce, wraps import numpy as np -from webdataset import autodecode +from . import autodecode from . import utils from .paddle_utils import PaddleTensor from .utils import PipelineStage @@ -932,4 +932,4 @@ def _placeholder(source): for data in source: yield data -placeholder = pipelinefilter(_placeholder) \ No newline at end of file +placeholder = pipelinefilter(_placeholder) diff --git a/paddlespeech/audio/streamdata/gopen.py b/paddlespeech/audio/streamdata/gopen.py new file mode 100644 index 00000000..457d048a --- /dev/null +++ b/paddlespeech/audio/streamdata/gopen.py @@ -0,0 +1,340 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + + +"""Open URLs by calling subcommands.""" + +import os, sys, re +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +# global used for printing additional node information during verbose output +info = {} + + +class Pipe: + """Wrapper class for subprocess.Pipe. + + This class looks like a stream from the outside, but it checks + subprocess status and handles timeouts with exceptions. + This way, clients of the class do not need to know that they are + dealing with subprocesses. 
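`with_epoch` re-draws from a possibly infinite source but imposes an epoch boundary after `length` samples. Calling `invoke()` directly, as below, is only for illustration; in the library the stage is composed into a `DataPipeline`:

```python
import itertools
from paddlespeech.audio.streamdata.extradatasets import with_epoch

stage = with_epoch(None, length=5)             # the first argument is unused here
epoch = list(stage.invoke(itertools.count()))  # draw exactly 5 samples
print(epoch)                                   # [0, 1, 2, 3, 4]
```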
+ + :param *args: passed to `subprocess.Pipe` + :param **kw: passed to `subprocess.Pipe` + :param timeout: timeout for closing/waiting + :param ignore_errors: don't raise exceptions on subprocess errors + :param ignore_status: list of status codes to ignore + """ + + def __init__( + self, + *args, + mode=None, + timeout=7200.0, + ignore_errors=False, + ignore_status=[], + **kw, + ): + """Create an IO Pipe.""" + self.ignore_errors = ignore_errors + self.ignore_status = [0] + ignore_status + self.timeout = timeout + self.args = (args, kw) + if mode[0] == "r": + self.proc = Popen(*args, stdout=PIPE, **kw) + self.stream = self.proc.stdout + if self.stream is None: + raise ValueError(f"{args}: couldn't open") + elif mode[0] == "w": + self.proc = Popen(*args, stdin=PIPE, **kw) + self.stream = self.proc.stdin + if self.stream is None: + raise ValueError(f"{args}: couldn't open") + self.status = None + + def __str__(self): + return f"" + + def check_status(self): + """Poll the process and handle any errors.""" + status = self.proc.poll() + if status is not None: + self.wait_for_child() + + def wait_for_child(self): + """Check the status variable and raise an exception if necessary.""" + verbose = int(os.environ.get("GOPEN_VERBOSE", 0)) + if self.status is not None and verbose: + # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr) + return + self.status = self.proc.wait() + if verbose: + print( + f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}", + file=sys.stderr, + ) + if self.status not in self.ignore_status and not self.ignore_errors: + raise Exception(f"{self.args}: exit {self.status} (read) {info}") + + def read(self, *args, **kw): + """Wrap stream.read and checks status.""" + result = self.stream.read(*args, **kw) + self.check_status() + return result + + def write(self, *args, **kw): + """Wrap stream.write and checks status.""" + result = self.stream.write(*args, **kw) + self.check_status() + return result + + def readLine(self, *args, **kw): + """Wrap stream.readLine and checks status.""" + result = self.stream.readLine(*args, **kw) + self.status = self.proc.poll() + self.check_status() + return result + + def close(self): + """Wrap stream.close, wait for the subprocess, and handle errors.""" + self.stream.close() + self.status = self.proc.wait(self.timeout) + self.wait_for_child() + + def __enter__(self): + """Context handler.""" + return self + + def __exit__(self, etype, value, traceback): + """Context handler.""" + self.close() + + +def set_options( + obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None +): + """Set options for Pipes. + + This function can be called on any stream. It will set pipe options only + when its argument is a pipe. + + :param obj: any kind of stream + :param timeout: desired timeout + :param ignore_errors: desired ignore_errors setting + :param ignore_status: desired ignore_status setting + :param handler: desired error handler + """ + if not isinstance(obj, Pipe): + return False + if timeout is not None: + obj.timeout = timeout + if ignore_errors is not None: + obj.ignore_errors = ignore_errors + if ignore_status is not None: + obj.ignore_status = ignore_status + if handler is not None: + obj.handler = handler + return True + + +def gopen_file(url, mode="rb", bufsize=8192): + """Open a file. + + This works for local files, files over HTTP, and pipe: files. 
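`Pipe` is the building block the pipe/curl/gsutil handlers below are built on: it looks like an ordinary stream but checks the subprocess exit status on every read and on close. A hedged stand-alone example (the command is arbitrary):

```python
from paddlespeech.audio.streamdata.gopen import Pipe

with Pipe("echo streaming", mode="rb", shell=True) as stream:
    print(stream.read())   # b'streaming\n'; exit status is checked on close
```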
+ + :param url: URL to be opened + :param mode: mode to open it with + :param bufsize: requested buffer size + """ + return open(url, mode) + + +def gopen_pipe(url, mode="rb", bufsize=8192): + """Use gopen to open a pipe. + + :param url: a pipe: URL + :param mode: desired mode + :param bufsize: desired buffer size + """ + assert url.startswith("pipe:") + cmd = url[5:] + if mode[0] == "r": + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + +def gopen_curl(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + cmd = f"curl -s -L '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + cmd = f"curl -s -L -T - '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 26], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + +def gopen_htgs(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + url = re.sub(r"(?i)^htgs://", "gs://", url) + cmd = f"curl -s -L '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + raise ValueError(f"{mode}: cannot write") + else: + raise ValueError(f"{mode}: unknown mode") + + + +def gopen_gsutil(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + cmd = f"gsutil cat '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + cmd = f"gsutil cp - '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 26], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + + +def gopen_error(url, *args, **kw): + """Raise a value error. + + :param url: url + :param args: other arguments + :param kw: other keywords + """ + raise ValueError(f"{url}: no gopen handler defined") + + +"""A dispatch table mapping URL schemes to handlers.""" +gopen_schemes = dict( + __default__=gopen_error, + pipe=gopen_pipe, + http=gopen_curl, + https=gopen_curl, + sftp=gopen_curl, + ftps=gopen_curl, + scp=gopen_curl, + gs=gopen_gsutil, + htgs=gopen_htgs, +) + + +def gopen(url, mode="rb", bufsize=8192, **kw): + """Open the URL. + + This uses the `gopen_schemes` dispatch table to dispatch based + on scheme. + + Support for the following schemes is built-in: pipe, file, + http, https, sftp, ftps, scp. + + When no scheme is given the url is treated as a file. + + You can use the OPEN_VERBOSE argument to get info about + files being opened. 
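A hedged sketch of `gopen` dispatch (defined just below): the URL scheme picks the handler from `gopen_schemes`, so `pipe:` URLs run a shell command, plain paths fall back to `open()`, and http/https/gs shell out to curl or gsutil. The command here is arbitrary:

```python
from paddlespeech.audio.streamdata.gopen import gopen

with gopen("pipe:echo hello", "rb") as stream:
    print(stream.read())   # b'hello\n'
```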
+ + :param url: the source URL + :param mode: the mode ("rb", "r") + :param bufsize: the buffer size + """ + global fallback_gopen + verbose = int(os.environ.get("GOPEN_VERBOSE", 0)) + if verbose: + print("GOPEN", url, info, file=sys.stderr) + assert mode in ["rb", "wb"], mode + if url == "-": + if mode == "rb": + return sys.stdin.buffer + elif mode == "wb": + return sys.stdout.buffer + else: + raise ValueError(f"unknown mode {mode}") + pr = urlparse(url) + if pr.scheme == "": + bufsize = int(os.environ.get("GOPEN_BUFFER", -1)) + return open(url, mode, buffering=bufsize) + if pr.scheme == "file": + bufsize = int(os.environ.get("GOPEN_BUFFER", -1)) + return open(pr.path, mode, buffering=bufsize) + handler = gopen_schemes["__default__"] + handler = gopen_schemes.get(pr.scheme, handler) + return handler(url, mode, bufsize, **kw) + + +def reader(url, **kw): + """Open url with gopen and mode "rb". + + :param url: source URL + :param kw: other keywords forwarded to gopen + """ + return gopen(url, "rb", **kw) diff --git a/paddlespeech/audio/streamdata/handlers.py b/paddlespeech/audio/streamdata/handlers.py new file mode 100644 index 00000000..7f3d28b6 --- /dev/null +++ b/paddlespeech/audio/streamdata/handlers.py @@ -0,0 +1,47 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +"""Pluggable exception handlers. + +These are functions that take an exception as an argument and then return... + +- the exception (in order to re-raise it) +- True (in order to continue and ignore the exception) +- False (in order to ignore the exception and stop processing) + +They are used as handler= arguments in much of the library. +""" + +import time, warnings + + +def reraise_exception(exn): + """Call in an exception handler to re-raise the exception.""" + raise exn + + +def ignore_and_continue(exn): + """Call in an exception handler to ignore any exception and continue.""" + return True + + +def warn_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + warnings.warn(repr(exn)) + time.sleep(0.5) + return True + + +def ignore_and_stop(exn): + """Call in an exception handler to ignore any exception and stop further processing.""" + return False + + +def warn_and_stop(exn): + """Call in an exception handler to ignore any exception and stop further processing.""" + warnings.warn(repr(exn)) + time.sleep(0.5) + return False diff --git a/paddlespeech/audio/streamdata/mix.py b/paddlespeech/audio/streamdata/mix.py new file mode 100644 index 00000000..7d790f00 --- /dev/null +++ b/paddlespeech/audio/streamdata/mix.py @@ -0,0 +1,85 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Classes for mixing samples from multiple sources.""" + +import itertools, os, random, time, sys +from functools import reduce, wraps + +import numpy as np + +from . 
import autodecode, utils +from .paddle_utils import PaddleTensor, IterableDataset +from .utils import PipelineStage + + +def round_robin_shortest(*sources): + i = 0 + while True: + try: + sample = next(sources[i % len(sources)]) + yield sample + except StopIteration: + break + i += 1 + + +def round_robin_longest(*sources): + i = 0 + while len(sources) > 0: + try: + sample = next(sources[i]) + i += 1 + yield sample + except StopIteration: + del sources[i] + + +class RoundRobin(IterableDataset): + def __init__(self, datasets, longest=False): + self.datasets = datasets + self.longest = longest + + def __iter__(self): + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + if self.longest: + return round_robin_longest(*sources) + else: + return round_robin_shortest(*sources) + + +def random_samples(sources, probs=None, longest=False): + if probs is None: + probs = [1] * len(sources) + else: + probs = list(probs) + while len(sources) > 0: + cum = (np.array(probs) / np.sum(probs)).cumsum() + r = random.random() + i = np.searchsorted(cum, r) + try: + yield next(sources[i]) + except StopIteration: + if longest: + del sources[i] + del probs[i] + else: + break + + +class RandomMix(IterableDataset): + def __init__(self, datasets, probs=None, longest=False): + self.datasets = datasets + self.probs = probs + self.longest = longest + + def __iter__(self): + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + return random_samples(sources, self.probs, longest=self.longest) diff --git a/paddlespeech/audio/stream_data/paddle_utils.py b/paddlespeech/audio/streamdata/paddle_utils.py similarity index 100% rename from paddlespeech/audio/stream_data/paddle_utils.py rename to paddlespeech/audio/streamdata/paddle_utils.py diff --git a/paddlespeech/audio/stream_data/pipeline.py b/paddlespeech/audio/streamdata/pipeline.py similarity index 96% rename from paddlespeech/audio/stream_data/pipeline.py rename to paddlespeech/audio/streamdata/pipeline.py index e738083f..7339a762 100644 --- a/paddlespeech/audio/stream_data/pipeline.py +++ b/paddlespeech/audio/streamdata/pipeline.py @@ -10,8 +10,7 @@ from typing import List import braceexpand, yaml -from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators -from webdataset.handlers import reraise_exception +from .handlers import reraise_exception from .paddle_utils import DataLoader, IterableDataset from .utils import PipelineStage diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py similarity index 100% rename from paddlespeech/audio/stream_data/shardlists.py rename to paddlespeech/audio/streamdata/shardlists.py diff --git a/paddlespeech/audio/stream_data/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py similarity index 99% rename from paddlespeech/audio/stream_data/tariterators.py rename to paddlespeech/audio/streamdata/tariterators.py index d9469797..2c1daae1 100644 --- a/paddlespeech/audio/stream_data/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -14,8 +14,8 @@ import random, re, tarfile import braceexpand from . import filters -from webdataset import gopen -from webdataset.handlers import reraise_exception +from . 
import gopen +from .handlers import reraise_exception trace = False meta_prefix = "__" diff --git a/paddlespeech/audio/stream_data/utils.py b/paddlespeech/audio/streamdata/utils.py similarity index 100% rename from paddlespeech/audio/stream_data/utils.py rename to paddlespeech/audio/streamdata/utils.py diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py new file mode 100644 index 00000000..7d4f7703 --- /dev/null +++ b/paddlespeech/audio/streamdata/writer.py @@ -0,0 +1,450 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Classes and functions for writing tar files and WebDataset files.""" + +import io, json, pickle, re, tarfile, time +from typing import Any, Callable, Optional, Union + +import numpy as np + +from . import gopen + + +def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622 + """Compress an image using PIL and return it as a string. + + Can handle float or uint8 images. + + :param image: ndarray representing an image + :param format: compression format (PNG, JPEG, PPM) + + """ + import PIL + + assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image) + + if isinstance(image, np.ndarray): + if image.dtype in [np.dtype("f"), np.dtype("d")]: + if not (np.amin(image) > -0.001 and np.amax(image) < 1.001): + raise ValueError( + f"image values out of range {np.amin(image)} {np.amax(image)}" + ) + image = np.clip(image, 0.0, 1.0) + image = np.array(image * 255.0, "uint8") + assert image.ndim in [2, 3] + if image.ndim == 3: + assert image.shape[2] in [1, 3] + image = PIL.Image.fromarray(image) + if format.upper() == "JPG": + format = "JPEG" + elif format.upper() in ["IMG", "IMAGE"]: + format = "PPM" + if format == "JPEG": + opts = dict(quality=100) + else: + opts = {} + with io.BytesIO() as result: + image.save(result, format=format, **opts) + return result.getvalue() + + +def bytestr(data: Any): + """Convert data into a bytestring. + + Uses str and ASCII encoding for data that isn't already in string format. + + :param data: data + """ + if isinstance(data, bytes): + return data + if isinstance(data, str): + return data.encode("ascii") + return str(data).encode("ascii") + +def paddle_dumps(data: Any): + """Dump data into a bytestring using paddle.dumps. + + This delays importing paddle until needed. + + :param data: data to be dumped + """ + import io + + import paddle + + stream = io.BytesIO() + paddle.save(data, stream) + return stream.getvalue() + +def numpy_dumps(data: np.ndarray): + """Dump data into a bytestring using numpy npy format. + + :param data: data to be dumped + """ + import io + + import numpy.lib.format + + stream = io.BytesIO() + numpy.lib.format.write_array(stream, data) + return stream.getvalue() + + +def numpy_npz_dumps(data: np.ndarray): + """Dump data into a bytestring using numpy npz format. + + :param data: data to be dumped + """ + import io + + stream = io.BytesIO() + np.savez_compressed(stream, **data) + return stream.getvalue() + + +def tenbin_dumps(x): + from . 
import tenbin + + if isinstance(x, list): + return memoryview(tenbin.encode_buffer(x)) + else: + return memoryview(tenbin.encode_buffer([x])) + + +def cbor_dumps(x): + import cbor + + return cbor.dumps(x) + + +def mp_dumps(x): + import msgpack + + return msgpack.packb(x) + + +def add_handlers(d, keys, value): + if isinstance(keys, str): + keys = keys.split() + for k in keys: + d[k] = value + + +def make_handlers(): + """Create a list of handlers for encoding data.""" + handlers = {} + add_handlers( + handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii") + ) + add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) + add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) + add_handlers(handlers, "pyd pickle", pickle.dumps) + add_handlers(handlers, "pdparams", paddle_dumps) + add_handlers(handlers, "npy", numpy_dumps) + add_handlers(handlers, "npz", numpy_npz_dumps) + add_handlers(handlers, "ten tenbin tb", tenbin_dumps) + add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) + add_handlers(handlers, "mp msgpack msg", mp_dumps) + add_handlers(handlers, "cbor", cbor_dumps) + add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg")) + add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) + add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) + add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) + add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm")) + return handlers + + +default_handlers = make_handlers() + + +def encode_based_on_extension1(data: Any, tname: str, handlers: dict): + """Encode data based on its extension and a dict of handlers. + + :param data: data + :param tname: file extension + :param handlers: handlers + """ + if tname[0] == "_": + if not isinstance(data, str): + raise ValueError("the values of metadata must be of string type") + return data + extension = re.sub(r".*\.", "", tname).lower() + if isinstance(data, bytes): + return data + if isinstance(data, str): + return data.encode("utf-8") + handler = handlers.get(extension) + if handler is None: + raise ValueError(f"no handler found for {extension}") + return handler(data) + + +def encode_based_on_extension(sample: dict, handlers: dict): + """Encode an entire sample with a collection of handlers. + + :param sample: data sample (a dict) + :param handlers: handlers for encoding + """ + return { + k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items()) + } + + +def make_encoder(spec: Union[bool, str, dict, Callable]): + """Make an encoder function from a specification. + + :param spec: specification + """ + if spec is False or spec is None: + + def encoder(x): + """Do not encode at all.""" + return x + + elif callable(spec): + encoder = spec + elif isinstance(spec, dict): + + def f(sample): + """Encode based on extension.""" + return encode_based_on_extension(sample, spec) + + encoder = f + + elif spec is True: + handlers = default_handlers + + def g(sample): + """Encode based on extension.""" + return encode_based_on_extension(sample, handlers) + + encoder = g + + else: + raise ValueError(f"{spec}: unknown decoder spec") + if not callable(encoder): + raise ValueError(f"{spec} did not yield a callable encoder") + return encoder + + +class TarWriter: + """A class for writing dictionaries to tar files. 
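The encoder side mirrors the decoders in `autodecode.py`: `make_encoder(True)` uses the default extension handlers defined above. A hedged example (sample keys and values are made up):

```python
from paddlespeech.audio.streamdata.writer import make_encoder

encode = make_encoder(True)
sample = {"__key__": "utt_0001", "txt": "hello world", "json": {"dur": 1.2}}
encoded = encode(sample)
print(encoded["txt"])    # b'hello world'   (utf-8)
print(encoded["json"])   # b'{"dur": 1.2}'  (json-encoded)
```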
+ + :param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor + :param encoder: sample encoding (Default value = True) + :param compress: (Default value = None) + + `True` will use an encoder that behaves similar to the automatic + decoder for `Dataset`. `False` disables encoding and expects byte strings + (except for metadata, which must be strings). The `encoder` argument can + also be a `callable`, or a dictionary mapping extensions to encoders. + + The following code will add two file to the tar archive: `a/b.png` and + `a/b.output.png`. + + ```Python + tarwriter = TarWriter(stream) + image = imread("b.jpg") + image2 = imread("b.out.jpg") + sample = {"__key__": "a/b", "png": image, "output.png": image2} + tarwriter.write(sample) + ``` + """ + + def __init__( + self, + fileobj, + user: str = "bigdata", + group: str = "bigdata", + mode: int = 0o0444, + compress: Optional[bool] = None, + encoder: Union[None, bool, Callable] = True, + keep_meta: bool = False, + ): + """Create a tar writer. + + :param fileobj: stream to write data to + :param user: user for tar files + :param group: group for tar files + :param mode: mode for tar files + :param compress: desired compression + :param encoder: encoder function + :param keep_meta: keep metadata (entries starting with "_") + """ + if isinstance(fileobj, str): + if compress is False: + tarmode = "w|" + elif compress is True: + tarmode = "w|gz" + else: + tarmode = "w|gz" if fileobj.endswith("gz") else "w|" + fileobj = gopen.gopen(fileobj, "wb") + self.own_fileobj = fileobj + else: + tarmode = "w|gz" if compress is True else "w|" + self.own_fileobj = None + self.encoder = make_encoder(encoder) + self.keep_meta = keep_meta + self.stream = fileobj + self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode) + + self.user = user + self.group = group + self.mode = mode + self.compress = compress + + def __enter__(self): + """Enter context.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context.""" + self.close() + + def close(self): + """Close the tar file.""" + self.tarstream.close() + if self.own_fileobj is not None: + self.own_fileobj.close() + self.own_fileobj = None + + def write(self, obj): + """Write a dictionary to the tar file. + + :param obj: dictionary of objects to be stored + :returns: size of the entry + + """ + total = 0 + obj = self.encoder(obj) + if "__key__" not in obj: + raise ValueError("object must contain a __key__") + for k, v in list(obj.items()): + if k[0] == "_": + continue + if not isinstance(v, (bytes, bytearray, memoryview)): + raise ValueError( + f"{k} doesn't map to a bytes after encoding ({type(v)})" + ) + key = obj["__key__"] + for k in sorted(obj.keys()): + if k == "__key__": + continue + if not self.keep_meta and k[0] == "_": + continue + v = obj[k] + if isinstance(v, str): + v = v.encode("utf-8") + now = time.time() + ti = tarfile.TarInfo(key + "." + k) + ti.size = len(v) + ti.mtime = now + ti.mode = self.mode + ti.uname = self.user + ti.gname = self.group + if not isinstance(v, (bytes, bytearray, memoryview)): + raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}") + stream = io.BytesIO(v) + self.tarstream.addfile(ti, stream) + total += ti.size + return total + + +class ShardWriter: + """Like TarWriter but splits into multiple shards.""" + + def __init__( + self, + pattern: str, + maxcount: int = 100000, + maxsize: float = 3e9, + post: Optional[Callable] = None, + start_shard: int = 0, + **kw, + ): + """Create a ShardWriter. 
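A hedged usage sketch for `ShardWriter` (continued below): it behaves like `TarWriter` but rolls over to a new tar file every `maxcount` samples or `maxsize` bytes. The file pattern and keys are made up; shards are written into the current directory:

```python
from paddlespeech.audio.streamdata.writer import ShardWriter

with ShardWriter("train-shard-%06d.tar", maxcount=1000) as writer:
    for i in range(3000):                      # produces 3 shards of 1000 samples
        writer.write({"__key__": f"utt_{i:06d}", "txt": f"transcript {i}"})
```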
+ + :param pattern: output file pattern + :param maxcount: maximum number of records per shard (Default value = 100000) + :param maxsize: maximum size of each shard (Default value = 3e9) + :param kw: other options passed to TarWriter + """ + self.verbose = 1 + self.kw = kw + self.maxcount = maxcount + self.maxsize = maxsize + self.post = post + + self.tarstream = None + self.shard = start_shard + self.pattern = pattern + self.total = 0 + self.count = 0 + self.size = 0 + self.fname = None + self.next_stream() + + def next_stream(self): + """Close the current stream and move to the next.""" + self.finish() + self.fname = self.pattern % self.shard + if self.verbose: + print( + "# writing", + self.fname, + self.count, + "%.1f GB" % (self.size / 1e9), + self.total, + ) + self.shard += 1 + stream = open(self.fname, "wb") + self.tarstream = TarWriter(stream, **self.kw) + self.count = 0 + self.size = 0 + + def write(self, obj): + """Write a sample. + + :param obj: sample to be written + """ + if ( + self.tarstream is None + or self.count >= self.maxcount + or self.size >= self.maxsize + ): + self.next_stream() + size = self.tarstream.write(obj) + self.count += 1 + self.total += 1 + self.size += size + + def finish(self): + """Finish all writing (use close instead).""" + if self.tarstream is not None: + self.tarstream.close() + assert self.fname is not None + if callable(self.post): + self.post(self.fname) + self.tarstream = None + + def close(self): + """Close the stream.""" + self.finish() + del self.tarstream + del self.shard + del self.count + del self.size + + def __enter__(self): + """Enter context.""" + return self + + def __exit__(self, *args, **kw): + """Exit context.""" + self.close() diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index c27969f0..2f3803fa 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -18,6 +18,7 @@ from typing import Text import jsonlines import numpy as np +import paddle from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler @@ -28,7 +29,7 @@ from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log -import paddlespeech.audio.stream_data as stream_data +import paddlespeech.audio.streamdata as streamdata from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer __all__ = ["BatchDataLoader"] @@ -101,38 +102,46 @@ class StreamDataLoader(): shardlist.append(line.strip()) if self.dist_sampler: - base_dataset = stream_data.DataPipeline( - stream_data.SimpleShardList(shardlist), - stream_data.split_by_node, - stream_data.split_by_worker, - stream_data.tarfile_to_samples(stream_data.reraise_exception) + base_dataset = streamdata.DataPipeline( + streamdata.SimpleShardList(shardlist), + streamdata.split_by_node, + streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) ) else: - base_dataset = stream_data.DataPipeline( - stream_data.SimpleShardList(shardlist), - stream_data.split_by_worker, - stream_data.tarfile_to_samples(stream_data.reraise_exception) + base_dataset = streamdata.DataPipeline( + streamdata.SimpleShardList(shardlist), + streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) ) self.dataset = base_dataset.append_list( - stream_data.tokenize(symbol_table), - stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, 
min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), - stream_data.resample(resample_rate=resample_rate), - stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) - stream_data.shuffle(shuffle_size), - stream_data.sort(sort_size=sort_size), - stream_data.batched(batch_size), - stream_data.padding(), - stream_data.cmvn(cmvn_file) - ) - self.loader = stream_data.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - prefetch_factor = self.prefetch_factor, - batch_size=None + streamdata.tokenize(symbol_table), + streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), + streamdata.resample(resample_rate=resample_rate), + streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.shuffle(shuffle_size), + streamdata.sort(sort_size=sort_size), + streamdata.batched(batch_size), + streamdata.padding(), + streamdata.cmvn(cmvn_file) ) + if paddle.__version__ >= '2.3.2': + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor = self.prefetch_factor, + batch_size=None + ) + else: + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + batch_size=None + ) + def __iter__(self): return self.loader.__iter__() diff --git a/setup.py b/setup.py index b94a4cb2..035d0b2d 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,8 @@ base = [ "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", - "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset' + "yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8", + "braceexpand", "pyyaml" ] server = [