diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index 6c2bbca4..d1ac20b9 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -1,7 +1,6 @@ ############################################ # Network Architecture # ############################################ -cmvn_file: cmvn_file_type: "json" # encoder related encoder: conformer @@ -43,40 +42,42 @@ model_conf: ########################################### # Data # ########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test +train_manifest: data/train_l/data.list +dev_manifest: data/dev/data.list +test_manifest: data/test_meeting/data.list ########################################### # Dataloader # ########################################### -vocab_filepath: data/lang_char/vocab.txt +use_stream_data: True unit_type: 'char' +vocab_filepath: data/lang_char/vocab.txt preprocess_config: conf/preprocess.yaml +cmvn_file: data/mean_std.json spm_model_prefix: '' feat_dim: 80 stride_ms: 10.0 window_ms: 25.0 +dither: 0.1 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs -batch_size: 64 -maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced -maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced -minibatches: 0 # for debug -batch_count: auto -batch_bins: 0 -batch_frames_in: 0 -batch_frames_out: 0 -batch_frames_inout: 0 -num_workers: 0 -subsampling_factor: 1 +batch_size: 32 +minlen_in: 10 +maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed +minlen_out: 0 +maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed +resample_rate: 16000 +shuffle_size: 1500 # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk +sort_size: 1000 # read number of 'sort_size' data as a chunk, sort the data in the chunk +num_workers: 8 +prefetch_factor: 10 +dist_sampler: True num_encs: 1 - ########################################### # Training # ########################################### -n_epoch: 240 -accum_grad: 16 +n_epoch: 32 +accum_grad: 32 global_grad_clip: 5.0 log_interval: 100 checkpoint: diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index d216dd84..62579ba3 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -2,6 +2,8 @@ # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang) # NPU, ASLP Group (Author: Qijie Shao) +# +# Modified from wenet(https://github.com/wenet-e2e/wenet) stage=-1 stop_stage=100 @@ -30,7 +32,7 @@ mkdir -p data TARGET_DIR=${MAIN_ROOT}/dataset mkdir -p ${TARGET_DIR} -if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then # download data echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data." exit 0; @@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then data || exit 1; fi -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # generate manifests - python3 ${TARGET_DIR}/aishell/aishell.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/aishell" - - if [ $? -ne 0 ]; then - echo "Prepare Aishell failed. Terminated." - exit 1 - fi - - for dataset in train dev test; do - mv data/manifest.${dataset} data/manifest.${dataset}.raw - done -fi - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # compute mean and stddev for normalizer - if $cmvn; then - full_size=`cat data/${train_set}/wav.scp | wc -l` - sampling_size=$((full_size / cmvn_sampling_divisor)) - shuf -n $sampling_size data/$train_set/wav.scp \ - > data/$train_set/wav.scp.sampled - num_workers=$(nproc) - - python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ - --manifest_path="data/manifest.train.raw" \ - --spectrum_type="fbank" \ - --feat_dim=80 \ - --delta_delta=false \ - --stride_ms=10 \ - --window_ms=25 \ - --sample_rate=16000 \ - --use_dB_normalization=False \ - --num_samples=-1 \ - --num_workers=${num_workers} \ - --output_path="data/mean_std.json" - - if [ $? -ne 0 ]; then - echo "Compute mean and stddev failed. Terminated." - exit 1 - fi - fi -fi - -dict=data/dict/lang_char.txt +dict=data/lang_char/vocab.txt if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # download data, generate manifests - # build vocabulary - python3 ${MAIN_ROOT}/utils/build_vocab.py \ - --unit_type="char" \ - --count_threshold=0 \ - --vocab_path="data/lang_char/vocab.txt" \ - --manifest_paths "data/manifest.train.raw" - - if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." - exit 1 - fi + echo "Make a dictionary" + echo "dictionary: ${dict}" + mkdir -p $(dirname $dict) + echo "" > ${dict} # 0 will be used for "blank" in CTC + echo "" >> ${dict} # must be 1 + echo "▁" >> ${dict} # ▁ is for space + utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \ + | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' \ + | grep -v "▁" \ + | awk '{print $0}' >> ${dict} \ + || exit 1; + echo "" >> $dict fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for dataset in train dev test; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.json" \ - --unit_type "char" \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${dataset}.raw" \ - --output_path="data/manifest.${dataset}" + echo "Compute cmvn" + # Here we use all the training data, you can sample some some data to save time + # BUG!!! We should use the segmented data for CMVN + if $cmvn; then + full_size=`cat data/${train_set}/wav.scp | wc -l` + sampling_size=$((full_size / cmvn_sampling_divisor)) + shuf -n $sampling_size data/$train_set/wav.scp \ + > data/$train_set/wav.scp.sampled + python3 utils/compute_cmvn_stats.py \ + --num_workers 16 \ + --train_config $train_config \ + --in_scp data/$train_set/wav.scp.sampled \ + --out_cmvn data/$train_set/mean_std.json \ + || exit 1; + fi +fi - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 - fi - } & - done - wait +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Making shards, please wait..." + RED='\033[0;31m' + NOCOLOR='\033[0m' + echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space" + echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads" + for x in $dev_set $test_sets ${train_set}; do + dst=$shards_dir/$x + mkdir -p $dst + utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \ + --do_filter --resample 16000 \ + --num_threads 32 --segments data/$x/segments \ + data/$x/wav.scp data/$x/text \ + $(realpath $dst) data/$x/data.list + done fi -echo "Aishell data preparation done." +echo "Wenetspeech data preparation done." exit 0 diff --git a/examples/wenetspeech/asr1/local/train.sh b/examples/wenetspeech/asr1/local/train.sh new file mode 100755 index 00000000..01af00b6 --- /dev/null +++ b/examples/wenetspeech/asr1/local/train.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +profiler_options= +benchmark_batch_size=0 +benchmark_max_step=0 + +# seed may break model convergence +seed=0 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True + echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." +fi + +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" + exit -1 +fi + +config_path=$1 +ckpt_name=$2 +ips=$3 + +if [ ! $ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi +echo ${ips_config} + +mkdir -p exp + +if [ ${ngpu} == 0 ]; then +python3 -u ${BIN_DIR}/train.py \ +--ngpu ${ngpu} \ +--seed ${seed} \ +--config ${config_path} \ +--output exp/${ckpt_name} \ +--profiler-options "${profiler_options}" \ +--benchmark-batch-size ${benchmark_batch_size} \ +--benchmark-max-step ${benchmark_max_step} +else +NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ +--ngpu ${ngpu} \ +--seed ${seed} \ +--config ${config_path} \ +--output exp/${ckpt_name} \ +--profiler-options "${profiler_options}" \ +--benchmark-batch-size ${benchmark_batch_size} \ +--benchmark-max-step ${benchmark_max_step} +fi + + +if [ ${seed} != 0 ]; then + unset FLAGS_cudnn_deterministic +fi + +if [ $? -ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 diff --git a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh index 85853053..baa2b32d 100755 --- a/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh +++ b/examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh @@ -24,7 +24,7 @@ stage=1 prefix= train_subset=L -. ./tools/parse_options.sh || exit 1; +. ./utils/parse_options.sh || exit 1; filter_by_id () { idlist=$1 @@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then done fi -echo "$0: Done" \ No newline at end of file +echo "$0: Done" diff --git a/examples/wenetspeech/asr1/run.sh b/examples/wenetspeech/asr1/run.sh index 9995bc63..ddce0a9c 100644 --- a/examples/wenetspeech/asr1/run.sh +++ b/examples/wenetspeech/asr1/run.sh @@ -7,6 +7,7 @@ gpus=0,1,2,3,4,5,6,7 stage=0 stop_stage=100 conf_path=conf/conformer.yaml +ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx decode_conf_path=conf/tuning/decode.yaml average_checkpoint=true avg_num=10 @@ -26,7 +27,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py new file mode 100644 index 00000000..753fcc11 --- /dev/null +++ b/paddlespeech/audio/streamdata/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# +# flake8: noqa + +from .cache import ( + cached_tarfile_samples, + cached_tarfile_to_samples, + lru_cleanup, + pipe_cleaner, +) +from .compat import WebDataset, WebLoader, FluidWrapper +from .extradatasets import MockDataset, with_epoch, with_length +from .filters import ( + associate, + batched, + decode, + detshuffle, + extract_keys, + getfirst, + info, + map, + map_dict, + map_tuple, + pipelinefilter, + rename, + rename_keys, + audio_resample, + select, + shuffle, + slice, + to_tuple, + transform_with, + unbatched, + xdecode, + audio_data_filter, + audio_tokenize, + audio_resample, + audio_compute_fbank, + audio_spec_aug, + sort, + audio_padding, + audio_cmvn, + placeholder, +) +from .handlers import ( + ignore_and_continue, + ignore_and_stop, + reraise_exception, + warn_and_continue, + warn_and_stop, +) +from .pipeline import DataPipeline +from .shardlists import ( + MultiShardSample, + ResampledShards, + SimpleShardList, + non_empty, + resampled, + shardspec, + single_node_only, + split_by_node, + split_by_worker, +) +from .tariterators import tarfile_samples, tarfile_to_samples +from .utils import PipelineStage, repeatedly +from .writer import ShardWriter, TarWriter, numpy_dumps +from .mix import RandomMix, RoundRobin diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py new file mode 100644 index 00000000..ca0e2ea2 --- /dev/null +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -0,0 +1,445 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Automatically decode webdataset samples.""" + +import io, json, os, pickle, re, tempfile +from functools import partial + +import numpy as np + +"""Extensions passed on to the image decoder.""" +image_extensions = "jpg jpeg png ppm pgm pbm pnm".split() + + +################################################################ +# handle basic datatypes +################################################################ + + +def paddle_loads(data): + """Load data using paddle.loads, importing paddle only if needed. + + :param data: data to be decoded + """ + import io + + import paddle + + stream = io.BytesIO(data) + return paddle.load(stream) + + +def tenbin_loads(data): + from . import tenbin + + return tenbin.decode_buffer(data) + + +def msgpack_loads(data): + import msgpack + + return msgpack.unpackb(data) + + +def npy_loads(data): + import numpy.lib.format + + stream = io.BytesIO(data) + return numpy.lib.format.read_array(stream) + + +def cbor_loads(data): + import cbor + + return cbor.loads(data) + + +decoders = { + "txt": lambda data: data.decode("utf-8"), + "text": lambda data: data.decode("utf-8"), + "transcript": lambda data: data.decode("utf-8"), + "cls": lambda data: int(data), + "cls2": lambda data: int(data), + "index": lambda data: int(data), + "inx": lambda data: int(data), + "id": lambda data: int(data), + "json": lambda data: json.loads(data), + "jsn": lambda data: json.loads(data), + "pyd": lambda data: pickle.loads(data), + "pickle": lambda data: pickle.loads(data), + "pdparams": lambda data: paddle_loads(data), + "ten": tenbin_loads, + "tb": tenbin_loads, + "mp": msgpack_loads, + "msg": msgpack_loads, + "npy": npy_loads, + "npz": lambda data: np.load(io.BytesIO(data)), + "cbor": cbor_loads, +} + + +def basichandlers(key, data): + """Handle basic file decoding. + + This function is usually part of the post= decoders. + This handles the following forms of decoding: + + - txt -> unicode string + - cls cls2 class count index inx id -> int + - json jsn -> JSON decoding + - pyd pickle -> pickle decoding + - pdparams -> paddle.loads + - ten tenbin -> fast tensor loading + - mp messagepack msg -> messagepack decoding + - npy -> Python NPY decoding + + :param key: file name extension + :param data: binary data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + + if extension in decoders: + return decoders[extension](data) + + return None + + +################################################################ +# Generic extension handler. +################################################################ + + +def call_extension_handler(key, data, f, extensions): + """Call the function f with the given data if the key matches the extensions. + + :param key: actual key found in the sample + :param data: binary data + :param f: decoder function + :param extensions: list of matching extensions + """ + extension = key.lower().split(".") + for target in extensions: + target = target.split(".") + if len(target) > len(extension): + continue + if extension[-len(target) :] == target: + return f(data) + return None + + +def handle_extension(extensions, f): + """Return a decoder function for the list of extensions. + + Extensions can be a space separated list of extensions. + Extensions can contain dots, in which case the corresponding number + of extension components must be present in the key given to f. + Comparisons are case insensitive. + + Examples: + handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg + handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg + """ + extensions = extensions.lower().split() + return partial(call_extension_handler, f=f, extensions=extensions) + + +################################################################ +# handle images +################################################################ + +imagespecs = { + "l8": ("numpy", "uint8", "l"), + "rgb8": ("numpy", "uint8", "rgb"), + "rgba8": ("numpy", "uint8", "rgba"), + "l": ("numpy", "float", "l"), + "rgb": ("numpy", "float", "rgb"), + "rgba": ("numpy", "float", "rgba"), + "paddlel8": ("paddle", "uint8", "l"), + "paddlergb8": ("paddle", "uint8", "rgb"), + "paddlergba8": ("paddle", "uint8", "rgba"), + "paddlel": ("paddle", "float", "l"), + "paddlergb": ("paddle", "float", "rgb"), + "paddle": ("paddle", "float", "rgb"), + "paddlergba": ("paddle", "float", "rgba"), + "pill": ("pil", None, "l"), + "pil": ("pil", None, "rgb"), + "pilrgb": ("pil", None, "rgb"), + "pilrgba": ("pil", None, "rgba"), +} + + +class ImageHandler: + """Decode image data using the given `imagespec`. + + The `imagespec` specifies whether the image is decoded + to numpy/paddle/pi, decoded to uint8/float, and decoded + to l/rgb/rgba: + + - l8: numpy uint8 l + - rgb8: numpy uint8 rgb + - rgba8: numpy uint8 rgba + - l: numpy float l + - rgb: numpy float rgb + - rgba: numpy float rgba + - paddlel8: paddle uint8 l + - paddlergb8: paddle uint8 rgb + - paddlergba8: paddle uint8 rgba + - paddlel: paddle float l + - paddlergb: paddle float rgb + - paddle: paddle float rgb + - paddlergba: paddle float rgba + - pill: pil None l + - pil: pil None rgb + - pilrgb: pil None rgb + - pilrgba: pil None rgba + + """ + + def __init__(self, imagespec, extensions=image_extensions): + """Create an image handler. + + :param imagespec: short string indicating the type of decoding + :param extensions: list of extensions the image handler is invoked for + """ + if imagespec not in list(imagespecs.keys()): + raise ValueError("Unknown imagespec: %s" % imagespec) + self.imagespec = imagespec.lower() + self.extensions = extensions + + def __call__(self, key, data): + """Perform image decoding. + + :param key: file name extension + :param data: binary data + """ + import PIL.Image + + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in self.extensions: + return None + imagespec = self.imagespec + atype, etype, mode = imagespecs[imagespec] + with io.BytesIO(data) as stream: + img = PIL.Image.open(stream) + img.load() + img = img.convert(mode.upper()) + if atype == "pil": + return img + elif atype == "numpy": + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: numpy image must be uint8") + if etype == "uint8": + return result + else: + return result.astype("f") / 255.0 + elif atype == "paddle": + import paddle + + result = np.asarray(img) + if result.dtype != np.uint8: + raise ValueError("ImageHandler: paddle image must be uint8") + if etype == "uint8": + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) + else: + result = np.array(result.transpose(2, 0, 1)) + return paddle.tensor(result) / 255.0 + return None + + +def imagehandler(imagespec, extensions=image_extensions): + """Create an image handler. + + This is just a lower case alias for ImageHander. + + :param imagespec: textual image spec + :param extensions: list of extensions the handler should be applied for + """ + return ImageHandler(imagespec, extensions) + + +################################################################ +# torch video +################################################################ + +''' +def torch_video(key, data): + """Decode video using the torchvideo library. + + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + import torchvision.io + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchvision.io.read_video(fname, pts_unit="sec") +''' + + +################################################################ +# paddlespeech.audio +################################################################ + + +def paddle_audio(key, data): + """Decode audio using the paddlespeech.audio library. + + :param key: file name extension + :param data: data to be decoded + """ + extension = re.sub(r".*[.]", "", key) + if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: + return None + + import paddlespeech.audio + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return paddlespeech.audio.load(fname) + + +################################################################ +# special class for continuing decoding +################################################################ + + +class Continue: + """Special class for continuing decoding. + + This is mostly used for decompression, as in: + + def decompressor(key, data): + if key.endswith(".gz"): + return Continue(key[:-3], decompress(data)) + return None + """ + + def __init__(self, key, data): + """__init__. + + :param key: + :param data: + """ + self.key, self.data = key, data + + +def gzfilter(key, data): + """Decode .gz files. + + This decodes compressed files and the continues decoding. + + :param key: file name extension + :param data: binary data + """ + import gzip + + if not key.endswith(".gz"): + return None + decompressed = gzip.open(io.BytesIO(data)).read() + return Continue(key[:-3], decompressed) + + +################################################################ +# decode entire training amples +################################################################ + + +default_pre_handlers = [gzfilter] +default_post_handlers = [basichandlers] + + +class Decoder: + """Decode samples using a list of handlers. + + For each key/data item, this iterates through the list of + handlers until some handler returns something other than None. + """ + + def __init__(self, handlers, pre=None, post=None, only=None, partial=False): + """Create a Decoder. + + :param handlers: main list of handlers + :param pre: handlers called before the main list (.gz handler by default) + :param post: handlers called after the main list (default handlers by default) + :param only: a list of extensions; when give, only ignores files with those extensions + :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes) + """ + if isinstance(only, str): + only = only.split() + self.only = only if only is None else set(only) + if pre is None: + pre = default_pre_handlers + if post is None: + post = default_post_handlers + assert all(callable(h) for h in handlers), f"one of {handlers} not callable" + assert all(callable(h) for h in pre), f"one of {pre} not callable" + assert all(callable(h) for h in post), f"one of {post} not callable" + self.handlers = pre + handlers + post + self.partial = partial + + def decode1(self, key, data): + """Decode a single field of a sample. + + :param key: file name extension + :param data: binary data + """ + key = "." + key + for f in self.handlers: + result = f(key, data) + if isinstance(result, Continue): + key, data = result.key, result.data + continue + if result is not None: + return result + return data + + def decode(self, sample): + """Decode an entire sample. + + :param sample: the sample, a dictionary of key value pairs + """ + result = {} + assert isinstance(sample, dict), sample + for k, v in list(sample.items()): + if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + if self.only is not None and k not in self.only: + result[k] = v + continue + assert v is not None + if self.partial: + if isinstance(v, bytes): + result[k] = self.decode1(k, v) + else: + result[k] = v + else: + assert isinstance(v, bytes) + result[k] = self.decode1(k, v) + return result + + def __call__(self, sample): + """Decode an entire sample. + + :param sample: the sample + """ + assert isinstance(sample, dict), (len(sample), sample) + return self.decode(sample) diff --git a/paddlespeech/audio/streamdata/cache.py b/paddlespeech/audio/streamdata/cache.py new file mode 100644 index 00000000..e7bbffa1 --- /dev/null +++ b/paddlespeech/audio/streamdata/cache.py @@ -0,0 +1,190 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +import itertools, os, random, re, sys +from urllib.parse import urlparse + +from . import filters +from . import gopen +from .handlers import reraise_exception +from .tariterators import tar_file_and_group_expander + +default_cache_dir = os.environ.get("WDS_CACHE", "./_cache") +default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18")) + + +def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False): + """Performs cleanup of the file cache in cache_dir using an LRU strategy, + keeping the total size of all remaining files below cache_size.""" + if not os.path.exists(cache_dir): + return + total_size = 0 + for dirpath, dirnames, filenames in os.walk(cache_dir): + for filename in filenames: + total_size += os.path.getsize(os.path.join(dirpath, filename)) + if total_size <= cache_size: + return + # sort files by last access time + files = [] + for dirpath, dirnames, filenames in os.walk(cache_dir): + for filename in filenames: + files.append(os.path.join(dirpath, filename)) + files.sort(key=keyfn, reverse=True) + # delete files until we're under the cache size + while len(files) > 0 and total_size > cache_size: + fname = files.pop() + total_size -= os.path.getsize(fname) + if verbose: + print("# deleting %s" % fname, file=sys.stderr) + os.remove(fname) + + +def download(url, dest, chunk_size=1024 ** 2, verbose=False): + """Download a file from `url` to `dest`.""" + temp = dest + f".temp{os.getpid()}" + with gopen.gopen(url) as stream: + with open(temp, "wb") as f: + while True: + data = stream.read(chunk_size) + if not data: + break + f.write(data) + os.rename(temp, dest) + + +def pipe_cleaner(spec): + """Guess the actual URL from a "pipe:" specification.""" + if spec.startswith("pipe:"): + spec = spec[5:] + words = spec.split(" ") + for word in words: + if re.match(r"^(https?|gs|ais|s3)", word): + return word + return spec + + +def get_file_cached( + spec, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + verbose=False, +): + if cache_size == -1: + cache_size = default_cache_size + if cache_dir is None: + cache_dir = default_cache_dir + url = url_to_name(spec) + parsed = urlparse(url) + dirname, filename = os.path.split(parsed.path) + dirname = dirname.lstrip("/") + dirname = re.sub(r"[:/|;]", "_", dirname) + destdir = os.path.join(cache_dir, dirname) + os.makedirs(destdir, exist_ok=True) + dest = os.path.join(cache_dir, dirname, filename) + if not os.path.exists(dest): + if verbose: + print("# downloading %s to %s" % (url, dest), file=sys.stderr) + lru_cleanup(cache_dir, cache_size, verbose=verbose) + download(spec, dest, verbose=verbose) + return dest + + +def get_filetype(fname): + with os.popen("file '%s'" % fname) as f: + ftype = f.read() + return ftype + + +def check_tar_format(fname): + """Check whether a file is a tar archive.""" + ftype = get_filetype(fname) + return "tar archive" in ftype or "gzip compressed" in ftype + + +verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0")) + + +def cached_url_opener( + data, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + validator=check_tar_format, + verbose=False, + always=False, +): + """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" + verbose = verbose or verbose_cache + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + attempts = 5 + try: + if not always and os.path.exists(url): + dest = url + else: + dest = get_file_cached( + url, + cache_size=cache_size, + cache_dir=cache_dir, + url_to_name=url_to_name, + verbose=verbose, + ) + if verbose: + print("# opening %s" % dest, file=sys.stderr) + assert os.path.exists(dest) + if not validator(dest): + ftype = get_filetype(dest) + with open(dest, "rb") as f: + data = f.read(200) + os.remove(dest) + raise ValueError( + "%s (%s) is not a tar archive, but a %s, contains %s" + % (dest, url, ftype, repr(data)) + ) + try: + stream = open(dest, "rb") + sample.update(stream=stream) + yield sample + except FileNotFoundError as exn: + # dealing with race conditions in lru_cleanup + attempts -= 1 + if attempts > 0: + time.sleep(random.random() * 10) + continue + raise exn + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def cached_tarfile_samples( + src, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + verbose=False, + url_to_name=pipe_cleaner, + always=False, +): + streams = cached_url_opener( + src, + handler=handler, + cache_size=cache_size, + cache_dir=cache_dir, + verbose=verbose, + url_to_name=url_to_name, + always=always, + ) + samples = tar_file_and_group_expander(streams, handler=handler) + return samples + + +cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples) diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py new file mode 100644 index 00000000..deda5338 --- /dev/null +++ b/paddlespeech/audio/streamdata/compat.py @@ -0,0 +1,170 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +from dataclasses import dataclass +from itertools import islice +from typing import List + +import braceexpand, yaml + +from . import autodecode +from . import cache, filters, shardlists, tariterators +from .filters import reraise_exception +from .pipeline import DataPipeline +from .paddle_utils import DataLoader, IterableDataset + + +class FluidInterface: + def batched(self, batchsize): + return self.compose(filters.batched(batchsize)) + + def dynamic_batched(self, max_frames_in_batch): + return self.compose(filter.dynamic_batched(max_frames_in_batch)) + + def unbatched(self): + return self.compose(filters.unbatched()) + + def listed(self, batchsize, partial=True): + return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None) + + def unlisted(self): + return self.compose(filters.unlisted()) + + def log_keys(self, logfile=None): + return self.compose(filters.log_keys(logfile)) + + def shuffle(self, size, **kw): + if size < 1: + return self + else: + return self.compose(filters.shuffle(size, **kw)) + + def map(self, f, handler=reraise_exception): + return self.compose(filters.map(f, handler=handler)) + + def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception): + handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args] + decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial) + return self.map(decoder, handler=handler) + + def map_dict(self, handler=reraise_exception, **kw): + return self.compose(filters.map_dict(handler=handler, **kw)) + + def select(self, predicate, **kw): + return self.compose(filters.select(predicate, **kw)) + + def to_tuple(self, *args, handler=reraise_exception): + return self.compose(filters.to_tuple(*args, handler=handler)) + + def map_tuple(self, *args, handler=reraise_exception): + return self.compose(filters.map_tuple(*args, handler=handler)) + + def slice(self, *args): + return self.compose(filters.slice(*args)) + + def rename(self, **kw): + return self.compose(filters.rename(**kw)) + + def rsample(self, p=0.5): + return self.compose(filters.rsample(p)) + + def rename_keys(self, *args, **kw): + return self.compose(filters.rename_keys(*args, **kw)) + + def extract_keys(self, *args, **kw): + return self.compose(filters.extract_keys(*args, **kw)) + + def xdecode(self, *args, **kw): + return self.compose(filters.xdecode(*args, **kw)) + + def audio_data_filter(self, *args, **kw): + return self.compose(filters.audio_data_filter(*args, **kw)) + + def audio_tokenize(self, *args, **kw): + return self.compose(filters.audio_tokenize(*args, **kw)) + + def resample(self, *args, **kw): + return self.compose(filters.resample(*args, **kw)) + + def audio_compute_fbank(self, *args, **kw): + return self.compose(filters.audio_compute_fbank(*args, **kw)) + + def audio_spec_aug(self, *args, **kw): + return self.compose(filters.audio_spec_aug(*args, **kw)) + + def sort(self, size=500): + return self.compose(filters.sort(size)) + + def audio_padding(self): + return self.compose(filters.audio_padding()) + + def audio_cmvn(self, cmvn_file): + return self.compose(filters.audio_cmvn(cmvn_file)) + +class WebDataset(DataPipeline, FluidInterface): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__( + self, + urls, + handler=reraise_exception, + resampled=False, + repeat=False, + shardshuffle=None, + cache_size=0, + cache_dir=None, + detshuffle=False, + nodesplitter=shardlists.single_node_only, + verbose=False, + ): + super().__init__() + if isinstance(urls, IterableDataset): + assert not resampled + self.append(urls) + elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")): + with (open(urls)) as stream: + spec = yaml.safe_load(stream) + assert "datasets" in spec + self.append(shardlists.MultiShardSample(spec)) + elif isinstance(urls, dict): + assert "datasets" in urls + self.append(shardlists.MultiShardSample(urls)) + elif resampled: + self.append(shardlists.ResampledShards(urls)) + else: + self.append(shardlists.SimpleShardList(urls)) + self.append(nodesplitter) + self.append(shardlists.split_by_worker) + if shardshuffle is True: + shardshuffle = 100 + if shardshuffle is not None: + if detshuffle: + self.append(filters.detshuffle(shardshuffle)) + else: + self.append(filters.shuffle(shardshuffle)) + if cache_size == 0: + self.append(tariterators.tarfile_to_samples(handler=handler)) + else: + assert cache_size == -1 or cache_size > 0 + self.append( + cache.cached_tarfile_to_samples( + handler=handler, + verbose=verbose, + cache_size=cache_size, + cache_dir=cache_dir, + ) + ) + + +class FluidWrapper(DataPipeline, FluidInterface): + """Small fluid-interface wrapper for DataPipeline.""" + + def __init__(self, initial): + super().__init__() + self.append(initial) + + +class WebLoader(DataPipeline, FluidInterface): + def __init__(self, *args, **kw): + super().__init__(DataLoader(*args, **kw)) diff --git a/paddlespeech/audio/streamdata/extradatasets.py b/paddlespeech/audio/streamdata/extradatasets.py new file mode 100644 index 00000000..e6d61772 --- /dev/null +++ b/paddlespeech/audio/streamdata/extradatasets.py @@ -0,0 +1,141 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + + +"""Train PyTorch models directly from POSIX tar archive. + +Code works locally or over HTTP connections. +""" + +import itertools as itt +import os +import random +import sys + +import braceexpand + +from . import utils +from .paddle_utils import IterableDataset +from .utils import PipelineStage + + +class MockDataset(IterableDataset): + """MockDataset. + + A mock dataset for performance testing and unit testing. + """ + + def __init__(self, sample, length): + """Create a mock dataset instance. + + :param sample: the sample to be returned repeatedly + :param length: the length of the mock dataset + """ + self.sample = sample + self.length = length + + def __iter__(self): + """Return an iterator over this mock dataset.""" + for i in range(self.length): + yield self.sample + + +class repeatedly(IterableDataset, PipelineStage): + """Repeatedly yield samples from a dataset.""" + + def __init__(self, source, nepochs=None, nbatches=None, length=None): + """Create an instance of Repeatedly. + + :param nepochs: repeat for a maximum of nepochs + :param nbatches: repeat for a maximum of nbatches + """ + self.source = source + self.length = length + self.nbatches = nbatches + + def invoke(self, source): + """Return an iterator that iterates repeatedly over a source.""" + return utils.repeatedly( + source, + nepochs=self.nepochs, + nbatches=self.nbatches, + ) + + +class with_epoch(IterableDataset): + """Change the actual and nominal length of an IterableDataset. + + This will continuously iterate through the original dataset, but + impose new epoch boundaries at the given length/nominal. + This exists mainly as a workaround for the odd logic in DataLoader. + It is also useful for choosing smaller nominal epoch sizes with + very large datasets. + + """ + + def __init__(self, dataset, length): + """Chop the dataset to the given length. + + :param dataset: IterableDataset + :param length: declared length of the dataset + :param nominal: nominal length of dataset (if different from declared) + """ + super().__init__() + self.length = length + self.source = None + + def __getstate__(self): + """Return the pickled state of the dataset. + + This resets the dataset iterator, since that can't be pickled. + """ + result = dict(self.__dict__) + result["source"] = None + return result + + def invoke(self, dataset): + """Return an iterator over the dataset. + + This iterator returns as many samples as given by the `length` + parameter. + """ + if self.source is None: + self.source = iter(dataset) + for i in range(self.length): + try: + sample = next(self.source) + except StopIteration: + self.source = iter(dataset) + try: + sample = next(self.source) + except StopIteration: + return + yield sample + self.source = None + + +class with_length(IterableDataset, PipelineStage): + """Repeatedly yield samples from a dataset.""" + + def __init__(self, dataset, length): + """Create an instance of Repeatedly. + + :param dataset: source dataset + :param length: stated length + """ + super().__init__() + self.dataset = dataset + self.length = length + + def invoke(self, dataset): + """Return an iterator that iterates repeatedly over a source.""" + return iter(dataset) + + def __len__(self): + """Return the user specified length.""" + return self.length diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py new file mode 100644 index 00000000..82b9c6ba --- /dev/null +++ b/paddlespeech/audio/streamdata/filters.py @@ -0,0 +1,935 @@ +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) +"""A collection of iterators for data transformations. + +These functions are plain iterator functions. You can find curried versions +in webdataset.filters, and you can find IterableDataset wrappers in +webdataset.processing. +""" + +import io +from fnmatch import fnmatch +import re +import itertools, os, random, sys, time +from functools import reduce, wraps + +import numpy as np + +from . import autodecode +from . import utils +from .paddle_utils import PaddleTensor +from .utils import PipelineStage + +from .. import backends +from ..compliance import kaldi +import paddle +from ..transform.cmvn import GlobalCMVN +from ..utils.tensor_utils import pad_sequence +from ..transform.spec_augment import time_warp +from ..transform.spec_augment import time_mask +from ..transform.spec_augment import freq_mask + +class FilterFunction(object): + """Helper class for currying pipeline stages. + + We use this roundabout construct becauce it can be pickled. + """ + + def __init__(self, f, *args, **kw): + """Create a curried function.""" + self.f = f + self.args = args + self.kw = kw + + def __call__(self, data): + """Call the curried function with the given argument.""" + return self.f(data, *self.args, **self.kw) + + def __str__(self): + """Compute a string representation.""" + return f"<{self.f.__name__} {self.args} {self.kw}>" + + def __repr__(self): + """Compute a string representation.""" + return f"<{self.f.__name__} {self.args} {self.kw}>" + + +class RestCurried(object): + """Helper class for currying pipeline stages. + + We use this roundabout construct because it can be pickled. + """ + + def __init__(self, f): + """Store the function for future currying.""" + self.f = f + + def __call__(self, *args, **kw): + """Curry with the given arguments.""" + return FilterFunction(self.f, *args, **kw) + + +def pipelinefilter(f): + """Turn the decorated function into one that is partially applied for + all arguments other than the first.""" + result = RestCurried(f) + return result + + +def reraise_exception(exn): + """Reraises the given exception; used as a handler. + + :param exn: exception + """ + raise exn + + +def identity(x): + """Return the argument.""" + return x + + +def compose2(f, g): + """Compose two functions, g(f(x)).""" + return lambda x: g(f(x)) + + +def compose(*args): + """Compose a sequence of functions (left-to-right).""" + return reduce(compose2, args) + + +def pipeline(source, *args): + """Write an input pipeline; first argument is source, rest are filters.""" + if len(args) == 0: + return source + return compose(*args)(source) + + +def getfirst(a, keys, default=None, missing_is_error=True): + """Get the first matching key from a dictionary. + + Keys can be specified as a list, or as a string of keys separated by ';'. + """ + if isinstance(keys, str): + assert " " not in keys + keys = keys.split(";") + for k in keys: + if k in a: + return a[k] + if missing_is_error: + raise ValueError(f"didn't find {keys} in {list(a.keys())}") + return default + + +def parse_field_spec(fields): + """Parse a specification for a list of fields to be extracted. + + Keys are separated by spaces in the spec. Each key can itself + be composed of key alternatives separated by ';'. + """ + if isinstance(fields, str): + fields = fields.split() + return [field.split(";") for field in fields] + + +def transform_with(sample, transformers): + """Transform a list of values using a list of functions. + + sample: list of values + transformers: list of functions + + If there are fewer transformers than inputs, or if a transformer + function is None, then the identity function is used for the + corresponding sample fields. + """ + if transformers is None or len(transformers) == 0: + return sample + result = list(sample) + assert len(transformers) <= len(sample) + for i in range(len(transformers)): # skipcq: PYL-C0200 + f = transformers[i] + if f is not None: + result[i] = f(sample[i]) + return result + +### +# Iterators +### + +def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""): + """Print information about the samples that are passing through. + + :param data: source iterator + :param fmt: format statement (using sample dict as keyword) + :param n: when to stop + :param every: how often to print + :param width: maximum width + :param stream: output stream + :param name: identifier printed before any output + """ + for i, sample in enumerate(data): + if i < n or (every > 0 and (i + 1) % every == 0): + if fmt is None: + print("---", name, file=stream) + for k, v in sample.items(): + print(k, repr(v)[:width], file=stream) + else: + print(fmt.format(**sample), file=stream) + yield sample + + +info = pipelinefilter(_info) + + +def pick(buf, rng): + k = rng.randint(0, len(buf) - 1) + sample = buf[k] + buf[k] = buf[-1] + buf.pop() + return sample + + +def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): + """Shuffle the data in the stream. + + This uses a buffer of size `bufsize`. Shuffling at + startup is less random; this is traded off against + yielding samples quickly. + + data: iterator + bufsize: buffer size for shuffling + returns: iterator + rng: either random module or random.Random instance + + """ + if rng is None: + rng = random.Random(int((os.getpid() + time.time()) * 1e9)) + initial = min(initial, bufsize) + buf = [] + for sample in data: + buf.append(sample) + if len(buf) < bufsize: + try: + buf.append(next(data)) # skipcq: PYL-R1708 + except StopIteration: + pass + if len(buf) >= initial: + yield pick(buf, rng) + while len(buf) > 0: + yield pick(buf, rng) + + +shuffle = pipelinefilter(_shuffle) + + +class detshuffle(PipelineStage): + def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + self.epoch += 1 + rng = random.Random() + rng.seed((self.seed, self.epoch)) + return _shuffle(src, self.bufsize, self.initial, rng) + + +def _select(data, predicate): + """Select samples based on a predicate. + + :param data: source iterator + :param predicate: predicate (function) + """ + for sample in data: + if predicate(sample): + yield sample + + +select = pipelinefilter(_select) + + +def _log_keys(data, logfile=None): + import fcntl + + if logfile is None or logfile == "": + for sample in data: + yield sample + else: + with open(logfile, "a") as stream: + for i, sample in enumerate(data): + buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n" + try: + fcntl.flock(stream.fileno(), fcntl.LOCK_EX) + stream.write(buf) + finally: + fcntl.flock(stream.fileno(), fcntl.LOCK_UN) + yield sample + + +log_keys = pipelinefilter(_log_keys) + + +def _decode(data, *args, handler=reraise_exception, **kw): + """Decode data based on the decoding functions given as arguments.""" + + decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x + handlers = [decoder(x) for x in args] + f = autodecode.Decoder(handlers, **kw) + + for sample in data: + assert isinstance(sample, dict), sample + try: + decoded = f(sample) + except Exception as exn: # skipcq: PYL-W0703 + if handler(exn): + continue + else: + break + yield decoded + + +decode = pipelinefilter(_decode) + + +def _map(data, f, handler=reraise_exception): + """Map samples.""" + for sample in data: + try: + result = f(sample) + except Exception as exn: + if handler(exn): + continue + else: + break + if result is None: + continue + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + yield result + + +map = pipelinefilter(_map) + + +def _rename(data, handler=reraise_exception, keep=True, **kw): + """Rename samples based on keyword arguments.""" + for sample in data: + try: + if not keep: + yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()} + else: + + def listify(v): + return v.split(";") if isinstance(v, str) else v + + to_be_replaced = {x for v in kw.values() for x in listify(v)} + result = {k: v for k, v in sample.items() if k not in to_be_replaced} + result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}) + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +rename = pipelinefilter(_rename) + + +def _associate(data, associator, **kw): + """Associate additional data with samples.""" + for sample in data: + if callable(associator): + extra = associator(sample["__key__"]) + else: + extra = associator.get(sample["__key__"], {}) + sample.update(extra) # destructive + yield sample + + +associate = pipelinefilter(_associate) + + +def _map_dict(data, handler=reraise_exception, **kw): + """Map the entries in a dict sample with individual functions.""" + assert len(list(kw.keys())) > 0 + for key, f in kw.items(): + assert callable(f), (key, f) + + for sample in data: + assert isinstance(sample, dict) + try: + for k, f in kw.items(): + sample[k] = f(sample[k]) + except Exception as exn: + if handler(exn): + continue + else: + break + yield sample + + +map_dict = pipelinefilter(_map_dict) + + +def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None): + """Convert dict samples to tuples.""" + if none_is_error is None: + none_is_error = missing_is_error + if len(args) == 1 and isinstance(args[0], str) and " " in args[0]: + args = args[0].split() + + for sample in data: + try: + result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args]) + if none_is_error and any(x is None for x in result): + raise ValueError(f"to_tuple {args} got {sample.keys()}") + yield result + except Exception as exn: + if handler(exn): + continue + else: + break + + +to_tuple = pipelinefilter(_to_tuple) + + +def _map_tuple(data, *args, handler=reraise_exception): + """Map the entries of a tuple with individual functions.""" + args = [f if f is not None else utils.identity for f in args] + for f in args: + assert callable(f), f + for sample in data: + assert isinstance(sample, (list, tuple)) + sample = list(sample) + n = min(len(args), len(sample)) + try: + for i in range(n): + sample[i] = args[i](sample[i]) + except Exception as exn: + if handler(exn): + continue + else: + break + yield tuple(sample) + + +map_tuple = pipelinefilter(_map_tuple) + + +def _unlisted(data): + """Turn batched data back into unbatched data.""" + for batch in data: + assert isinstance(batch, list), sample + for sample in batch: + yield sample + + +unlisted = pipelinefilter(_unlisted) + + +def _unbatched(data): + """Turn batched data back into unbatched data.""" + for sample in data: + assert isinstance(sample, (tuple, list)), sample + assert len(sample) > 0 + for i in range(len(sample[0])): + yield tuple(x[i] for x in sample) + + +unbatched = pipelinefilter(_unbatched) + + +def _rsample(data, p=0.5): + """Randomly subsample a stream of data.""" + assert p >= 0.0 and p <= 1.0 + for sample in data: + if random.uniform(0.0, 1.0) < p: + yield sample + + +rsample = pipelinefilter(_rsample) + +slice = pipelinefilter(itertools.islice) + + +def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False): + for sample in source: + result = [] + for pattern in patterns: + pattern = pattern.split(";") if isinstance(pattern, str) else pattern + matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)] + if len(matches) == 0: + if ignore_missing: + continue + else: + raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + if len(matches) > 1 and duplicate_is_error: + raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + value = sample[matches[0]] + result.append(value) + yield tuple(result) + + +extract_keys = pipelinefilter(_extract_keys) + + +def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw): + renamings = [(pattern, output) for output, pattern in args] + renamings += [(pattern, output) for output, pattern in kw.items()] + for sample in source: + new_sample = {} + matched = {k: False for k, _ in renamings} + for path, value in sample.items(): + fname = re.sub(r".*/", "", path) + new_name = None + for pattern, name in renamings[::-1]: + if fnmatch(fname.lower(), pattern): + matched[pattern] = True + new_name = name + break + if new_name is None: + if keep_unselected: + new_sample[path] = value + continue + if new_name in new_sample: + if duplicate_is_error: + raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + continue + new_sample[new_name] = value + if must_match and not all(matched.values()): + raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + + yield new_sample + + +rename_keys = pipelinefilter(_rename_keys) + + +def decode_bin(stream): + return stream.read() + + +def decode_text(stream): + binary = stream.read() + return binary.decode("utf-8") + + +def decode_pickle(stream): + return pickle.load(stream) + + +default_decoders = [ + ("*.bin", decode_bin), + ("*.txt", decode_text), + ("*.pyd", decode_pickle), +] + + +def find_decoder(decoders, path): + fname = re.sub(r".*/", "", path) + if fname.startswith("__"): + return lambda x: x + for pattern, fun in decoders[::-1]: + if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern): + return fun + return None + + +def _xdecode( + source, + *args, + must_decode=True, + defaults=default_decoders, + **kw, +): + decoders = list(defaults) + list(args) + decoders += [("*." + k, v) for k, v in kw.items()] + for sample in source: + new_sample = {} + for path, data in sample.items(): + if path.startswith("__"): + new_sample[path] = data + continue + decoder = find_decoder(decoders, path) + if decoder is False: + value = data + elif decoder is None: + if must_decode: + raise ValueError(f"No decoder found for {path}.") + value = data + else: + if isinstance(data, bytes): + data = io.BytesIO(data) + value = decoder(data) + new_sample[path] = value + yield new_sample + +xdecode = pipelinefilter(_xdecode) + + + +def _audio_data_filter(source, + frame_shift=10, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + source: Iterable[{fname, wav, label, sample_rate}] + frame_shift: length of frame shift (ms) + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is paddle.Tensor, we have 100 frames every second (default) + num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift) + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + +audio_data_filter = pipelinefilter(_audio_data_filter) + +def _audio_tokenize(source, + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + source: Iterable[{fname, wav, txt, sample_rate}] + + Returns: + Iterable[{fname, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in source: + assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + +audio_tokenize = pipelinefilter(_audio_tokenize) + +def _audio_resample(source, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{fname, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{fname, wav, label, sample_rate}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample( + waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate + )) + yield sample + +audio_resample = pipelinefilter(_audio_resample) + +def _audio_compute_fbank(source, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + source: Iterable[{fname, wav, label, sample_rate}] + num_mel_bins: number of mel filter bank + frame_length: length of one frame (ms) + frame_shift: length of frame shift (ms) + dither: value of dither + + Returns: + Iterable[{fname, feat, label}] + """ + for sample in source: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'fname' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep fname, feat, label + mat = kaldi.fbank(waveform, + n_mels=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sr=sample_rate) + yield dict(fname=sample['fname'], label=sample['label'], feat=mat) + + +audio_compute_fbank = pipelinefilter(_audio_compute_fbank) + +def _audio_spec_aug(source, + max_w=5, + w_inplace=True, + w_mode="PIL", + max_f=30, + num_f_mask=2, + f_inplace=True, + f_replace_with_zero=False, + max_t=40, + num_t_mask=2, + t_inplace=True, + t_replace_with_zero=False,): + """ Do spec augmentation + Inplace operation + + Args: + source: Iterable[{fname, feat, label}] + max_w: max width of time warp + w_inplace: whether to inplace the original data while time warping + w_mode: time warp mode + max_f: max width of freq mask + num_f_mask: number of freq mask to apply + f_inplace: whether to inplace the original data while frequency masking + f_replace_with_zero: use zero to mask + max_t: max width of time mask + num_t_mask: number of time mask to apply + t_inplace: whether to inplace the original data while time masking + t_replace_with_zero: use zero to mask + + Returns + Iterable[{fname, feat, label}] + """ + for sample in source: + x = sample['feat'] + x = x.numpy() + x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode) + x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero) + x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero) + sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) + yield sample + +audio_spec_aug = pipelinefilter(_audio_spec_aug) + + +def _sort(source, sort_size=500): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + source: Iterable[{fname, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{fname, feat, label}] + """ + + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].shape[0]) + for x in buf: + yield x + +sort = pipelinefilter(_sort) + +def _batched(source, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{fname, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + for sample in source: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + +batched = pipelinefilter(_batched) + +def dynamic_batched(source, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + source: Iterable[{fname, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{fname, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in source: + assert 'feat' in sample + assert isinstance(sample['feat'], paddle.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def _audio_padding(source): + """ Padding the data into training data + + Args: + source: Iterable[List[{fname, feat, label}]] + + Returns: + Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)] + """ + for sample in source: + assert isinstance(sample, list) + feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample], + dtype="int64") + order = paddle.argsort(feats_length, descending=True) + feats_lengths = paddle.to_tensor( + [sample[i]['feat'].shape[0] for i in order], dtype="int64") + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['fname'] for i in order] + sorted_labels = [ + paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order + ] + label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels], + dtype="int64") + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +audio_padding = pipelinefilter(_audio_padding) + +def _audio_cmvn(source, cmvn_file): + global_cmvn = GlobalCMVN(cmvn_file) + for batch in source: + sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch + padded_feats = padded_feats.numpy() + padded_feats = global_cmvn(padded_feats) + padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32) + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + +audio_cmvn = pipelinefilter(_audio_cmvn) + +def _placeholder(source): + for data in source: + yield data + +placeholder = pipelinefilter(_placeholder) diff --git a/paddlespeech/audio/streamdata/gopen.py b/paddlespeech/audio/streamdata/gopen.py new file mode 100644 index 00000000..457d048a --- /dev/null +++ b/paddlespeech/audio/streamdata/gopen.py @@ -0,0 +1,340 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + + +"""Open URLs by calling subcommands.""" + +import os, sys, re +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +# global used for printing additional node information during verbose output +info = {} + + +class Pipe: + """Wrapper class for subprocess.Pipe. + + This class looks like a stream from the outside, but it checks + subprocess status and handles timeouts with exceptions. + This way, clients of the class do not need to know that they are + dealing with subprocesses. + + :param *args: passed to `subprocess.Pipe` + :param **kw: passed to `subprocess.Pipe` + :param timeout: timeout for closing/waiting + :param ignore_errors: don't raise exceptions on subprocess errors + :param ignore_status: list of status codes to ignore + """ + + def __init__( + self, + *args, + mode=None, + timeout=7200.0, + ignore_errors=False, + ignore_status=[], + **kw, + ): + """Create an IO Pipe.""" + self.ignore_errors = ignore_errors + self.ignore_status = [0] + ignore_status + self.timeout = timeout + self.args = (args, kw) + if mode[0] == "r": + self.proc = Popen(*args, stdout=PIPE, **kw) + self.stream = self.proc.stdout + if self.stream is None: + raise ValueError(f"{args}: couldn't open") + elif mode[0] == "w": + self.proc = Popen(*args, stdin=PIPE, **kw) + self.stream = self.proc.stdin + if self.stream is None: + raise ValueError(f"{args}: couldn't open") + self.status = None + + def __str__(self): + return f"" + + def check_status(self): + """Poll the process and handle any errors.""" + status = self.proc.poll() + if status is not None: + self.wait_for_child() + + def wait_for_child(self): + """Check the status variable and raise an exception if necessary.""" + verbose = int(os.environ.get("GOPEN_VERBOSE", 0)) + if self.status is not None and verbose: + # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr) + return + self.status = self.proc.wait() + if verbose: + print( + f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}", + file=sys.stderr, + ) + if self.status not in self.ignore_status and not self.ignore_errors: + raise Exception(f"{self.args}: exit {self.status} (read) {info}") + + def read(self, *args, **kw): + """Wrap stream.read and checks status.""" + result = self.stream.read(*args, **kw) + self.check_status() + return result + + def write(self, *args, **kw): + """Wrap stream.write and checks status.""" + result = self.stream.write(*args, **kw) + self.check_status() + return result + + def readLine(self, *args, **kw): + """Wrap stream.readLine and checks status.""" + result = self.stream.readLine(*args, **kw) + self.status = self.proc.poll() + self.check_status() + return result + + def close(self): + """Wrap stream.close, wait for the subprocess, and handle errors.""" + self.stream.close() + self.status = self.proc.wait(self.timeout) + self.wait_for_child() + + def __enter__(self): + """Context handler.""" + return self + + def __exit__(self, etype, value, traceback): + """Context handler.""" + self.close() + + +def set_options( + obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None +): + """Set options for Pipes. + + This function can be called on any stream. It will set pipe options only + when its argument is a pipe. + + :param obj: any kind of stream + :param timeout: desired timeout + :param ignore_errors: desired ignore_errors setting + :param ignore_status: desired ignore_status setting + :param handler: desired error handler + """ + if not isinstance(obj, Pipe): + return False + if timeout is not None: + obj.timeout = timeout + if ignore_errors is not None: + obj.ignore_errors = ignore_errors + if ignore_status is not None: + obj.ignore_status = ignore_status + if handler is not None: + obj.handler = handler + return True + + +def gopen_file(url, mode="rb", bufsize=8192): + """Open a file. + + This works for local files, files over HTTP, and pipe: files. + + :param url: URL to be opened + :param mode: mode to open it with + :param bufsize: requested buffer size + """ + return open(url, mode) + + +def gopen_pipe(url, mode="rb", bufsize=8192): + """Use gopen to open a pipe. + + :param url: a pipe: URL + :param mode: desired mode + :param bufsize: desired buffer size + """ + assert url.startswith("pipe:") + cmd = url[5:] + if mode[0] == "r": + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + +def gopen_curl(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + cmd = f"curl -s -L '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + cmd = f"curl -s -L -T - '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 26], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + +def gopen_htgs(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + url = re.sub(r"(?i)^htgs://", "gs://", url) + cmd = f"curl -s -L '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + raise ValueError(f"{mode}: cannot write") + else: + raise ValueError(f"{mode}: unknown mode") + + + +def gopen_gsutil(url, mode="rb", bufsize=8192): + """Open a URL with `curl`. + + :param url: url (usually, http:// etc.) + :param mode: file mode + :param bufsize: buffer size + """ + if mode[0] == "r": + cmd = f"gsutil cat '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + cmd = f"gsutil cp - '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 26], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + + +def gopen_error(url, *args, **kw): + """Raise a value error. + + :param url: url + :param args: other arguments + :param kw: other keywords + """ + raise ValueError(f"{url}: no gopen handler defined") + + +"""A dispatch table mapping URL schemes to handlers.""" +gopen_schemes = dict( + __default__=gopen_error, + pipe=gopen_pipe, + http=gopen_curl, + https=gopen_curl, + sftp=gopen_curl, + ftps=gopen_curl, + scp=gopen_curl, + gs=gopen_gsutil, + htgs=gopen_htgs, +) + + +def gopen(url, mode="rb", bufsize=8192, **kw): + """Open the URL. + + This uses the `gopen_schemes` dispatch table to dispatch based + on scheme. + + Support for the following schemes is built-in: pipe, file, + http, https, sftp, ftps, scp. + + When no scheme is given the url is treated as a file. + + You can use the OPEN_VERBOSE argument to get info about + files being opened. + + :param url: the source URL + :param mode: the mode ("rb", "r") + :param bufsize: the buffer size + """ + global fallback_gopen + verbose = int(os.environ.get("GOPEN_VERBOSE", 0)) + if verbose: + print("GOPEN", url, info, file=sys.stderr) + assert mode in ["rb", "wb"], mode + if url == "-": + if mode == "rb": + return sys.stdin.buffer + elif mode == "wb": + return sys.stdout.buffer + else: + raise ValueError(f"unknown mode {mode}") + pr = urlparse(url) + if pr.scheme == "": + bufsize = int(os.environ.get("GOPEN_BUFFER", -1)) + return open(url, mode, buffering=bufsize) + if pr.scheme == "file": + bufsize = int(os.environ.get("GOPEN_BUFFER", -1)) + return open(pr.path, mode, buffering=bufsize) + handler = gopen_schemes["__default__"] + handler = gopen_schemes.get(pr.scheme, handler) + return handler(url, mode, bufsize, **kw) + + +def reader(url, **kw): + """Open url with gopen and mode "rb". + + :param url: source URL + :param kw: other keywords forwarded to gopen + """ + return gopen(url, "rb", **kw) diff --git a/paddlespeech/audio/streamdata/handlers.py b/paddlespeech/audio/streamdata/handlers.py new file mode 100644 index 00000000..7f3d28b6 --- /dev/null +++ b/paddlespeech/audio/streamdata/handlers.py @@ -0,0 +1,47 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +"""Pluggable exception handlers. + +These are functions that take an exception as an argument and then return... + +- the exception (in order to re-raise it) +- True (in order to continue and ignore the exception) +- False (in order to ignore the exception and stop processing) + +They are used as handler= arguments in much of the library. +""" + +import time, warnings + + +def reraise_exception(exn): + """Call in an exception handler to re-raise the exception.""" + raise exn + + +def ignore_and_continue(exn): + """Call in an exception handler to ignore any exception and continue.""" + return True + + +def warn_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + warnings.warn(repr(exn)) + time.sleep(0.5) + return True + + +def ignore_and_stop(exn): + """Call in an exception handler to ignore any exception and stop further processing.""" + return False + + +def warn_and_stop(exn): + """Call in an exception handler to ignore any exception and stop further processing.""" + warnings.warn(repr(exn)) + time.sleep(0.5) + return False diff --git a/paddlespeech/audio/streamdata/mix.py b/paddlespeech/audio/streamdata/mix.py new file mode 100644 index 00000000..7d790f00 --- /dev/null +++ b/paddlespeech/audio/streamdata/mix.py @@ -0,0 +1,85 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Classes for mixing samples from multiple sources.""" + +import itertools, os, random, time, sys +from functools import reduce, wraps + +import numpy as np + +from . import autodecode, utils +from .paddle_utils import PaddleTensor, IterableDataset +from .utils import PipelineStage + + +def round_robin_shortest(*sources): + i = 0 + while True: + try: + sample = next(sources[i % len(sources)]) + yield sample + except StopIteration: + break + i += 1 + + +def round_robin_longest(*sources): + i = 0 + while len(sources) > 0: + try: + sample = next(sources[i]) + i += 1 + yield sample + except StopIteration: + del sources[i] + + +class RoundRobin(IterableDataset): + def __init__(self, datasets, longest=False): + self.datasets = datasets + self.longest = longest + + def __iter__(self): + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + if self.longest: + return round_robin_longest(*sources) + else: + return round_robin_shortest(*sources) + + +def random_samples(sources, probs=None, longest=False): + if probs is None: + probs = [1] * len(sources) + else: + probs = list(probs) + while len(sources) > 0: + cum = (np.array(probs) / np.sum(probs)).cumsum() + r = random.random() + i = np.searchsorted(cum, r) + try: + yield next(sources[i]) + except StopIteration: + if longest: + del sources[i] + del probs[i] + else: + break + + +class RandomMix(IterableDataset): + def __init__(self, datasets, probs=None, longest=False): + self.datasets = datasets + self.probs = probs + self.longest = longest + + def __iter__(self): + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + return random_samples(sources, self.probs, longest=self.longest) diff --git a/paddlespeech/audio/streamdata/paddle_utils.py b/paddlespeech/audio/streamdata/paddle_utils.py new file mode 100644 index 00000000..02bc4c84 --- /dev/null +++ b/paddlespeech/audio/streamdata/paddle_utils.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Mock implementations of paddle interfaces when paddle is not available.""" + + +try: + from paddle.io import DataLoader, IterableDataset +except ModuleNotFoundError: + + class IterableDataset: + """Empty implementation of IterableDataset when paddle is not available.""" + + pass + + class DataLoader: + """Empty implementation of DataLoader when paddle is not available.""" + + pass + +try: + from paddle import Tensor as PaddleTensor +except ModuleNotFoundError: + + class TorchTensor: + """Empty implementation of PaddleTensor when paddle is not available.""" + + pass diff --git a/paddlespeech/audio/streamdata/pipeline.py b/paddlespeech/audio/streamdata/pipeline.py new file mode 100644 index 00000000..7339a762 --- /dev/null +++ b/paddlespeech/audio/streamdata/pipeline.py @@ -0,0 +1,132 @@ +# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +#%% +import copy, os, random, sys, time +from dataclasses import dataclass +from itertools import islice +from typing import List + +import braceexpand, yaml + +from .handlers import reraise_exception +from .paddle_utils import DataLoader, IterableDataset +from .utils import PipelineStage + + +def add_length_method(obj): + def length(self): + return self.size + + Combined = type( + obj.__class__.__name__ + "_Length", + (obj.__class__, IterableDataset), + {"__len__": length}, + ) + obj.__class__ = Combined + return obj + + +class DataPipeline(IterableDataset, PipelineStage): + """A pipeline starting with an IterableDataset and a series of filters.""" + + def __init__(self, *args, **kwargs): + super().__init__() + self.pipeline = [] + self.length = -1 + self.repetitions = 1 + self.nsamples = -1 + for arg in args: + if arg is None: + continue + if isinstance(arg, list): + self.pipeline.extend(arg) + else: + self.pipeline.append(arg) + + def invoke(self, f, *args, **kwargs): + """Apply a pipeline stage, possibly to the output of a previous stage.""" + if isinstance(f, PipelineStage): + return f.run(*args, **kwargs) + if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0: + return iter(f) + if isinstance(f, list): + return iter(f) + if callable(f): + result = f(*args, **kwargs) + return result + raise ValueError(f"{f}: not a valid pipeline stage") + + def iterator1(self): + """Create an iterator through one epoch in the pipeline.""" + source = self.invoke(self.pipeline[0]) + for step in self.pipeline[1:]: + source = self.invoke(step, source) + return source + + def iterator(self): + """Create an iterator through the entire dataset, using the given number of repetitions.""" + for i in range(self.repetitions): + for sample in self.iterator1(): + yield sample + + def __iter__(self): + """Create an iterator through the pipeline, repeating and slicing as requested.""" + if self.repetitions != 1: + if self.nsamples > 0: + return islice(self.iterator(), self.nsamples) + else: + return self.iterator() + else: + return self.iterator() + + def stage(self, i): + """Return pipeline stage i.""" + return self.pipeline[i] + + def append(self, f): + """Append a pipeline stage (modifies the object).""" + self.pipeline.append(f) + return self + + def append_list(self, *args): + for arg in args: + self.pipeline.append(arg) + return self + + def compose(self, *args): + """Append a pipeline stage to a copy of the pipeline and returns the copy.""" + result = copy.copy(self) + for arg in args: + result.append(arg) + return result + + def with_length(self, n): + """Add a __len__ method returning the desired value. + + This does not change the actual number of samples in an epoch. + PyTorch IterableDataset should not have a __len__ method. + This is provided only as a workaround for some broken training environments + that require a __len__ method. + """ + self.size = n + return add_length_method(self) + + def with_epoch(self, nsamples=-1, nbatches=-1): + """Change the epoch to return the given number of samples/batches. + + The two arguments mean the same thing.""" + self.repetitions = sys.maxsize + self.nsamples = max(nsamples, nbatches) + return self + + def repeat(self, nepochs=-1, nbatches=-1): + """Repeat iterating through the dataset for the given #epochs up to the given #samples.""" + if nepochs > 0: + self.repetitions = nepochs + self.nsamples = nbatches + else: + self.repetitions = sys.maxsize + self.nsamples = nbatches + return self diff --git a/paddlespeech/audio/streamdata/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py new file mode 100644 index 00000000..cfaf9a64 --- /dev/null +++ b/paddlespeech/audio/streamdata/shardlists.py @@ -0,0 +1,261 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset + +"""Train PyTorch models directly from POSIX tar archive. + +Code works locally or over HTTP connections. +""" + +import os, random, sys, time +from dataclasses import dataclass, field +from itertools import islice +from typing import List + +import braceexpand, yaml + +from . import utils +from .filters import pipelinefilter +from .paddle_utils import IterableDataset + + +from ..utils.log import Logger +logger = Logger(__name__) +def expand_urls(urls): + if isinstance(urls, str): + urllist = urls.split("::") + result = [] + for url in urllist: + result.extend(braceexpand.braceexpand(url)) + return result + else: + return list(urls) + + +class SimpleShardList(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__(self, urls, seed=None): + """Iterate through the list of shards. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.seed = seed + + def __len__(self): + return len(self.urls) + + def __iter__(self): + """Return an iterator over the shards.""" + urls = self.urls.copy() + if self.seed is not None: + random.Random(self.seed).shuffle(urls) + for url in urls: + yield dict(url=url) + + +def split_by_node(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + logger.info(f"world_size:{world_size}, rank:{rank}") + if world_size > 1: + for s in islice(src, rank, None, world_size): + yield s + else: + for s in src: + yield s + + +def single_node_only(src, group=None): + rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + if world_size > 1: + raise ValueError("input pipeline needs to be reconfigured for multinode training") + for s in src: + yield s + + +def split_by_worker(src): + rank, world_size, worker, num_workers = utils.paddle_worker_info() + logger.info(f"num_workers:{num_workers}, worker:{worker}") + if num_workers > 1: + for s in islice(src, worker, None, num_workers): + yield s + else: + for s in src: + yield s + + +def resampled_(src, n=sys.maxsize): + import random + + seed = time.time() + try: + seed = open("/dev/random", "rb").read(20) + except Exception as exn: + print(repr(exn)[:50], file=sys.stderr) + rng = random.Random(seed) + print("# resampled loading", file=sys.stderr) + items = list(src) + print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) + for i in range(n): + yield rng.choice(items) + + +resampled = pipelinefilter(resampled_) + + +def non_empty(src): + count = 0 + for s in src: + yield s + count += 1 + if count == 0: + raise ValueError("pipeline stage received no data at all and this was declared as an error") + + +@dataclass +class MSSource: + """Class representing a data source.""" + + name: str = "" + perepoch: int = -1 + resample: bool = False + urls: List[str] = field(default_factory=list) + + +default_rng = random.Random() + + +def expand(s): + return os.path.expanduser(os.path.expandvars(s)) + + +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 +class MultiShardSample(IterableDataset): + def __init__(self, fname): + """Construct a shardlist from multiple sources using a YAML spec.""" + self.epoch = -1 + self.parse_spec(fname) + + def parse_spec(self, fname): + self.rng = default_rng # capture default_rng if we fork + if isinstance(fname, dict): + spec = fname + fname = "{dict}" + else: + with open(fname) as stream: + spec = yaml.safe_load(stream) + assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys()) + prefix = expand(spec.get("prefix", "")) + self.sources = [] + for ds in spec["datasets"]: + assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list( + ds.keys() + ) + buckets = ds.get("buckets", spec.get("buckets", [])) + if isinstance(buckets, str): + buckets = [buckets] + buckets = [expand(s) for s in buckets] + if buckets == []: + buckets = [""] + assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" + bucket = buckets[0] + name = ds.get("name", "@" + bucket) + urls = ds["shards"] + if isinstance(urls, str): + urls = [urls] + # urls = [u for url in urls for u in braceexpand.braceexpand(url)] + urls = [ + prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url)) + ] + resample = ds.get("resample", -1) + nsample = ds.get("choose", -1) + if nsample > len(urls): + raise ValueError(f"perepoch {nsample} must be no greater than the number of shards") + if (nsample > 0) and (resample > 0): + raise ValueError("specify only one of perepoch or choose") + entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample) + self.sources.append(entry) + print(f"# {name} {len(urls)} {nsample}", file=sys.stderr) + + def set_epoch(self, seed): + """Set the current epoch (for consistent shard selection among nodes).""" + self.rng = random.Random(seed) + + def get_shards_for_epoch(self): + result = [] + for source in self.sources: + if source.resample > 0: + # sample with replacement + l = self.rng.choices(source.urls, k=source.resample) + elif source.perepoch > 0: + # sample without replacement + l = list(source.urls) + self.rng.shuffle(l) + l = l[: source.perepoch] + else: + l = list(source.urls) + result += l + self.rng.shuffle(result) + return result + + def __iter__(self): + shards = self.get_shards_for_epoch() + for shard in shards: + yield dict(url=shard) + + +def shardspec(spec): + if spec.endswith(".yaml"): + return MultiShardSample(spec) + else: + return SimpleShardList(spec) + + +class ResampledShards(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = -1 + + def __iter__(self): + """Return an iterator over the shards.""" + self.epoch += 1 + if self.deterministic: + seed = utils.make_seed(self.worker_seed(), self.epoch) + else: + seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4)) + if os.environ.get("WDS_SHOW_SEED", "0") == "1": + print(f"# ResampledShards seed {seed}") + self.rng = random.Random(seed) + for _ in range(self.nshards): + index = self.rng.randint(0, len(self.urls) - 1) + yield dict(url=self.urls[index]) diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py new file mode 100644 index 00000000..b1616918 --- /dev/null +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -0,0 +1,283 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). + +# Modified from https://github.com/webdataset/webdataset +# Modified from wenet(https://github.com/wenet-e2e/wenet) + +"""Low level iteration functions for tar archives.""" + +import random, re, tarfile + +import braceexpand + +from . import filters +from . import gopen +from .handlers import reraise_exception + +trace = False +meta_prefix = "__" +meta_suffix = "__" + +import paddlespeech +import paddle +import numpy as np + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + +def base_plus_ext(path): + """Split off all file extensions. + + Returns base, allext. + + :param path: path with extensions + :param returns: path with all extensions removed + + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + +def valid_sample(sample): + """Check whether a sample is valid. + + :param sample: sample to be checked + """ + return ( + sample is not None + and isinstance(sample, dict) + and len(list(sample.keys())) > 0 + and not sample.get("__bad__", False) + ) + + +# FIXME: UNUSED +def shardlist(urls, *, shuffle=False): + """Given a list of URLs, yields that list, possibly shuffled.""" + if isinstance(urls, str): + urls = braceexpand.braceexpand(urls) + else: + urls = list(urls) + if shuffle: + random.shuffle(urls) + for url in urls: + yield dict(url=url) + + +def url_opener(data, handler=reraise_exception, **kw): + """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + try: + stream = gopen.gopen(url, **kw) + sample.update(stream=stream) + yield sample + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def tar_file_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. + + :param fileobj: byte stream suitable for tarfile + :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)") + + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + for tarinfo in stream: + fname = tarinfo.name + try: + if not tarinfo.isreg(): + continue + if fname is None: + continue + if ( + "/" not in fname + and fname.startswith(meta_prefix) + and fname.endswith(meta_suffix) + ): + # skipping metadata for now + continue + if skip_meta is not None and re.match(skip_meta, fname): + continue + + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if postfix == 'wav': + waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False) + result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) + else: + txt = stream.extractfile(tarinfo).read().decode('utf8').strip() + result = dict(fname=prefix, txt=txt) + #result = dict(fname=fname, data=data) + yield result + stream.members = [] + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + del stream + +def tar_file_and_group_iterator( + fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception +): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + stream = tarfile.open(fileobj=fileobj, mode="r:*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['fname'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False) + waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) + + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + valid = False + # logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['fname'] = prev_prefix + yield example + stream.close() + +def tar_file_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "data" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + + + +def tar_file_and_group_expander(data, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_and_group_iterator(source["stream"]): + assert ( + isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample + ) + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + +def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + raise ValueError( + f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}" + ) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_samples(src, handler=reraise_exception): + streams = url_opener(src, handler=handler) + samples = tar_file_and_group_expander(streams, handler=handler) + return samples + + +tarfile_to_samples = filters.pipelinefilter(tarfile_samples) diff --git a/paddlespeech/audio/streamdata/utils.py b/paddlespeech/audio/streamdata/utils.py new file mode 100644 index 00000000..c7294f2b --- /dev/null +++ b/paddlespeech/audio/streamdata/utils.py @@ -0,0 +1,132 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +# Modified from https://github.com/webdataset/webdataset + +"""Miscellaneous utility functions.""" + +import importlib +import itertools as itt +import os +import re +import sys +from typing import Any, Callable, Iterator, Optional, Union + +from ..utils.log import Logger + +logger = Logger(__name__) + +def make_seed(*args): + seed = 0 + for arg in args: + seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF + return seed + + +class PipelineStage: + def invoke(self, *args, **kw): + raise NotImplementedError + + +def identity(x: Any) -> Any: + """Return the argument as is.""" + return x + + +def safe_eval(s: str, expr: str = "{}"): + """Evaluate the given expression more safely.""" + if re.sub("[^A-Za-z0-9_]", "", s) != s: + raise ValueError(f"safe_eval: illegal characters in: '{s}'") + return eval(expr.format(s)) + + +def lookup_sym(sym: str, modules: list): + """Look up a symbol in a list of modules.""" + for mname in modules: + module = importlib.import_module(mname, package="webdataset") + result = getattr(module, sym, None) + if result is not None: + return result + return None + + +def repeatedly0( + loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize +): + """Repeatedly returns batches from a DataLoader.""" + for epoch in range(nepochs): + for sample in itt.islice(loader, nbatches): + yield sample + + +def guess_batchsize(batch: Union[tuple, list]): + """Guess the batch size by looking at the length of the first element in a tuple.""" + return len(batch[0]) + + +def repeatedly( + source: Iterator, + nepochs: int = None, + nbatches: int = None, + nsamples: int = None, + batchsize: Callable[..., int] = guess_batchsize, +): + """Repeatedly yield samples from an iterator.""" + epoch = 0 + batch = 0 + total = 0 + while True: + for sample in source: + yield sample + batch += 1 + if nbatches is not None and batch >= nbatches: + return + if nsamples is not None: + total += guess_batchsize(sample) + if total >= nsamples: + return + epoch += 1 + if nepochs is not None and epoch >= nepochs: + return + +def paddle_worker_info(group=None): + """Return node and worker info for PyTorch and some distributed environments.""" + rank = 0 + world_size = 1 + worker = 0 + num_workers = 1 + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + else: + try: + import paddle.distributed + group = group or paddle.distributed.get_group() + rank = paddle.distributed.get_rank() + world_size = paddle.distributed.get_world_size() + except ModuleNotFoundError: + pass + if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: + worker = int(os.environ["WORKER"]) + num_workers = int(os.environ["NUM_WORKERS"]) + else: + try: + from paddle.io import get_worker_info + worker_info = paddle.io.get_worker_info() + if worker_info is not None: + worker = worker_info.id + num_workers = worker_info.num_workers + except ModuleNotFoundError as E: + logger.info(f"not found {E}") + exit(-1) + + return rank, world_size, worker, num_workers + +def paddle_worker_seed(group=None): + """Compute a distinct, deterministic RNG seed for each worker and node.""" + rank, world_size, worker, num_workers = paddle_worker_info(group=group) + return rank * 1000 + worker diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py new file mode 100644 index 00000000..7d4f7703 --- /dev/null +++ b/paddlespeech/audio/streamdata/writer.py @@ -0,0 +1,450 @@ +# +# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# Modified from https://github.com/webdataset/webdataset +# + +"""Classes and functions for writing tar files and WebDataset files.""" + +import io, json, pickle, re, tarfile, time +from typing import Any, Callable, Optional, Union + +import numpy as np + +from . import gopen + + +def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622 + """Compress an image using PIL and return it as a string. + + Can handle float or uint8 images. + + :param image: ndarray representing an image + :param format: compression format (PNG, JPEG, PPM) + + """ + import PIL + + assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image) + + if isinstance(image, np.ndarray): + if image.dtype in [np.dtype("f"), np.dtype("d")]: + if not (np.amin(image) > -0.001 and np.amax(image) < 1.001): + raise ValueError( + f"image values out of range {np.amin(image)} {np.amax(image)}" + ) + image = np.clip(image, 0.0, 1.0) + image = np.array(image * 255.0, "uint8") + assert image.ndim in [2, 3] + if image.ndim == 3: + assert image.shape[2] in [1, 3] + image = PIL.Image.fromarray(image) + if format.upper() == "JPG": + format = "JPEG" + elif format.upper() in ["IMG", "IMAGE"]: + format = "PPM" + if format == "JPEG": + opts = dict(quality=100) + else: + opts = {} + with io.BytesIO() as result: + image.save(result, format=format, **opts) + return result.getvalue() + + +def bytestr(data: Any): + """Convert data into a bytestring. + + Uses str and ASCII encoding for data that isn't already in string format. + + :param data: data + """ + if isinstance(data, bytes): + return data + if isinstance(data, str): + return data.encode("ascii") + return str(data).encode("ascii") + +def paddle_dumps(data: Any): + """Dump data into a bytestring using paddle.dumps. + + This delays importing paddle until needed. + + :param data: data to be dumped + """ + import io + + import paddle + + stream = io.BytesIO() + paddle.save(data, stream) + return stream.getvalue() + +def numpy_dumps(data: np.ndarray): + """Dump data into a bytestring using numpy npy format. + + :param data: data to be dumped + """ + import io + + import numpy.lib.format + + stream = io.BytesIO() + numpy.lib.format.write_array(stream, data) + return stream.getvalue() + + +def numpy_npz_dumps(data: np.ndarray): + """Dump data into a bytestring using numpy npz format. + + :param data: data to be dumped + """ + import io + + stream = io.BytesIO() + np.savez_compressed(stream, **data) + return stream.getvalue() + + +def tenbin_dumps(x): + from . import tenbin + + if isinstance(x, list): + return memoryview(tenbin.encode_buffer(x)) + else: + return memoryview(tenbin.encode_buffer([x])) + + +def cbor_dumps(x): + import cbor + + return cbor.dumps(x) + + +def mp_dumps(x): + import msgpack + + return msgpack.packb(x) + + +def add_handlers(d, keys, value): + if isinstance(keys, str): + keys = keys.split() + for k in keys: + d[k] = value + + +def make_handlers(): + """Create a list of handlers for encoding data.""" + handlers = {} + add_handlers( + handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii") + ) + add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) + add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) + add_handlers(handlers, "pyd pickle", pickle.dumps) + add_handlers(handlers, "pdparams", paddle_dumps) + add_handlers(handlers, "npy", numpy_dumps) + add_handlers(handlers, "npz", numpy_npz_dumps) + add_handlers(handlers, "ten tenbin tb", tenbin_dumps) + add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) + add_handlers(handlers, "mp msgpack msg", mp_dumps) + add_handlers(handlers, "cbor", cbor_dumps) + add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg")) + add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) + add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) + add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) + add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm")) + return handlers + + +default_handlers = make_handlers() + + +def encode_based_on_extension1(data: Any, tname: str, handlers: dict): + """Encode data based on its extension and a dict of handlers. + + :param data: data + :param tname: file extension + :param handlers: handlers + """ + if tname[0] == "_": + if not isinstance(data, str): + raise ValueError("the values of metadata must be of string type") + return data + extension = re.sub(r".*\.", "", tname).lower() + if isinstance(data, bytes): + return data + if isinstance(data, str): + return data.encode("utf-8") + handler = handlers.get(extension) + if handler is None: + raise ValueError(f"no handler found for {extension}") + return handler(data) + + +def encode_based_on_extension(sample: dict, handlers: dict): + """Encode an entire sample with a collection of handlers. + + :param sample: data sample (a dict) + :param handlers: handlers for encoding + """ + return { + k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items()) + } + + +def make_encoder(spec: Union[bool, str, dict, Callable]): + """Make an encoder function from a specification. + + :param spec: specification + """ + if spec is False or spec is None: + + def encoder(x): + """Do not encode at all.""" + return x + + elif callable(spec): + encoder = spec + elif isinstance(spec, dict): + + def f(sample): + """Encode based on extension.""" + return encode_based_on_extension(sample, spec) + + encoder = f + + elif spec is True: + handlers = default_handlers + + def g(sample): + """Encode based on extension.""" + return encode_based_on_extension(sample, handlers) + + encoder = g + + else: + raise ValueError(f"{spec}: unknown decoder spec") + if not callable(encoder): + raise ValueError(f"{spec} did not yield a callable encoder") + return encoder + + +class TarWriter: + """A class for writing dictionaries to tar files. + + :param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor + :param encoder: sample encoding (Default value = True) + :param compress: (Default value = None) + + `True` will use an encoder that behaves similar to the automatic + decoder for `Dataset`. `False` disables encoding and expects byte strings + (except for metadata, which must be strings). The `encoder` argument can + also be a `callable`, or a dictionary mapping extensions to encoders. + + The following code will add two file to the tar archive: `a/b.png` and + `a/b.output.png`. + + ```Python + tarwriter = TarWriter(stream) + image = imread("b.jpg") + image2 = imread("b.out.jpg") + sample = {"__key__": "a/b", "png": image, "output.png": image2} + tarwriter.write(sample) + ``` + """ + + def __init__( + self, + fileobj, + user: str = "bigdata", + group: str = "bigdata", + mode: int = 0o0444, + compress: Optional[bool] = None, + encoder: Union[None, bool, Callable] = True, + keep_meta: bool = False, + ): + """Create a tar writer. + + :param fileobj: stream to write data to + :param user: user for tar files + :param group: group for tar files + :param mode: mode for tar files + :param compress: desired compression + :param encoder: encoder function + :param keep_meta: keep metadata (entries starting with "_") + """ + if isinstance(fileobj, str): + if compress is False: + tarmode = "w|" + elif compress is True: + tarmode = "w|gz" + else: + tarmode = "w|gz" if fileobj.endswith("gz") else "w|" + fileobj = gopen.gopen(fileobj, "wb") + self.own_fileobj = fileobj + else: + tarmode = "w|gz" if compress is True else "w|" + self.own_fileobj = None + self.encoder = make_encoder(encoder) + self.keep_meta = keep_meta + self.stream = fileobj + self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode) + + self.user = user + self.group = group + self.mode = mode + self.compress = compress + + def __enter__(self): + """Enter context.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context.""" + self.close() + + def close(self): + """Close the tar file.""" + self.tarstream.close() + if self.own_fileobj is not None: + self.own_fileobj.close() + self.own_fileobj = None + + def write(self, obj): + """Write a dictionary to the tar file. + + :param obj: dictionary of objects to be stored + :returns: size of the entry + + """ + total = 0 + obj = self.encoder(obj) + if "__key__" not in obj: + raise ValueError("object must contain a __key__") + for k, v in list(obj.items()): + if k[0] == "_": + continue + if not isinstance(v, (bytes, bytearray, memoryview)): + raise ValueError( + f"{k} doesn't map to a bytes after encoding ({type(v)})" + ) + key = obj["__key__"] + for k in sorted(obj.keys()): + if k == "__key__": + continue + if not self.keep_meta and k[0] == "_": + continue + v = obj[k] + if isinstance(v, str): + v = v.encode("utf-8") + now = time.time() + ti = tarfile.TarInfo(key + "." + k) + ti.size = len(v) + ti.mtime = now + ti.mode = self.mode + ti.uname = self.user + ti.gname = self.group + if not isinstance(v, (bytes, bytearray, memoryview)): + raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}") + stream = io.BytesIO(v) + self.tarstream.addfile(ti, stream) + total += ti.size + return total + + +class ShardWriter: + """Like TarWriter but splits into multiple shards.""" + + def __init__( + self, + pattern: str, + maxcount: int = 100000, + maxsize: float = 3e9, + post: Optional[Callable] = None, + start_shard: int = 0, + **kw, + ): + """Create a ShardWriter. + + :param pattern: output file pattern + :param maxcount: maximum number of records per shard (Default value = 100000) + :param maxsize: maximum size of each shard (Default value = 3e9) + :param kw: other options passed to TarWriter + """ + self.verbose = 1 + self.kw = kw + self.maxcount = maxcount + self.maxsize = maxsize + self.post = post + + self.tarstream = None + self.shard = start_shard + self.pattern = pattern + self.total = 0 + self.count = 0 + self.size = 0 + self.fname = None + self.next_stream() + + def next_stream(self): + """Close the current stream and move to the next.""" + self.finish() + self.fname = self.pattern % self.shard + if self.verbose: + print( + "# writing", + self.fname, + self.count, + "%.1f GB" % (self.size / 1e9), + self.total, + ) + self.shard += 1 + stream = open(self.fname, "wb") + self.tarstream = TarWriter(stream, **self.kw) + self.count = 0 + self.size = 0 + + def write(self, obj): + """Write a sample. + + :param obj: sample to be written + """ + if ( + self.tarstream is None + or self.count >= self.maxcount + or self.size >= self.maxsize + ): + self.next_stream() + size = self.tarstream.write(obj) + self.count += 1 + self.total += 1 + self.size += size + + def finish(self): + """Finish all writing (use close instead).""" + if self.tarstream is not None: + self.tarstream.close() + assert self.fname is not None + if callable(self.post): + self.post(self.fname) + self.tarstream = None + + def close(self): + """Close the stream.""" + self.finish() + del self.tarstream + del self.shard + del self.count + del self.size + + def __enter__(self): + """Enter context.""" + return self + + def __exit__(self, *args, **kw): + """Exit context.""" + self.close() diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py new file mode 100644 index 00000000..91c4d75c --- /dev/null +++ b/paddlespeech/audio/text/text_featurizer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains the text featurizer class.""" +from pprint import pformat +from typing import Union + +import sentencepiece as spm + +from .utility import BLANK +from .utility import EOS +from .utility import load_dict +from .utility import MASKCTC +from .utility import SOS +from .utility import SPACE +from .utility import UNK +from ..utils.log import Logger + +logger = Logger(__name__) + +__all__ = ["TextFeaturizer"] + + +class TextFeaturizer(): + def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False): + """Text featurizer, for processing or extracting features from text. + + Currently, it supports char/word/sentence-piece level tokenizing and conversion into + a list of token indices. Note that the token indexing order follows the + given vocabulary file. + + Args: + unit_type (str): unit type, e.g. char, word, spm + vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list. + spm_model_prefix (str, optional): spm model prefix. Defaults to None. + """ + assert unit_type in ('char', 'spm', 'word') + self.unit_type = unit_type + self.unk = UNK + self.maskctc = maskctc + + if vocab: + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( + vocab, maskctc) + self.vocab_size = len(self.vocab_list) + else: + logger.warning("TextFeaturizer: not have vocab file or vocab list.") + + if unit_type == 'spm': + spm_model = spm_model_prefix + '.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(spm_model) + + def tokenize(self, text, replace_space=True): + if self.unit_type == 'char': + tokens = self.char_tokenize(text, replace_space) + elif self.unit_type == 'word': + tokens = self.word_tokenize(text) + else: # spm + tokens = self.spm_tokenize(text) + return tokens + + def detokenize(self, tokens): + if self.unit_type == 'char': + text = self.char_detokenize(tokens) + elif self.unit_type == 'word': + text = self.word_detokenize(tokens) + else: # spm + text = self.spm_detokenize(tokens) + return text + + def featurize(self, text): + """Convert text string to a list of token indices. + + Args: + text (str): Text to process. + + Returns: + List[int]: List of token indices. + """ + tokens = self.tokenize(text) + ids = [] + for token in tokens: + if token not in self.vocab_dict: + logger.debug(f"Text Token: {token} -> {self.unk}") + token = self.unk + ids.append(self.vocab_dict[token]) + return ids + + def defeaturize(self, idxs): + """Convert a list of token indices to text string, + ignore index after eos_id. + + Args: + idxs (List[int]): List of token indices. + + Returns: + str: Text. + """ + tokens = [] + for idx in idxs: + if idx == self.eos_id: + break + tokens.append(self._id2token[idx]) + text = self.detokenize(tokens) + return text + + def char_tokenize(self, text, replace_space=True): + """Character tokenizer. + + Args: + text (str): text string. + replace_space (bool): False only used by build_vocab.py. + + Returns: + List[str]: tokens. + """ + text = text.strip() + if replace_space: + text_list = [SPACE if item == " " else item for item in list(text)] + else: + text_list = list(text) + return text_list + + def char_detokenize(self, tokens): + """Character detokenizer. + + Args: + tokens (List[str]): tokens. + + Returns: + str: text string. + """ + tokens = [t.replace(SPACE, " ") for t in tokens] + return "".join(tokens) + + def word_tokenize(self, text): + """Word tokenizer, separate by .""" + return text.strip().split() + + def word_detokenize(self, tokens): + """Word detokenizer, separate by .""" + return " ".join(tokens) + + def spm_tokenize(self, text): + """spm tokenize. + + Args: + text (str): text string. + + Returns: + List[str]: sentence pieces str code + """ + stats = {"num_empty": 0, "num_filtered": 0} + + def valid(line): + return True + + def encode(l): + return self.sp.EncodeAsPieces(l) + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + else: + stats["num_filtered"] += 1 + else: + stats["num_empty"] += 1 + return None + + enc_line = encode_line(text) + return enc_line + + def spm_detokenize(self, tokens, input_format='piece'): + """spm detokenize. + + Args: + ids (List[str]): tokens. + + Returns: + str: text + """ + if input_format == "piece": + + def decode(l): + return "".join(self.sp.DecodePieces(l)) + elif input_format == "id": + + def decode(l): + return "".join(self.sp.DecodeIds(l)) + + return decode(tokens) + + def _load_vocabulary_from_file(self, vocab: Union[str, list], + maskctc: bool): + """Load vocabulary from file.""" + if isinstance(vocab, list): + vocab_list = vocab + else: + vocab_list = load_dict(vocab, maskctc) + assert vocab_list is not None + logger.debug(f"Vocab: {pformat(vocab_list)}") + + id2token = dict( + [(idx, token) for (idx, token) in enumerate(vocab_list)]) + token2id = dict( + [(token, idx) for (idx, token) in enumerate(vocab_list)]) + + blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1 + maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1 + unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1 + eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1 + sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1 + space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1 + + logger.info(f"BLANK id: {blank_id}") + logger.info(f"UNK id: {unk_id}") + logger.info(f"EOS id: {eos_id}") + logger.info(f"SOS id: {sos_id}") + logger.info(f"SPACE id: {space_id}") + logger.info(f"MASKCTC id: {maskctc_id}") + return token2id, id2token, vocab_list, unk_id, eos_id, blank_id diff --git a/paddlespeech/audio/text/utility.py b/paddlespeech/audio/text/utility.py new file mode 100644 index 00000000..d35785db --- /dev/null +++ b/paddlespeech/audio/text/utility.py @@ -0,0 +1,393 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains data helper functions.""" +import json +import math +import tarfile +from collections import namedtuple +from typing import List +from typing import Optional +from typing import Text + +import jsonlines +import numpy as np + +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = [ + "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", + "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", + "EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32", + "convert_samples_from_float32" +] + +IGNORE_ID = -1 +# `sos` and `eos` using same token +SOS = "" +EOS = SOS +UNK = "" +BLANK = "" +MASKCTC = "" +SPACE = "" + + +def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: + if dict_path is None: + return None + + with open(dict_path, "r") as f: + dictionary = f.readlines() + # first token is `` + # multi line: ` 0\n` + # one line: `` + # space is relpace with + char_list = [entry[:-1].split(" ")[0] for entry in dictionary] + if BLANK not in char_list: + char_list.insert(0, BLANK) + if EOS not in char_list: + char_list.append(EOS) + # for non-autoregressive maskctc model + if maskctc and MASKCTC not in char_list: + char_list.append(MASKCTC) + return char_list + + +def read_manifest( + manifest_path, + max_input_len=float('inf'), + min_input_len=0.0, + max_output_len=float('inf'), + min_output_len=0.0, + max_output_input_ratio=float('inf'), + min_output_input_ratio=0.0, ): + """Load and parse manifest file. + + Args: + manifest_path ([type]): Manifest file to load and parse. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to 0.0. + max_output_len (float, optional): maximum input seq length, + in modeling units. Defaults to 500.0. + min_output_len (float, optional): minimum input seq length, + in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): + maximum output seq length/output seq length ratio. Defaults to 10.0. + min_output_input_ratio (float, optional): + minimum output seq length/output seq length ratio. Defaults to 0.05. + + Raises: + IOError: If failed to parse the manifest. + + Returns: + List[dict]: Manifest parsing results. + """ + manifest = [] + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + feat_len = json_data["input"][0]["shape"][ + 0] if "input" in json_data and "shape" in json_data["input"][ + 0] else 1.0 + token_len = json_data["output"][0]["shape"][ + 0] if "output" in json_data and "shape" in json_data["output"][ + 0] else 1.0 + conditions = [ + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, + ] + if all(conditions): + manifest.append(json_data) + return manifest + + +# Tar File read +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +def parse_tar(file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + +def subfile_from_tar(file, local_data=None): + """Get subfile object from tar. + + tar:tarpath#filename + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + + if local_data is None: + local_data = TarLocalData(tar2info={}, tar2object={}) + + assert isinstance(local_data, TarLocalData) + + if 'tar2info' not in local_data.__dict__: + local_data.tar2info = {} + if 'tar2object' not in local_data.__dict__: + local_data.tar2object = {} + + if tarpath not in local_data.tar2info: + fobj, infos = parse_tar(tarpath) + local_data.tar2info[tarpath] = infos + local_data.tar2object[tarpath] = fobj + else: + fobj = local_data.tar2object[tarpath] + infos = local_data.tar2info[tarpath] + return fobj.extractfile(infos[filename]) + + +def rms_to_db(rms: float): + """Root Mean Square to dB. + + Args: + rms ([float]): root mean square + + Returns: + float: dB + """ + return 20.0 * math.log10(max(1e-16, rms)) + + +def rms_to_dbfs(rms: float): + """Root Mean Square to dBFS. + https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ + Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. + + dB = dBFS + 3.0103 + dBFS = db - 3.0103 + e.g. 0 dB = -3.0103 dBFS + + Args: + rms ([float]): root mean square + + Returns: + float: dBFS + """ + return rms_to_db(rms) - 3.0103 + + +def max_dbfs(sample_data: np.ndarray): + """Peak dBFS based on the maximum energy sample. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. + + Returns: + float: dBFS + """ + # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. + return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) + + +def mean_dbfs(sample_data): + """Peak dBFS based on the RMS energy. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. + + Returns: + float: dBFS + """ + return rms_to_dbfs( + math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) + + +def gain_db_to_ratio(gain_db: float): + """dB to ratio + + Args: + gain_db (float): gain in dB + + Returns: + float: scale in amp + """ + return math.pow(10.0, gain_db / 20.0) + + +def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): + """Nomalize audio to dBFS. + + Args: + sample_data (np.ndarray): input wave samples, [-1, 1]. + dbfs (float, optional): target dBFS. Defaults to -3.0103. + + Returns: + np.ndarray: normalized wave + """ + return np.maximum( + np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), + 1.0), -1.0) + + +def _load_json_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + + Args: + json_cmvn_file: cmvn stats file in json format + + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_kaldi_cmvn(kaldi_cmvn_file): + """ Load the kaldi format cmvn stats file and calculate cmvn + + Args: + kaldi_cmvn_file: kaldi text style global cmvn file, which + is generated by: + compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + + Returns: + a numpy array of [means, vars] + """ + means = [] + variance = [] + with open(kaldi_cmvn_file, 'r') as fid: + # kaldi binary file start with '\0B' + if fid.read(2) == '\0B': + logger.error('kaldi cmvn binary file is not supported, please ' + 'recompute it by: compute-cmvn-stats --binary=false ' + ' scp:feats.scp global_cmvn') + sys.exit(1) + fid.seek(0) + arr = fid.read().split() + assert (arr[0] == '[') + assert (arr[-2] == '0') + assert (arr[-1] == ']') + feat_dim = int((len(arr) - 2 - 2) / 2) + for i in range(1, feat_dim + 1): + means.append(float(arr[i])) + count = float(arr[feat_dim + 1]) + for i in range(feat_dim + 2, 2 * feat_dim + 2): + variance.append(float(arr[i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def load_cmvn(cmvn_file: str, filetype: str): + """load cmvn from file. + + Args: + cmvn_file (str): cmvn path. + filetype (str): file type, optional[npz, json, kaldi]. + + Raises: + ValueError: file type not support. + + Returns: + Tuple[np.ndarray, np.ndarray]: mean, istd + """ + assert filetype in ['npz', 'json', 'kaldi'], filetype + filetype = filetype.lower() + if filetype == "json": + cmvn = _load_json_cmvn(cmvn_file) + elif filetype == "kaldi": + cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] + else: + raise ValueError(f"cmvn file type no support: {filetype}") + return cmvn[0], cmvn[1] + + +def convert_samples_to_float32(samples): + """Convert sample type to float32. + + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + + PCM16 -> PCM32 + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= (1. / 2**(bits - 1)) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + +def convert_samples_from_float32(samples, dtype): + """Convert sample type from float32 to dtype. + + Audio sample type is usually integer or float-point. For integer + type, float32 will be rescaled from [-1, 1] to the maximum range + supported by the integer type. + + PCM32 -> PCM16 + """ + dtype = np.dtype(dtype) + output_samples = samples.copy() + if dtype in np.sctypes['int']: + bits = np.iinfo(dtype).bits + output_samples *= (2**(bits - 1) / 1.) + min_val = np.iinfo(dtype).min + max_val = np.iinfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + elif samples.dtype in np.sctypes['float']: + min_val = np.finfo(dtype).min + max_val = np.finfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return output_samples.astype(dtype) diff --git a/paddlespeech/s2t/transform/__init__.py b/paddlespeech/audio/transform/__init__.py similarity index 100% rename from paddlespeech/s2t/transform/__init__.py rename to paddlespeech/audio/transform/__init__.py diff --git a/paddlespeech/s2t/transform/add_deltas.py b/paddlespeech/audio/transform/add_deltas.py similarity index 100% rename from paddlespeech/s2t/transform/add_deltas.py rename to paddlespeech/audio/transform/add_deltas.py diff --git a/paddlespeech/s2t/transform/channel_selector.py b/paddlespeech/audio/transform/channel_selector.py similarity index 100% rename from paddlespeech/s2t/transform/channel_selector.py rename to paddlespeech/audio/transform/channel_selector.py diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/audio/transform/cmvn.py similarity index 100% rename from paddlespeech/s2t/transform/cmvn.py rename to paddlespeech/audio/transform/cmvn.py diff --git a/paddlespeech/s2t/transform/functional.py b/paddlespeech/audio/transform/functional.py similarity index 94% rename from paddlespeech/s2t/transform/functional.py rename to paddlespeech/audio/transform/functional.py index ccb50081..271819ad 100644 --- a/paddlespeech/s2t/transform/functional.py +++ b/paddlespeech/audio/transform/functional.py @@ -14,8 +14,8 @@ # Modified from espnet(https://github.com/espnet/espnet) import inspect -from paddlespeech.s2t.transform.transform_interface import TransformInterface -from paddlespeech.s2t.utils.check_kwargs import check_kwargs +from paddlespeech.audio.transform.transform_interface import TransformInterface +from paddlespeech.audio.utils.check_kwargs import check_kwargs class FuncTrans(TransformInterface): diff --git a/paddlespeech/s2t/transform/perturb.py b/paddlespeech/audio/transform/perturb.py similarity index 86% rename from paddlespeech/s2t/transform/perturb.py rename to paddlespeech/audio/transform/perturb.py index b18caefb..8044dc36 100644 --- a/paddlespeech/s2t/transform/perturb.py +++ b/paddlespeech/audio/transform/perturb.py @@ -17,8 +17,97 @@ import numpy import scipy import soundfile -from paddlespeech.s2t.io.reader import SoundHDF5File +import io +import os +import h5py +import numpy as np +class SoundHDF5File(): + """Collecting sound files to a HDF5 file + + >>> f = SoundHDF5File('a.flac.h5', mode='a') + >>> array = np.random.randint(0, 100, 100, dtype=np.int16) + >>> f['id'] = (array, 16000) + >>> array, rate = f['id'] + + + :param: str filepath: + :param: str mode: + :param: str format: The type used when saving wav. flac, nist, htk, etc. + :param: str dtype: + + """ + + def __init__(self, + filepath, + mode="r+", + format=None, + dtype="int16", + **kwargs): + self.filepath = filepath + self.mode = mode + self.dtype = dtype + + self.file = h5py.File(filepath, mode, **kwargs) + if format is None: + # filepath = a.flac.h5 -> format = flac + second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1] + format = second_ext[1:] + if format.upper() not in soundfile.available_formats(): + # If not found, flac is selected + format = "flac" + + # This format affects only saving + self.format = format + + def __repr__(self): + return ''.format( + self.filepath, self.mode, self.format, self.dtype) + + def create_dataset(self, name, shape=None, data=None, **kwds): + f = io.BytesIO() + array, rate = data + soundfile.write(f, array, rate, format=self.format) + self.file.create_dataset( + name, shape=shape, data=np.void(f.getvalue()), **kwds) + + def __setitem__(self, name, data): + self.create_dataset(name, data=data) + + def __getitem__(self, key): + data = self.file[key][()] + f = io.BytesIO(data.tobytes()) + array, rate = soundfile.read(f, dtype=self.dtype) + return array, rate + + def keys(self): + return self.file.keys() + + def values(self): + for k in self.file: + yield self[k] + + def items(self): + for k in self.file: + yield k, self[k] + + def __iter__(self): + return iter(self.file) + + def __contains__(self, item): + return item in self.file + + def __len__(self, item): + return len(self.file) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def close(self): + self.file.close() class SpeedPerturbation(): """SpeedPerturbation @@ -469,3 +558,4 @@ class RIRConvolve(): [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) else: return scipy.convolve(x, rir, mode="same") + diff --git a/paddlespeech/s2t/transform/spec_augment.py b/paddlespeech/audio/transform/spec_augment.py similarity index 99% rename from paddlespeech/s2t/transform/spec_augment.py rename to paddlespeech/audio/transform/spec_augment.py index 7b3485b1..029e7b8f 100644 --- a/paddlespeech/s2t/transform/spec_augment.py +++ b/paddlespeech/audio/transform/spec_augment.py @@ -17,7 +17,7 @@ import random import numpy from PIL import Image -from paddlespeech.s2t.transform.functional import FuncTrans +from .functional import FuncTrans def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): diff --git a/paddlespeech/s2t/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py similarity index 99% rename from paddlespeech/s2t/transform/spectrogram.py rename to paddlespeech/audio/transform/spectrogram.py index 19f0237b..864f3f99 100644 --- a/paddlespeech/s2t/transform/spectrogram.py +++ b/paddlespeech/audio/transform/spectrogram.py @@ -17,7 +17,7 @@ import numpy as np import paddle from python_speech_features import logfbank -import paddlespeech.audio.compliance.kaldi as kaldi +from ..compliance import kaldi def stft(x, diff --git a/paddlespeech/s2t/transform/transform_interface.py b/paddlespeech/audio/transform/transform_interface.py similarity index 100% rename from paddlespeech/s2t/transform/transform_interface.py rename to paddlespeech/audio/transform/transform_interface.py diff --git a/paddlespeech/s2t/transform/transformation.py b/paddlespeech/audio/transform/transformation.py similarity index 75% rename from paddlespeech/s2t/transform/transformation.py rename to paddlespeech/audio/transform/transformation.py index 3b433cb0..d24d6437 100644 --- a/paddlespeech/s2t/transform/transformation.py +++ b/paddlespeech/audio/transform/transformation.py @@ -22,32 +22,32 @@ from inspect import signature import yaml -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from ..utils.dynamic_import import dynamic_import import_alias = dict( - identity="paddlespeech.s2t.transform.transform_interface:Identity", - time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp", - time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask", - freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask", - spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment", - speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation", - speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox", - volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation", - noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection", - bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation", - rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve", - delta="paddlespeech.s2t.transform.add_deltas:AddDeltas", - cmvn="paddlespeech.s2t.transform.cmvn:CMVN", - utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN", - fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram", - spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram", - stft="paddlespeech.s2t.transform.spectrogram:Stft", - istft="paddlespeech.s2t.transform.spectrogram:IStft", - stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram", - wpe="paddlespeech.s2t.transform.wpe:WPE", - channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", - fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", - cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") + identity="paddlespeech.audio.transform.transform_interface:Identity", + time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp", + time_mask="paddlespeech.audio.transform.spec_augment:TimeMask", + freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask", + spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment", + speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation", + speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox", + volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation", + noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection", + bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation", + rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve", + delta="paddlespeech.audio.transform.add_deltas:AddDeltas", + cmvn="paddlespeech.audio.transform.cmvn:CMVN", + utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN", + fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram", + spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram", + stft="paddlespeech.audio.transform.spectrogram:Stft", + istft="paddlespeech.audio.transform.spectrogram:IStft", + stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram", + wpe="paddlespeech.audio.transform.wpe:WPE", + channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector", + fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi", + cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN") class Transformation(): diff --git a/paddlespeech/s2t/transform/wpe.py b/paddlespeech/audio/transform/wpe.py similarity index 100% rename from paddlespeech/s2t/transform/wpe.py rename to paddlespeech/audio/transform/wpe.py diff --git a/paddlespeech/audio/utils/check_kwargs.py b/paddlespeech/audio/utils/check_kwargs.py new file mode 100644 index 00000000..0aa839ac --- /dev/null +++ b/paddlespeech/audio/utils/check_kwargs.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import inspect + + +def check_kwargs(func, kwargs, name=None): + """check kwargs are valid for func + + If kwargs are invalid, raise TypeError as same as python default + :param function func: function to be validated + :param dict kwargs: keyword arguments for func + :param str name: name used in TypeError (default is func name) + """ + try: + params = inspect.signature(func).parameters + except ValueError: + return + if name is None: + name = func.__name__ + for k in kwargs.keys(): + if k not in params: + raise TypeError( + f"{name}() got an unexpected keyword argument '{k}'") diff --git a/paddlespeech/audio/utils/dynamic_import.py b/paddlespeech/audio/utils/dynamic_import.py new file mode 100644 index 00000000..99f93356 --- /dev/null +++ b/paddlespeech/audio/utils/dynamic_import.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import importlib + +__all__ = ["dynamic_import"] + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'paddlespeech.s2t.models.u2:U2Model' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError( + "import_path should be one of {} or " + 'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : ' + "{}".format(set(alias), import_path)) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname) diff --git a/paddlespeech/audio/utils/log.py b/paddlespeech/audio/utils/log.py index 5656b286..0a25bbd5 100644 --- a/paddlespeech/audio/utils/log.py +++ b/paddlespeech/audio/utils/log.py @@ -65,6 +65,7 @@ class Logger(object): def __init__(self, name: str=None): name = 'PaddleAudio' if not name else name + self.name = name self.logger = logging.getLogger(name) for key, conf in log_config.items(): @@ -101,7 +102,7 @@ class Logger(object): if not self.is_enable: return - self.logger.log(log_level, msg) + self.logger.log(log_level, self.name + " | " + msg) @contextlib.contextmanager def use_terminator(self, terminator: str): diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py new file mode 100644 index 00000000..16f60810 --- /dev/null +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unility functions for Transformer.""" +from typing import List +from typing import Tuple + +import paddle + +from .log import Logger + +__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"] + +logger = Logger(__name__) + + +def has_tensor(val): + if isinstance(val, (list, tuple)): + for item in val: + if has_tensor(item): + return True + elif isinstance(val, dict): + for k, v in val.items(): + print(k) + if has_tensor(v): + return True + else: + return paddle.is_tensor(val) + + +def pad_sequence(sequences: List[paddle.Tensor], + batch_first: bool=False, + padding_value: float=0.0) -> paddle.Tensor: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is list of + sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` + otherwise. + + `B` is batch size. It is equal to the number of elements in ``sequences``. + `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> from paddle.nn.utils.rnn import pad_sequence + >>> a = paddle.ones(25, 300) + >>> b = paddle.ones(22, 300) + >>> c = paddle.ones(15, 300) + >>> pad_sequence([a, b, c]).shape + paddle.Tensor([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise + padding_value (float, optional): value for padded elements. Default: 0. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise + """ + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_size = paddle.shape(sequences[0]) + # (TODO Hui Zhang): slice not supprot `end==start` + # trailing_dims = max_size[1:] + trailing_dims = tuple( + max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () + max_len = max([s.shape[0] for s in sequences]) + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype) + for i, tensor in enumerate(sequences): + length = tensor.shape[0] + # use index notation to prevent duplicate references to the tensor + if batch_first: + # TODO (Hui Zhang): set_value op not supprot `end==start` + # TODO (Hui Zhang): set_value op not support int16 + # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] + # out_tensor[i, :length, ...] = tensor + if length != 0: + out_tensor[i, :length] = tensor + else: + out_tensor[i, length] = tensor + else: + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[:length, i, ...] = tensor + if length != 0: + out_tensor[:length, i] = tensor + else: + out_tensor[length, i] = tensor + + return out_tensor + + +def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, + ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Add and labels. + Args: + ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax) + sos (int): index of + eos (int): index of + ignore_id (int): index of padding + Returns: + ys_in (paddle.Tensor) : (B, Lmax + 1) + ys_out (paddle.Tensor) : (B, Lmax + 1) + Examples: + >>> sos_id = 10 + >>> eos_id = 11 + >>> ignore_id = -1 + >>> ys_pad + tensor([[ 1, 2, 3, 4, 5], + [ 4, 5, 6, -1, -1], + [ 7, 8, 9, -1, -1]], dtype=paddle.int32) + >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id) + >>> ys_in + tensor([[10, 1, 2, 3, 4, 5], + [10, 4, 5, 6, 11, 11], + [10, 7, 8, 9, 11, 11]]) + >>> ys_out + tensor([[ 1, 2, 3, 4, 5, 11], + [ 4, 5, 6, 11, -1, -1], + [ 7, 8, 9, 11, -1, -1]]) + """ + # TODO(Hui Zhang): using comment code, + #_sos = paddle.to_tensor( + # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + #_eos = paddle.to_tensor( + # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] + #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] + #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) + B = ys_pad.shape[0] + _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos + _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos + ys_in = paddle.cat([_sos, ys_pad], dim=1) + mask_pad = (ys_in == ignore_id) + ys_in = ys_in.masked_fill(mask_pad, eos) + + ys_out = paddle.cat([ys_pad, _eos], dim=1) + ys_out = ys_out.masked_fill(mask_pad, eos) + mask_eos = (ys_out == ignore_id) + ys_out = ys_out.masked_fill(mask_eos, eos) + ys_out = ys_out.masked_fill(mask_pad, ignore_id) + return ys_in, ys_out + + +def th_accuracy(pad_outputs: paddle.Tensor, + pad_targets: paddle.Tensor, + ignore_label: int) -> float: + """Calculate accuracy. + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + Returns: + float: Accuracy value (0.0 - 1.0). + """ + pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]).argmax(2) + mask = pad_targets != ignore_label + #TODO(Hui Zhang): sum not support bool type + # numerator = paddle.sum( + # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = ( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = paddle.sum(numerator.type_as(pad_targets)) + #TODO(Hui Zhang): sum not support bool type + # denominator = paddle.sum(mask) + denominator = paddle.sum(mask.type_as(pad_targets)) + return float(numerator) / float(denominator) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 52cb7cc8..76dfafb9 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -33,8 +33,8 @@ from ..log import logger from ..utils import CLI_TIMER from ..utils import stats_wrapper from ..utils import timer_register +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 511997a7..7ab8cf85 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -23,7 +23,7 @@ import paddle from paddle import distributed as dist from paddle import inference -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.audio.text.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 86c3db89..887ec7a6 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -20,10 +20,10 @@ import paddle import soundfile from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.utility import UpdateConfig logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index efcc9629..cdad3b8f 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -26,6 +26,8 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import StreamDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -106,7 +108,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -132,7 +135,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -152,7 +156,8 @@ class U2Trainer(Trainer): self.before_train() - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -170,7 +175,8 @@ class U2Trainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -191,7 +197,6 @@ class U2Trainer(Trainer): except Exception as e: logger.error(e) raise e - with Timer("Eval Time Cost: {}"): total_loss, num_seen_utts = self.valid() if dist.get_world_size() > 1: @@ -218,92 +223,16 @@ class U2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() - + self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=config.sortagrad, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=config.minibatches, - mini_batch_size=self.args.ngpu, - batch_count=config.batch_count, - batch_bins=config.batch_bins, - batch_frames_in=config.batch_frames_in, - batch_frames_out=config.batch_frames_out, - batch_frames_inout=config.batch_frames_inout, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - # test dataset, return raw text - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) logger.info("Setup test/align Dataloader!") def setup_model(self): @@ -452,7 +381,8 @@ class U2Tester(U2Trainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") stride_ms = self.config.stride_ms error_rate_type = None diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index bc995977..cb015c11 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -25,7 +25,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.utility import load_dict -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.scheduler import LRSchedulerFactory @@ -104,7 +104,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -131,7 +132,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -150,8 +152,8 @@ class U2Trainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -162,7 +164,8 @@ class U2Trainer(Trainer): msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, + if not self.use_streamdata: + msg += "batch : {}/{}, ".format(batch_index + 1, len(self.train_loader)) msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "data time: {:>.3f}s, ".format(dataload_time) @@ -198,87 +201,23 @@ class U2Trainer(Trainer): self.new_epoch() def setup_dataloader(self): - config = self.config.clone() - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - # test dataset, return raw text - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - logger.info("Setup train/valid/test/align Dataloader!") + self.use_streamdata = config.get("use_stream_data", False) + if self.train: + config = self.config.clone() + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + logger.info("Setup train/valid Dataloader!") + else: + config = self.config.clone() + config['preprocess_config'] = None + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + logger.info("Setup test/align Dataloader!") + def setup_model(self): config = self.config @@ -406,7 +345,8 @@ class U2Tester(U2Trainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") stride_ms = self.config.stride_ms error_rate_type = None diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 6a32eda7..60382543 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -25,7 +25,7 @@ import paddle from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2_st import U2STModel from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -120,7 +120,8 @@ class U2STTrainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -153,7 +154,8 @@ class U2STTrainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -172,8 +174,8 @@ class U2STTrainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -191,7 +193,8 @@ class U2STTrainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -241,79 +244,18 @@ class U2STTrainer(Trainer): load_transcript = True if config.model_conf.asr_weight > 0 else False + config = self.config.clone() + config['load_transcript'] = load_transcript + self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=True) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: - # test dataset, return raw text - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=False) - + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) logger.info("Setup test Dataloader!") + def setup_model(self): config = self.config model_conf = config @@ -468,7 +410,8 @@ class U2STTester(U2STTrainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") decode_cfg = self.config.decode bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 55aa13ff..83183024 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -18,6 +18,7 @@ from typing import Text import jsonlines import numpy as np +import paddle from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler @@ -28,7 +29,11 @@ from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log -__all__ = ["BatchDataLoader"] +import paddlespeech.audio.streamdata as streamdata +from paddlespeech.audio.text.text_featurizer import TextFeaturizer +from yacs.config import CfgNode + +__all__ = ["BatchDataLoader", "StreamDataLoader"] logger = Log(__name__).getlog() @@ -56,6 +61,136 @@ def batch_collate(x): """ return x[0] +def read_preprocess_cfg(preprocess_conf_file): + augment_conf = dict() + preprocess_cfg = CfgNode(new_allowed=True) + preprocess_cfg.merge_from_file(preprocess_conf_file) + for idx, process in enumerate(preprocess_cfg["process"]): + opts = dict(process) + process_type = opts.pop("type") + if process_type == 'time_warp': + augment_conf['max_w'] = process['max_time_warp'] + augment_conf['w_inplace'] = process['inplace'] + augment_conf['w_mode'] = process['mode'] + if process_type == 'freq_mask': + augment_conf['max_f'] = process['F'] + augment_conf['num_f_mask'] = process['n_mask'] + augment_conf['f_inplace'] = process['inplace'] + augment_conf['f_replace_with_zero'] = process['replace_with_zero'] + if process_type == 'time_mask': + augment_conf['max_t'] = process['T'] + augment_conf['num_t_mask'] = process['n_mask'] + augment_conf['t_inplace'] = process['inplace'] + augment_conf['t_replace_with_zero'] = process['replace_with_zero'] + return augment_conf + +class StreamDataLoader(): + def __init__(self, + manifest_file: str, + train_mode: bool, + unit_type: str='char', + batch_size: int=0, + preprocess_conf=None, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + minlen_in: float=0.0, + maxlen_in: float=float('inf'), + minlen_out: float=0.0, + maxlen_out: float=float('inf'), + resample_rate: int=16000, + shuffle_size: int=10000, + sort_size: int=1000, + n_iter_processes: int=1, + prefetch_factor: int=2, + dist_sampler: bool=False, + cmvn_file="data/mean_std.json", + vocab_filepath='data/lang_char/vocab.txt'): + self.manifest_file = manifest_file + self.train_model = train_mode + self.batch_size = batch_size + self.prefetch_factor = prefetch_factor + self.dist_sampler = dist_sampler + self.n_iter_processes = n_iter_processes + + text_featurizer = TextFeaturizer(unit_type, vocab_filepath) + symbol_table = text_featurizer.vocab_dict + self.feat_dim = num_mel_bins + self.vocab_size = text_featurizer.vocab_size + + augment_conf = read_preprocess_cfg(preprocess_conf) + + # The list of shard + shardlist = [] + with open(manifest_file, "r") as f: + for line in f.readlines(): + shardlist.append(line.strip()) + world_size = 1 + try: + world_size = paddle.distributed.get_world_size() + except Exception as e: + logger.warninig(e) + logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1") + assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...") + + update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0)) + logger.info(f"update_n_iter_processes {update_n_iter_processes}") + if update_n_iter_processes != self.n_iter_processes: + self.n_iter_processes = update_n_iter_processes + logger.info(f"change nun_workers to {self.n_iter_processes}") + + if self.dist_sampler: + base_dataset = streamdata.DataPipeline( + streamdata.SimpleShardList(shardlist), + streamdata.split_by_node if train_mode else streamdata.placeholder(), + streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) + ) + else: + base_dataset = streamdata.DataPipeline( + streamdata.SimpleShardList(shardlist), + streamdata.split_by_worker, + streamdata.tarfile_to_samples(streamdata.reraise_exception) + ) + + self.dataset = base_dataset.append_list( + streamdata.audio_tokenize(symbol_table), + streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), + streamdata.audio_resample(resample_rate=resample_rate), + streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.shuffle(shuffle_size), + streamdata.sort(sort_size=sort_size), + streamdata.batched(batch_size), + streamdata.audio_padding(), + streamdata.audio_cmvn(cmvn_file) + ) + + if paddle.__version__ >= '2.3.2': + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor = self.prefetch_factor, + batch_size=None + ) + else: + self.loader = streamdata.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + batch_size=None + ) + + def __iter__(self): + return self.loader.__iter__() + + def __call__(self): + return self.__iter__() + + def __len__(self): + logger.info("Stream dataloader does not support calculate the length of the dataset") + return -1 + class BatchDataLoader(): def __init__(self, @@ -199,3 +334,119 @@ class BatchDataLoader(): echo += f"shortest_first: {self.shortest_first}, " echo += f"file: {self.json_file}" return echo + + +class DataLoaderFactory(): + @staticmethod + def get_dataloader(mode: str, config, args): + config = config.clone() + use_streamdata = config.get("use_stream_data", False) + if use_streamdata: + if mode == 'train': + config['manifest'] = config.train_manifest + config['train_mode'] = True + elif mode == 'valid': + config['manifest'] = config.dev_manifest + config['train_mode'] = False + elif model == 'test' or mode == 'align': + config['manifest'] = config.test_manifest + config['train_mode'] = False + config['dither'] = 0.0 + config['minlen_in'] = 0.0 + config['maxlen_in'] = float('inf') + config['minlen_out'] = 0 + config['maxlen_out'] = float('inf') + config['dist_sampler'] = False + else: + raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") + return StreamDataLoader( + manifest_file=config.manifest, + train_mode=config.train_mode, + unit_type=config.unit_type, + preprocess_conf=config.preprocess_config, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.dist_sampler, + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + else: + if mode == 'train': + config['manifest'] = config.train_manifest + config['train_mode'] = True + config['mini_batch_size'] = args.ngpu + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + elif mode == 'valid': + config['manifest'] = config.dev_manifest + config['train_mode'] = False + config['sortagrad'] = False + config['maxlen_in'] = float('inf') + config['maxlen_out'] = float('inf') + config['minibatches'] = 0 + config['mini_batch_size'] = args.ngpu + config['batch_count'] = 'auto' + config['batch_bins'] = 0 + config['batch_frames_in'] = 0 + config['batch_frames_out'] = 0 + config['batch_frames_inout'] = 0 + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + config['shortest_first'] = False + elif mode == 'test' or mode == 'align': + config['manifest'] = config.test_manifest + config['train_mode'] = False + config['sortagrad'] = False + config['batch_size'] = config.get('decode', dict()).get( + 'decode_batch_size', 1) + config['maxlen_in'] = float('inf') + config['maxlen_out'] = float('inf') + config['minibatches'] = 0 + config['mini_batch_size'] = 1 + config['batch_count'] = 'auto' + config['batch_bins'] = 0 + config['batch_frames_in'] = 0 + config['batch_frames_out'] = 0 + config['batch_frames_inout'] = 0 + config['num_workers'] = 1 + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + config['dist_sampler'] = False + config['shortest_first'] = False + else: + raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") + + return BatchDataLoader( + json_file=config.manifest, + train_mode=config.train_mode, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, + mini_batch_size=config.mini_batch_size, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + batch_frames_in=config.batch_frames_in, + batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=config.subsampling_factor, + load_aux_output=config.get('load_transcript', None), + num_encs=config.num_encs, + dist_sampler=config.dist_sampler, + shortest_first=config.shortest_first) + diff --git a/paddlespeech/s2t/io/reader.py b/paddlespeech/s2t/io/reader.py index 4e136bdc..5e018bef 100644 --- a/paddlespeech/s2t/io/reader.py +++ b/paddlespeech/s2t/io/reader.py @@ -19,7 +19,7 @@ import numpy as np import soundfile from .utility import feat_type -from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.log import Log # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index c9b6e2fb..100aca18 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.tensor_utils import add_sos_eos -from paddlespeech.s2t.utils.tensor_utils import pad_sequence -from paddlespeech.s2t.utils.tensor_utils import th_accuracy +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import pad_sequence +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 6447753c..00ded912 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.tensor_utils import add_sos_eos -from paddlespeech.s2t.utils.tensor_utils import th_accuracy +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index 9ec7bb5a..ab4f1130 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -26,7 +26,7 @@ from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils import onnx_infer diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index 1f0a8b80..182e6418 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -24,9 +24,9 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index b3472a83..2bacfecd 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -24,9 +24,9 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig diff --git a/setup.py b/setup.py index b46c24e8..a3ef753a 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,9 @@ base = [ "prettytable", "zhon", "colorlog", - "pathos == 0.2.8" + "pathos == 0.2.8", + "braceexpand", + "pyyaml" ] server = [