You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/audio/stream_data/pipeline.py

134 lines
4.3 KiB

# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage
def add_length_method(obj):
    """Return ``obj`` with a ``__len__`` method reporting ``obj.size``.

    The object's class is swapped for a dynamically built subclass (named
    ``<OriginalName>_Length``) that also inherits from IterableDataset, so
    the instance keeps all of its behavior while gaining a length.
    """

    def _length(self):
        return self.size

    patched_cls = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": _length},
    )
    obj.__class__ = patched_cls
    return obj
class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        """Build a pipeline from the given stages.

        Each positional argument is either a stage, a list of stages
        (flattened in order), or None (skipped).
        """
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage.

        Dispatch order matters: PipelineStage instances run via .run(),
        datasets/loaders/lists become iterators (only when they are a source,
        i.e. no upstream args), and anything else callable is called directly.

        :raises ValueError: if ``f`` is not a recognized stage type.
        """
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            return f(*args, **kwargs)
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        # The first stage is the source; every later stage wraps the
        # iterator produced by the stage before it.
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for _ in range(self.repetitions):
            yield from self.iterator1()

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        # nsamples only caps the stream when the pipeline is set to repeat;
        # a single-epoch pipeline (repetitions == 1) is returned unsliced,
        # matching the original (upstream webdataset) behavior.
        if self.repetitions != 1 and self.nsamples > 0:
            return islice(self.iterator(), self.nsamples)
        return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)
        return self

    def append_list(self, *args):
        """Append multiple pipeline stages (modifies the object)."""
        self.pipeline.extend(args)
        return self

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy.

        Bug fix: ``copy.copy(self)`` is shallow, so the copy shared the same
        ``pipeline`` list object as the original; appending to the copy used
        to mutate the original pipeline as well.  The copy now gets its own
        list before any stages are appended.
        """
        result = copy.copy(self)
        result.pipeline = list(self.pipeline)  # detach from the original's stage list
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing; the larger of the two wins.
        Sets the pipeline to repeat indefinitely, capped by ``__iter__``.
        """
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        # Non-positive nepochs means "repeat forever" (bounded by nbatches, if set).
        self.repetitions = nepochs if nepochs > 0 else sys.maxsize
        self.nsamples = nbatches
        return self