add training scripts

pull/2062/head
huangyuxin 2 years ago
parent c7a7b113c8
commit 0c7abc1f17

@@ -1,7 +1,6 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
@@ -43,9 +42,9 @@ model_conf:
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list
###########################################
# Dataloader #
@@ -54,23 +53,22 @@ use_stream_data: True
unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
cmvn_file: data/mean_std.json
preprocess_config: conf/preprocess.yaml
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
dither: 0.1
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 64
batch_size: 32
minlen_in: 10
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed
minlen_out: 0
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed
resample_rate: 16000
shuffle_size: 10000
sort_size: 500
num_workers: 4
prefetch_factor: 100
shuffle_size: 1500
sort_size: 1000
num_workers: 0
prefetch_factor: 10
dist_sampler: True
num_encs: 1
augment_conf:
@@ -90,10 +88,10 @@ augment_conf:
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 16
n_epoch: 30
accum_grad: 32
global_grad_clip: 5.0
log_interval: 1
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
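For reference, a rough sketch of what the revised training numbers imply for the effective batch per optimizer update. Only batch_size and accum_grad come from this config; the GPU count below is an assumption, not part of the diff.

# Hedged sanity check on the new batching numbers (Python).
batch_size = 32    # utterances per GPU per forward pass (from this config)
accum_grad = 32    # gradient-accumulation steps per optimizer update (from this config)
num_gpus = 8       # assumed world size, adjust to the actual setup

utts_per_update = batch_size * accum_grad * num_gpus
print(utts_per_update)  # 8192 utterances contribute to each optimizer step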

@@ -2,6 +2,8 @@
# Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
# NPU, ASLP Group (Author: Qijie Shao)
#
# Modified from wenet(https://github.com/wenet-e2e/wenet)
stage=-1
stop_stage=100
@@ -30,7 +32,7 @@ mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data
echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
exit 0;
@@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
data || exit 1;
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# generate manifests
python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated."
exit 1
fi
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer
if $cmvn; then
full_size=`cat data/${train_set}/wav.scp | wc -l`
sampling_size=$((full_size / cmvn_sampling_divisor))
shuf -n $sampling_size data/$train_set/wav.scp \
> data/$train_set/wav.scp.sampled
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--spectrum_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--stride_ms=10 \
--window_ms=25 \
--sample_rate=16000 \
--use_dB_normalization=False \
--num_samples=-1 \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
fi
dict=data/dict/lang_char.txt
dict=data/lang_char/vocab.txt
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# download data, generate manifests
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/lang_char/vocab.txt" \
--manifest_paths "data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
echo "Make a dictionary"
echo "dictionary: ${dict}"
mkdir -p $(dirname $dict)
echo "<blank>" > ${dict} # 0 will be used for "blank" in CTC
echo "<unk>" >> ${dict} # <unk> must be 1
echo "▁" >> ${dict} # ▁ is for space
utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
| cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' \
| grep -v "▁" \
| awk '{print $0}' >> ${dict} \
|| exit 1;
echo "<eos>" >> $dict
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
echo "Compute cmvn"
# Here we use all the training data; you can sample some data to save time
# BUG!!! We should use the segmented data for CMVN
if $cmvn; then
full_size=`cat data/${train_set}/wav.scp | wc -l`
sampling_size=$((full_size / cmvn_sampling_divisor))
shuf -n $sampling_size data/$train_set/wav.scp \
> data/$train_set/wav.scp.sampled
python3 utils/compute_cmvn_stats.py \
--num_workers 16 \
--train_config $train_config \
--in_scp data/$train_set/wav.scp.sampled \
--out_cmvn data/$train_set/mean_std.json \
|| exit 1;
fi
fi
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Making shards, please wait..."
RED='\033[0;31m'
NOCOLOR='\033[0m'
echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space"
echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads"
for x in $dev_set $test_sets ${train_set}; do
dst=$shards_dir/$x
mkdir -p $dst
utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \
--do_filter --num_node 1 --num_gpus_per_node 8 \
--num_threads 32 --segments data/$x/segments \
data/$x/wav.scp data/$x/text \
$(realpath $dst) data/$x/data.list
done
fi
echo "Aishell data preparation done."
echo "Wenetspeech data preparation done."
exit 0

@@ -24,7 +24,7 @@ stage=1
prefix=
train_subset=L
. ./tools/parse_options.sh || exit 1;
. ./utils/parse_options.sh || exit 1;
filter_by_id () {
idlist=$1
@@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then
done
fi
echo "$0: Done"
echo "$0: Done"

@@ -11,7 +11,7 @@ from .cache import (
pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from webdataset.extradatasets import MockDataset, with_epoch, with_length
from .extradatasets import MockDataset, with_epoch, with_length
from .filters import (
associate,
batched,
@@ -65,5 +65,5 @@ from .shardlists import (
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
from webdataset.mix import RandomMix, RoundRobin
from .writer import ShardWriter, TarWriter, numpy_dumps
from .mix import RandomMix, RoundRobin

@@ -0,0 +1,445 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""
import io, json, os, pickle, re, tempfile
from functools import partial
import numpy as np
"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
################################################################
# handle basic datatypes
################################################################
def paddle_loads(data):
"""Load data using paddle.loads, importing paddle only if needed.
:param data: data to be decoded
"""
import io
import paddle
stream = io.BytesIO(data)
return paddle.load(stream)
def tenbin_loads(data):
from . import tenbin
return tenbin.decode_buffer(data)
def msgpack_loads(data):
import msgpack
return msgpack.unpackb(data)
def npy_loads(data):
import numpy.lib.format
stream = io.BytesIO(data)
return numpy.lib.format.read_array(stream)
def cbor_loads(data):
import cbor
return cbor.loads(data)
decoders = {
"txt": lambda data: data.decode("utf-8"),
"text": lambda data: data.decode("utf-8"),
"transcript": lambda data: data.decode("utf-8"),
"cls": lambda data: int(data),
"cls2": lambda data: int(data),
"index": lambda data: int(data),
"inx": lambda data: int(data),
"id": lambda data: int(data),
"json": lambda data: json.loads(data),
"jsn": lambda data: json.loads(data),
"pyd": lambda data: pickle.loads(data),
"pickle": lambda data: pickle.loads(data),
"pdparams": lambda data: paddle_loads(data),
"ten": tenbin_loads,
"tb": tenbin_loads,
"mp": msgpack_loads,
"msg": msgpack_loads,
"npy": npy_loads,
"npz": lambda data: np.load(io.BytesIO(data)),
"cbor": cbor_loads,
}
def basichandlers(key, data):
"""Handle basic file decoding.
This function is usually part of the post= decoders.
This handles the following forms of decoding:
- txt -> unicode string
- cls cls2 class count index inx id -> int
- json jsn -> JSON decoding
- pyd pickle -> pickle decoding
- pdparams -> paddle.loads
- ten tenbin -> fast tensor loading
- mp messagepack msg -> messagepack decoding
- npy -> Python NPY decoding
:param key: file name extension
:param data: binary data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension in decoders:
return decoders[extension](data)
return None
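A minimal, hypothetical illustration of the dispatch above: the extension is stripped from the key and looked up in `decoders`, and anything without a basic handler falls through as None. The keys and payloads below are made up.

import json

# Hypothetical keys/payloads, just to show the extension-based dispatch.
print(basichandlers("utt0001.txt", "ni hao".encode("utf-8")))                    # 'ni hao'
print(basichandlers("utt0001.cls", b"3"))                                        # 3
print(basichandlers("utt0001.json", json.dumps({"dur": 1.2}).encode("utf-8")))   # {'dur': 1.2}
print(basichandlers("utt0001.wav", b"RIFF...."))                                 # None, left to other handlers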
################################################################
# Generic extension handler.
################################################################
def call_extension_handler(key, data, f, extensions):
"""Call the function f with the given data if the key matches the extensions.
:param key: actual key found in the sample
:param data: binary data
:param f: decoder function
:param extensions: list of matching extensions
"""
extension = key.lower().split(".")
for target in extensions:
target = target.split(".")
if len(target) > len(extension):
continue
if extension[-len(target) :] == target:
return f(data)
return None
def handle_extension(extensions, f):
"""Return a decoder function for the list of extensions.
Extensions can be a space separated list of extensions.
Extensions can contain dots, in which case the corresponding number
of extension components must be present in the key given to f.
Comparisons are case insensitive.
Examples:
handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg
handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg
"""
extensions = extensions.lower().split()
return partial(call_extension_handler, f=f, extensions=extensions)
################################################################
# handle images
################################################################
imagespecs = {
"l8": ("numpy", "uint8", "l"),
"rgb8": ("numpy", "uint8", "rgb"),
"rgba8": ("numpy", "uint8", "rgba"),
"l": ("numpy", "float", "l"),
"rgb": ("numpy", "float", "rgb"),
"rgba": ("numpy", "float", "rgba"),
"paddlel8": ("paddle", "uint8", "l"),
"paddlergb8": ("paddle", "uint8", "rgb"),
"paddlergba8": ("paddle", "uint8", "rgba"),
"paddlel": ("paddle", "float", "l"),
"paddlergb": ("paddle", "float", "rgb"),
"paddle": ("paddle", "float", "rgb"),
"paddlergba": ("paddle", "float", "rgba"),
"pill": ("pil", None, "l"),
"pil": ("pil", None, "rgb"),
"pilrgb": ("pil", None, "rgb"),
"pilrgba": ("pil", None, "rgba"),
}
class ImageHandler:
"""Decode image data using the given `imagespec`.
The `imagespec` specifies whether the image is decoded
to numpy/paddle/pil, decoded to uint8/float, and decoded
to l/rgb/rgba:
- l8: numpy uint8 l
- rgb8: numpy uint8 rgb
- rgba8: numpy uint8 rgba
- l: numpy float l
- rgb: numpy float rgb
- rgba: numpy float rgba
- paddlel8: paddle uint8 l
- paddlergb8: paddle uint8 rgb
- paddlergba8: paddle uint8 rgba
- paddlel: paddle float l
- paddlergb: paddle float rgb
- paddle: paddle float rgb
- paddlergba: paddle float rgba
- pill: pil None l
- pil: pil None rgb
- pilrgb: pil None rgb
- pilrgba: pil None rgba
"""
def __init__(self, imagespec, extensions=image_extensions):
"""Create an image handler.
:param imagespec: short string indicating the type of decoding
:param extensions: list of extensions the image handler is invoked for
"""
if imagespec not in list(imagespecs.keys()):
raise ValueError("Unknown imagespec: %s" % imagespec)
self.imagespec = imagespec.lower()
self.extensions = extensions
def __call__(self, key, data):
"""Perform image decoding.
:param key: file name extension
:param data: binary data
"""
import PIL.Image
extension = re.sub(r".*[.]", "", key)
if extension.lower() not in self.extensions:
return None
imagespec = self.imagespec
atype, etype, mode = imagespecs[imagespec]
with io.BytesIO(data) as stream:
img = PIL.Image.open(stream)
img.load()
img = img.convert(mode.upper())
if atype == "pil":
return img
elif atype == "numpy":
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: numpy image must be uint8")
if etype == "uint8":
return result
else:
return result.astype("f") / 255.0
elif atype == "paddle":
import paddle
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: paddle image must be uint8")
if etype == "uint8":
result = np.array(result.transpose(2, 0, 1))
return paddle.to_tensor(result)
else:
result = np.array(result.transpose(2, 0, 1))
return paddle.to_tensor(result) / 255.0
return None
def imagehandler(imagespec, extensions=image_extensions):
"""Create an image handler.
This is just a lower case alias for ImageHandler.
:param imagespec: textual image spec
:param extensions: list of extensions the handler should be applied for
"""
return ImageHandler(imagespec, extensions)
################################################################
# torch video
################################################################
'''
def torch_video(key, data):
"""Decode video using the torchvideo library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
return None
import torchvision.io
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return torchvision.io.read_video(fname, pts_unit="sec")
'''
################################################################
# paddleaudio
################################################################
def paddle_audio(key, data):
"""Decode audio using the paddleaudio library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
return None
import paddleaudio
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return paddleaudio.load(fname)
################################################################
# special class for continuing decoding
################################################################
class Continue:
"""Special class for continuing decoding.
This is mostly used for decompression, as in:
def decompressor(key, data):
if key.endswith(".gz"):
return Continue(key[:-3], decompress(data))
return None
"""
def __init__(self, key, data):
"""__init__.
:param key:
:param data:
"""
self.key, self.data = key, data
def gzfilter(key, data):
"""Decode .gz files.
This decodes compressed files and then continues decoding.
:param key: file name extension
:param data: binary data
"""
import gzip
if not key.endswith(".gz"):
return None
decompressed = gzip.open(io.BytesIO(data)).read()
return Continue(key[:-3], decompressed)
################################################################
# decode entire training samples
################################################################
default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]
class Decoder:
"""Decode samples using a list of handlers.
For each key/data item, this iterates through the list of
handlers until some handler returns something other than None.
"""
def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
"""Create a Decoder.
:param handlers: main list of handlers
:param pre: handlers called before the main list (.gz handler by default)
:param post: handlers called after the main list (default handlers by default)
:param only: a list of extensions; when given, only decode files with those extensions
:param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
"""
if isinstance(only, str):
only = only.split()
self.only = only if only is None else set(only)
if pre is None:
pre = default_pre_handlers
if post is None:
post = default_post_handlers
assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
assert all(callable(h) for h in pre), f"one of {pre} not callable"
assert all(callable(h) for h in post), f"one of {post} not callable"
self.handlers = pre + handlers + post
self.partial = partial
def decode1(self, key, data):
"""Decode a single field of a sample.
:param key: file name extension
:param data: binary data
"""
key = "." + key
for f in self.handlers:
result = f(key, data)
if isinstance(result, Continue):
key, data = result.key, result.data
continue
if result is not None:
return result
return data
def decode(self, sample):
"""Decode an entire sample.
:param sample: the sample, a dictionary of key value pairs
"""
result = {}
assert isinstance(sample, dict), sample
for k, v in list(sample.items()):
if k[0] == "_":
if isinstance(v, bytes):
v = v.decode("utf-8")
result[k] = v
continue
if self.only is not None and k not in self.only:
result[k] = v
continue
assert v is not None
if self.partial:
if isinstance(v, bytes):
result[k] = self.decode1(k, v)
else:
result[k] = v
else:
assert isinstance(v, bytes)
result[k] = self.decode1(k, v)
return result
def __call__(self, sample):
"""Decode an entire sample.
:param sample: the sample
"""
assert isinstance(sample, dict), (len(sample), sample)
return self.decode(sample)
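A hedged usage sketch of the classes above: gzfilter and basichandlers wrap a user-supplied image handler, and each non-metadata field is decoded by the first handler that returns something other than None. The sample contents are invented for illustration.

# Sketch only; the sample contents are hypothetical.
decoder = Decoder([imagehandler("rgb8")])   # pre: gzfilter, post: basichandlers

sample = {
    "__key__": "utt0001",
    "txt": "ni hao".encode("utf-8"),
    "json": b'{"speaker": "S0001"}',
}
decoded = decoder(sample)
print(decoded["txt"])    # 'ni hao'
print(decoded["json"])   # {'speaker': 'S0001'}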

@@ -6,8 +6,8 @@ import itertools, os, random, re, sys
from urllib.parse import urlparse
from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from . import gopen
from .handlers import reraise_exception
from .tariterators import tar_file_and_group_expander
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")

@@ -8,7 +8,7 @@ from typing import List
import braceexpand, yaml
from webdataset import autodecode
from . import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline

@@ -0,0 +1,141 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import itertools as itt
import os
import random
import sys
import braceexpand
from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
"""MockDataset.
A mock dataset for performance testing and unit testing.
"""
def __init__(self, sample, length):
"""Create a mock dataset instance.
:param sample: the sample to be returned repeatedly
:param length: the length of the mock dataset
"""
self.sample = sample
self.length = length
def __iter__(self):
"""Return an iterator over this mock dataset."""
for i in range(self.length):
yield self.sample
class repeatedly(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, source, nepochs=None, nbatches=None, length=None):
"""Create an instance of Repeatedly.
:param nepochs: repeat for a maximum of nepochs
:param nbatches: repeat for a maximum of nbatches
"""
self.source = source
self.length = length
self.nepochs = nepochs   # stored so that invoke() can read it
self.nbatches = nbatches
def invoke(self, source):
"""Return an iterator that iterates repeatedly over a source."""
return utils.repeatedly(
source,
nepochs=self.nepochs,
nbatches=self.nbatches,
)
class with_epoch(IterableDataset):
"""Change the actual and nominal length of an IterableDataset.
This will continuously iterate through the original dataset, but
impose new epoch boundaries at the given length/nominal.
This exists mainly as a workaround for the odd logic in DataLoader.
It is also useful for choosing smaller nominal epoch sizes with
very large datasets.
"""
def __init__(self, dataset, length):
"""Chop the dataset to the given length.
:param dataset: IterableDataset
:param length: declared length of the dataset
"""
super().__init__()
self.length = length
self.source = None
def __getstate__(self):
"""Return the pickled state of the dataset.
This resets the dataset iterator, since that can't be pickled.
"""
result = dict(self.__dict__)
result["source"] = None
return result
def invoke(self, dataset):
"""Return an iterator over the dataset.
This iterator returns as many samples as given by the `length`
parameter.
"""
if self.source is None:
self.source = iter(dataset)
for i in range(self.length):
try:
sample = next(self.source)
except StopIteration:
self.source = iter(dataset)
try:
sample = next(self.source)
except StopIteration:
return
yield sample
self.source = None
class with_length(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, dataset, length):
"""Create an instance of Repeatedly.
:param dataset: source dataset
:param length: stated length
"""
super().__init__()
self.dataset = dataset
self.length = length
def invoke(self, dataset):
"""Return an iterator that iterates repeatedly over a source."""
return iter(dataset)
def __len__(self):
"""Return the user specified length."""
return self.length
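A small hedged sketch of how `with_epoch` behaves when driven directly through `invoke`: it yields exactly `length` samples, restarting the underlying iterator if the source runs out mid-epoch. MockDataset stands in for a real shard pipeline; the sample dict is made up.

# Sketch only: impose a nominal epoch of 5 samples on a 3-sample mock source.
mock = MockDataset({"__key__": "utt", "txt": b"ni hao"}, length=3)
stage = with_epoch(mock, 5)

samples = list(stage.invoke(mock))
print(len(samples))   # 5 (the 3-sample mock is restarted once to fill the epoch)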

@@ -21,7 +21,7 @@ from functools import reduce, wraps
import numpy as np
from webdataset import autodecode
from . import autodecode
from . import utils
from .paddle_utils import PaddleTensor
from .utils import PipelineStage
@@ -932,4 +932,4 @@ def _placeholder(source):
for data in source:
yield data
placeholder = pipelinefilter(_placeholder)
placeholder = pipelinefilter(_placeholder)

@@ -0,0 +1,340 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse
# global used for printing additional node information during verbose output
info = {}
class Pipe:
"""Wrapper class for subprocess.Pipe.
This class looks like a stream from the outside, but it checks
subprocess status and handles timeouts with exceptions.
This way, clients of the class do not need to know that they are
dealing with subprocesses.
:param *args: passed to `subprocess.Popen`
:param **kw: passed to `subprocess.Popen`
:param timeout: timeout for closing/waiting
:param ignore_errors: don't raise exceptions on subprocess errors
:param ignore_status: list of status codes to ignore
"""
def __init__(
self,
*args,
mode=None,
timeout=7200.0,
ignore_errors=False,
ignore_status=[],
**kw,
):
"""Create an IO Pipe."""
self.ignore_errors = ignore_errors
self.ignore_status = [0] + ignore_status
self.timeout = timeout
self.args = (args, kw)
if mode[0] == "r":
self.proc = Popen(*args, stdout=PIPE, **kw)
self.stream = self.proc.stdout
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
elif mode[0] == "w":
self.proc = Popen(*args, stdin=PIPE, **kw)
self.stream = self.proc.stdin
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
self.status = None
def __str__(self):
return f"<Pipe {self.args}>"
def check_status(self):
"""Poll the process and handle any errors."""
status = self.proc.poll()
if status is not None:
self.wait_for_child()
def wait_for_child(self):
"""Check the status variable and raise an exception if necessary."""
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if self.status is not None and verbose:
# print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
return
self.status = self.proc.wait()
if verbose:
print(
f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
file=sys.stderr,
)
if self.status not in self.ignore_status and not self.ignore_errors:
raise Exception(f"{self.args}: exit {self.status} (read) {info}")
def read(self, *args, **kw):
"""Wrap stream.read and checks status."""
result = self.stream.read(*args, **kw)
self.check_status()
return result
def write(self, *args, **kw):
"""Wrap stream.write and checks status."""
result = self.stream.write(*args, **kw)
self.check_status()
return result
def readLine(self, *args, **kw):
"""Wrap stream.readLine and checks status."""
result = self.stream.readLine(*args, **kw)
self.status = self.proc.poll()
self.check_status()
return result
def close(self):
"""Wrap stream.close, wait for the subprocess, and handle errors."""
self.stream.close()
self.status = self.proc.wait(self.timeout)
self.wait_for_child()
def __enter__(self):
"""Context handler."""
return self
def __exit__(self, etype, value, traceback):
"""Context handler."""
self.close()
def set_options(
obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
"""Set options for Pipes.
This function can be called on any stream. It will set pipe options only
when its argument is a pipe.
:param obj: any kind of stream
:param timeout: desired timeout
:param ignore_errors: desired ignore_errors setting
:param ignore_status: desired ignore_status setting
:param handler: desired error handler
"""
if not isinstance(obj, Pipe):
return False
if timeout is not None:
obj.timeout = timeout
if ignore_errors is not None:
obj.ignore_errors = ignore_errors
if ignore_status is not None:
obj.ignore_status = ignore_status
if handler is not None:
obj.handler = handler
return True
def gopen_file(url, mode="rb", bufsize=8192):
"""Open a file.
This is a thin wrapper around the built-in open() for local files.
:param url: URL to be opened
:param mode: mode to open it with
:param bufsize: requested buffer size
"""
return open(url, mode)
def gopen_pipe(url, mode="rb", bufsize=8192):
"""Use gopen to open a pipe.
:param url: a pipe: URL
:param mode: desired mode
:param bufsize: desired buffer size
"""
assert url.startswith("pipe:")
cmd = url[5:]
if mode[0] == "r":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
elif mode[0] == "w":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_curl(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"curl -s -L -T - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_htgs(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
url = re.sub(r"(?i)^htgs://", "gs://", url)
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
raise ValueError(f"{mode}: cannot write")
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_gsutil(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"gsutil cat '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"gsutil cp - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_error(url, *args, **kw):
"""Raise a value error.
:param url: url
:param args: other arguments
:param kw: other keywords
"""
raise ValueError(f"{url}: no gopen handler defined")
"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
__default__=gopen_error,
pipe=gopen_pipe,
http=gopen_curl,
https=gopen_curl,
sftp=gopen_curl,
ftps=gopen_curl,
scp=gopen_curl,
gs=gopen_gsutil,
htgs=gopen_htgs,
)
def gopen(url, mode="rb", bufsize=8192, **kw):
"""Open the URL.
This uses the `gopen_schemes` dispatch table to dispatch based
on scheme.
Support for the following schemes is built-in: pipe, file,
http, https, sftp, ftps, scp, gs, and htgs.
When no scheme is given the url is treated as a file.
You can use the GOPEN_VERBOSE environment variable to get info about
files being opened.
:param url: the source URL
:param mode: the mode ("rb", "r")
:param bufsize: the buffer size
"""
global fallback_gopen
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if verbose:
print("GOPEN", url, info, file=sys.stderr)
assert mode in ["rb", "wb"], mode
if url == "-":
if mode == "rb":
return sys.stdin.buffer
elif mode == "wb":
return sys.stdout.buffer
else:
raise ValueError(f"unknown mode {mode}")
pr = urlparse(url)
if pr.scheme == "":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(url, mode, buffering=bufsize)
if pr.scheme == "file":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(pr.path, mode, buffering=bufsize)
handler = gopen_schemes["__default__"]
handler = gopen_schemes.get(pr.scheme, handler)
return handler(url, mode, bufsize, **kw)
def reader(url, **kw):
"""Open url with gopen and mode "rb".
:param url: source URL
:param kw: other keywords forwarded to gopen
"""
return gopen(url, "rb", **kw)
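A hedged usage sketch of the dispatch above: a plain path falls through to open(), a pipe: URL streams the stdout of a shell command, and http(s) URLs are routed to curl. The path and URL below are placeholders, not part of this repo.

# Sketch only; paths/URLs below are placeholders.
with gopen("data/train_l/data.list", "rb") as f:   # no scheme -> plain open()
    head = f.read(64)

with gopen("pipe:echo hello", "rb") as f:          # pipe: -> Pipe wrapping a shell command
    print(f.read())                                # b'hello\n'

# http/https would be routed through gopen_curl:
# with gopen("https://example.com/shard-000000.tar", "rb") as f:
#     blob = f.read()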

@@ -0,0 +1,47 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)
They are used as handler= arguments in much of the library.
"""
import time, warnings
def reraise_exception(exn):
"""Call in an exception handler to re-raise the exception."""
raise exn
def ignore_and_continue(exn):
"""Call in an exception handler to ignore any exception and continue."""
return True
def warn_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
warnings.warn(repr(exn))
time.sleep(0.5)
return True
def ignore_and_stop(exn):
"""Call in an exception handler to ignore any exception and stop further processing."""
return False
def warn_and_stop(exn):
"""Call in an exception handler to ignore any exception and stop further processing."""
warnings.warn(repr(exn))
time.sleep(0.5)
return False
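For context, a hedged sketch of how these handlers are consumed by pipeline stages: the stage catches the exception, hands it to `handler`, and continues or stops based on the returned boolean. `safe_map` is illustrative only, not part of the library.

# Illustrative helper, not part of this module.
def safe_map(source, fn, handler=reraise_exception):
    """Apply fn to each sample, delegating failures to a pluggable handler."""
    for sample in source:
        try:
            yield fn(sample)
        except Exception as exn:
            if handler(exn):   # True: skip the bad sample and continue
                continue
            break              # False: stop processing quietly

# e.g. list(safe_map(samples, parse_fn, handler=warn_and_continue))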

@@ -0,0 +1,85 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
import itertools, os, random, time, sys
from functools import reduce, wraps
import numpy as np
from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage
def round_robin_shortest(*sources):
i = 0
while True:
try:
sample = next(sources[i % len(sources)])
yield sample
except StopIteration:
break
i += 1
def round_robin_longest(*sources):
i = 0
while len(sources) > 0:
try:
sample = next(sources[i])
i += 1
yield sample
except StopIteration:
del sources[i]
class RoundRobin(IterableDataset):
def __init__(self, datasets, longest=False):
self.datasets = datasets
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
if self.longest:
return round_robin_longest(*sources)
else:
return round_robin_shortest(*sources)
def random_samples(sources, probs=None, longest=False):
if probs is None:
probs = [1] * len(sources)
else:
probs = list(probs)
while len(sources) > 0:
cum = (np.array(probs) / np.sum(probs)).cumsum()
r = random.random()
i = np.searchsorted(cum, r)
try:
yield next(sources[i])
except StopIteration:
if longest:
del sources[i]
del probs[i]
else:
break
class RandomMix(IterableDataset):
def __init__(self, datasets, probs=None, longest=False):
self.datasets = datasets
self.probs = probs
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
return random_samples(sources, self.probs, longest=self.longest)
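A hedged sketch of mixing two sources with the classes above; the toy lists stand in for real shard pipelines, and the 0.9/0.1 weights are only illustrative. With longest=False (the default), iteration stops as soon as one source is exhausted.

# Sketch only: two toy sources mixed roughly 9:1 until one runs out.
ds_a = [{"__key__": f"a{i:03d}"} for i in range(90)]
ds_b = [{"__key__": f"b{i:03d}"} for i in range(10)]

mixed = RandomMix([ds_a, ds_b], probs=[0.9, 0.1])
keys = [sample["__key__"] for sample in mixed]
print(len(keys), keys[:5])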

@@ -10,8 +10,7 @@ from typing import List
import braceexpand, yaml
from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage

@@ -14,8 +14,8 @@ import random, re, tarfile
import braceexpand
from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from . import gopen
from .handlers import reraise_exception
trace = False
meta_prefix = "__"

@@ -0,0 +1,450 @@
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union
import numpy as np
from . import gopen
def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string.
Can handle float or uint8 images.
:param image: ndarray representing an image
:param format: compression format (PNG, JPEG, PPM)
"""
import PIL.Image
assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
if isinstance(image, np.ndarray):
if image.dtype in [np.dtype("f"), np.dtype("d")]:
if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
raise ValueError(
f"image values out of range {np.amin(image)} {np.amax(image)}"
)
image = np.clip(image, 0.0, 1.0)
image = np.array(image * 255.0, "uint8")
assert image.ndim in [2, 3]
if image.ndim == 3:
assert image.shape[2] in [1, 3]
image = PIL.Image.fromarray(image)
if format.upper() == "JPG":
format = "JPEG"
elif format.upper() in ["IMG", "IMAGE"]:
format = "PPM"
if format == "JPEG":
opts = dict(quality=100)
else:
opts = {}
with io.BytesIO() as result:
image.save(result, format=format, **opts)
return result.getvalue()
def bytestr(data: Any):
"""Convert data into a bytestring.
Uses str and ASCII encoding for data that isn't already in string format.
:param data: data
"""
if isinstance(data, bytes):
return data
if isinstance(data, str):
return data.encode("ascii")
return str(data).encode("ascii")
def paddle_dumps(data: Any):
"""Dump data into a bytestring using paddle.dumps.
This delays importing paddle until needed.
:param data: data to be dumped
"""
import io
import paddle
stream = io.BytesIO()
paddle.save(data, stream)
return stream.getvalue()
def numpy_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npy format.
:param data: data to be dumped
"""
import io
import numpy.lib.format
stream = io.BytesIO()
numpy.lib.format.write_array(stream, data)
return stream.getvalue()
def numpy_npz_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npz format.
:param data: data to be dumped
"""
import io
stream = io.BytesIO()
np.savez_compressed(stream, **data)
return stream.getvalue()
def tenbin_dumps(x):
from . import tenbin
if isinstance(x, list):
return memoryview(tenbin.encode_buffer(x))
else:
return memoryview(tenbin.encode_buffer([x]))
def cbor_dumps(x):
import cbor
return cbor.dumps(x)
def mp_dumps(x):
import msgpack
return msgpack.packb(x)
def add_handlers(d, keys, value):
if isinstance(keys, str):
keys = keys.split()
for k in keys:
d[k] = value
def make_handlers():
"""Create a list of handlers for encoding data."""
handlers = {}
add_handlers(
handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
)
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
add_handlers(handlers, "pyd pickle", pickle.dumps)
add_handlers(handlers, "pdparams", paddle_dumps)
add_handlers(handlers, "npy", numpy_dumps)
add_handlers(handlers, "npz", numpy_npz_dumps)
add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
add_handlers(handlers, "mp msgpack msg", mp_dumps)
add_handlers(handlers, "cbor", cbor_dumps)
add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
return handlers
default_handlers = make_handlers()
def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
"""Encode data based on its extension and a dict of handlers.
:param data: data
:param tname: file extension
:param handlers: handlers
"""
if tname[0] == "_":
if not isinstance(data, str):
raise ValueError("the values of metadata must be of string type")
return data
extension = re.sub(r".*\.", "", tname).lower()
if isinstance(data, bytes):
return data
if isinstance(data, str):
return data.encode("utf-8")
handler = handlers.get(extension)
if handler is None:
raise ValueError(f"no handler found for {extension}")
return handler(data)
def encode_based_on_extension(sample: dict, handlers: dict):
"""Encode an entire sample with a collection of handlers.
:param sample: data sample (a dict)
:param handlers: handlers for encoding
"""
return {
k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
}
def make_encoder(spec: Union[bool, str, dict, Callable]):
"""Make an encoder function from a specification.
:param spec: specification
"""
if spec is False or spec is None:
def encoder(x):
"""Do not encode at all."""
return x
elif callable(spec):
encoder = spec
elif isinstance(spec, dict):
def f(sample):
"""Encode based on extension."""
return encode_based_on_extension(sample, spec)
encoder = f
elif spec is True:
handlers = default_handlers
def g(sample):
"""Encode based on extension."""
return encode_based_on_extension(sample, handlers)
encoder = g
else:
raise ValueError(f"{spec}: unknown decoder spec")
if not callable(encoder):
raise ValueError(f"{spec} did not yield a callable encoder")
return encoder
class TarWriter:
"""A class for writing dictionaries to tar files.
:param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
:param encoder: sample encoding (Default value = True)
:param compress: (Default value = None)
`True` will use an encoder that behaves similarly to the automatic
decoder for `Dataset`. `False` disables encoding and expects byte strings
(except for metadata, which must be strings). The `encoder` argument can
also be a `callable`, or a dictionary mapping extensions to encoders.
The following code will add two files to the tar archive: `a/b.png` and
`a/b.output.png`.
```Python
tarwriter = TarWriter(stream)
image = imread("b.jpg")
image2 = imread("b.out.jpg")
sample = {"__key__": "a/b", "png": image, "output.png": image2}
tarwriter.write(sample)
```
"""
def __init__(
self,
fileobj,
user: str = "bigdata",
group: str = "bigdata",
mode: int = 0o0444,
compress: Optional[bool] = None,
encoder: Union[None, bool, Callable] = True,
keep_meta: bool = False,
):
"""Create a tar writer.
:param fileobj: stream to write data to
:param user: user for tar files
:param group: group for tar files
:param mode: mode for tar files
:param compress: desired compression
:param encoder: encoder function
:param keep_meta: keep metadata (entries starting with "_")
"""
if isinstance(fileobj, str):
if compress is False:
tarmode = "w|"
elif compress is True:
tarmode = "w|gz"
else:
tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
fileobj = gopen.gopen(fileobj, "wb")
self.own_fileobj = fileobj
else:
tarmode = "w|gz" if compress is True else "w|"
self.own_fileobj = None
self.encoder = make_encoder(encoder)
self.keep_meta = keep_meta
self.stream = fileobj
self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)
self.user = user
self.group = group
self.mode = mode
self.compress = compress
def __enter__(self):
"""Enter context."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context."""
self.close()
def close(self):
"""Close the tar file."""
self.tarstream.close()
if self.own_fileobj is not None:
self.own_fileobj.close()
self.own_fileobj = None
def write(self, obj):
"""Write a dictionary to the tar file.
:param obj: dictionary of objects to be stored
:returns: size of the entry
"""
total = 0
obj = self.encoder(obj)
if "__key__" not in obj:
raise ValueError("object must contain a __key__")
for k, v in list(obj.items()):
if k[0] == "_":
continue
if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError(
f"{k} doesn't map to a bytes after encoding ({type(v)})"
)
key = obj["__key__"]
for k in sorted(obj.keys()):
if k == "__key__":
continue
if not self.keep_meta and k[0] == "_":
continue
v = obj[k]
if isinstance(v, str):
v = v.encode("utf-8")
now = time.time()
ti = tarfile.TarInfo(key + "." + k)
ti.size = len(v)
ti.mtime = now
ti.mode = self.mode
ti.uname = self.user
ti.gname = self.group
if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
stream = io.BytesIO(v)
self.tarstream.addfile(ti, stream)
total += ti.size
return total
class ShardWriter:
"""Like TarWriter but splits into multiple shards."""
def __init__(
self,
pattern: str,
maxcount: int = 100000,
maxsize: float = 3e9,
post: Optional[Callable] = None,
start_shard: int = 0,
**kw,
):
"""Create a ShardWriter.
:param pattern: output file pattern
:param maxcount: maximum number of records per shard (Default value = 100000)
:param maxsize: maximum size of each shard (Default value = 3e9)
:param kw: other options passed to TarWriter
"""
self.verbose = 1
self.kw = kw
self.maxcount = maxcount
self.maxsize = maxsize
self.post = post
self.tarstream = None
self.shard = start_shard
self.pattern = pattern
self.total = 0
self.count = 0
self.size = 0
self.fname = None
self.next_stream()
def next_stream(self):
"""Close the current stream and move to the next."""
self.finish()
self.fname = self.pattern % self.shard
if self.verbose:
print(
"# writing",
self.fname,
self.count,
"%.1f GB" % (self.size / 1e9),
self.total,
)
self.shard += 1
stream = open(self.fname, "wb")
self.tarstream = TarWriter(stream, **self.kw)
self.count = 0
self.size = 0
def write(self, obj):
"""Write a sample.
:param obj: sample to be written
"""
if (
self.tarstream is None
or self.count >= self.maxcount
or self.size >= self.maxsize
):
self.next_stream()
size = self.tarstream.write(obj)
self.count += 1
self.total += 1
self.size += size
def finish(self):
"""Finish all writing (use close instead)."""
if self.tarstream is not None:
self.tarstream.close()
assert self.fname is not None
if callable(self.post):
self.post(self.fname)
self.tarstream = None
def close(self):
"""Close the stream."""
self.finish()
del self.tarstream
del self.shard
del self.count
del self.size
def __enter__(self):
"""Enter context."""
return self
def __exit__(self, *args, **kw):
"""Exit context."""
self.close()
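A hedged usage sketch of ShardWriter with made-up samples and an invented output pattern; in this recipe the real shards are produced by utils/make_filted_shard_list.py, so this only shows the API shape.

import os

# Sketch only: write toy samples into tar shards of at most 1000 entries each.
os.makedirs("shards", exist_ok=True)
samples = [
    {"__key__": f"utt{i:06d}", "wav": b"RIFF....", "txt": f"transcript {i}"}
    for i in range(10)
]

with ShardWriter("shards/train-%06d.tar", maxcount=1000) as sink:
    for sample in samples:
        sink.write(sample)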

@@ -18,6 +18,7 @@ from typing import Text
import jsonlines
import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
@@ -28,7 +29,7 @@ from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
import paddlespeech.audio.stream_data as stream_data
import paddlespeech.audio.streamdata as streamdata
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
__all__ = ["BatchDataLoader"]
@@ -101,38 +102,46 @@ class StreamDataLoader():
shardlist.append(line.strip())
if self.dist_sampler:
base_dataset = stream_data.DataPipeline(
stream_data.SimpleShardList(shardlist),
stream_data.split_by_node,
stream_data.split_by_worker,
stream_data.tarfile_to_samples(stream_data.reraise_exception)
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_node,
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
else:
base_dataset = stream_data.DataPipeline(
stream_data.SimpleShardList(shardlist),
stream_data.split_by_worker,
stream_data.tarfile_to_samples(stream_data.reraise_exception)
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
self.dataset = base_dataset.append_list(
stream_data.tokenize(symbol_table),
stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
stream_data.resample(resample_rate=resample_rate),
stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
stream_data.shuffle(shuffle_size),
stream_data.sort(sort_size=sort_size),
stream_data.batched(batch_size),
stream_data.padding(),
stream_data.cmvn(cmvn_file)
)
self.loader = stream_data.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
prefetch_factor = self.prefetch_factor,
batch_size=None
streamdata.tokenize(symbol_table),
streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
streamdata.resample(resample_rate=resample_rate),
streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.shuffle(shuffle_size),
streamdata.sort(sort_size=sort_size),
streamdata.batched(batch_size),
streamdata.padding(),
streamdata.cmvn(cmvn_file)
)
if paddle.__version__ >= '2.3.2':
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
prefetch_factor = self.prefetch_factor,
batch_size=None
)
else:
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
batch_size=None
)
def __iter__(self):
return self.loader.__iter__()

@@ -38,7 +38,8 @@ base = [
"pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
"sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
"textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
"yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset'
"yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8",
"braceexpand", "pyyaml"
]
server = [
