Merge pull request #2062 from Jackwaterveg/webdataset

[Audio] Add webdataset in paddlespeech.audio
Jackwaterveg 3 years ago committed by GitHub
commit d1a25f6cb1

@ -1,7 +1,6 @@
 ############################################
 # Network Architecture #
 ############################################
-cmvn_file:
 cmvn_file_type: "json"
 # encoder related
 encoder: conformer
@ -43,40 +42,42 @@ model_conf:
 ###########################################
 # Data #
 ###########################################
-train_manifest: data/manifest.train
-dev_manifest: data/manifest.dev
-test_manifest: data/manifest.test
+train_manifest: data/train_l/data.list
+dev_manifest: data/dev/data.list
+test_manifest: data/test_meeting/data.list
 ###########################################
 # Dataloader #
 ###########################################
-vocab_filepath: data/lang_char/vocab.txt
+use_stream_data: True
 unit_type: 'char'
+vocab_filepath: data/lang_char/vocab.txt
 preprocess_config: conf/preprocess.yaml
+cmvn_file: data/mean_std.json
 spm_model_prefix: ''
 feat_dim: 80
 stride_ms: 10.0
 window_ms: 25.0
+dither: 0.1
 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-batch_size: 64
-maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
-maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
-minibatches: 0 # for debug
-batch_count: auto
-batch_bins: 0
-batch_frames_in: 0
-batch_frames_out: 0
-batch_frames_inout: 0
-num_workers: 0
-subsampling_factor: 1
+batch_size: 32
+minlen_in: 10
+maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed
+minlen_out: 0
+maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed
+resample_rate: 16000
+shuffle_size: 1500 # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk
+sort_size: 1000 # read number of 'sort_size' data as a chunk, sort the data in the chunk
+num_workers: 8
+prefetch_factor: 10
+dist_sampler: True
 num_encs: 1
 ###########################################
 # Training #
 ###########################################
-n_epoch: 240
-accum_grad: 16
+n_epoch: 32
+accum_grad: 32
 global_grad_clip: 5.0
 log_interval: 100
 checkpoint:
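For orientation, the new dataloader keys map onto the streaming (webdataset-style) pipeline added by this PR: the manifests now point at data.list files produced by the shard-making stage, shuffle_size/sort_size control in-memory buffers, and maxlen_in/maxlen_out drop over-long samples instead of shrinking the batch. A minimal sketch of the outer plumbing, where the paddlespeech.audio.streamdata import path and the one-shard-path-per-line layout of data.list are assumptions rather than something this diff shows:

# Hypothetical sketch; the import path and data.list layout are assumed, not
# confirmed by this diff.
import paddlespeech.audio.streamdata as streamdata

with open("data/train_l/data.list") as f:       # assumed: one tar-shard path per line
    shards = [line.strip() for line in f if line.strip()]

dataset = streamdata.WebDataset(shards, shardshuffle=True)   # shard-level shuffle

The per-sample audio stages (tokenize, fbank, SpecAug, sorting, padding, CMVN) are sketched after the compat.py listing further down.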

@ -2,6 +2,8 @@
 # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
 # NPU, ASLP Group (Author: Qijie Shao)
+#
+# Modified from wenet(https://github.com/wenet-e2e/wenet)
 stage=-1
 stop_stage=100
@ -30,7 +32,7 @@ mkdir -p data
 TARGET_DIR=${MAIN_ROOT}/dataset
 mkdir -p ${TARGET_DIR}
-if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
     # download data
     echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
     exit 0;
@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     data || exit 1;
 fi
-if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
-    # generate manifests
-    python3 ${TARGET_DIR}/aishell/aishell.py \
-        --manifest_prefix="data/manifest" \
-        --target_dir="${TARGET_DIR}/aishell"
-    if [ $? -ne 0 ]; then
-        echo "Prepare Aishell failed. Terminated."
-        exit 1
-    fi
-    for dataset in train dev test; do
-        mv data/manifest.${dataset} data/manifest.${dataset}.raw
-    done
+dict=data/lang_char/vocab.txt
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "Make a dictionary"
+    echo "dictionary: ${dict}"
+    mkdir -p $(dirname $dict)
+    echo "<blank>" > ${dict} # 0 will be used for "blank" in CTC
+    echo "<unk>" >> ${dict} # <unk> must be 1
+    echo "▁" >> ${dict} # ▁ is for space
+    utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
+        | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' \
+        | grep -v "▁" \
+        | awk '{print $0}' >> ${dict} \
+        || exit 1;
+    echo "<eos>" >> $dict
 fi
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    # compute mean and stddev for normalizer
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Compute cmvn"
     # Here we use all the training data, you can sample some some data to save time
     # BUG!!! We should use the segmented data for CMVN
     if $cmvn; then
         full_size=`cat data/${train_set}/wav.scp | wc -l`
         sampling_size=$((full_size / cmvn_sampling_divisor))
         shuf -n $sampling_size data/$train_set/wav.scp \
             > data/$train_set/wav.scp.sampled
-        num_workers=$(nproc)
-        python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
-            --manifest_path="data/manifest.train.raw" \
-            --spectrum_type="fbank" \
-            --feat_dim=80 \
-            --delta_delta=false \
-            --stride_ms=10 \
-            --window_ms=25 \
-            --sample_rate=16000 \
-            --use_dB_normalization=False \
-            --num_samples=-1 \
-            --num_workers=${num_workers} \
-            --output_path="data/mean_std.json"
-        if [ $? -ne 0 ]; then
-            echo "Compute mean and stddev failed. Terminated."
-            exit 1
-        fi
-    fi
-fi
-dict=data/dict/lang_char.txt
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    # download data, generate manifests
-    # build vocabulary
-    python3 ${MAIN_ROOT}/utils/build_vocab.py \
-        --unit_type="char" \
-        --count_threshold=0 \
-        --vocab_path="data/lang_char/vocab.txt" \
-        --manifest_paths "data/manifest.train.raw"
-    if [ $? -ne 0 ]; then
-        echo "Build vocabulary failed. Terminated."
-        exit 1
-    fi
+        python3 utils/compute_cmvn_stats.py \
+            --num_workers 16 \
+            --train_config $train_config \
+            --in_scp data/$train_set/wav.scp.sampled \
+            --out_cmvn data/$train_set/mean_std.json \
+            || exit 1;
     fi
 fi
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    # format manifest with tokenids, vocab size
-    for dataset in train dev test; do
-    {
-        python3 ${MAIN_ROOT}/utils/format_data.py \
-            --cmvn_path "data/mean_std.json" \
-            --unit_type "char" \
-            --vocab_path="data/vocab.txt" \
-            --manifest_path="data/manifest.${dataset}.raw" \
-            --output_path="data/manifest.${dataset}"
-        if [ $? -ne 0 ]; then
-            echo "Formt mnaifest failed. Terminated."
-            exit 1
-        fi
-    } &
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "Making shards, please wait..."
+    RED='\033[0;31m'
+    NOCOLOR='\033[0m'
+    echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space"
+    echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads"
+    for x in $dev_set $test_sets ${train_set}; do
+        dst=$shards_dir/$x
+        mkdir -p $dst
+        utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \
+            --do_filter --resample 16000 \
+            --num_threads 32 --segments data/$x/segments \
+            data/$x/wav.scp data/$x/text \
+            $(realpath $dst) data/$x/data.list
     done
-    wait
 fi
-echo "Aishell data preparation done."
+echo "Wenetspeech data preparation done."
 exit 0
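Stage 1 above writes a plain token list with <blank> pinned to id 0, <unk> to id 1, and <eos> appended last; that file is what the audio_tokenize filter added later in this PR expects as its symbol_table. A minimal loading sketch (the line-per-token layout follows from the shell snippet above; the variable name is illustrative):

# Build the token -> id mapping consumed by the audio_tokenize filter
# (see filters.py later in this diff). Layout: one token per line.
symbol_table = {}
with open("data/lang_char/vocab.txt", encoding="utf-8") as f:
    for idx, line in enumerate(f):
        token = line.strip()
        if token:
            symbol_table[token] = idx   # <blank>=0, <unk>=1, ..., <eos> last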

@ -0,0 +1,68 @@
#!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# -lt 2 ] || [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
config_path=$1
ckpt_name=$2
ips=$3
if [ ! $ips ];then
ips_config=
else
ips_config="--ips="${ips}
fi
echo ${ips_config}
mkdir -p exp
if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
fi
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0

@ -24,7 +24,7 @@ stage=1
 prefix=
 train_subset=L
-. ./tools/parse_options.sh || exit 1;
+. ./utils/parse_options.sh || exit 1;
 filter_by_id () {
     idlist=$1

@ -7,6 +7,7 @@ gpus=0,1,2,3,4,5,6,7
 stage=0
 stop_stage=100
 conf_path=conf/conformer.yaml
+ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
 decode_conf_path=conf/tuning/decode.yaml
 average_checkpoint=true
 avg_num=10
@ -26,7 +27,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

@ -0,0 +1,70 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from .cache import (
cached_tarfile_samples,
cached_tarfile_to_samples,
lru_cleanup,
pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from .extradatasets import MockDataset, with_epoch, with_length
from .filters import (
associate,
batched,
decode,
detshuffle,
extract_keys,
getfirst,
info,
map,
map_dict,
map_tuple,
pipelinefilter,
rename,
rename_keys,
select,
shuffle,
slice,
to_tuple,
transform_with,
unbatched,
xdecode,
audio_data_filter,
audio_tokenize,
audio_resample,
audio_compute_fbank,
audio_spec_aug,
sort,
audio_padding,
audio_cmvn,
placeholder,
)
from .handlers import (
ignore_and_continue,
ignore_and_stop,
reraise_exception,
warn_and_continue,
warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
MultiShardSample,
ResampledShards,
SimpleShardList,
non_empty,
resampled,
shardspec,
single_node_only,
split_by_node,
split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from .writer import ShardWriter, TarWriter, numpy_dumps
from .mix import RandomMix, RoundRobin
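The writer side (ShardWriter, TarWriter, numpy_dumps) is re-exported here but its implementation is not part of this excerpt. A minimal shard-writing sketch, assuming the API matches upstream webdataset (a filename pattern plus maxcount, and samples written as dicts keyed by file extension):

# Hypothetical; assumes ShardWriter keeps the upstream webdataset signature.
from paddlespeech.audio.streamdata import ShardWriter   # assumed import path

def write_shards(samples):
    # samples: any iterable of (utt_id, wav_bytes, transcript) tuples
    with ShardWriter("shards/train-%06d.tar", maxcount=1000) as sink:
        for utt_id, wav_bytes, text in samples:
            sink.write({
                "__key__": utt_id,   # groups the files of one sample inside the tar
                "wav": wav_bytes,    # raw audio bytes, decoded later by the pipeline
                "txt": text,         # transcript, stored as UTF-8 text
            })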

@ -0,0 +1,445 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""
import io, json, os, pickle, re, tempfile
from functools import partial
import numpy as np
"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
################################################################
# handle basic datatypes
################################################################
def paddle_loads(data):
"""Load data using paddle.loads, importing paddle only if needed.
:param data: data to be decoded
"""
import io
import paddle
stream = io.BytesIO(data)
return paddle.load(stream)
def tenbin_loads(data):
from . import tenbin
return tenbin.decode_buffer(data)
def msgpack_loads(data):
import msgpack
return msgpack.unpackb(data)
def npy_loads(data):
import numpy.lib.format
stream = io.BytesIO(data)
return numpy.lib.format.read_array(stream)
def cbor_loads(data):
import cbor
return cbor.loads(data)
decoders = {
"txt": lambda data: data.decode("utf-8"),
"text": lambda data: data.decode("utf-8"),
"transcript": lambda data: data.decode("utf-8"),
"cls": lambda data: int(data),
"cls2": lambda data: int(data),
"index": lambda data: int(data),
"inx": lambda data: int(data),
"id": lambda data: int(data),
"json": lambda data: json.loads(data),
"jsn": lambda data: json.loads(data),
"pyd": lambda data: pickle.loads(data),
"pickle": lambda data: pickle.loads(data),
"pdparams": lambda data: paddle_loads(data),
"ten": tenbin_loads,
"tb": tenbin_loads,
"mp": msgpack_loads,
"msg": msgpack_loads,
"npy": npy_loads,
"npz": lambda data: np.load(io.BytesIO(data)),
"cbor": cbor_loads,
}
def basichandlers(key, data):
"""Handle basic file decoding.
This function is usually part of the post= decoders.
This handles the following forms of decoding:
- txt -> unicode string
- cls cls2 class count index inx id -> int
- json jsn -> JSON decoding
- pyd pickle -> pickle decoding
- pdparams -> paddle.load
- ten tenbin -> fast tensor loading
- mp messagepack msg -> messagepack decoding
- npy -> Python NPY decoding
:param key: file name extension
:param data: binary data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension in decoders:
return decoders[extension](data)
return None
################################################################
# Generic extension handler.
################################################################
def call_extension_handler(key, data, f, extensions):
"""Call the function f with the given data if the key matches the extensions.
:param key: actual key found in the sample
:param data: binary data
:param f: decoder function
:param extensions: list of matching extensions
"""
extension = key.lower().split(".")
for target in extensions:
target = target.split(".")
if len(target) > len(extension):
continue
if extension[-len(target) :] == target:
return f(data)
return None
def handle_extension(extensions, f):
"""Return a decoder function for the list of extensions.
Extensions can be a space separated list of extensions.
Extensions can contain dots, in which case the corresponding number
of extension components must be present in the key given to f.
Comparisons are case insensitive.
Examples:
handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg
handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg
"""
extensions = extensions.lower().split()
return partial(call_extension_handler, f=f, extensions=extensions)
################################################################
# handle images
################################################################
imagespecs = {
"l8": ("numpy", "uint8", "l"),
"rgb8": ("numpy", "uint8", "rgb"),
"rgba8": ("numpy", "uint8", "rgba"),
"l": ("numpy", "float", "l"),
"rgb": ("numpy", "float", "rgb"),
"rgba": ("numpy", "float", "rgba"),
"paddlel8": ("paddle", "uint8", "l"),
"paddlergb8": ("paddle", "uint8", "rgb"),
"paddlergba8": ("paddle", "uint8", "rgba"),
"paddlel": ("paddle", "float", "l"),
"paddlergb": ("paddle", "float", "rgb"),
"paddle": ("paddle", "float", "rgb"),
"paddlergba": ("paddle", "float", "rgba"),
"pill": ("pil", None, "l"),
"pil": ("pil", None, "rgb"),
"pilrgb": ("pil", None, "rgb"),
"pilrgba": ("pil", None, "rgba"),
}
class ImageHandler:
"""Decode image data using the given `imagespec`.
The `imagespec` specifies whether the image is decoded
to numpy/paddle/pil, decoded to uint8/float, and decoded
to l/rgb/rgba:
- l8: numpy uint8 l
- rgb8: numpy uint8 rgb
- rgba8: numpy uint8 rgba
- l: numpy float l
- rgb: numpy float rgb
- rgba: numpy float rgba
- paddlel8: paddle uint8 l
- paddlergb8: paddle uint8 rgb
- paddlergba8: paddle uint8 rgba
- paddlel: paddle float l
- paddlergb: paddle float rgb
- paddle: paddle float rgb
- paddlergba: paddle float rgba
- pill: pil None l
- pil: pil None rgb
- pilrgb: pil None rgb
- pilrgba: pil None rgba
"""
def __init__(self, imagespec, extensions=image_extensions):
"""Create an image handler.
:param imagespec: short string indicating the type of decoding
:param extensions: list of extensions the image handler is invoked for
"""
if imagespec not in list(imagespecs.keys()):
raise ValueError("Unknown imagespec: %s" % imagespec)
self.imagespec = imagespec.lower()
self.extensions = extensions
def __call__(self, key, data):
"""Perform image decoding.
:param key: file name extension
:param data: binary data
"""
import PIL.Image
extension = re.sub(r".*[.]", "", key)
if extension.lower() not in self.extensions:
return None
imagespec = self.imagespec
atype, etype, mode = imagespecs[imagespec]
with io.BytesIO(data) as stream:
img = PIL.Image.open(stream)
img.load()
img = img.convert(mode.upper())
if atype == "pil":
return img
elif atype == "numpy":
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: numpy image must be uint8")
if etype == "uint8":
return result
else:
return result.astype("f") / 255.0
elif atype == "paddle":
import paddle
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: paddle image must be uint8")
if etype == "uint8":
result = np.array(result.transpose(2, 0, 1))
return paddle.to_tensor(result)
else:
result = np.array(result.transpose(2, 0, 1))
return paddle.to_tensor(result) / 255.0
return None
def imagehandler(imagespec, extensions=image_extensions):
"""Create an image handler.
This is just a lower case alias for ImageHandler.
:param imagespec: textual image spec
:param extensions: list of extensions the handler should be applied for
"""
return ImageHandler(imagespec, extensions)
################################################################
# torch video
################################################################
'''
def torch_video(key, data):
"""Decode video using the torchvideo library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
return None
import torchvision.io
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return torchvision.io.read_video(fname, pts_unit="sec")
'''
################################################################
# paddlespeech.audio
################################################################
def paddle_audio(key, data):
"""Decode audio using the paddlespeech.audio library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
return None
import paddlespeech.audio
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return paddlespeech.audio.load(fname)
################################################################
# special class for continuing decoding
################################################################
class Continue:
"""Special class for continuing decoding.
This is mostly used for decompression, as in:
def decompressor(key, data):
if key.endswith(".gz"):
return Continue(key[:-3], decompress(data))
return None
"""
def __init__(self, key, data):
"""__init__.
:param key:
:param data:
"""
self.key, self.data = key, data
def gzfilter(key, data):
"""Decode .gz files.
This decodes compressed files and then continues decoding.
:param key: file name extension
:param data: binary data
"""
import gzip
if not key.endswith(".gz"):
return None
decompressed = gzip.open(io.BytesIO(data)).read()
return Continue(key[:-3], decompressed)
################################################################
# decode entire training samples
################################################################
default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]
class Decoder:
"""Decode samples using a list of handlers.
For each key/data item, this iterates through the list of
handlers until some handler returns something other than None.
"""
def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
"""Create a Decoder.
:param handlers: main list of handlers
:param pre: handlers called before the main list (.gz handler by default)
:param post: handlers called after the main list (default handlers by default)
:param only: a list of extensions; when given, only files with those extensions are decoded
:param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
"""
if isinstance(only, str):
only = only.split()
self.only = only if only is None else set(only)
if pre is None:
pre = default_pre_handlers
if post is None:
post = default_post_handlers
assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
assert all(callable(h) for h in pre), f"one of {pre} not callable"
assert all(callable(h) for h in post), f"one of {post} not callable"
self.handlers = pre + handlers + post
self.partial = partial
def decode1(self, key, data):
"""Decode a single field of a sample.
:param key: file name extension
:param data: binary data
"""
key = "." + key
for f in self.handlers:
result = f(key, data)
if isinstance(result, Continue):
key, data = result.key, result.data
continue
if result is not None:
return result
return data
def decode(self, sample):
"""Decode an entire sample.
:param sample: the sample, a dictionary of key value pairs
"""
result = {}
assert isinstance(sample, dict), sample
for k, v in list(sample.items()):
if k[0] == "_":
if isinstance(v, bytes):
v = v.decode("utf-8")
result[k] = v
continue
if self.only is not None and k not in self.only:
result[k] = v
continue
assert v is not None
if self.partial:
if isinstance(v, bytes):
result[k] = self.decode1(k, v)
else:
result[k] = v
else:
assert isinstance(v, bytes)
result[k] = self.decode1(k, v)
return result
def __call__(self, sample):
"""Decode an entire sample.
:param sample: the sample
"""
assert isinstance(sample, dict), (len(sample), sample)
return self.decode(sample)
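A small usage sketch of the decoder above: handlers are tried in order (the gzip pre-handler, then the user handlers, then basichandlers), and each handler dispatches on the file-name extension. The sample contents below are illustrative:

# Illustrative use of Decoder with the paddle_audio handler defined above.
decoder = Decoder([paddle_audio])

sample = {
    "__key__": "utt0001",                      # metadata keys ("__*") pass through
    "wav": open("utt0001.wav", "rb").read(),   # handled by paddle_audio (paddlespeech.audio.load)
    "txt": "ni hao".encode("utf-8"),           # falls through to basichandlers -> str
}
decoded = decoder(sample)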

@ -0,0 +1,190 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time
from urllib.parse import urlparse
from . import filters
from . import gopen
from .handlers import reraise_exception
from .tariterators import tar_file_and_group_expander
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
"""Performs cleanup of the file cache in cache_dir using an LRU strategy,
keeping the total size of all remaining files below cache_size."""
if not os.path.exists(cache_dir):
return
total_size = 0
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
total_size += os.path.getsize(os.path.join(dirpath, filename))
if total_size <= cache_size:
return
# sort files by last access time
files = []
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
files.append(os.path.join(dirpath, filename))
files.sort(key=keyfn, reverse=True)
# delete files until we're under the cache size
while len(files) > 0 and total_size > cache_size:
fname = files.pop()
total_size -= os.path.getsize(fname)
if verbose:
print("# deleting %s" % fname, file=sys.stderr)
os.remove(fname)
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
"""Download a file from `url` to `dest`."""
temp = dest + f".temp{os.getpid()}"
with gopen.gopen(url) as stream:
with open(temp, "wb") as f:
while True:
data = stream.read(chunk_size)
if not data:
break
f.write(data)
os.rename(temp, dest)
def pipe_cleaner(spec):
"""Guess the actual URL from a "pipe:" specification."""
if spec.startswith("pipe:"):
spec = spec[5:]
words = spec.split(" ")
for word in words:
if re.match(r"^(https?|gs|ais|s3)", word):
return word
return spec
def get_file_cached(
spec,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
verbose=False,
):
if cache_size == -1:
cache_size = default_cache_size
if cache_dir is None:
cache_dir = default_cache_dir
url = url_to_name(spec)
parsed = urlparse(url)
dirname, filename = os.path.split(parsed.path)
dirname = dirname.lstrip("/")
dirname = re.sub(r"[:/|;]", "_", dirname)
destdir = os.path.join(cache_dir, dirname)
os.makedirs(destdir, exist_ok=True)
dest = os.path.join(cache_dir, dirname, filename)
if not os.path.exists(dest):
if verbose:
print("# downloading %s to %s" % (url, dest), file=sys.stderr)
lru_cleanup(cache_dir, cache_size, verbose=verbose)
download(spec, dest, verbose=verbose)
return dest
def get_filetype(fname):
with os.popen("file '%s'" % fname) as f:
ftype = f.read()
return ftype
def check_tar_format(fname):
"""Check whether a file is a tar archive."""
ftype = get_filetype(fname)
return "tar archive" in ftype or "gzip compressed" in ftype
verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def cached_url_opener(
data,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
validator=check_tar_format,
verbose=False,
always=False,
):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
verbose = verbose or verbose_cache
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
attempts = 5
try:
if not always and os.path.exists(url):
dest = url
else:
dest = get_file_cached(
url,
cache_size=cache_size,
cache_dir=cache_dir,
url_to_name=url_to_name,
verbose=verbose,
)
if verbose:
print("# opening %s" % dest, file=sys.stderr)
assert os.path.exists(dest)
if not validator(dest):
ftype = get_filetype(dest)
with open(dest, "rb") as f:
data = f.read(200)
os.remove(dest)
raise ValueError(
"%s (%s) is not a tar archive, but a %s, contains %s"
% (dest, url, ftype, repr(data))
)
try:
stream = open(dest, "rb")
sample.update(stream=stream)
yield sample
except FileNotFoundError as exn:
# dealing with race conditions in lru_cleanup
attempts -= 1
if attempts > 0:
time.sleep(random.random() * 10)
continue
raise exn
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def cached_tarfile_samples(
src,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
verbose=False,
url_to_name=pipe_cleaner,
always=False,
):
streams = cached_url_opener(
src,
handler=handler,
cache_size=cache_size,
cache_dir=cache_dir,
verbose=verbose,
url_to_name=url_to_name,
always=always,
)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
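For shards addressed through "pipe:" commands, pipe_cleaner recovers the underlying URL so the cache can key on it, get_file_cached mirrors the remote file under the cache directory (WDS_CACHE, default ./_cache), and lru_cleanup evicts by creation time once WDS_CACHE_SIZE is exceeded. A small illustrative sketch (the shard URL is a placeholder):

# Illustrative only; the URL below is made up.
spec = "pipe:curl -s -L https://example.com/shards/train-000000.tar"

url = pipe_cleaner(spec)   # -> "https://example.com/shards/train-000000.tar"
local = get_file_cached(spec, cache_size=10 * 2**30, cache_dir="./_cache")
# `local` is the cached copy; cached_tarfile_samples opens it instead of
# re-fetching the shard on every epoch.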

@ -0,0 +1,170 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from . import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset
class FluidInterface:
def batched(self, batchsize):
return self.compose(filters.batched(batchsize))
def dynamic_batched(self, max_frames_in_batch):
return self.compose(filters.dynamic_batched(max_frames_in_batch))
def unbatched(self):
return self.compose(filters.unbatched())
def listed(self, batchsize, partial=True):
return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)
def unlisted(self):
return self.compose(filters.unlisted())
def log_keys(self, logfile=None):
return self.compose(filters.log_keys(logfile))
def shuffle(self, size, **kw):
if size < 1:
return self
else:
return self.compose(filters.shuffle(size, **kw))
def map(self, f, handler=reraise_exception):
return self.compose(filters.map(f, handler=handler))
def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
return self.map(decoder, handler=handler)
def map_dict(self, handler=reraise_exception, **kw):
return self.compose(filters.map_dict(handler=handler, **kw))
def select(self, predicate, **kw):
return self.compose(filters.select(predicate, **kw))
def to_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.to_tuple(*args, handler=handler))
def map_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.map_tuple(*args, handler=handler))
def slice(self, *args):
return self.compose(filters.slice(*args))
def rename(self, **kw):
return self.compose(filters.rename(**kw))
def rsample(self, p=0.5):
return self.compose(filters.rsample(p))
def rename_keys(self, *args, **kw):
return self.compose(filters.rename_keys(*args, **kw))
def extract_keys(self, *args, **kw):
return self.compose(filters.extract_keys(*args, **kw))
def xdecode(self, *args, **kw):
return self.compose(filters.xdecode(*args, **kw))
def audio_data_filter(self, *args, **kw):
return self.compose(filters.audio_data_filter(*args, **kw))
def audio_tokenize(self, *args, **kw):
return self.compose(filters.audio_tokenize(*args, **kw))
def resample(self, *args, **kw):
return self.compose(filters.audio_resample(*args, **kw))
def audio_compute_fbank(self, *args, **kw):
return self.compose(filters.audio_compute_fbank(*args, **kw))
def audio_spec_aug(self, *args, **kw):
return self.compose(filters.audio_spec_aug(*args, **kw))
def sort(self, size=500):
return self.compose(filters.sort(size))
def audio_padding(self):
return self.compose(filters.audio_padding())
def audio_cmvn(self, cmvn_file):
return self.compose(filters.audio_cmvn(cmvn_file))
class WebDataset(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(
self,
urls,
handler=reraise_exception,
resampled=False,
repeat=False,
shardshuffle=None,
cache_size=0,
cache_dir=None,
detshuffle=False,
nodesplitter=shardlists.single_node_only,
verbose=False,
):
super().__init__()
if isinstance(urls, IterableDataset):
assert not resampled
self.append(urls)
elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
with (open(urls)) as stream:
spec = yaml.safe_load(stream)
assert "datasets" in spec
self.append(shardlists.MultiShardSample(spec))
elif isinstance(urls, dict):
assert "datasets" in urls
self.append(shardlists.MultiShardSample(urls))
elif resampled:
self.append(shardlists.ResampledShards(urls))
else:
self.append(shardlists.SimpleShardList(urls))
self.append(nodesplitter)
self.append(shardlists.split_by_worker)
if shardshuffle is True:
shardshuffle = 100
if shardshuffle is not None:
if detshuffle:
self.append(filters.detshuffle(shardshuffle))
else:
self.append(filters.shuffle(shardshuffle))
if cache_size == 0:
self.append(tariterators.tarfile_to_samples(handler=handler))
else:
assert cache_size == -1 or cache_size > 0
self.append(
cache.cached_tarfile_to_samples(
handler=handler,
verbose=verbose,
cache_size=cache_size,
cache_dir=cache_dir,
)
)
class FluidWrapper(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(self, initial):
super().__init__()
self.append(initial)
class WebLoader(DataPipeline, FluidInterface):
def __init__(self, *args, **kw):
super().__init__(DataLoader(*args, **kw))
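The FluidInterface methods above are thin wrappers over the filters defined later in this diff, so a training pipeline can be expressed as one chain. A hedged end-to-end sketch for the wenetspeech config: shards and symbol_table are prepared as in the earlier sketches, the paddlespeech.audio.streamdata import path is assumed, and the step that decodes raw tar entries into {fname, wav, txt, sample_rate} dicts is omitted because it is not shown in this excerpt:

# Hypothetical chain; keyword names come from the filter definitions below.
import paddlespeech.audio.streamdata as streamdata   # assumed import path

dataset = (
    streamdata.WebDataset(shards, shardshuffle=True)
    # ... decoding of tar entries into {fname, wav, txt, sample_rate} goes here ...
    .audio_tokenize(symbol_table)                              # txt -> tokens / label ids
    .audio_data_filter(max_length=1200, token_max_length=150)  # maxlen_in / maxlen_out
    .resample(resample_rate=16000)                             # resample_rate
    .audio_compute_fbank(num_mel_bins=80, frame_length=25, frame_shift=10, dither=0.1)
    .audio_spec_aug()                                          # SpecAug on the fbank features
    .shuffle(1500)                                             # shuffle_size
    .sort(1000)                                                # sort_size
    .batched(32)                                               # batch_size
    .audio_padding()        # -> (keys, padded_feats, feat_lens, padded_labels, label_lens)
    .audio_cmvn("data/mean_std.json")                          # cmvn_file
)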

@ -0,0 +1,141 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import itertools as itt
import os
import random
import sys
import braceexpand
from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
"""MockDataset.
A mock dataset for performance testing and unit testing.
"""
def __init__(self, sample, length):
"""Create a mock dataset instance.
:param sample: the sample to be returned repeatedly
:param length: the length of the mock dataset
"""
self.sample = sample
self.length = length
def __iter__(self):
"""Return an iterator over this mock dataset."""
for i in range(self.length):
yield self.sample
class repeatedly(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, source, nepochs=None, nbatches=None, length=None):
"""Create an instance of Repeatedly.
:param nepochs: repeat for a maximum of nepochs
:param nbatches: repeat for a maximum of nbatches
"""
self.source = source
self.length = length
self.nepochs = nepochs
self.nbatches = nbatches
def invoke(self, source):
"""Return an iterator that iterates repeatedly over a source."""
return utils.repeatedly(
source,
nepochs=self.nepochs,
nbatches=self.nbatches,
)
class with_epoch(IterableDataset):
"""Change the actual and nominal length of an IterableDataset.
This will continuously iterate through the original dataset, but
impose new epoch boundaries at the given length/nominal.
This exists mainly as a workaround for the odd logic in DataLoader.
It is also useful for choosing smaller nominal epoch sizes with
very large datasets.
"""
def __init__(self, dataset, length):
"""Chop the dataset to the given length.
:param dataset: IterableDataset
:param length: declared length of the dataset
:param nominal: nominal length of dataset (if different from declared)
"""
super().__init__()
self.length = length
self.source = None
def __getstate__(self):
"""Return the pickled state of the dataset.
This resets the dataset iterator, since that can't be pickled.
"""
result = dict(self.__dict__)
result["source"] = None
return result
def invoke(self, dataset):
"""Return an iterator over the dataset.
This iterator returns as many samples as given by the `length`
parameter.
"""
if self.source is None:
self.source = iter(dataset)
for i in range(self.length):
try:
sample = next(self.source)
except StopIteration:
self.source = iter(dataset)
try:
sample = next(self.source)
except StopIteration:
return
yield sample
self.source = None
class with_length(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, dataset, length):
"""Create an instance of Repeatedly.
:param dataset: source dataset
:param length: stated length
"""
super().__init__()
self.dataset = dataset
self.length = length
def invoke(self, dataset):
"""Return an iterator that iterates repeatedly over a source."""
return iter(dataset)
def __len__(self):
"""Return the user specified length."""
return self.length
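MockDataset is handy for exercising or timing the downstream stages without touching real shards; a tiny illustrative sketch (the fake sample roughly mirrors what the fbank filter emits):

# Benchmark later pipeline stages against a fixed fake sample.
import numpy as np

fake = {"fname": "utt0000", "feat": np.zeros((1200, 80), dtype="float32"), "label": [1, 2, 3]}
mock = MockDataset(fake, length=10000)

for sample in mock:
    pass   # plug `mock` in where the tar-backed source would normally go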

@ -0,0 +1,935 @@
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import io
from fnmatch import fnmatch
import re
import itertools, os, pickle, random, sys, time
from functools import reduce, wraps
import numpy as np
from . import autodecode
from . import utils
from .paddle_utils import PaddleTensor
from .utils import PipelineStage
from .. import backends
from ..compliance import kaldi
import paddle
from ..transform.cmvn import GlobalCMVN
from ..utils.tensor_utils import pad_sequence
from ..transform.spec_augment import time_warp
from ..transform.spec_augment import time_mask
from ..transform.spec_augment import freq_mask
class FilterFunction(object):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def __init__(self, f, *args, **kw):
"""Create a curried function."""
self.f = f
self.args = args
self.kw = kw
def __call__(self, data):
"""Call the curried function with the given argument."""
return self.f(data, *self.args, **self.kw)
def __str__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
def __repr__(self):
"""Compute a string representation."""
return f"<{self.f.__name__} {self.args} {self.kw}>"
class RestCurried(object):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def __init__(self, f):
"""Store the function for future currying."""
self.f = f
def __call__(self, *args, **kw):
"""Curry with the given arguments."""
return FilterFunction(self.f, *args, **kw)
def pipelinefilter(f):
"""Turn the decorated function into one that is partially applied for
all arguments other than the first."""
result = RestCurried(f)
return result
def reraise_exception(exn):
"""Reraises the given exception; used as a handler.
:param exn: exception
"""
raise exn
def identity(x):
"""Return the argument."""
return x
def compose2(f, g):
"""Compose two functions, g(f(x))."""
return lambda x: g(f(x))
def compose(*args):
"""Compose a sequence of functions (left-to-right)."""
return reduce(compose2, args)
def pipeline(source, *args):
"""Write an input pipeline; first argument is source, rest are filters."""
if len(args) == 0:
return source
return compose(*args)(source)
def getfirst(a, keys, default=None, missing_is_error=True):
"""Get the first matching key from a dictionary.
Keys can be specified as a list, or as a string of keys separated by ';'.
"""
if isinstance(keys, str):
assert " " not in keys
keys = keys.split(";")
for k in keys:
if k in a:
return a[k]
if missing_is_error:
raise ValueError(f"didn't find {keys} in {list(a.keys())}")
return default
def parse_field_spec(fields):
"""Parse a specification for a list of fields to be extracted.
Keys are separated by spaces in the spec. Each key can itself
be composed of key alternatives separated by ';'.
"""
if isinstance(fields, str):
fields = fields.split()
return [field.split(";") for field in fields]
def transform_with(sample, transformers):
"""Transform a list of values using a list of functions.
sample: list of values
transformers: list of functions
If there are fewer transformers than inputs, or if a transformer
function is None, then the identity function is used for the
corresponding sample fields.
"""
if transformers is None or len(transformers) == 0:
return sample
result = list(sample)
assert len(transformers) <= len(sample)
for i in range(len(transformers)): # skipcq: PYL-C0200
f = transformers[i]
if f is not None:
result[i] = f(sample[i])
return result
###
# Iterators
###
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
"""Print information about the samples that are passing through.
:param data: source iterator
:param fmt: format statement (using sample dict as keyword)
:param n: when to stop
:param every: how often to print
:param width: maximum width
:param stream: output stream
:param name: identifier printed before any output
"""
for i, sample in enumerate(data):
if i < n or (every > 0 and (i + 1) % every == 0):
if fmt is None:
print("---", name, file=stream)
for k, v in sample.items():
print(k, repr(v)[:width], file=stream)
else:
print(fmt.format(**sample), file=stream)
yield sample
info = pipelinefilter(_info)
def pick(buf, rng):
k = rng.randint(0, len(buf) - 1)
sample = buf[k]
buf[k] = buf[-1]
buf.pop()
return sample
def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
"""Shuffle the data in the stream.
This uses a buffer of size `bufsize`. Shuffling at
startup is less random; this is traded off against
yielding samples quickly.
data: iterator
bufsize: buffer size for shuffling
returns: iterator
rng: either random module or random.Random instance
"""
if rng is None:
rng = random.Random(int((os.getpid() + time.time()) * 1e9))
initial = min(initial, bufsize)
buf = []
for sample in data:
buf.append(sample)
if len(buf) < bufsize:
try:
buf.append(next(data)) # skipcq: PYL-R1708
except StopIteration:
pass
if len(buf) >= initial:
yield pick(buf, rng)
while len(buf) > 0:
yield pick(buf, rng)
shuffle = pipelinefilter(_shuffle)
class detshuffle(PipelineStage):
def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
self.epoch += 1
rng = random.Random()
rng.seed((self.seed, self.epoch))
return _shuffle(src, self.bufsize, self.initial, rng)
def _select(data, predicate):
"""Select samples based on a predicate.
:param data: source iterator
:param predicate: predicate (function)
"""
for sample in data:
if predicate(sample):
yield sample
select = pipelinefilter(_select)
def _log_keys(data, logfile=None):
import fcntl
if logfile is None or logfile == "":
for sample in data:
yield sample
else:
with open(logfile, "a") as stream:
for i, sample in enumerate(data):
buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
try:
fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
stream.write(buf)
finally:
fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
yield sample
log_keys = pipelinefilter(_log_keys)
def _decode(data, *args, handler=reraise_exception, **kw):
"""Decode data based on the decoding functions given as arguments."""
decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
handlers = [decoder(x) for x in args]
f = autodecode.Decoder(handlers, **kw)
for sample in data:
assert isinstance(sample, dict), sample
try:
decoded = f(sample)
except Exception as exn: # skipcq: PYL-W0703
if handler(exn):
continue
else:
break
yield decoded
decode = pipelinefilter(_decode)
def _map(data, f, handler=reraise_exception):
"""Map samples."""
for sample in data:
try:
result = f(sample)
except Exception as exn:
if handler(exn):
continue
else:
break
if result is None:
continue
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
map = pipelinefilter(_map)
def _rename(data, handler=reraise_exception, keep=True, **kw):
"""Rename samples based on keyword arguments."""
for sample in data:
try:
if not keep:
yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
else:
def listify(v):
return v.split(";") if isinstance(v, str) else v
to_be_replaced = {x for v in kw.values() for x in listify(v)}
result = {k: v for k, v in sample.items() if k not in to_be_replaced}
result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
rename = pipelinefilter(_rename)
def _associate(data, associator, **kw):
"""Associate additional data with samples."""
for sample in data:
if callable(associator):
extra = associator(sample["__key__"])
else:
extra = associator.get(sample["__key__"], {})
sample.update(extra) # destructive
yield sample
associate = pipelinefilter(_associate)
def _map_dict(data, handler=reraise_exception, **kw):
"""Map the entries in a dict sample with individual functions."""
assert len(list(kw.keys())) > 0
for key, f in kw.items():
assert callable(f), (key, f)
for sample in data:
assert isinstance(sample, dict)
try:
for k, f in kw.items():
sample[k] = f(sample[k])
except Exception as exn:
if handler(exn):
continue
else:
break
yield sample
map_dict = pipelinefilter(_map_dict)
def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
"""Convert dict samples to tuples."""
if none_is_error is None:
none_is_error = missing_is_error
if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
args = args[0].split()
for sample in data:
try:
result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
if none_is_error and any(x is None for x in result):
raise ValueError(f"to_tuple {args} got {sample.keys()}")
yield result
except Exception as exn:
if handler(exn):
continue
else:
break
to_tuple = pipelinefilter(_to_tuple)
def _map_tuple(data, *args, handler=reraise_exception):
"""Map the entries of a tuple with individual functions."""
args = [f if f is not None else utils.identity for f in args]
for f in args:
assert callable(f), f
for sample in data:
assert isinstance(sample, (list, tuple))
sample = list(sample)
n = min(len(args), len(sample))
try:
for i in range(n):
sample[i] = args[i](sample[i])
except Exception as exn:
if handler(exn):
continue
else:
break
yield tuple(sample)
map_tuple = pipelinefilter(_map_tuple)
def _unlisted(data):
"""Turn batched data back into unbatched data."""
for batch in data:
assert isinstance(batch, list), batch
for sample in batch:
yield sample
unlisted = pipelinefilter(_unlisted)
def _unbatched(data):
"""Turn batched data back into unbatched data."""
for sample in data:
assert isinstance(sample, (tuple, list)), sample
assert len(sample) > 0
for i in range(len(sample[0])):
yield tuple(x[i] for x in sample)
unbatched = pipelinefilter(_unbatched)
def _rsample(data, p=0.5):
"""Randomly subsample a stream of data."""
assert p >= 0.0 and p <= 1.0
for sample in data:
if random.uniform(0.0, 1.0) < p:
yield sample
rsample = pipelinefilter(_rsample)
slice = pipelinefilter(itertools.islice)
def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
for sample in source:
result = []
for pattern in patterns:
pattern = pattern.split(";") if isinstance(pattern, str) else pattern
matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
if len(matches) == 0:
if ignore_missing:
continue
else:
raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
if len(matches) > 1 and duplicate_is_error:
raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
value = sample[matches[0]]
result.append(value)
yield tuple(result)
extract_keys = pipelinefilter(_extract_keys)
def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
renamings = [(pattern, output) for output, pattern in args]
renamings += [(pattern, output) for output, pattern in kw.items()]
for sample in source:
new_sample = {}
matched = {k: False for k, _ in renamings}
for path, value in sample.items():
fname = re.sub(r".*/", "", path)
new_name = None
for pattern, name in renamings[::-1]:
if fnmatch(fname.lower(), pattern):
matched[pattern] = True
new_name = name
break
if new_name is None:
if keep_unselected:
new_sample[path] = value
continue
if new_name in new_sample:
if duplicate_is_error:
raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
continue
new_sample[new_name] = value
if must_match and not all(matched.values()):
raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
yield new_sample
rename_keys = pipelinefilter(_rename_keys)
def decode_bin(stream):
return stream.read()
def decode_text(stream):
binary = stream.read()
return binary.decode("utf-8")
def decode_pickle(stream):
return pickle.load(stream)
default_decoders = [
("*.bin", decode_bin),
("*.txt", decode_text),
("*.pyd", decode_pickle),
]
def find_decoder(decoders, path):
fname = re.sub(r".*/", "", path)
if fname.startswith("__"):
return lambda x: x
for pattern, fun in decoders[::-1]:
if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
return fun
return None
def _xdecode(
source,
*args,
must_decode=True,
defaults=default_decoders,
**kw,
):
decoders = list(defaults) + list(args)
decoders += [("*." + k, v) for k, v in kw.items()]
for sample in source:
new_sample = {}
for path, data in sample.items():
if path.startswith("__"):
new_sample[path] = data
continue
decoder = find_decoder(decoders, path)
if decoder is False:
value = data
elif decoder is None:
if must_decode:
raise ValueError(f"No decoder found for {path}.")
value = data
else:
if isinstance(data, bytes):
data = io.BytesIO(data)
value = decoder(data)
new_sample[path] = value
yield new_sample
xdecode = pipelinefilter(_xdecode)
def _audio_data_filter(source,
frame_shift=10,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1):
""" Filter sample according to feature and label length
Inplace operation.
Args:
source: Iterable[{fname, wav, label, sample_rate}]
frame_shift: length of frame shift (ms)
max_length: drop utterance which is greater than max_length (10ms frames)
min_length: drop utterance which is less than min_length (10ms frames)
token_max_length: drop utterance whose token length is greater than
token_max_length, especially when using char units for
English modeling
token_min_length: drop utterance whose token length is
less than token_min_length
min_output_input_ratio: minimal ratio of
token_length / feats_length (10ms frames)
max_output_input_ratio: maximum ratio of
token_length / feats_length (10ms frames)
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'label' in sample
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['label']) < token_min_length:
continue
if len(sample['label']) > token_max_length:
continue
if num_frames != 0:
if len(sample['label']) / num_frames < min_output_input_ratio:
continue
if len(sample['label']) / num_frames > max_output_input_ratio:
continue
yield sample
audio_data_filter = pipelinefilter(_audio_data_filter)
def _audio_tokenize(source,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
""" Decode text to chars or BPE
Inplace operation
Args:
source: Iterable[{fname, wav, txt, sample_rate}]
Returns:
Iterable[{fname, wav, txt, tokens, label, sample_rate}]
"""
if non_lang_syms is not None:
non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
else:
non_lang_syms = {}
non_lang_syms_pattern = None
if bpe_model is not None:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
else:
sp = None
for sample in source:
assert 'txt' in sample
txt = sample['txt'].strip()
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
else:
parts = [txt]
label = []
tokens = []
for part in parts:
if part in non_lang_syms:
tokens.append(part)
else:
if bpe_model is not None:
tokens.extend(__tokenize_by_bpe_model(sp, part))
else:
if split_with_space:
part = part.split(" ")
for ch in part:
if ch == ' ':
ch = "<space>"
tokens.append(ch)
for ch in tokens:
if ch in symbol_table:
label.append(symbol_table[ch])
elif '<unk>' in symbol_table:
label.append(symbol_table['<unk>'])
sample['tokens'] = tokens
sample['label'] = label
yield sample
audio_tokenize = pipelinefilter(_audio_tokenize)
def _audio_resample(source, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{fname, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{fname, wav, label, sample_rate}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate
))
yield sample
audio_resample = pipelinefilter(_audio_resample)
def _audio_compute_fbank(source,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
source: Iterable[{fname, wav, label, sample_rate}]
num_mel_bins: number of mel filter bank
frame_length: length of one frame (ms)
frame_shift: length of frame shift (ms)
dither: value of dither
Returns:
Iterable[{fname, feat, label}]
"""
for sample in source:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'fname' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
waveform = waveform * (1 << 15)
# Only keep fname, feat, label
mat = kaldi.fbank(waveform,
n_mels=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sr=sample_rate)
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
def _audio_spec_aug(source,
max_w=5,
w_inplace=True,
w_mode="PIL",
max_f=30,
num_f_mask=2,
f_inplace=True,
f_replace_with_zero=False,
max_t=40,
num_t_mask=2,
t_inplace=True,
t_replace_with_zero=False,):
""" Do spec augmentation
Inplace operation
Args:
source: Iterable[{fname, feat, label}]
max_w: max width of time warp
w_inplace: whether to inplace the original data while time warping
w_mode: time warp mode
max_f: max width of freq mask
num_f_mask: number of freq mask to apply
f_inplace: whether to inplace the original data while frequency masking
f_replace_with_zero: use zero to mask
max_t: max width of time mask
num_t_mask: number of time mask to apply
t_inplace: whether to inplace the original data while time masking
t_replace_with_zero: use zero to mask
Returns
Iterable[{fname, feat, label}]
"""
for sample in source:
x = sample['feat']
x = x.numpy()
x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode)
x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero)
x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero)
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample
audio_spec_aug = pipelinefilter(_audio_spec_aug)
def _sort(source, sort_size=500):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
source: Iterable[{fname, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{fname, feat, label}]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['feat'].shape[0])
for x in buf:
yield x
sort = pipelinefilter(_sort)
def _batched(source, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{fname, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
for sample in source:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
batched = pipelinefilter(_batched)
def _dynamic_batched(source, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
source: Iterable[{fname, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{fname, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in source:
assert 'feat' in sample
assert isinstance(sample['feat'], paddle.Tensor)
new_sample_frames = sample['feat'].shape[0]
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
dynamic_batched = pipelinefilter(_dynamic_batched)
def _audio_padding(source):
""" Padding the data into training data
Args:
source: Iterable[List[{fname, feat, label}]]
Returns:
Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
"""
for sample in source:
assert isinstance(sample, list)
feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
dtype="int64")
order = paddle.argsort(feats_length, descending=True)
feats_lengths = paddle.to_tensor(
[sample[i]['feat'].shape[0] for i in order], dtype="int64")
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['fname'] for i in order]
sorted_labels = [
paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
]
label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
dtype="int64")
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
padding_labels = pad_sequence(sorted_labels,
batch_first=True,
padding_value=-1)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
audio_padding = pipelinefilter(_audio_padding)
def _audio_cmvn(source, cmvn_file):
global_cmvn = GlobalCMVN(cmvn_file)
for batch in source:
sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
padded_feats = padded_feats.numpy()
padded_feats = global_cmvn(padded_feats)
padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
audio_cmvn = pipelinefilter(_audio_cmvn)
def _placeholder(source):
for data in source:
yield data
placeholder = pipelinefilter(_placeholder)

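As a quick illustration of how the dynamic batching above trades batch size against padded length (padded frames = longest sample in the buffer times batch size), here is a toy run; paddle is assumed to be installed, and the utterance names, feature shapes and frame budget are made up.

```python
import paddle

samples = [
    {"fname": f"utt{i}", "feat": paddle.zeros([n, 80])}
    for i, n in enumerate([7, 3, 9, 5])
]
for batch in dynamic_batched(iter(samples), max_frames_in_batch=20):
    print([s["feat"].shape[0] for s in batch])
# -> [7, 3] then [9, 5]: the 9-frame sample would push the padded total past 20
```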
@ -0,0 +1,340 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse
# global used for printing additional node information during verbose output
info = {}
class Pipe:
"""Wrapper class for subprocess.Popen.
This class looks like a stream from the outside, but it checks
subprocess status and handles timeouts with exceptions.
This way, clients of the class do not need to know that they are
dealing with subprocesses.
:param *args: passed to `subprocess.Popen`
:param **kw: passed to `subprocess.Popen`
:param timeout: timeout for closing/waiting
:param ignore_errors: don't raise exceptions on subprocess errors
:param ignore_status: list of status codes to ignore
"""
def __init__(
self,
*args,
mode=None,
timeout=7200.0,
ignore_errors=False,
ignore_status=[],
**kw,
):
"""Create an IO Pipe."""
self.ignore_errors = ignore_errors
self.ignore_status = [0] + ignore_status
self.timeout = timeout
self.args = (args, kw)
if mode[0] == "r":
self.proc = Popen(*args, stdout=PIPE, **kw)
self.stream = self.proc.stdout
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
elif mode[0] == "w":
self.proc = Popen(*args, stdin=PIPE, **kw)
self.stream = self.proc.stdin
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
self.status = None
def __str__(self):
return f"<Pipe {self.args}>"
def check_status(self):
"""Poll the process and handle any errors."""
status = self.proc.poll()
if status is not None:
self.wait_for_child()
def wait_for_child(self):
"""Check the status variable and raise an exception if necessary."""
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if self.status is not None and verbose:
# print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
return
self.status = self.proc.wait()
if verbose:
print(
f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
file=sys.stderr,
)
if self.status not in self.ignore_status and not self.ignore_errors:
raise Exception(f"{self.args}: exit {self.status} (read) {info}")
def read(self, *args, **kw):
"""Wrap stream.read and check status."""
result = self.stream.read(*args, **kw)
self.check_status()
return result
def write(self, *args, **kw):
"""Wrap stream.write and check status."""
result = self.stream.write(*args, **kw)
self.check_status()
return result
def readline(self, *args, **kw):
"""Wrap stream.readline and check status."""
result = self.stream.readline(*args, **kw)
self.status = self.proc.poll()
self.check_status()
return result
def close(self):
"""Wrap stream.close, wait for the subprocess, and handle errors."""
self.stream.close()
self.status = self.proc.wait(self.timeout)
self.wait_for_child()
def __enter__(self):
"""Context handler."""
return self
def __exit__(self, etype, value, traceback):
"""Context handler."""
self.close()
def set_options(
obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
"""Set options for Pipes.
This function can be called on any stream. It will set pipe options only
when its argument is a pipe.
:param obj: any kind of stream
:param timeout: desired timeout
:param ignore_errors: desired ignore_errors setting
:param ignore_status: desired ignore_status setting
:param handler: desired error handler
"""
if not isinstance(obj, Pipe):
return False
if timeout is not None:
obj.timeout = timeout
if ignore_errors is not None:
obj.ignore_errors = ignore_errors
if ignore_status is not None:
obj.ignore_status = ignore_status
if handler is not None:
obj.handler = handler
return True
def gopen_file(url, mode="rb", bufsize=8192):
"""Open a local file.
This simply calls the built-in open() and is intended for local paths.
:param url: URL to be opened
:param mode: mode to open it with
:param bufsize: requested buffer size
"""
return open(url, mode)
def gopen_pipe(url, mode="rb", bufsize=8192):
"""Use gopen to open a pipe.
:param url: a pipe: URL
:param mode: desired mode
:param bufsize: desired buffer size
"""
assert url.startswith("pipe:")
cmd = url[5:]
if mode[0] == "r":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
elif mode[0] == "w":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_curl(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"curl -s -L -T - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_htgs(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
url = re.sub(r"(?i)^htgs://", "gs://", url)
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
raise ValueError(f"{mode}: cannot write")
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_gsutil(url, mode="rb", bufsize=8192):
"""Open a URL with `gsutil`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"gsutil cat '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"gsutil cp - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_error(url, *args, **kw):
"""Raise a value error.
:param url: url
:param args: other arguments
:param kw: other keywords
"""
raise ValueError(f"{url}: no gopen handler defined")
"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
__default__=gopen_error,
pipe=gopen_pipe,
http=gopen_curl,
https=gopen_curl,
sftp=gopen_curl,
ftps=gopen_curl,
scp=gopen_curl,
gs=gopen_gsutil,
htgs=gopen_htgs,
)
def gopen(url, mode="rb", bufsize=8192, **kw):
"""Open the URL.
This uses the `gopen_schemes` dispatch table to dispatch based
on scheme.
Support for the following schemes is built-in: pipe, file,
http, https, sftp, ftps, scp.
When no scheme is given the url is treated as a file.
You can use the GOPEN_VERBOSE environment variable to get info about
files being opened.
:param url: the source URL
:param mode: the mode ("rb", "wb")
:param bufsize: the buffer size
"""
global fallback_gopen
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if verbose:
print("GOPEN", url, info, file=sys.stderr)
assert mode in ["rb", "wb"], mode
if url == "-":
if mode == "rb":
return sys.stdin.buffer
elif mode == "wb":
return sys.stdout.buffer
else:
raise ValueError(f"unknown mode {mode}")
pr = urlparse(url)
if pr.scheme == "":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(url, mode, buffering=bufsize)
if pr.scheme == "file":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(pr.path, mode, buffering=bufsize)
handler = gopen_schemes["__default__"]
handler = gopen_schemes.get(pr.scheme, handler)
return handler(url, mode, bufsize, **kw)
def reader(url, **kw):
"""Open url with gopen and mode "rb".
:param url: source URL
:param kw: other keywords forwarded to gopen
"""
return gopen(url, "rb", **kw)

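A minimal usage sketch for the `gopen()` dispatcher above; the `pipe:` example assumes a POSIX shell is available, and the shard path in the comment is made up.

```python
# Reading from a subprocess via a pipe: URL; the Pipe wrapper checks exit status on close.
with gopen("pipe:echo hello", "rb") as stream:
    print(stream.read())   # b'hello\n'

# A plain path (no URL scheme) falls through to open(); this shard name is illustrative:
# gopen("data/train_l/shard_000000.tar", "rb")
```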
@ -0,0 +1,47 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)
They are used as handler= arguments in much of the library.
"""
import time, warnings
def reraise_exception(exn):
"""Call in an exception handler to re-raise the exception."""
raise exn
def ignore_and_continue(exn):
"""Call in an exception handler to ignore any exception and continue."""
return True
def warn_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
warnings.warn(repr(exn))
time.sleep(0.5)
return True
def ignore_and_stop(exn):
"""Call in an exception handler to ignore any exception and stop further processing."""
return False
def warn_and_stop(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and stop further processing."""
warnings.warn(repr(exn))
time.sleep(0.5)
return False

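The handlers above are meant to be passed as `handler=` to pipeline stages, which call `handler(exn)` and either continue or stop. A small sketch with a made-up `flaky_stage` shows the pattern, using the `warn_and_continue` handler defined above.

```python
def flaky_stage(source, handler=warn_and_continue):
    for sample in source:
        try:
            yield 1 / sample            # toy transformation that can fail
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break

print(list(flaky_stage(iter([1, 2, 0, 4]))))   # warns on the ZeroDivisionError and keeps going
```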
@ -0,0 +1,85 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
import itertools, os, random, time, sys
from functools import reduce, wraps
import numpy as np
from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage
def round_robin_shortest(*sources):
i = 0
while True:
try:
sample = next(sources[i % len(sources)])
yield sample
except StopIteration:
break
i += 1
def round_robin_longest(*sources):
sources = list(sources)
i = 0
while len(sources) > 0:
try:
sample = next(sources[i % len(sources)])
i += 1
yield sample
except StopIteration:
del sources[i % len(sources)]
class RoundRobin(IterableDataset):
def __init__(self, datasets, longest=False):
self.datasets = datasets
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
if self.longest:
return round_robin_longest(*sources)
else:
return round_robin_shortest(*sources)
def random_samples(sources, probs=None, longest=False):
if probs is None:
probs = [1] * len(sources)
else:
probs = list(probs)
while len(sources) > 0:
cum = (np.array(probs) / np.sum(probs)).cumsum()
r = random.random()
i = np.searchsorted(cum, r)
try:
yield next(sources[i])
except StopIteration:
if longest:
del sources[i]
del probs[i]
else:
break
class RandomMix(IterableDataset):
def __init__(self, datasets, probs=None, longest=False):
self.datasets = datasets
self.probs = probs
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
return random_samples(sources, self.probs, longest=self.longest)

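A toy sketch of the two mixing strategies above, using small in-memory lists as "datasets"; the sample dicts are made up.

```python
a = [{"fname": f"a{i}"} for i in range(3)]
b = [{"fname": f"b{i}"} for i in range(5)]

print([s["fname"] for s in RoundRobin([a, b])])
# -> ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']  (longest=False stops at the shorter dataset)

mixed = [s["fname"] for s in RandomMix([a, b], probs=[1, 3])]
# draws from b roughly three times as often as from a, until one dataset runs out
```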
@ -0,0 +1,33 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try:
from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:
class IterableDataset:
"""Empty implementation of IterableDataset when paddle is not available."""
pass
class DataLoader:
"""Empty implementation of DataLoader when paddle is not available."""
pass
try:
from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:
class PaddleTensor:
"""Empty implementation of PaddleTensor when paddle is not available."""
pass

@ -0,0 +1,132 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from .handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage
def add_length_method(obj):
def length(self):
return self.size
Combined = type(
obj.__class__.__name__ + "_Length",
(obj.__class__, IterableDataset),
{"__len__": length},
)
obj.__class__ = Combined
return obj
class DataPipeline(IterableDataset, PipelineStage):
"""A pipeline starting with an IterableDataset and a series of filters."""
def __init__(self, *args, **kwargs):
super().__init__()
self.pipeline = []
self.length = -1
self.repetitions = 1
self.nsamples = -1
for arg in args:
if arg is None:
continue
if isinstance(arg, list):
self.pipeline.extend(arg)
else:
self.pipeline.append(arg)
def invoke(self, f, *args, **kwargs):
"""Apply a pipeline stage, possibly to the output of a previous stage."""
if isinstance(f, PipelineStage):
return f.run(*args, **kwargs)
if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
return iter(f)
if isinstance(f, list):
return iter(f)
if callable(f):
result = f(*args, **kwargs)
return result
raise ValueError(f"{f}: not a valid pipeline stage")
def iterator1(self):
"""Create an iterator through one epoch in the pipeline."""
source = self.invoke(self.pipeline[0])
for step in self.pipeline[1:]:
source = self.invoke(step, source)
return source
def iterator(self):
"""Create an iterator through the entire dataset, using the given number of repetitions."""
for i in range(self.repetitions):
for sample in self.iterator1():
yield sample
def __iter__(self):
"""Create an iterator through the pipeline, repeating and slicing as requested."""
if self.repetitions != 1:
if self.nsamples > 0:
return islice(self.iterator(), self.nsamples)
else:
return self.iterator()
else:
return self.iterator()
def stage(self, i):
"""Return pipeline stage i."""
return self.pipeline[i]
def append(self, f):
"""Append a pipeline stage (modifies the object)."""
self.pipeline.append(f)
return self
def append_list(self, *args):
for arg in args:
self.pipeline.append(arg)
return self
def compose(self, *args):
"""Append a pipeline stage to a copy of the pipeline and returns the copy."""
result = copy.copy(self)
for arg in args:
result.append(arg)
return result
def with_length(self, n):
"""Add a __len__ method returning the desired value.
This does not change the actual number of samples in an epoch.
An IterableDataset should not have a __len__ method.
This is provided only as a workaround for some broken training environments
that require a __len__ method.
"""
self.size = n
return add_length_method(self)
def with_epoch(self, nsamples=-1, nbatches=-1):
"""Change the epoch to return the given number of samples/batches.
The two arguments mean the same thing."""
self.repetitions = sys.maxsize
self.nsamples = max(nsamples, nbatches)
return self
def repeat(self, nepochs=-1, nbatches=-1):
"""Repeat iterating through the dataset for the given #epochs up to the given #samples."""
if nepochs > 0:
self.repetitions = nepochs
self.nsamples = nbatches
else:
self.repetitions = sys.maxsize
self.nsamples = nbatches
return self

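A minimal sketch of composing a `DataPipeline` from the pieces above: a generator function as the source stage followed by two callable stages; the helper names are made up.

```python
from itertools import islice

def source():
    yield from [1, 2, 3, 4, 5]

def double(data):
    for x in data:
        yield 2 * x

def take(n):
    def _take(data):
        return islice(data, n)
    return _take

print(list(DataPipeline(source, double, take(3))))       # [2, 4, 6]
print(list(DataPipeline(source, double).with_epoch(7)))  # [2, 4, 6, 8, 10, 2, 4]
```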
@ -0,0 +1,261 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train Paddle models directly from POSIX tar archives.
Code works locally or over HTTP connections.
"""
import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List
import braceexpand, yaml
from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset
from ..utils.log import Logger
logger = Logger(__name__)
def expand_urls(urls):
if isinstance(urls, str):
urllist = urls.split("::")
result = []
for url in urllist:
result.extend(braceexpand.braceexpand(url))
return result
else:
return list(urls)
class SimpleShardList(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(self, urls, seed=None):
"""Iterate through the list of shards.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.seed = seed
def __len__(self):
return len(self.urls)
def __iter__(self):
"""Return an iterator over the shards."""
urls = self.urls.copy()
if self.seed is not None:
random.Random(self.seed).shuffle(urls)
for url in urls:
yield dict(url=url)
def split_by_node(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
logger.info(f"world_size:{world_size}, rank:{rank}")
if world_size > 1:
for s in islice(src, rank, None, world_size):
yield s
else:
for s in src:
yield s
def single_node_only(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
if world_size > 1:
raise ValueError("input pipeline needs to be reconfigured for multinode training")
for s in src:
yield s
def split_by_worker(src):
rank, world_size, worker, num_workers = utils.paddle_worker_info()
logger.info(f"num_workers:{num_workers}, worker:{worker}")
if num_workers > 1:
for s in islice(src, worker, None, num_workers):
yield s
else:
for s in src:
yield s
def resampled_(src, n=sys.maxsize):
import random
seed = time.time()
try:
seed = open("/dev/random", "rb").read(20)
except Exception as exn:
print(repr(exn)[:50], file=sys.stderr)
rng = random.Random(seed)
print("# resampled loading", file=sys.stderr)
items = list(src)
print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
for i in range(n):
yield rng.choice(items)
resampled = pipelinefilter(resampled_)
def non_empty(src):
count = 0
for s in src:
yield s
count += 1
if count == 0:
raise ValueError("pipeline stage received no data at all and this was declared as an error")
@dataclass
class MSSource:
"""Class representing a data source."""
name: str = ""
perepoch: int = -1
resample: bool = False
urls: List[str] = field(default_factory=list)
default_rng = random.Random()
def expand(s):
return os.path.expanduser(os.path.expandvars(s))
class MultiShardSample(IterableDataset):
def __init__(self, fname):
"""Construct a shardlist from multiple sources using a YAML spec."""
self.epoch = -1
self.parse_spec(fname)
def parse_spec(self, fname):
self.rng = default_rng # capture default_rng if we fork
if isinstance(fname, dict):
spec = fname
fname = "{dict}"
else:
with open(fname) as stream:
spec = yaml.safe_load(stream)
assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
prefix = expand(spec.get("prefix", ""))
self.sources = []
for ds in spec["datasets"]:
assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
ds.keys()
)
buckets = ds.get("buckets", spec.get("buckets", []))
if isinstance(buckets, str):
buckets = [buckets]
buckets = [expand(s) for s in buckets]
if buckets == []:
buckets = [""]
assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
bucket = buckets[0]
name = ds.get("name", "@" + bucket)
urls = ds["shards"]
if isinstance(urls, str):
urls = [urls]
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
urls = [
prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
]
resample = ds.get("resample", -1)
nsample = ds.get("choose", -1)
if nsample > len(urls):
raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
if (nsample > 0) and (resample > 0):
raise ValueError("specify only one of perepoch or choose")
entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
self.sources.append(entry)
print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
def set_epoch(self, seed):
"""Set the current epoch (for consistent shard selection among nodes)."""
self.rng = random.Random(seed)
def get_shards_for_epoch(self):
result = []
for source in self.sources:
if source.resample > 0:
# sample with replacement
l = self.rng.choices(source.urls, k=source.resample)
elif source.perepoch > 0:
# sample without replacement
l = list(source.urls)
self.rng.shuffle(l)
l = l[: source.perepoch]
else:
l = list(source.urls)
result += l
self.rng.shuffle(result)
return result
def __iter__(self):
shards = self.get_shards_for_epoch()
for shard in shards:
yield dict(url=shard)
def shardspec(spec):
if spec.endswith(".yaml"):
return MultiShardSample(spec)
else:
return SimpleShardList(spec)
class ResampledShards(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
self.deterministic = deterministic
self.epoch = -1
def __iter__(self):
"""Return an iterator over the shards."""
self.epoch += 1
if self.deterministic:
seed = utils.make_seed(self.worker_seed(), self.epoch)
else:
seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
if os.environ.get("WDS_SHOW_SEED", "0") == "1":
print(f"# ResampledShards seed {seed}")
self.rng = random.Random(seed)
for _ in range(self.nshards):
index = self.rng.randint(0, len(self.urls) - 1)
yield dict(url=self.urls[index])

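A short sketch of how a brace-notation shard spec expands into per-shard `dict(url=...)` items via `SimpleShardList`; the shard file names are made up. In a multi-GPU or multi-worker run these items would then pass through `split_by_node` and `split_by_worker`.

```python
shards = SimpleShardList("data/train_l/shard_{000000..000003}.tar", seed=17)
print(len(shards))          # 4
for item in shards:
    print(item["url"])      # one dict(url=...) per shard, shuffled with the given seed
```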
@ -0,0 +1,283 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import random, re, tarfile
import braceexpand
from . import filters
from . import gopen
from .handlers import reraise_exception
trace = False
meta_prefix = "__"
meta_suffix = "__"
import paddlespeech
import paddle
import numpy as np
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def base_plus_ext(path):
"""Split off all file extensions.
Returns base, allext.
:param path: path with extensions
:param returns: path with all extensions removed
"""
match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
if not match:
return None, None
return match.group(1), match.group(2)
def valid_sample(sample):
"""Check whether a sample is valid.
:param sample: sample to be checked
"""
return (
sample is not None
and isinstance(sample, dict)
and len(list(sample.keys())) > 0
and not sample.get("__bad__", False)
)
# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
"""Given a list of URLs, yields that list, possibly shuffled."""
if isinstance(urls, str):
urls = braceexpand.braceexpand(urls)
else:
urls = list(urls)
if shuffle:
random.shuffle(urls)
for url in urls:
yield dict(url=url)
def url_opener(data, handler=reraise_exception, **kw):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
try:
stream = gopen.gopen(url, **kw)
sample.update(stream=stream)
yield sample
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def tar_file_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
for tarinfo in stream:
fname = tarinfo.name
try:
if not tarinfo.isreg():
continue
if fname is None:
continue
if (
"/" not in fname
and fname.startswith(meta_prefix)
and fname.endswith(meta_suffix)
):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
continue
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
else:
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
result = dict(fname=prefix, txt=txt)
#result = dict(fname=fname, data=data)
yield result
stream.members = []
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
del stream
def tar_file_and_group_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
""" Iterate over a tar stream, grouping entries that share the same prefix
(e.g. foo.wav and foo.txt) into a single sample.
Args:
fileobj: byte stream suitable for tarfile
Returns:
Iterable[{fname, wav, txt, sample_rate}]
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example['fname'] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
example['wav'] = waveform
example['sample_rate'] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
valid = False
# logging.warning('error to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['fname'] = prev_prefix
yield example
stream.close()
def tar_file_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over the sample dicts produced by `tar_file_iterator`.
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def tar_file_and_group_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of grouped samples.
This returns an iterator over dicts with keys {fname, wav, txt, sample_rate}.
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_and_group_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if trace:
print(
prefix,
suffix,
current_sample.keys() if isinstance(current_sample, dict) else None,
)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
if current_sample is None or prefix != current_sample["__key__"]:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffix in current_sample:
raise ValueError(
f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
)
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_samples(src, handler=reraise_exception):
streams = url_opener(src, handler=handler)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)

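To see what `tar_file_iterator` yields, here is a self-contained sketch that builds a tiny in-memory tar with a single transcript entry; the utterance id and text are illustrative (audio entries would additionally go through `paddlespeech.audio.load`).

```python
import io
import tarfile

payload = "ni hao".encode("utf8")
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    info = tarfile.TarInfo("BAC009S0002W0122.txt")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))
buf.seek(0)

for sample in tar_file_iterator(buf):
    print(sample)    # {'fname': 'BAC009S0002W0122', 'txt': 'ni hao'}
```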
@ -0,0 +1,132 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union
from ..utils.log import Logger
logger = Logger(__name__)
def make_seed(*args):
seed = 0
for arg in args:
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
return seed
class PipelineStage:
def invoke(self, *args, **kw):
raise NotImplementedError
def identity(x: Any) -> Any:
"""Return the argument as is."""
return x
def safe_eval(s: str, expr: str = "{}"):
"""Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
return eval(expr.format(s))
def lookup_sym(sym: str, modules: list):
"""Look up a symbol in a list of modules."""
for mname in modules:
module = importlib.import_module(mname, package="webdataset")
result = getattr(module, sym, None)
if result is not None:
return result
return None
def repeatedly0(
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
"""Repeatedly returns batches from a DataLoader."""
for epoch in range(nepochs):
for sample in itt.islice(loader, nbatches):
yield sample
def guess_batchsize(batch: Union[tuple, list]):
"""Guess the batch size by looking at the length of the first element in a tuple."""
return len(batch[0])
def repeatedly(
source: Iterator,
nepochs: int = None,
nbatches: int = None,
nsamples: int = None,
batchsize: Callable[..., int] = guess_batchsize,
):
"""Repeatedly yield samples from an iterator."""
epoch = 0
batch = 0
total = 0
while True:
for sample in source:
yield sample
batch += 1
if nbatches is not None and batch >= nbatches:
return
if nsamples is not None:
total += guess_batchsize(sample)
if total >= nsamples:
return
epoch += 1
if nepochs is not None and epoch >= nepochs:
return
def paddle_worker_info(group=None):
"""Return node and worker info for Paddle and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
else:
try:
import paddle.distributed
group = group or paddle.distributed.get_group()
rank = paddle.distributed.get_rank()
world_size = paddle.distributed.get_world_size()
except ModuleNotFoundError:
pass
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
worker = int(os.environ["WORKER"])
num_workers = int(os.environ["NUM_WORKERS"])
else:
try:
from paddle.io import get_worker_info
worker_info = get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError as E:
logger.info(f"not found {E}")
exit(-1)
return rank, world_size, worker, num_workers
def paddle_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = paddle_worker_info(group=group)
return rank * 1000 + worker

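A quick sketch of the environment-variable path of `paddle_worker_info()`, which is what distributed launchers typically populate; the values below are made up.

```python
import os

os.environ.update(RANK="1", WORLD_SIZE="4", WORKER="2", NUM_WORKERS="8")
print(paddle_worker_info())   # (1, 4, 2, 8)
print(paddle_worker_seed())   # 1002, i.e. rank * 1000 + worker
```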
@ -0,0 +1,450 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union
import numpy as np
from . import gopen
def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string.
Can handle float or uint8 images.
:param image: ndarray representing an image
:param format: compression format (PNG, JPEG, PPM)
"""
import PIL.Image
assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
if isinstance(image, np.ndarray):
if image.dtype in [np.dtype("f"), np.dtype("d")]:
if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
raise ValueError(
f"image values out of range {np.amin(image)} {np.amax(image)}"
)
image = np.clip(image, 0.0, 1.0)
image = np.array(image * 255.0, "uint8")
assert image.ndim in [2, 3]
if image.ndim == 3:
assert image.shape[2] in [1, 3]
image = PIL.Image.fromarray(image)
if format.upper() == "JPG":
format = "JPEG"
elif format.upper() in ["IMG", "IMAGE"]:
format = "PPM"
if format == "JPEG":
opts = dict(quality=100)
else:
opts = {}
with io.BytesIO() as result:
image.save(result, format=format, **opts)
return result.getvalue()
def bytestr(data: Any):
"""Convert data into a bytestring.
Uses str and ASCII encoding for data that isn't already in string format.
:param data: data
"""
if isinstance(data, bytes):
return data
if isinstance(data, str):
return data.encode("ascii")
return str(data).encode("ascii")
def paddle_dumps(data: Any):
"""Dump data into a bytestring using paddle.dumps.
This delays importing paddle until needed.
:param data: data to be dumped
"""
import io
import paddle
stream = io.BytesIO()
paddle.save(data, stream)
return stream.getvalue()
def numpy_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npy format.
:param data: data to be dumped
"""
import io
import numpy.lib.format
stream = io.BytesIO()
numpy.lib.format.write_array(stream, data)
return stream.getvalue()
def numpy_npz_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npz format.
:param data: data to be dumped
"""
import io
stream = io.BytesIO()
np.savez_compressed(stream, **data)
return stream.getvalue()
def tenbin_dumps(x):
from . import tenbin
if isinstance(x, list):
return memoryview(tenbin.encode_buffer(x))
else:
return memoryview(tenbin.encode_buffer([x]))
def cbor_dumps(x):
import cbor
return cbor.dumps(x)
def mp_dumps(x):
import msgpack
return msgpack.packb(x)
def add_handlers(d, keys, value):
if isinstance(keys, str):
keys = keys.split()
for k in keys:
d[k] = value
def make_handlers():
"""Create a list of handlers for encoding data."""
handlers = {}
add_handlers(
handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
)
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
add_handlers(handlers, "pyd pickle", pickle.dumps)
add_handlers(handlers, "pdparams", paddle_dumps)
add_handlers(handlers, "npy", numpy_dumps)
add_handlers(handlers, "npz", numpy_npz_dumps)
add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
add_handlers(handlers, "mp msgpack msg", mp_dumps)
add_handlers(handlers, "cbor", cbor_dumps)
add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
return handlers
default_handlers = make_handlers()
def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
"""Encode data based on its extension and a dict of handlers.
:param data: data
:param tname: file extension
:param handlers: handlers
"""
if tname[0] == "_":
if not isinstance(data, str):
raise ValueError("the values of metadata must be of string type")
return data
extension = re.sub(r".*\.", "", tname).lower()
if isinstance(data, bytes):
return data
if isinstance(data, str):
return data.encode("utf-8")
handler = handlers.get(extension)
if handler is None:
raise ValueError(f"no handler found for {extension}")
return handler(data)
def encode_based_on_extension(sample: dict, handlers: dict):
"""Encode an entire sample with a collection of handlers.
:param sample: data sample (a dict)
:param handlers: handlers for encoding
"""
return {
k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
}
def make_encoder(spec: Union[bool, str, dict, Callable]):
"""Make an encoder function from a specification.
:param spec: specification
"""
if spec is False or spec is None:
def encoder(x):
"""Do not encode at all."""
return x
elif callable(spec):
encoder = spec
elif isinstance(spec, dict):
def f(sample):
"""Encode based on extension."""
return encode_based_on_extension(sample, spec)
encoder = f
elif spec is True:
handlers = default_handlers
def g(sample):
"""Encode based on extension."""
return encode_based_on_extension(sample, handlers)
encoder = g
else:
raise ValueError(f"{spec}: unknown decoder spec")
if not callable(encoder):
raise ValueError(f"{spec} did not yield a callable encoder")
return encoder
class TarWriter:
"""A class for writing dictionaries to tar files.
:param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
:param encoder: sample encoding (Default value = True)
:param compress: whether to gzip-compress the archive (Default value = None)
`True` for `encoder` will use an encoder that behaves similar to the automatic
decoder for `Dataset`. `False` disables encoding and expects byte strings
(except for metadata, which must be strings). The `encoder` argument can
also be a `callable`, or a dictionary mapping extensions to encoders.
The following code will add two files to the tar archive: `a/b.png` and
`a/b.output.png`.
```Python
tarwriter = TarWriter(stream)
image = imread("b.jpg")
image2 = imread("b.out.jpg")
sample = {"__key__": "a/b", "png": image, "output.png": image2}
tarwriter.write(sample)
```
"""
def __init__(
self,
fileobj,
user: str = "bigdata",
group: str = "bigdata",
mode: int = 0o0444,
compress: Optional[bool] = None,
encoder: Union[None, bool, Callable] = True,
keep_meta: bool = False,
):
"""Create a tar writer.
:param fileobj: stream to write data to
:param user: user for tar files
:param group: group for tar files
:param mode: mode for tar files
:param compress: desired compression
:param encoder: encoder function
:param keep_meta: keep metadata (entries starting with "_")
"""
if isinstance(fileobj, str):
if compress is False:
tarmode = "w|"
elif compress is True:
tarmode = "w|gz"
else:
tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
fileobj = gopen.gopen(fileobj, "wb")
self.own_fileobj = fileobj
else:
tarmode = "w|gz" if compress is True else "w|"
self.own_fileobj = None
self.encoder = make_encoder(encoder)
self.keep_meta = keep_meta
self.stream = fileobj
self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
self.user = user
self.group = group
self.mode = mode
self.compress = compress
def __enter__(self):
"""Enter context."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context."""
self.close()
def close(self):
"""Close the tar file."""
self.tarstream.close()
if self.own_fileobj is not None:
self.own_fileobj.close()
self.own_fileobj = None
def write(self, obj):
"""Write a dictionary to the tar file.
:param obj: dictionary of objects to be stored
:returns: size of the entry
"""
total = 0
obj = self.encoder(obj)
if "__key__" not in obj:
raise ValueError("object must contain a __key__")
for k, v in list(obj.items()):
if k[0] == "_":
continue
if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError(
f"{k} doesn't map to a bytes after encoding ({type(v)})"
)
key = obj["__key__"]
for k in sorted(obj.keys()):
if k == "__key__":
continue
if not self.keep_meta and k[0] == "_":
continue
v = obj[k]
if isinstance(v, str):
v = v.encode("utf-8")
now = time.time()
ti = tarfile.TarInfo(key + "." + k)
ti.size = len(v)
ti.mtime = now
ti.mode = self.mode
ti.uname = self.user
ti.gname = self.group
if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
stream = io.BytesIO(v)
self.tarstream.addfile(ti, stream)
total += ti.size
return total
class ShardWriter:
"""Like TarWriter but splits into multiple shards."""
def __init__(
self,
pattern: str,
maxcount: int = 100000,
maxsize: float = 3e9,
post: Optional[Callable] = None,
start_shard: int = 0,
**kw,
):
"""Create a ShardWriter.
:param pattern: output file pattern
:param maxcount: maximum number of records per shard (Default value = 100000)
:param maxsize: maximum size of each shard (Default value = 3e9)
:param kw: other options passed to TarWriter
"""
self.verbose = 1
self.kw = kw
self.maxcount = maxcount
self.maxsize = maxsize
self.post = post
self.tarstream = None
self.shard = start_shard
self.pattern = pattern
self.total = 0
self.count = 0
self.size = 0
self.fname = None
self.next_stream()
def next_stream(self):
"""Close the current stream and move to the next."""
self.finish()
self.fname = self.pattern % self.shard
if self.verbose:
print(
"# writing",
self.fname,
self.count,
"%.1f GB" % (self.size / 1e9),
self.total,
)
self.shard += 1
stream = open(self.fname, "wb")
self.tarstream = TarWriter(stream, **self.kw)
self.count = 0
self.size = 0
def write(self, obj):
"""Write a sample.
:param obj: sample to be written
"""
if (
self.tarstream is None
or self.count >= self.maxcount
or self.size >= self.maxsize
):
self.next_stream()
size = self.tarstream.write(obj)
self.count += 1
self.total += 1
self.size += size
def finish(self):
"""Finish all writing (use close instead)."""
if self.tarstream is not None:
self.tarstream.close()
assert self.fname is not None
if callable(self.post):
self.post(self.fname)
self.tarstream = None
def close(self):
"""Close the stream."""
self.finish()
del self.tarstream
del self.shard
del self.count
del self.size
def __enter__(self):
"""Enter context."""
return self
def __exit__(self, *args, **kw):
"""Exit context."""
self.close()

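A minimal sketch of writing text-only samples with `ShardWriter`; the shard pattern, keys and values are made up. The default `encoder=True` routes each key through the extension handlers above (`txt` is UTF-8 encoded, `cls` is stringified).

```python
samples = [
    {"__key__": "utt0001", "txt": "ni hao", "cls": 3},
    {"__key__": "utt0002", "txt": "zai jian", "cls": 7},
]
with ShardWriter("shard-%06d.tar", maxcount=100000) as writer:
    for sample in samples:
        writer.write(sample)
# produces shard-000000.tar containing utt0001.cls, utt0001.txt, utt0002.cls, utt0002.txt
```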
@ -0,0 +1,235 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union
import sentencepiece as spm
from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger
logger = Logger(__name__)
__all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab (Union[str, list]): Filepath to load vocabulary for token indices conversion, or a vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
maskctc (bool, optional): whether to reserve the <mask> token for MaskCTC. Defaults to False.
"""
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
logger.warning("TextFeaturizer: no vocab file or vocab list provided.")
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
def tokenize(self, text, replace_space=True):
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
tokens = self.spm_tokenize(text)
return tokens
def detokenize(self, tokens):
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
text = self.word_detokenize(tokens)
else: # spm
text = self.spm_detokenize(tokens)
return text
def featurize(self, text):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
Returns:
List[int]: List of token indices.
"""
tokens = self.tokenize(text)
ids = []
for token in tokens:
if token not in self.vocab_dict:
logger.debug(f"Text Token: {token} -> {self.unk}")
token = self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text.
"""
tokens = []
for idx in idxs:
if idx == self.eos_id:
break
tokens.append(self._id2token[idx])
text = self.detokenize(tokens)
return text
def char_tokenize(self, text, replace_space=True):
"""Character tokenizer.
Args:
text (str): text string.
replace_space (bool): whether to replace " " with <space>; False is only used by build_vocab.py.
Returns:
List[str]: tokens.
"""
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
def char_detokenize(self, tokens):
"""Character detokenizer.
Args:
tokens (List[str]): tokens.
Returns:
str: text string.
"""
tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)
def word_tokenize(self, text):
"""Word tokenizer, separate by <space>."""
return text.strip().split()
def word_detokenize(self, tokens):
"""Word detokenizer, separate by <space>."""
return " ".join(tokens)
def spm_tokenize(self, text):
"""spm tokenize.
Args:
text (str): text string.
Returns:
List[str]: sentence pieces str code
"""
stats = {"num_empty": 0, "num_filtered": 0}
def valid(line):
return True
def encode(l):
return self.sp.EncodeAsPieces(l)
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
return None
enc_line = encode_line(text)
return enc_line
def spm_detokenize(self, tokens, input_format='piece'):
"""spm detokenize.
Args:
ids (List[str]): tokens.
Returns:
str: text
"""
if input_format == "piece":
def decode(l):
return "".join(self.sp.DecodePieces(l))
elif input_format == "id":
def decode(l):
return "".join(self.sp.DecodeIds(l))
return decode(tokens)
def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file."""
if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None
logger.debug(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id

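A small sketch of char-level featurization with an in-memory vocab list; the tokens below are illustrative, real vocabularies come from files such as data/lang_char/vocab.txt.

```python
vocab = ["<blank>", "<unk>", "n", "i", "h", "a", "o", "<space>", "<eos>"]
featurizer = TextFeaturizer(unit_type="char", vocab=vocab)

ids = featurizer.featurize("ni hao")
print(ids)                          # [2, 3, 7, 4, 5, 6]; unseen chars map to <unk>
print(featurizer.defeaturize(ids))  # 'ni hao', stopping at <eos> if present
```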
@ -0,0 +1,393 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import json
import math
import sys
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text
import jsonlines
import numpy as np
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
"EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
"convert_samples_from_float32"
]
IGNORE_ID = -1
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
if dict_path is None:
return None
with open(dict_path, "r") as f:
dictionary = f.readlines()
# first token is `<blank>`
# multi line: `<blank> 0\n`
# one line: `<blank>`
# space is replaced with <space>
char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
char_list.append(EOS)
# for non-autoregressive maskctc model
if maskctc and MASKCTC not in char_list:
char_list.append(MASKCTC)
return char_list
def read_manifest(
manifest_path,
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, ):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len (float, optional): maximum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum output seq length,
in modeling units. Defaults to float('inf').
min_output_len (float, optional): minimum output seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length / input seq length ratio. Defaults to float('inf').
min_output_input_ratio (float, optional):
minimum output seq length / input seq length ratio. Defaults to 0.0.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
feat_len = json_data["input"][0]["shape"][
0] if "input" in json_data and "shape" in json_data["input"][
0] else 1.0
token_len = json_data["output"][0]["shape"][
0] if "output" in json_data and "shape" in json_data["output"][
0] else 1.0
conditions = [
feat_len >= min_input_len,
feat_len <= max_input_len,
token_len >= min_output_len,
token_len <= max_output_len,
token_len / feat_len >= min_output_input_ratio,
token_len / feat_len <= max_output_input_ratio,
]
if all(conditions):
manifest.append(json_data)
return manifest
# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
def parse_tar(file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def subfile_from_tar(file, local_data=None):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if local_data is None:
local_data = TarLocalData(tar2info={}, tar2object={})
assert isinstance(local_data, TarLocalData)
if 'tar2info' not in local_data.__dict__:
local_data.tar2info = {}
if 'tar2object' not in local_data.__dict__:
local_data.tar2object = {}
if tarpath not in local_data.tar2info:
fobj, infos = parse_tar(tarpath)
local_data.tar2info[tarpath] = infos
local_data.tar2object[tarpath] = fobj
else:
fobj = local_data.tar2object[tarpath]
infos = local_data.tar2info[tarpath]
return fobj.extractfile(infos[filename])
def rms_to_db(rms: float):
"""Root Mean Square to dB.
Args:
rms ([float]): root mean square
Returns:
float: dB
"""
return 20.0 * math.log10(max(1e-16, rms))
def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is modeled as a mix of sine waves; a full-scale sine wave (amplitude 1.0) has an RMS of 0.7071, i.e. -3.0103 dB relative to peak.
dB = dBFS + 3.0103
dBFS = dB - 3.0103
e.g. 0 dB = -3.0103 dBFS
Args:
rms ([float]): root mean square
Returns:
float: dBFS
"""
return rms_to_db(rms) - 3.0103
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
def gain_db_to_ratio(gain_db: float):
"""dB to ratio
Args:
gain_db (float): gain in dB
Returns:
float: scale in amp
"""
return math.pow(10.0, gain_db / 20.0)
def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
"""Normalize audio to the target dBFS.
Args:
sample_data (np.ndarray): input wave samples, [-1, 1].
dbfs (float, optional): target dBFS. Defaults to -3.0103.
Returns:
np.ndarray: normalized wave
"""
return np.maximum(
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
1.0), -1.0)
def _load_json_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, inverse stds]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def _load_kaldi_cmvn(kaldi_cmvn_file):
""" Load the kaldi format cmvn stats file and calculate cmvn
Args:
kaldi_cmvn_file: kaldi text style global cmvn file, which
is generated by:
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
Returns:
a numpy array of [means, vars]
"""
means = []
variance = []
with open(kaldi_cmvn_file, 'r') as fid:
# kaldi binary file start with '\0B'
if fid.read(2) == '\0B':
logger.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_cmvn(cmvn_file: str, filetype: str):
"""load cmvn from file.
Args:
cmvn_file (str): cmvn path.
filetype (str): file type, optional[npz, json, kaldi].
Raises:
ValueError: file type not support.
Returns:
Tuple[np.ndarray, np.ndarray]: mean, istd
"""
filetype = filetype.lower()
assert filetype in ['npz', 'json', 'kaldi'], filetype
if filetype == "json":
cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file)
elif filetype == "npz":
eps = 1e-14
npzfile = np.load(cmvn_file)
mean = np.squeeze(npzfile["mean"])
std = np.squeeze(npzfile["std"])
istd = 1 / (std + eps)
cmvn = [mean, istd]
else:
raise ValueError(f"cmvn file type not supported: {filetype}")
return cmvn[0], cmvn[1]
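# Sketch of applying the loaded stats: load_cmvn returns (mean, istd), so global
# CMVN is just (feat - mean) * istd. The json path below is an assumption that
# matches the default cmvn_file used by StreamDataLoader later in this PR.
import numpy as np

mean, istd = load_cmvn("data/mean_std.json", filetype="json")
feats = np.random.randn(100, mean.shape[0]).astype("float32")  # (frames, feat_dim)
feats_cmvn = (feats - mean) * istd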
def convert_samples_to_float32(samples):
"""Convert sample type to float32.
Audio sample type is usually integer or floating-point.
Integers will be scaled to [-1, 1] in float32.
PCM16 -> PCM32
"""
float32_samples = samples.astype('float32')
if samples.dtype in np.sctypes['int']:
bits = np.iinfo(samples.dtype).bits
float32_samples *= (1. / 2**(bits - 1))
elif samples.dtype in np.sctypes['float']:
pass
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return float32_samples
def convert_samples_from_float32(samples, dtype):
"""Convert sample type from float32 to dtype.
Audio sample type is usually integer or floating-point. For integer
type, float32 will be rescaled from [-1, 1] to the maximum range
supported by the integer type.
PCM32 -> PCM16
"""
dtype = np.dtype(dtype)
output_samples = samples.copy()
if dtype in np.sctypes['int']:
bits = np.iinfo(dtype).bits
output_samples *= (2**(bits - 1) / 1.)
min_val = np.iinfo(dtype).min
max_val = np.iinfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
elif dtype in np.sctypes['float']:
min_val = np.finfo(dtype).min
max_val = np.finfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return output_samples.astype(dtype)
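# Sketch: round-trip PCM16 -> float32 -> PCM16 with the two helpers above;
# integers are scaled into [-1, 1] on the way in and rescaled/clipped on the way out.
import numpy as np

pcm16 = np.array([0, 16384, -32768], dtype="int16")
f32 = convert_samples_to_float32(pcm16)                 # ~[0.0, 0.5, -1.0]
restored = convert_samples_from_float32(f32, "int16")   # [0, 16384, -32768]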

@ -14,8 +14,8 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
import inspect import inspect
from paddlespeech.s2t.transform.transform_interface import TransformInterface from paddlespeech.audio.transform.transform_interface import TransformInterface
from paddlespeech.s2t.utils.check_kwargs import check_kwargs from paddlespeech.audio.utils.check_kwargs import check_kwargs
class FuncTrans(TransformInterface): class FuncTrans(TransformInterface):

@ -17,8 +17,97 @@ import numpy
import scipy import scipy
import soundfile import soundfile
from paddlespeech.s2t.io.reader import SoundHDF5File import io
import os
import h5py
import numpy as np
class SoundHDF5File():
"""Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param str filepath: path of the HDF5 file
:param str mode: h5py file mode, e.g. 'r', 'r+', 'a'
:param str format: The audio format used when saving: flac, nist, htk, etc.
:param str dtype: sample dtype used when reading, e.g. 'int16'
"""
def __init__(self,
filepath,
mode="r+",
format=None,
dtype="int16",
**kwargs):
self.filepath = filepath
self.mode = mode
self.dtype = dtype
self.file = h5py.File(filepath, mode, **kwargs)
if format is None:
# filepath = a.flac.h5 -> format = flac
second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
format = second_ext[1:]
if format.upper() not in soundfile.available_formats():
# If not found, flac is selected
format = "flac"
# This format affects only saving
self.format = format
def __repr__(self):
return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
self.filepath, self.mode, self.format, self.dtype)
def create_dataset(self, name, shape=None, data=None, **kwds):
f = io.BytesIO()
array, rate = data
soundfile.write(f, array, rate, format=self.format)
self.file.create_dataset(
name, shape=shape, data=np.void(f.getvalue()), **kwds)
def __setitem__(self, name, data):
self.create_dataset(name, data=data)
def __getitem__(self, key):
data = self.file[key][()]
f = io.BytesIO(data.tobytes())
array, rate = soundfile.read(f, dtype=self.dtype)
return array, rate
def keys(self):
return self.file.keys()
def values(self):
for k in self.file:
yield self[k]
def items(self):
for k in self.file:
yield k, self[k]
def __iter__(self):
return iter(self.file)
def __contains__(self, item):
return item in self.file
def __len__(self):
return len(self.file)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def close(self):
self.file.close()
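# Sketch: SoundHDF5File is also a context manager, so shards can be written and
# read back without an explicit close(); the file name here is hypothetical.
import numpy as np

with SoundHDF5File("dump.flac.h5", mode="w") as writer:
    writer["utt1"] = (np.random.randint(-100, 100, 1600, dtype=np.int16), 16000)
with SoundHDF5File("dump.flac.h5", mode="r") as reader:
    samples, rate = reader["utt1"]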
class SpeedPerturbation(): class SpeedPerturbation():
"""SpeedPerturbation """SpeedPerturbation
@ -469,3 +558,4 @@ class RIRConvolve():
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1) [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
else: else:
return scipy.convolve(x, rir, mode="same") return scipy.convolve(x, rir, mode="same")

@ -17,7 +17,7 @@ import random
import numpy import numpy
from PIL import Image from PIL import Image
from paddlespeech.s2t.transform.functional import FuncTrans from .functional import FuncTrans
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):

@ -17,7 +17,7 @@ import numpy as np
import paddle import paddle
from python_speech_features import logfbank from python_speech_features import logfbank
import paddlespeech.audio.compliance.kaldi as kaldi from ..compliance import kaldi
def stft(x, def stft(x,

@ -22,32 +22,32 @@ from inspect import signature
import yaml import yaml
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from ..utils.dynamic_import import dynamic_import
import_alias = dict( import_alias = dict(
identity="paddlespeech.s2t.transform.transform_interface:Identity", identity="paddlespeech.audio.transform.transform_interface:Identity",
time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp", time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask", time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask", freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment", spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation", speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox", speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation", volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection", noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation", bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve", rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
delta="paddlespeech.s2t.transform.add_deltas:AddDeltas", delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
cmvn="paddlespeech.s2t.transform.cmvn:CMVN", cmvn="paddlespeech.audio.transform.cmvn:CMVN",
utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN", utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram", fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram", spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
stft="paddlespeech.s2t.transform.spectrogram:Stft", stft="paddlespeech.audio.transform.spectrogram:Stft",
istft="paddlespeech.s2t.transform.spectrogram:IStft", istft="paddlespeech.audio.transform.spectrogram:IStft",
stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram", stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="paddlespeech.s2t.transform.wpe:WPE", wpe="paddlespeech.audio.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")
class Transformation(): class Transformation():

@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect
def check_kwargs(func, kwargs, name=None):
"""check kwargs are valid for func
If kwargs are invalid, raise a TypeError, the same as Python's default behavior
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in TypeError (default is func name)
"""
try:
params = inspect.signature(func).parameters
except ValueError:
return
if name is None:
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(
f"{name}() got an unexpected keyword argument '{k}'")
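# Sketch: check_kwargs validates keyword arguments before a deferred call, so a
# bad transform config fails up front with the usual TypeError message.
def make_fbank(n_mels=80, dither=0.0):
    return {"n_mels": n_mels, "dither": dither}

check_kwargs(make_fbank, {"n_mels": 40})    # fine
# check_kwargs(make_fbank, {"nmels": 40})   # TypeError: make_fbank() got an
#                                           # unexpected keyword argument 'nmels'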

@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib
__all__ = ["dynamic_import"]
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'paddlespeech.s2t.models.u2:U2Model'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError(
"import_path should be one of {} or "
'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
"{}".format(set(alias), import_path))
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
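# Sketch: resolve classes either from a full "module:Class" path or through an
# alias table such as the import_alias dict used by Transformation above.
Stft = dynamic_import("paddlespeech.audio.transform.spectrogram:Stft")
CMVN = dynamic_import(
    "cmvn", alias={"cmvn": "paddlespeech.audio.transform.cmvn:CMVN"})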

@ -65,6 +65,7 @@ class Logger(object):
def __init__(self, name: str=None): def __init__(self, name: str=None):
name = 'PaddleAudio' if not name else name name = 'PaddleAudio' if not name else name
self.name = name
self.logger = logging.getLogger(name) self.logger = logging.getLogger(name)
for key, conf in log_config.items(): for key, conf in log_config.items():
@ -101,7 +102,7 @@ class Logger(object):
if not self.is_enable: if not self.is_enable:
return return
self.logger.log(log_level, msg) self.logger.log(log_level, self.name + " | " + msg)
@contextlib.contextmanager @contextlib.contextmanager
def use_terminator(self, terminator: str): def use_terminator(self, terminator: str):

@ -0,0 +1,192 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from typing import List
from typing import Tuple
import paddle
from .log import Logger
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Logger(__name__)
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is a list of
sequences each of size ``L x *``, the output is of size ``T x B x *`` if
batch_first is False, and ``B x T x *`` otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddlespeech.audio.utils.tensor_utils import pad_sequence
>>> a = paddle.ones([25, 300])
>>> b = paddle.ones([22, 300])
>>> c = paddle.ones([15, 300])
>>> pad_sequence([a, b, c]).shape
[25, 3, 300]
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are the same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not support `end==start`
# trailing_dims = max_size[1:]
trailing_dims = tuple(
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = paddle.full(out_dims, padding_value, sequences[0].dtype)
for i, tensor in enumerate(sequences):
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
# TODO (Hui Zhang): set_value op not support `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length] = tensor
else:
out_tensor[i, length] = tensor
else:
# TODO (Hui Zhang): set_value op not support `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i] = tensor
else:
out_tensor[length, i] = tensor
return out_tensor
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos)
ys_out = paddle.cat([ys_pad, _eos], dim=1)
ys_out = ys_out.masked_fill(mask_pad, eos)
mask_eos = (ys_out == ignore_id)
ys_out = ys_out.masked_fill(mask_eos, eos)
ys_out = ys_out.masked_fill(mask_pad, ignore_id)
return ys_in, ys_out
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
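# Sketch: token accuracy on a padded batch; positions equal to ignore_label are
# excluded from both numerator and denominator, so padding does not dilute it.
import paddle

B, Lmax, D = 2, 3, 5
logits = paddle.randn([B * Lmax, D])                    # (B * Lmax, D)
targets = paddle.to_tensor([[1, 2, -1], [3, -1, -1]])   # -1 marks padding
acc = th_accuracy(logits, targets, ignore_label=-1)     # float in [0, 1]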

@ -33,8 +33,8 @@ from ..log import logger
from ..utils import CLI_TIMER from ..utils import CLI_TIMER
from ..utils import stats_wrapper from ..utils import stats_wrapper
from ..utils import timer_register from ..utils import timer_register
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']

@ -23,7 +23,7 @@ import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle import inference from paddle import inference
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -20,10 +20,10 @@ import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()

@ -26,6 +26,8 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
@ -106,6 +108,7 @@ class U2Trainer(Trainer):
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
@ -132,6 +135,7 @@ class U2Trainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, " msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
@ -152,6 +156,7 @@ class U2Trainer(Trainer):
self.before_train() self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
@ -170,6 +175,7 @@ class U2Trainer(Trainer):
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
self.after_train_batch() self.after_train_batch()
report('iter', batch_index + 1) report('iter', batch_index + 1)
if not self.use_streamdata:
report('total', len(self.train_loader)) report('total', len(self.train_loader))
report('reader_cost', dataload_time) report('reader_cost', dataload_time)
observation['batch_cost'] = observation[ observation['batch_cost'] = observation[
@ -191,7 +197,6 @@ class U2Trainer(Trainer):
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e
with Timer("Eval Time Cost: {}"): with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid() total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1: if dist.get_world_size() > 1:
@ -218,92 +223,16 @@ class U2Trainer(Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
# train/valid dataset, return token ids self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.train_loader = BatchDataLoader( self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
json_file=config.train_manifest,
train_mode=True,
sortagrad=config.sortagrad,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=config.minibatches,
mini_batch_size=self.args.ngpu,
batch_count=config.batch_count,
batch_bins=config.batch_bins,
batch_frames_in=config.batch_frames_in,
batch_frames_out=config.batch_frames_out,
batch_frames_inout=config.batch_frames_inout,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
decode_batch_size = config.get('decode', dict()).get( decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1) 'decode_batch_size', 1)
# test dataset, return raw text self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.test_loader = BatchDataLoader( self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup test/align Dataloader!") logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
@ -452,6 +381,7 @@ class U2Tester(U2Trainer):
def test(self): def test(self):
assert self.args.result_file assert self.args.result_file
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.config.stride_ms stride_ms = self.config.stride_ms

@ -25,7 +25,7 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_dict from paddlespeech.s2t.frontend.utility import load_dict
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
@ -104,6 +104,7 @@ class U2Trainer(Trainer):
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
@ -131,6 +132,7 @@ class U2Trainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, " msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
@ -150,7 +152,7 @@ class U2Trainer(Trainer):
# paddle.jit.save(script_model, script_model_path) # paddle.jit.save(script_model, script_model_path)
self.before_train() self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
@ -162,6 +164,7 @@ class U2Trainer(Trainer):
msg = "Train: Rank: {}, ".format(dist.get_rank()) msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch : {}/{}, ".format(batch_index + 1, msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader)) len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
@ -198,87 +201,23 @@ class U2Trainer(Trainer):
self.new_epoch() self.new_epoch()
def setup_dataloader(self): def setup_dataloader(self):
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
config = self.config.clone()
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
config = self.config.clone()
config['preprocess_config'] = None
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
config = self.config.clone()
config['preprocess_config'] = None
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
config = self.config.clone() config = self.config.clone()
# train/valid dataset, return token ids config['preprocess_config'] = None
self.train_loader = BatchDataLoader( self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
json_file=config.train_manifest, logger.info("Setup test/align Dataloader!")
train_mode=True,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1)
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
# test dataset, return raw text
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config
@ -406,6 +345,7 @@ class U2Tester(U2Trainer):
def test(self): def test(self):
assert self.args.result_file assert self.args.result_file
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.config.stride_ms stride_ms = self.config.stride_ms

@ -25,7 +25,7 @@ import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2_st import U2STModel from paddlespeech.s2t.models.u2_st import U2STModel
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
@ -120,6 +120,7 @@ class U2STTrainer(Trainer):
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
@ -153,6 +154,7 @@ class U2STTrainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, " msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
@ -172,7 +174,7 @@ class U2STTrainer(Trainer):
# paddle.jit.save(script_model, script_model_path) # paddle.jit.save(script_model, script_model_path)
self.before_train() self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
@ -191,6 +193,7 @@ class U2STTrainer(Trainer):
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
self.after_train_batch() self.after_train_batch()
report('iter', batch_index + 1) report('iter', batch_index + 1)
if not self.use_streamdata:
report('total', len(self.train_loader)) report('total', len(self.train_loader))
report('reader_cost', dataload_time) report('reader_cost', dataload_time)
observation['batch_cost'] = observation[ observation['batch_cost'] = observation[
@ -241,79 +244,18 @@ class U2STTrainer(Trainer):
load_transcript = True if config.model_conf.asr_weight > 0 else False load_transcript = True if config.model_conf.asr_weight > 0 else False
config = self.config.clone()
config['load_transcript'] = load_transcript
self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
# train/valid dataset, return token ids self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.train_loader = BatchDataLoader( self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
json_file=config.train_manifest,
train_mode=True,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=True)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=False)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
# test dataset, return raw text self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=False)
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config
model_conf = config model_conf = config
@ -468,6 +410,7 @@ class U2STTester(U2STTrainer):
def test(self): def test(self):
assert self.args.result_file assert self.args.result_file
self.model.eval() self.model.eval()
if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
decode_cfg = self.config.decode decode_cfg = self.config.decode

@ -18,6 +18,7 @@ from typing import Text
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
from paddle.io import BatchSampler from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
@ -28,7 +29,11 @@ from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
__all__ = ["BatchDataLoader"] import paddlespeech.audio.streamdata as streamdata
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from yacs.config import CfgNode
__all__ = ["BatchDataLoader", "StreamDataLoader"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -56,6 +61,136 @@ def batch_collate(x):
""" """
return x[0] return x[0]
def read_preprocess_cfg(preprocess_conf_file):
augment_conf = dict()
preprocess_cfg = CfgNode(new_allowed=True)
preprocess_cfg.merge_from_file(preprocess_conf_file)
for idx, process in enumerate(preprocess_cfg["process"]):
opts = dict(process)
process_type = opts.pop("type")
if process_type == 'time_warp':
augment_conf['max_w'] = process['max_time_warp']
augment_conf['w_inplace'] = process['inplace']
augment_conf['w_mode'] = process['mode']
if process_type == 'freq_mask':
augment_conf['max_f'] = process['F']
augment_conf['num_f_mask'] = process['n_mask']
augment_conf['f_inplace'] = process['inplace']
augment_conf['f_replace_with_zero'] = process['replace_with_zero']
if process_type == 'time_mask':
augment_conf['max_t'] = process['T']
augment_conf['num_t_mask'] = process['n_mask']
augment_conf['t_inplace'] = process['inplace']
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
return augment_conf
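# Sketch: with a preprocess config along these (assumed) lines
#
#   process:
#     - type: time_warp
#       max_time_warp: 5
#       inplace: true
#       mode: PIL
#     - type: freq_mask
#       F: 30
#       n_mask: 2
#       inplace: true
#       replace_with_zero: false
#     - type: time_mask
#       T: 40
#       n_mask: 2
#       inplace: true
#       replace_with_zero: false
#
# the helper flattens the three stages into the kwargs consumed by
# streamdata.audio_spec_aug below:
augment_conf = read_preprocess_cfg("conf/preprocess.yaml")
# -> {'max_w': 5, 'w_inplace': True, 'w_mode': 'PIL',
#     'max_f': 30, 'num_f_mask': 2, 'f_inplace': True, 'f_replace_with_zero': False,
#     'max_t': 40, 'num_t_mask': 2, 't_inplace': True, 't_replace_with_zero': False}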
class StreamDataLoader():
def __init__(self,
manifest_file: str,
train_mode: bool,
unit_type: str='char',
batch_size: int=0,
preprocess_conf=None,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
minlen_in: float=0.0,
maxlen_in: float=float('inf'),
minlen_out: float=0.0,
maxlen_out: float=float('inf'),
resample_rate: int=16000,
shuffle_size: int=10000,
sort_size: int=1000,
n_iter_processes: int=1,
prefetch_factor: int=2,
dist_sampler: bool=False,
cmvn_file="data/mean_std.json",
vocab_filepath='data/lang_char/vocab.txt'):
self.manifest_file = manifest_file
self.train_mode = train_mode
self.batch_size = batch_size
self.prefetch_factor = prefetch_factor
self.dist_sampler = dist_sampler
self.n_iter_processes = n_iter_processes
text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
symbol_table = text_featurizer.vocab_dict
self.feat_dim = num_mel_bins
self.vocab_size = text_featurizer.vocab_size
augment_conf = read_preprocess_cfg(preprocess_conf)
# The list of shard
shardlist = []
with open(manifest_file, "r") as f:
for line in f.readlines():
shardlist.append(line.strip())
world_size = 1
try:
world_size = paddle.distributed.get_world_size()
except Exception as e:
logger.warning(e)
logger.warning("cannot get world_size from paddle.distributed.get_world_size(), falling back to world_size=1")
assert len(shardlist) >= world_size, "the length of the shard list should be >= the number of gpus/xpus/..."
update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
if update_n_iter_processes != self.n_iter_processes:
self.n_iter_processes = update_n_iter_processes
logger.info(f"change num_workers to {self.n_iter_processes}")
if self.dist_sampler:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_node if train_mode else streamdata.placeholder(),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
else:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
self.dataset = base_dataset.append_list(
streamdata.audio_tokenize(symbol_table),
streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
streamdata.audio_resample(resample_rate=resample_rate),
streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.shuffle(shuffle_size),
streamdata.sort(sort_size=sort_size),
streamdata.batched(batch_size),
streamdata.audio_padding(),
streamdata.audio_cmvn(cmvn_file)
)
if paddle.__version__ >= '2.3.2':
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
prefetch_factor = self.prefetch_factor,
batch_size=None
)
else:
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
batch_size=None
)
def __iter__(self):
return self.loader.__iter__()
def __call__(self):
return self.__iter__()
def __len__(self):
logger.info("Stream dataloader does not support calculating the length of the dataset")
return -1
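# Sketch: constructing a streaming loader directly (the manifest path and sizes
# here are illustrative); DataLoaderFactory below normally wires these values
# in from the experiment config.
train_loader = StreamDataLoader(
    manifest_file="data/train/data.list",   # one tar shard path per line
    train_mode=True,
    unit_type="char",
    batch_size=16,
    preprocess_conf="conf/preprocess.yaml",
    num_mel_bins=80,
    frame_length=25,
    frame_shift=10,
    n_iter_processes=4,
    dist_sampler=False,
    cmvn_file="data/mean_std.json",
    vocab_filepath="data/lang_char/vocab.txt")
for batch in train_loader:                  # padded, CMVN-normalized fbank batches
    break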
class BatchDataLoader(): class BatchDataLoader():
def __init__(self, def __init__(self,
@ -199,3 +334,119 @@ class BatchDataLoader():
echo += f"shortest_first: {self.shortest_first}, " echo += f"shortest_first: {self.shortest_first}, "
echo += f"file: {self.json_file}" echo += f"file: {self.json_file}"
return echo return echo
class DataLoaderFactory():
@staticmethod
def get_dataloader(mode: str, config, args):
config = config.clone()
use_streamdata = config.get("use_stream_data", False)
if use_streamdata:
if mode == 'train':
config['manifest'] = config.train_manifest
config['train_mode'] = True
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0
config['minlen_in'] = 0.0
config['maxlen_in'] = float('inf')
config['minlen_out'] = 0
config['maxlen_out'] = float('inf')
config['dist_sampler'] = False
else:
raise KeyError("Invalid mode type! Please input one of 'train', 'valid', 'test', 'align'.")
return StreamDataLoader(
manifest_file=config.manifest,
train_mode=config.train_mode,
unit_type=config.unit_type,
preprocess_conf=config.preprocess_config,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.dist_sampler,
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
if mode == 'train':
config['manifest'] = config.train_manifest
config['train_mode'] = True
config['mini_batch_size'] = args.ngpu
config['subsampling_factor'] = 1
config['num_encs'] = 1
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
config['sortagrad'] = False
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
config['mini_batch_size'] = args.ngpu
config['batch_count'] = 'auto'
config['batch_bins'] = 0
config['batch_frames_in'] = 0
config['batch_frames_out'] = 0
config['batch_frames_inout'] = 0
config['subsampling_factor'] = 1
config['num_encs'] = 1
config['shortest_first'] = False
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['sortagrad'] = False
config['batch_size'] = config.get('decode', dict()).get(
'decode_batch_size', 1)
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
config['mini_batch_size'] = 1
config['batch_count'] = 'auto'
config['batch_bins'] = 0
config['batch_frames_in'] = 0
config['batch_frames_out'] = 0
config['batch_frames_inout'] = 0
config['num_workers'] = 1
config['subsampling_factor'] = 1
config['num_encs'] = 1
config['dist_sampler'] = False
config['shortest_first'] = False
else:
raise KeyError("Invalid mode type! Please input one of 'train', 'valid', 'test', 'align'.")
return BatchDataLoader(
json_file=config.manifest,
train_mode=config.train_mode,
sortagrad=config.sortagrad,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=config.minibatches,
mini_batch_size=config.mini_batch_size,
batch_count=config.batch_count,
batch_bins=config.batch_bins,
batch_frames_in=config.batch_frames_in,
batch_frames_out=config.batch_frames_out,
batch_frames_inout=config.batch_frames_inout,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=config.subsampling_factor,
load_aux_output=config.get('load_transcript', None),
num_encs=config.num_encs,
dist_sampler=config.dist_sampler,
shortest_first=config.shortest_first)
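# Sketch: the call sites used by the updated trainers; `config` is the experiment
# CfgNode (with use_stream_data deciding between the two loader types) and
# `args` is the parsed CLI namespace providing args.ngpu.
train_loader = DataLoaderFactory.get_dataloader('train', config, args)
valid_loader = DataLoaderFactory.get_dataloader('valid', config, args)
test_loader = DataLoaderFactory.get_dataloader('test', config, args)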

@ -19,7 +19,7 @@ import numpy as np
import soundfile import soundfile
from .utility import feat_type from .utility import feat_type
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
# from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation

@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.tensor_utils import th_accuracy from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig

@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import th_accuracy from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"] __all__ = ["U2STModel", "U2STInferModel"]

@ -26,7 +26,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer from paddlespeech.server.utils import onnx_infer

@ -24,9 +24,9 @@ from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import init_predictor

@ -24,9 +24,9 @@ from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig

@ -69,7 +69,9 @@ base = [
"prettytable", "prettytable",
"zhon", "zhon",
"colorlog", "colorlog",
"pathos == 0.2.8" "pathos == 0.2.8",
"braceexpand",
"pyyaml"
] ]
server = [ server = [
