parent
2ad3a81945
commit
e5b99965b8
@@ -1,11 +0,0 @@
|
|||||||
beam_size: 10
|
|
||||||
decode_batch_size: 128
|
|
||||||
error_rate_type: cer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
|
@@ -1,4 +1,3 @@
|
|||||||
process:
|
process:
|
||||||
# use raw audio
|
# use raw audio
|
||||||
- type: wav_process
|
- type: wav_process
|
||||||
dither: 0.0
|
|
@@ -0,0 +1,4 @@
|
|||||||
|
decode_batch_size: 1
|
||||||
|
error_rate_type: cer
|
||||||
|
decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
|
||||||
|
beam_size: 10
|
@@ -0,0 +1,184 @@
|
|||||||
|
"""Batch collation
|
||||||
|
|
||||||
|
Authors
|
||||||
|
* Aku Rouhe 2020
|
||||||
|
"""
|
||||||
|
import collections
|
||||||
|
import torch
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.data_utils import mod_default_collate
|
||||||
|
# from speechbrain.utils.data_utils import recursive_to
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.data_utils import batch_pad_right
|
||||||
|
from torch.utils.data._utils.collate import default_convert
|
||||||
|
# from torch.utils.data._utils.pin_memory import (
|
||||||
|
# pin_memory as recursive_pin_memory,
|
||||||
|
# )
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
|
||||||
|
|
||||||
|
|
||||||
|
class PaddedBatch:
|
||||||
|
"""Collate_fn when examples are dicts and have variable-length sequences.
|
||||||
|
|
||||||
|
Different elements in the examples get matched by key.
|
||||||
|
All numpy tensors get converted to Torch (PyTorch default_convert)
|
||||||
|
Then, by default, all torch.Tensor valued elements get padded and support
|
||||||
|
collective pin_memory() and to() calls.
|
||||||
|
Regular Python data types are just collected in a list.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
examples : list
|
||||||
|
List of example dicts, as produced by Dataloader.
|
||||||
|
padded_keys : list, None
|
||||||
|
(Optional) List of keys to pad on. If None, pad all torch.Tensors
|
||||||
|
device_prep_keys : list, None
|
||||||
|
(Optional) Only these keys participate in collective memory pinning and moving with
|
||||||
|
to().
|
||||||
|
If None, defaults to all items with torch.Tensor values.
|
||||||
|
padding_func : callable, optional
|
||||||
|
Called with a list of tensors to be padded together. Needs to return
|
||||||
|
two tensors: the padded data, and another tensor for the data lengths.
|
||||||
|
padding_kwargs : dict
|
||||||
|
(Optional) Extra kwargs to pass to padding_func. E.G. mode, value
|
||||||
|
apply_default_convert : bool
|
||||||
|
Whether to apply PyTorch default_convert (numpy to torch recursively,
|
||||||
|
etc.) on all data. Default:True, usually does the right thing.
|
||||||
|
nonpadded_stack : bool
|
||||||
|
Whether to apply PyTorch-default_collate-like stacking on values that
|
||||||
|
didn't get padded. This stacks if it can, but doesn't error out if it
|
||||||
|
cannot. Default:True, usually does the right thing.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> batch = PaddedBatch([
|
||||||
|
... {"id": "ex1", "foo": torch.Tensor([1.])},
|
||||||
|
... {"id": "ex2", "foo": torch.Tensor([2., 1.])}])
|
||||||
|
>>> # Attribute or key-based access:
|
||||||
|
>>> batch.id
|
||||||
|
['ex1', 'ex2']
|
||||||
|
>>> batch["id"]
|
||||||
|
['ex1', 'ex2']
|
||||||
|
>>> # torch.Tensors get padded
|
||||||
|
>>> type(batch.foo)
|
||||||
|
<class 'speechbrain.dataio.batch.PaddedData'>
|
||||||
|
>>> batch.foo.data
|
||||||
|
tensor([[1., 0.],
|
||||||
|
[2., 1.]])
|
||||||
|
>>> batch.foo.lengths
|
||||||
|
tensor([0.5000, 1.0000])
|
||||||
|
>>> # Batch supports collective operations:
|
||||||
|
>>> _ = batch.to(dtype=torch.half)
|
||||||
|
>>> batch.foo.data
|
||||||
|
tensor([[1., 0.],
|
||||||
|
[2., 1.]], dtype=torch.float16)
|
||||||
|
>>> batch.foo.lengths
|
||||||
|
tensor([0.5000, 1.0000], dtype=torch.float16)
|
||||||
|
>>> # Numpy tensors get converted to torch and padded as well:
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> batch = PaddedBatch([
|
||||||
|
... {"wav": np.asarray([1,2,3,4])},
|
||||||
|
... {"wav": np.asarray([1,2,3])}])
|
||||||
|
>>> batch.wav # +ELLIPSIS
|
||||||
|
PaddedData(data=tensor([[1, 2,...
|
||||||
|
>>> # Basic stacking collation deals with non padded data:
|
||||||
|
>>> batch = PaddedBatch([
|
||||||
|
... {"spk_id": torch.tensor([1]), "wav": torch.tensor([.1,.0,.3])},
|
||||||
|
... {"spk_id": torch.tensor([2]), "wav": torch.tensor([.2,.3,-.1])}],
|
||||||
|
... padded_keys=["wav"])
|
||||||
|
>>> batch.spk_id
|
||||||
|
tensor([[1],
|
||||||
|
[2]])
|
||||||
|
>>> # And some data is left alone:
|
||||||
|
>>> batch = PaddedBatch([
|
||||||
|
... {"text": ["Hello"]},
|
||||||
|
... {"text": ["How", "are", "you?"]}])
|
||||||
|
>>> batch.text
|
||||||
|
[['Hello'], ['How', 'are', 'you?']]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
examples,
|
||||||
|
padded_keys=None,
|
||||||
|
device_prep_keys=None,
|
||||||
|
padding_func=batch_pad_right,
|
||||||
|
padding_kwargs={},
|
||||||
|
nonpadded_stack=True,
|
||||||
|
):
|
||||||
|
self.__length = len(examples)
|
||||||
|
self.__keys = list(examples[0].keys())
|
||||||
|
self.__padded_keys = []
|
||||||
|
self.__device_prep_keys = []
|
||||||
|
for key in self.__keys:
|
||||||
|
values = [example[key] for example in examples]
|
||||||
|
# Default convert usually does the right thing (numpy2torch etc.)
|
||||||
|
values = default_convert(values)
|
||||||
|
|
||||||
|
if (padded_keys is not None and key in padded_keys) or (
|
||||||
|
padded_keys is None and isinstance(values[0], paddle.Tensor)
|
||||||
|
):
|
||||||
|
# Padding and PaddedData
|
||||||
|
self.__padded_keys.append(key)
|
||||||
|
padded = PaddedData(*padding_func(values, **padding_kwargs))
|
||||||
|
setattr(self, key, padded)
|
||||||
|
else:
|
||||||
|
# Default PyTorch collate usually does the right thing
|
||||||
|
# (convert lists of equal sized tensors to batch tensors, etc.)
|
||||||
|
if nonpadded_stack:
|
||||||
|
values = mod_default_collate(values)
|
||||||
|
setattr(self, key, values)
|
||||||
|
if (device_prep_keys is not None and key in device_prep_keys) or (
|
||||||
|
device_prep_keys is None and isinstance(values[0], paddle.Tensor)
|
||||||
|
):
|
||||||
|
self.__device_prep_keys.append(key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.__length
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if key in self.__keys:
|
||||||
|
return getattr(self, key)
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Batch doesn't have key: {key}")
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""Iterates over the different elements of the batch.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> batch = PaddedBatch([
|
||||||
|
... {"id": "ex1", "val": torch.Tensor([1.])},
|
||||||
|
... {"id": "ex2", "val": torch.Tensor([2., 1.])}])
|
||||||
|
>>> ids, vals = batch
|
||||||
|
>>> ids
|
||||||
|
['ex1', 'ex2']
|
||||||
|
"""
|
||||||
|
return iter((getattr(self, key) for key in self.__keys))
|
||||||
|
|
||||||
|
# def pin_memory(self):
|
||||||
|
# """In-place, moves relevant elements to pinned memory."""
|
||||||
|
# for key in self.__device_prep_keys:
|
||||||
|
# value = getattr(self, key)
|
||||||
|
# pinned = value
|
||||||
|
# setattr(self, key, pinned)
|
||||||
|
# return self
|
||||||
|
|
||||||
|
# def to(self, *args, **kwargs):
|
||||||
|
# """In-place move/cast relevant elements.
|
||||||
|
|
||||||
|
# Passes all arguments to torch.Tensor.to, see its documentation.
|
||||||
|
# """
|
||||||
|
# for key in self.__device_prep_keys:
|
||||||
|
# value = getattr(self, key)
|
||||||
|
# moved = recursive_to(value, *args, **kwargs)
|
||||||
|
# setattr(self, key, moved)
|
||||||
|
# return self
|
||||||
|
|
||||||
|
# def at_position(self, pos):
|
||||||
|
# """Gets the position."""
|
||||||
|
# key = self.__keys[pos]
|
||||||
|
# return getattr(self, key)
|
||||||
|
|
||||||
|
|
@@ -0,0 +1,518 @@
|
|||||||
|
"""A pipeline for data transformations.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from hyperpyyaml import load_hyperpyyaml
|
||||||
|
>>> yamlstring = '''
|
||||||
|
... pipeline: !new:speechbrain.utils.data_pipeline.DataPipeline
|
||||||
|
... static_data_keys: [a, b]
|
||||||
|
... dynamic_items:
|
||||||
|
... - func: !name:operator.add
|
||||||
|
... takes: ["a", "b"]
|
||||||
|
... provides: foo
|
||||||
|
... - func: !name:operator.sub
|
||||||
|
... takes: ["foo", "b"]
|
||||||
|
... provides: bar
|
||||||
|
... output_keys: ["foo", "bar"]
|
||||||
|
... '''
|
||||||
|
>>> hparams = load_hyperpyyaml(yamlstring)
|
||||||
|
>>> hparams["pipeline"]({"a":1, "b":2})
|
||||||
|
{'foo': 3, 'bar': 1}
|
||||||
|
|
||||||
|
Author:
|
||||||
|
* Aku Rouhe
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.depgraph import DependencyGraph
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StaticItem:
|
||||||
|
"""Data class that represents a static item.
|
||||||
|
|
||||||
|
Static items are in-memory items so they don't need to be computed
|
||||||
|
dynamically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicItem:
|
||||||
|
"""Essentially represents a data transformation function.
|
||||||
|
|
||||||
|
A DynamicItem takes some arguments and computes its value dynamically when
|
||||||
|
called. A straight-forward use-case is to load something from disk
|
||||||
|
dynamically; take the path and provide the loaded data.
|
||||||
|
|
||||||
|
Instances of this class are often created implicitly via the
|
||||||
|
@takes and @provides decorators or otherwise from specifying the taken and
|
||||||
|
provided arguments and the function.
|
||||||
|
|
||||||
|
A counterpart is the GeneratorDynamicItem, which should be used for
|
||||||
|
generator functions.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
takes : list
|
||||||
|
The keys of the items that this needs to compute its output.
|
||||||
|
func : callable
|
||||||
|
The function that is used to compute the output.
|
||||||
|
provides : list
|
||||||
|
The keys that this provides.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, takes=[], func=None, provides=[]):
|
||||||
|
self.takes = takes
|
||||||
|
self.func = func
|
||||||
|
self.provides = provides
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
return self.func(*args)
|
||||||
|
|
||||||
|
# The next methods are more about supporting GeneratorDynamicItems
|
||||||
|
def next_takes(self):
|
||||||
|
"""The next argkeys to provide to this, when called."""
|
||||||
|
# Regular function DynamicItems always just need the same set of args
|
||||||
|
return self.takes
|
||||||
|
|
||||||
|
def next_provides(self):
|
||||||
|
"""The next keys that this provides, when called."""
|
||||||
|
# Regular function DynamicItems always just provide the same set of keys
|
||||||
|
return self.provides
|
||||||
|
|
||||||
|
def provided_in_order(self):
|
||||||
|
"""Assuming that this may need to be called multiple times; which keys
|
||||||
|
does it provide at that call. Returns a list, with len equal to the
|
||||||
|
number of times that this may be called."""
|
||||||
|
# Regular function DynamicItems are only called once:
|
||||||
|
return [self.provides]
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Signals that this will not be called any more times on this pipeline
|
||||||
|
call."""
|
||||||
|
# Regular function DynamicItems don't need special resets.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorDynamicItem(DynamicItem):
|
||||||
|
"""Essentially represents a multi-step data transformation.
|
||||||
|
|
||||||
|
This is the generator function counterpart for DynamicItem (which should be
|
||||||
|
used for regular functions).
|
||||||
|
|
||||||
|
A GeneratorDynamicItem first takes some arguments and then uses those in
|
||||||
|
multiple steps to incrementally compute some values when called.
|
||||||
|
|
||||||
|
A typical use-case is a pipeline of transformations on data: e.g. taking in
|
||||||
|
text as a string, first providing a tokenized version, and then on the second
|
||||||
|
call providing an integer-encoded version. This can be used even though the
|
||||||
|
integer-encoder needs to be trained on the first outputs.
|
||||||
|
|
||||||
|
The main benefit is to be able to define the pipeline in a clear function,
|
||||||
|
even if parts of the pipeline depend on others for their initialization.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> lab2ind = {}
|
||||||
|
>>> def text_pipeline(text):
|
||||||
|
... text = text.lower().strip()
|
||||||
|
... text = "".join(c for c in text if c.isalpha() or c == " ")
|
||||||
|
... words = text.split()
|
||||||
|
... yield words
|
||||||
|
... encoded = [lab2ind[word] for word in words]
|
||||||
|
... yield encoded
|
||||||
|
>>> item = GeneratorDynamicItem(
|
||||||
|
... func=text_pipeline,
|
||||||
|
... takes=["text"],
|
||||||
|
... provides=["words", "words_encoded"])
|
||||||
|
>>> # First create the integer-encoding:
|
||||||
|
>>> ind = 1
|
||||||
|
>>> for token in item("Is this it? - This is it."):
|
||||||
|
... if token not in lab2ind:
|
||||||
|
... lab2ind[token] = ind
|
||||||
|
... ind += 1
|
||||||
|
>>> # Now the integers can be encoded!
|
||||||
|
>>> item()
|
||||||
|
[1, 2, 3, 2, 1, 3]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# Doesn't generate electricity, only stores the currently active
|
||||||
|
# generator:
|
||||||
|
self.current_generator = None
|
||||||
|
self.num_provided_items = 0
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
if self.num_provided_items == len(self.provides):
|
||||||
|
raise RuntimeError("DynamicItemPipeline called too many times!")
|
||||||
|
if not self.current_generator:
|
||||||
|
self.current_generator = self.func(*args)
|
||||||
|
# NOTE: Not supporting sending new values to the pipeline.
|
||||||
|
out = next(self.current_generator)
|
||||||
|
self.num_provided_items += 1
|
||||||
|
return out
|
||||||
|
|
||||||
|
def next_takes(self):
|
||||||
|
"""The next argkeys to provide to this, when called."""
|
||||||
|
if not self.current_generator:
|
||||||
|
return self.takes
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def next_provides(self):
|
||||||
|
"""The next keys that this provides, when called."""
|
||||||
|
keys = self.provides[self.num_provided_items]
|
||||||
|
# Support multiple yielded values like:
|
||||||
|
# @yields("wav_read", ["left_ch", "right_ch"])
|
||||||
|
if isinstance(keys, str):
|
||||||
|
return [keys]
|
||||||
|
else:
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def provided_in_order(self):
|
||||||
|
"""Assuming that this may need to be called multiple times; which keys
|
||||||
|
does it provide at that call. Returns a list, with len equal to the
|
||||||
|
number of times that this may be called."""
|
||||||
|
in_order = []
|
||||||
|
for keys in self.provides:
|
||||||
|
# Support multiple yielded values like:
|
||||||
|
# @provides("wav_read", ["left_ch", "right_ch"])
|
||||||
|
if isinstance(keys, str):
|
||||||
|
in_order.append([keys])
|
||||||
|
else:
|
||||||
|
in_order.append(keys)
|
||||||
|
return in_order
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Signals that this will not be called any more times on this pipeline
|
||||||
|
call."""
|
||||||
|
if self.current_generator is not None:
|
||||||
|
self.current_generator.close()
|
||||||
|
self.current_generator = None
|
||||||
|
self.num_provided_items = 0
|
||||||
|
|
||||||
|
|
||||||
|
def takes(*argkeys):
|
||||||
|
"""Decorator which makes a DynamicItem and specifies its argkeys.
|
||||||
|
|
||||||
|
If the wrapped object is a generator function (has a yield statement),
|
||||||
|
Creates a GeneratorDynamicItem. If the object is already a DynamicItem,
|
||||||
|
just specifies the argkeys for that. Otherwise creates a new regular
|
||||||
|
DynamicItem, with argkeys specified.
|
||||||
|
|
||||||
|
The args are always passed to the function at the start. Generators could
|
||||||
|
support sending new arguments, but for such use cases, simply create a new
|
||||||
|
dynamic item. The GeneratorDynamicItem class is meant for pipelines which
|
||||||
|
take in an input and transform it in multiple ways, where the intermediate
|
||||||
|
representations may be needed for e.g. fitting a BPE segmenter.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> @takes("text")
|
||||||
|
... def tokenize(text):
|
||||||
|
... return text.strip().lower().split()
|
||||||
|
>>> tokenize.provides = ["tokenized"]
|
||||||
|
>>> tokenize('\tThis Example gets tokenized')
|
||||||
|
['this', 'example', 'gets', 'tokenized']
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(obj):
|
||||||
|
"""Decorator definition."""
|
||||||
|
if isinstance(obj, DynamicItem):
|
||||||
|
if obj.takes:
|
||||||
|
raise ValueError("Can't overwrite DynamicItem.takes")
|
||||||
|
obj.takes = argkeys
|
||||||
|
return obj
|
||||||
|
elif inspect.isgeneratorfunction(obj):
|
||||||
|
return GeneratorDynamicItem(takes=argkeys, func=obj)
|
||||||
|
else:
|
||||||
|
return DynamicItem(takes=argkeys, func=obj)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
takes_decorator = takes # Just for DataPipeline.add_dynamic_item
|
||||||
|
|
||||||
|
def provides(*output_keys):
|
||||||
|
"""Decorator which makes a DynamicItem and specifies what keys it provides.
|
||||||
|
|
||||||
|
If the wrapped object is a generator function (has a yield statement),
|
||||||
|
Creates a GeneratorDynamicItem. If the object is already a DynamicItem,
|
||||||
|
just specifies the provided keys for that. Otherwise creates a new regular
|
||||||
|
DynamicItem, with provided keys specified.
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
The behavior is slightly different for generators and regular functions, if
|
||||||
|
many output keys are specified, e.g. @provides("signal", "mfcc"). Regular
|
||||||
|
functions should return a tuple with len equal to len(output_keys), while
|
||||||
|
generators should yield the items one by one.
|
||||||
|
|
||||||
|
>>> @provides("signal", "feat")
|
||||||
|
... def read_feat():
|
||||||
|
... wav = [.1,.2,-.1]
|
||||||
|
... feat = [s**2 for s in wav]
|
||||||
|
... return wav, feat
|
||||||
|
>>> @provides("signal", "feat")
|
||||||
|
... def read_feat():
|
||||||
|
... wav = [.1,.2,-.1]
|
||||||
|
... yield wav
|
||||||
|
... feat = [s**2 for s in wav]
|
||||||
|
... yield feat
|
||||||
|
|
||||||
|
If multiple keys are yielded at once, write e.g.,
|
||||||
|
|
||||||
|
>>> @provides("wav_read", ["left_channel", "right_channel"])
|
||||||
|
... def read_multi_channel():
|
||||||
|
... wav = [[.1,.2,-.1],[.2,.1,-.1]]
|
||||||
|
... yield wav
|
||||||
|
... yield wav[0], wav[1]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(obj):
|
||||||
|
"""Decorator definition."""
|
||||||
|
if isinstance(obj, DynamicItem):
|
||||||
|
if obj.provides:
|
||||||
|
raise ValueError("Can't overwrite DynamicItem provides-list.")
|
||||||
|
obj.provides = output_keys
|
||||||
|
return obj
|
||||||
|
elif inspect.isgeneratorfunction(obj):
|
||||||
|
return GeneratorDynamicItem(func=obj, provides=output_keys)
|
||||||
|
else:
|
||||||
|
return DynamicItem(func=obj, provides=output_keys)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
provides_decorator = provides # Just for DataPipeline.add_dynamic_item
|
||||||
|
|
||||||
|
|
||||||
|
class DataPipeline:
|
||||||
|
"""Organises data transformations into a pipeline.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> pipeline = DataPipeline(
|
||||||
|
... static_data_keys=["text"],
|
||||||
|
... dynamic_items=[
|
||||||
|
... {"func": lambda x: x.lower(), "takes": "text", "provides": "foo"},
|
||||||
|
... {"func": lambda x: x[::-1], "takes": "foo", "provides": "bar"},
|
||||||
|
... ],
|
||||||
|
... output_keys=["bar"],
|
||||||
|
... )
|
||||||
|
>>> pipeline({"text": "Test"})
|
||||||
|
{'bar': 'tset'}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, static_data_keys, dynamic_items=[], output_keys=[]):
|
||||||
|
self.dg = DependencyGraph()
|
||||||
|
self._exec_order = None
|
||||||
|
self.key_to_node = {}
|
||||||
|
self.unaccounted_keys = {}
|
||||||
|
self.dynamic_items = []
|
||||||
|
self.output_mapping = {}
|
||||||
|
self.add_static_keys(static_data_keys)
|
||||||
|
self.add_dynamic_items(dynamic_items)
|
||||||
|
self.set_output_keys(output_keys)
|
||||||
|
|
||||||
|
def add_static_keys(self, static_keys):
|
||||||
|
"""Informs the pipeline about static items.
|
||||||
|
|
||||||
|
Static items are the ones provided to __call__ as data.
|
||||||
|
"""
|
||||||
|
for key in static_keys:
|
||||||
|
node_id = self.dg.add_node(data=StaticItem(key=key))
|
||||||
|
self.key_to_node[key] = node_id
|
||||||
|
|
||||||
|
def add_dynamic_items(self, dynamic_items):
|
||||||
|
"""Add multiple dynamic items at once."""
|
||||||
|
for item in dynamic_items:
|
||||||
|
try:
|
||||||
|
self.add_dynamic_item(**item)
|
||||||
|
except TypeError:
|
||||||
|
self.add_dynamic_item(item)
|
||||||
|
|
||||||
|
def add_dynamic_item(self, func, takes=None, provides=None):
|
||||||
|
"""Adds a dynamic item to the Pipeline.
|
||||||
|
|
||||||
|
Two calling conventions. For DynamicItem objects, just use:
|
||||||
|
add_dynamic_item(dynamic_item)
|
||||||
|
But otherwise, should use:
|
||||||
|
add_dynamic_item(func, takes, provides)
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
func : callable, DynamicItem
|
||||||
|
If a DynamicItem is given, adds that directly. Otherwise a
|
||||||
|
DynamicItem is created, and this specifies the callable to use. If
|
||||||
|
a generator function is given, then create a GeneratorDynamicItem.
|
||||||
|
Otherwise creates a normal DynamicItem.
|
||||||
|
takes : list, str
|
||||||
|
List of keys. When func is called, each key is resolved to
|
||||||
|
either an entry in the data or the output of another dynamic_item.
|
||||||
|
The func is then called with these as positional arguments,
|
||||||
|
in the same order as specified here.
|
||||||
|
A single key can be given as a bare string.
|
||||||
|
provides : str, list
|
||||||
|
For regular functions, the key or list of keys that it provides.
|
||||||
|
If you give a generator function, key or list of keys that it
|
||||||
|
yields, in order. Also see the provides decorator.
|
||||||
|
A single key can be given as a bare string.
|
||||||
|
"""
|
||||||
|
if isinstance(func, DynamicItem):
|
||||||
|
if takes is not None or provides is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"If providing a DynamicItem directly, don't "
|
||||||
|
"specify takes or provides"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._add_dynamic_item_object(func)
|
||||||
|
return
|
||||||
|
if isinstance(takes, str):
|
||||||
|
takes = [takes]
|
||||||
|
if isinstance(provides, str):
|
||||||
|
provides = [provides]
|
||||||
|
di = takes_decorator(*takes)(provides_decorator(*provides)(func))
|
||||||
|
self._add_dynamic_item_object(di)
|
||||||
|
|
||||||
|
def _add_dynamic_item_object(self, obj):
|
||||||
|
"""Internally adds the object.
|
||||||
|
|
||||||
|
There is a node in the dependency graph for each call of the
|
||||||
|
DynamicItem. Each call may return multiple keys and depend on multiple
|
||||||
|
keys. An internal dict maps key to the id of the node that produces it.
|
||||||
|
"""
|
||||||
|
if not obj.provides:
|
||||||
|
raise ValueError(
|
||||||
|
"Won't add redundant dynamic item which doesn't "
|
||||||
|
"provide anything."
|
||||||
|
)
|
||||||
|
depended = []
|
||||||
|
for key in obj.takes:
|
||||||
|
# Might not be accounted for, yet:
|
||||||
|
if key not in self.key_to_node:
|
||||||
|
dependee_keys = self.unaccounted_keys.setdefault(key, [])
|
||||||
|
dependee_keys.extend(obj.next_provides())
|
||||||
|
else:
|
||||||
|
depended.append(self.key_to_node[key])
|
||||||
|
for provided in obj.provided_in_order():
|
||||||
|
node_id = self.dg.add_node(data=obj)
|
||||||
|
for key in provided:
|
||||||
|
self.key_to_node[key] = node_id
|
||||||
|
# This key may also be unaccounted for, so account for it now:
|
||||||
|
if key in self.unaccounted_keys:
|
||||||
|
for dependee_key in self.unaccounted_keys[key]:
|
||||||
|
dependee_node = self.key_to_node[dependee_key]
|
||||||
|
self.dg.add_edge(dependee_node, node_id)
|
||||||
|
del self.unaccounted_keys[key] # Now accounted for!
|
||||||
|
for dep_id in depended:
|
||||||
|
self.dg.add_edge(node_id, dep_id)
|
||||||
|
# Next call will depend on this call:
|
||||||
|
depended = [node_id]
|
||||||
|
# Keep a reference to the item in this object, as well:
|
||||||
|
self.dynamic_items.append(obj)
|
||||||
|
|
||||||
|
def set_output_keys(self, keys):
|
||||||
|
"""Use this to change the output keys.
|
||||||
|
|
||||||
|
Also re-evaluates execution order.
|
||||||
|
So if you request different outputs, some parts of the
|
||||||
|
data pipeline may be skipped.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
keys : dict, list, None
|
||||||
|
List of keys (str) to produce in output.
|
||||||
|
|
||||||
|
If a dict is given, it is used to map internal keys to output keys:
in each key:value pair, the key is the name used outside and the
value is the internal key.
|
||||||
|
"""
|
||||||
|
self.output_mapping = self._output_keys_to_mapping(keys)
|
||||||
|
self._exec_order = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _output_keys_to_mapping(keys):
|
||||||
|
# Ensure a mapping (accept a list for convenience, too)
|
||||||
|
if keys is None:
|
||||||
|
output_mapping = {}
|
||||||
|
elif isinstance(keys, dict):
|
||||||
|
output_mapping = keys
|
||||||
|
else:
|
||||||
|
output_mapping = {key: key for key in keys}
|
||||||
|
return output_mapping
|
||||||
|
|
||||||
|
def compute_outputs(self, data):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
data : dict
|
||||||
|
Dictionary with data entries by key.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
With the keys that were set.
|
||||||
|
"""
|
||||||
|
if self._exec_order is None:
|
||||||
|
self._prepare_run(data)
|
||||||
|
return self._compute(data, self._exec_order, self.output_mapping)
|
||||||
|
|
||||||
|
def compute_specific(self, keys, data):
|
||||||
|
"""Compute output of specific item, without changing output_keys."""
|
||||||
|
output_mapping = self._output_keys_to_mapping(keys)
|
||||||
|
order = self.dg.get_evaluation_order(
|
||||||
|
selected_keys=self.get_selected_node_ids(keys)
|
||||||
|
)
|
||||||
|
return self._compute(data, order, output_mapping)
|
||||||
|
|
||||||
|
def _compute(self, data, order, output_mapping):
|
||||||
|
if self.unaccounted_keys:
|
||||||
|
MSG = "These keys are still unaccounted for in the data pipeline: "
|
||||||
|
MSG += ", ".join(self.unaccounted_keys)
|
||||||
|
raise RuntimeError(MSG)
|
||||||
|
intermediate = {}
|
||||||
|
for node_id, edges, item in order:
|
||||||
|
if isinstance(item, StaticItem):
|
||||||
|
# Static item in data.
|
||||||
|
# Just check that key is found.
|
||||||
|
try:
|
||||||
|
data[item.key]
|
||||||
|
continue
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(f"Expected key {item.key} in data!")
|
||||||
|
# A dynamic item, which we should compute:
|
||||||
|
args = [
|
||||||
|
data[argkey] if argkey in data else intermediate[argkey]
|
||||||
|
for argkey in item.next_takes()
|
||||||
|
]
|
||||||
|
# This needs to be called BEFORE the dynamic item is called.
|
||||||
|
provided_keys = item.next_provides()
|
||||||
|
values = item(*args) # Call the DynamicItem to produce output
|
||||||
|
# If there is just one output value, wrap in a list so that
|
||||||
|
# it can be zipped as well:
|
||||||
|
if len(provided_keys) == 1:
|
||||||
|
values = [values]
|
||||||
|
intermediate.update(zip(provided_keys, values))
|
||||||
|
for dynamic_item in self.dynamic_items:
|
||||||
|
dynamic_item.reset()
|
||||||
|
return {
|
||||||
|
outkey: data[inkey] if inkey in data else intermediate[inkey]
|
||||||
|
for outkey, inkey in output_mapping.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_selected_node_ids(self, selected_keys):
|
||||||
|
"""Translates selected keys to dependency graph keys."""
|
||||||
|
return [self.key_to_node[key] for key in selected_keys]
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
return self.compute_outputs(data)
|
||||||
|
|
||||||
|
def _prepare_run(self, data):
|
||||||
|
self._exec_order = list(
|
||||||
|
self.dg.get_evaluation_order(
|
||||||
|
self.get_selected_node_ids(self.output_mapping.values())
|
||||||
|
)
|
||||||
|
)
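
A minimal usage sketch of the pipeline API above (illustrative; the tokenize function and data are made up, and the import path follows this module's own imports):

from paddlespeech.s2t.io.wav2vec2.data_pipeline import DataPipeline, provides, takes

@takes("text")
@provides("words")
def tokenize(text):
    # Hypothetical dynamic item: normalize and split text into words.
    return text.strip().lower().split()

pipeline = DataPipeline(static_data_keys=["text"])
pipeline.add_dynamic_item(tokenize)       # a DynamicItem can be passed directly
pipeline.set_output_keys(["words"])
print(pipeline({"text": "Hello World"}))  # {'words': ['hello', 'world']}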
|
@@ -0,0 +1,167 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import csv
|
||||||
|
import shutil
|
||||||
|
import urllib.request
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
import pathlib
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
def batch_pad_right(array: list, mode="constant", value=0):
|
||||||
|
"""Given a list of torch tensors it batches them together by padding to the right
|
||||||
|
on each dimension in order to get same length for all.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tensors : list
|
||||||
|
List of tensor we wish to pad together.
|
||||||
|
mode : str
|
||||||
|
Padding mode see torch.nn.functional.pad documentation.
|
||||||
|
value : float
|
||||||
|
Padding value see torch.nn.functional.pad documentation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Padded tensor.
|
||||||
|
valid_vals : listf
|
||||||
|
List containing proportion for each dimension of original, non-padded values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not len(array):
|
||||||
|
raise IndexError("Tensors list must not be empty")
|
||||||
|
|
||||||
|
if len(array) == 1:
|
||||||
|
# if there is only one tensor in the batch we simply unsqueeze it.
|
||||||
|
return np.expand_dims(array[0], 0), np.array([1.0], dtype="float32")
|
||||||
|
if not (
|
||||||
|
all(
|
||||||
|
[array[i].ndim == array[0].ndim for i in range(1, len(array))]
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise IndexError("All array must have same number of dimensions")
|
||||||
|
|
||||||
|
# FIXME we limit the support here: we allow padding of only the first dimension
|
||||||
|
# need to remove this when feat extraction is updated to handle multichannel.
|
||||||
|
max_shape = []
|
||||||
|
for dim in range(array[0].ndim):
|
||||||
|
if dim != 0:
|
||||||
|
if not all(
|
||||||
|
[x.shape[dim] == array[0].shape[dim] for x in array[1:]]
|
||||||
|
):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Tensors should have same dimensions except for the first one"
|
||||||
|
)
|
||||||
|
max_shape.append(max([x.shape[dim] for x in array]))
|
||||||
|
|
||||||
|
batched = []
|
||||||
|
valid = []
|
||||||
|
for t in array:
|
||||||
|
# for each tensor we apply pad_right_to
|
||||||
|
padded, valid_percent = pad_right_to(
|
||||||
|
t, max_shape, mode=mode, value=value
|
||||||
|
)
|
||||||
|
batched.append(padded)
|
||||||
|
valid.append(valid_percent[0])
|
||||||
|
|
||||||
|
batched = np.stack(batched)
|
||||||
|
|
||||||
|
return batched, np.array(valid, dtype="float32")
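
For reference, a small illustrative sketch of what batch_pad_right returns: the lengths are relative fractions, not absolute sample counts.

import numpy as np
from paddlespeech.s2t.io.wav2vec2.data_utils import batch_pad_right

wavs = [np.array([1., 2., 3., 4.]), np.array([1., 2.])]
padded, valid = batch_pad_right(wavs)
print(padded.shape)  # (2, 4): the shorter example is right-padded with zeros
print(valid)         # [1.  0.5]: fraction of each row holding real (non-padded) data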
|
||||||
|
|
||||||
|
np_str_obj_array_pattern = re.compile(r"[SaUO]")
|
||||||
|
|
||||||
|
def pad_right_to(
|
||||||
|
array: np.ndarray, target_shape: (list, tuple), mode="constant", value=0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function takes a numpy array of arbitrary shape and pads it to the target
shape by appending values on the right.

Parameters
----------
array : numpy.ndarray
Input array whose dimensions we need to pad.
target_shape : (list, tuple)
Target shape we want for the padded array; its length must equal array.ndim.
mode : str
Pad mode, please refer to numpy.pad documentation.
value : float
Pad value, please refer to numpy.pad documentation.

Returns
-------
array : numpy.ndarray
Padded array.
valid_vals : list
Proportion of original (non-padded) values for each dimension.
|
||||||
|
"""
|
||||||
|
assert len(target_shape) == array.ndim
|
||||||
|
pads = [] # this contains the abs length of the padding for each dimension.
|
||||||
|
valid_vals = [] # this contains the relative lengths for each dimension.
|
||||||
|
i = len(target_shape) - 1 # iterating over target_shape ndims
|
||||||
|
j = 0
|
||||||
|
while i >= 0:
|
||||||
|
assert (
|
||||||
|
target_shape[i] >= array.shape[i]
|
||||||
|
), "Target shape must be >= original shape for every dim"
|
||||||
|
pads.extend([0, target_shape[i] - array.shape[i]])
|
||||||
|
valid_vals.append(array.shape[j] / target_shape[j])
|
||||||
|
i -= 1
|
||||||
|
j += 1
|
||||||
|
array = np.pad(array, pads, mode, constant_values=(value, value))
|
||||||
|
|
||||||
|
return array, valid_vals
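
And pad_right_to on a single array (illustrative), padding a 1-D array up to length 5:

import numpy as np
from paddlespeech.s2t.io.wav2vec2.data_utils import pad_right_to

arr = np.array([1., 2., 3.])
padded, valid = pad_right_to(arr, target_shape=[5])
print(padded)  # [1. 2. 3. 0. 0.]
print(valid)   # [0.6]: original length / target length, per dimension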
|
||||||
|
|
||||||
|
def mod_default_collate(batch):
|
||||||
|
"""Makes a tensor from list of batch values.
|
||||||
|
|
||||||
|
Note that this doesn't need to zip(*) values together
|
||||||
|
as PaddedBatch connects them already (by key).
|
||||||
|
|
||||||
|
Here the idea is not to error out.
|
||||||
|
|
||||||
|
This is modified from:
|
||||||
|
https://github.com/pytorch/pytorch/blob/c0deb231db76dbea8a9d326401417f7d1ce96ed5/torch/utils/data/_utils/collate.py#L42
|
||||||
|
"""
|
||||||
|
elem = batch[0]
|
||||||
|
elem_type = type(elem)
|
||||||
|
if isinstance(elem, paddle.Tensor):
|
||||||
|
try:
# NOTE: the upstream torch version pre-allocates shared memory when running
# inside a DataLoader worker process; there is no direct paddle equivalent
# here, so equally sized tensors are simply stacked into a batch tensor.
return paddle.stack(batch, 0)
except (RuntimeError, ValueError):  # Unequal size:
return batch
|
||||||
|
elif (
|
||||||
|
elem_type.__module__ == "numpy"
|
||||||
|
and elem_type.__name__ != "str_"
|
||||||
|
and elem_type.__name__ != "string_"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if (
|
||||||
|
elem_type.__name__ == "ndarray"
|
||||||
|
or elem_type.__name__ == "memmap"
|
||||||
|
):
|
||||||
|
# array of string classes and object
|
||||||
|
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
||||||
|
return batch
|
||||||
|
return mod_default_collate([paddle.to_tensor(b, dtype=b.dtype) for b in batch])
|
||||||
|
elif elem.shape == (): # scalars
|
||||||
|
return paddle.to_tensor(batch, dtype=elem.dtype)
|
||||||
|
except RuntimeError: # Unequal size
|
||||||
|
return batch
|
||||||
|
elif isinstance(elem, float):
|
||||||
|
return paddle.to_tensor(batch, dtype=paddle.float64)
|
||||||
|
elif isinstance(elem, int):
|
||||||
|
return paddle.to_tensor(batch, dtype=paddle.int64)
|
||||||
|
else:
|
||||||
|
return batch
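
Roughly, the fallback collation behaves as sketched below (illustrative; exact tensor reprs depend on the Paddle version):

import paddle
from paddlespeech.s2t.io.wav2vec2.data_utils import mod_default_collate

# Equal-sized tensors stack into one [2, 2] batch tensor.
stacked = mod_default_collate([paddle.to_tensor([1., 2.]), paddle.to_tensor([3., 4.])])
# Plain Python numbers become a 1-D tensor (int64 here).
ints = mod_default_collate([1, 2, 3])
# Values it cannot stack (e.g. strings) are returned unchanged as a list.
texts = mod_default_collate(["hello", "world"])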
|
File diff suppressed because it is too large
@@ -0,0 +1,215 @@
|
|||||||
|
"""PyTorch compatible DataLoaders
|
||||||
|
|
||||||
|
Essentially we extend PyTorch DataLoader by adding the ability to save the
|
||||||
|
data loading state, so that a checkpoint may be saved in the middle of an
|
||||||
|
epoch.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> import torch
|
||||||
|
>>> from speechbrain.utils.checkpoints import Checkpointer
|
||||||
|
>>> # An example "dataset" and its loader
|
||||||
|
>>> dataset = torch.randn(10, 1)
|
||||||
|
>>> dataloader = SaveableDataLoader(dataset, num_workers = 3)
|
||||||
|
>>> # Setup the checkpointer:
|
||||||
|
>>> tmpdir = getfixture('tmpdir')
|
||||||
|
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
|
||||||
|
>>> # Iterate:
|
||||||
|
>>> for i, data_point in enumerate(dataloader):
|
||||||
|
... # Here you would process the data:
|
||||||
|
... rainfall_amount_prediction = data_point * 4.
|
||||||
|
... # Now, imagine the experiment gets killed on the fifth batch:
|
||||||
|
... if i == 4:
|
||||||
|
... break
|
||||||
|
... # Luckily, you had just saved a checkpoint:
|
||||||
|
... if i == 3:
|
||||||
|
... _ = checkpointer.save_checkpoint(end_of_epoch = False)
|
||||||
|
>>> # So when you restart the experiment:
|
||||||
|
>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3)
|
||||||
|
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
|
||||||
|
>>> _ = new_checkpointer.recover_if_possible()
|
||||||
|
>>> # The dataloader fast-forwards to the position where we left off:
|
||||||
|
>>> assert next(iter(new_dataloader)) == dataset[4]
|
||||||
|
|
||||||
|
Authors:
|
||||||
|
* Aku Rouhe 2020
|
||||||
|
"""
|
||||||
|
import collections
|
||||||
|
import torch
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.data_utils import mod_default_collate
|
||||||
|
# from speechbrain.utils.data_utils import recursive_to
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.data_utils import batch_pad_right
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
import functools
|
||||||
|
# from batch import PaddedBatch
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.sampler import ReproducibleRandomSampler
|
||||||
|
import paddle
|
||||||
|
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
class Wav2vec2DataLoader(DataLoader):
|
||||||
|
def __init__(self,
|
||||||
|
dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
sampler=None,
|
||||||
|
batch_sampler=None,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=None,
|
||||||
|
pin_memory=False,
|
||||||
|
drop_last=False,
|
||||||
|
timeout=0,
|
||||||
|
worker_init_fn=None,
|
||||||
|
multiprocessing_context=None,
|
||||||
|
generator=None):
|
||||||
|
if isinstance(dataset[0], (tuple, list)):
|
||||||
|
return_list = True
|
||||||
|
else:
|
||||||
|
return_list = False
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
dataset,
|
||||||
|
feed_list=None,
|
||||||
|
places=None,
|
||||||
|
return_list=return_list,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=drop_last,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
num_workers=num_workers,
|
||||||
|
use_buffer_reader=True,
|
||||||
|
use_shared_memory=False,
|
||||||
|
timeout=timeout,
|
||||||
|
worker_init_fn=worker_init_fn)
|
||||||
|
if sampler is not None:
|
||||||
|
self.batch_sampler.sampler = sampler
|
||||||
|
# self.dataloader = DataLoader(
|
||||||
|
# dataset=dataset,
|
||||||
|
# batch_sampler=batch_sampler,
|
||||||
|
# collate_fn=collate_fn,
|
||||||
|
# num_workers=num_workers,)
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return len(self.dataloader)
|
||||||
|
|
||||||
|
# def __iter__(self):
|
||||||
|
# return self.dataloader.__iter__()
|
||||||
|
|
||||||
|
# def __call__(self):
|
||||||
|
# return self.__iter__()
|
||||||
|
|
||||||
|
|
||||||
|
def PaddedBatch(
|
||||||
|
examples,
|
||||||
|
padded_keys=None,
|
||||||
|
device_prep_keys=None,
|
||||||
|
padding_func=batch_pad_right,
|
||||||
|
padding_kwargs={},
|
||||||
|
nonpadded_stack=True,
|
||||||
|
):
|
||||||
|
__length = len(examples)
|
||||||
|
__keys = list(examples[0].keys())
|
||||||
|
__padded_keys = []
|
||||||
|
__device_prep_keys = []
|
||||||
|
res = {}
|
||||||
|
for key in __keys:
|
||||||
|
values = [example[key] for example in examples]
|
||||||
|
# Default convert usually does the right thing (numpy2torch etc.)
|
||||||
|
# values = default_convert(values)
|
||||||
|
if (padded_keys is not None and key in padded_keys) or (
|
||||||
|
padded_keys is None and isinstance(values[0], numpy.ndarray)
|
||||||
|
):
|
||||||
|
# Padding and PaddedData
|
||||||
|
__padded_keys.append(key)
|
||||||
|
|
||||||
|
padded = PaddedData(*padding_func(values, **padding_kwargs))
|
||||||
|
res[key] = padded
|
||||||
|
else:
|
||||||
|
# Default PyTorch collate usually does the right thing
|
||||||
|
# (convert lists of equal sized tensors to batch tensors, etc.)
|
||||||
|
if nonpadded_stack:
|
||||||
|
values = mod_default_collate(values)
|
||||||
|
res[key] = values
|
||||||
|
if (device_prep_keys is not None and key in device_prep_keys) or (
|
||||||
|
device_prep_keys is None and isinstance(values[0], paddle.Tensor)
|
||||||
|
):
|
||||||
|
__device_prep_keys.append(key)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def make_dataloader(dataset, stage, **loader_kwargs):
|
||||||
|
"""Makes a basic DataLoader with SpeechBrain defaults.
|
||||||
|
|
||||||
|
For DynamicItemDatasets (which return dicts), use
|
||||||
|
PaddedBatch as the default collate_fn.
|
||||||
|
|
||||||
|
Shuffling gets implemented by ReproducibleRandomSampler.
|
||||||
|
|
||||||
|
If the Dataset is not an IterableDataset, the DataLoader
|
||||||
|
is a SaveableDataLoader.
|
||||||
|
|
||||||
|
If the Dataset is a webdataset.dataset.Composable, set default
|
||||||
|
batch_size = None.
|
||||||
|
|
||||||
|
Can also loop over the underlying dataloader continuously,
|
||||||
|
and stop iterations at nominal epoch lengths.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
dataset : Dataset
|
||||||
|
The dataset to make a DataLoader for.
|
||||||
|
looped_nominal_epoch : None, int
|
||||||
|
If an integer is given, loop the underlying DataLoader infinitely and
|
||||||
|
set a nominal epoch length in batches (or whatever the DataLoader
|
||||||
|
yields).
|
||||||
|
**loader_kwargs : dict
|
||||||
|
Keyword args to DataLoader, see PyTorch DataLoader for
|
||||||
|
options.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
DataLoader
|
||||||
|
If looped_nominal_epoch is None
|
||||||
|
LoopedLoader
|
||||||
|
If looped_nominal_epoch is not None
|
||||||
|
"""
|
||||||
|
# PaddedBatch as default collation for DynamicItemDataset
|
||||||
|
if "collate_fn" not in loader_kwargs and isinstance(
|
||||||
|
dataset, DynamicItemDataset
|
||||||
|
):
|
||||||
|
loader_kwargs["collate_fn"] = PaddedBatch
|
||||||
|
# Reproducible random sampling
|
||||||
|
if loader_kwargs.get("shuffle", False):
|
||||||
|
if loader_kwargs.get("sampler") is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both shuffle=True and a "
|
||||||
|
"sampler in loader_kwargs"
|
||||||
|
)
|
||||||
|
sampler = ReproducibleRandomSampler(dataset)
|
||||||
|
loader_kwargs["sampler"] = sampler
|
||||||
|
# Should delete shuffle because you can't set both Sampler and
|
||||||
|
# shuffle
|
||||||
|
# NOTE: the dict of loader options may get used elsewhere!
|
||||||
|
# However, this del doesn't touch those because loader_kwargs comes
|
||||||
|
# from a **kwargs dict.
|
||||||
|
del loader_kwargs["shuffle"]
|
||||||
|
# Create the loader
|
||||||
|
dataloader = Wav2vec2DataLoader(dataset, **loader_kwargs)
|
||||||
|
return dataloader
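
An illustrative sketch tying the pieces together (the toy data and keys are made up): the functional PaddedBatch above is the collate_fn that make_dataloader wires in for a DynamicItemDataset, and applied by hand it produces a dict of collated keys.

import numpy as np
from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset

data = {
    "utt1": {"wav": np.random.randn(1600).astype("float32")},
    "utt2": {"wav": np.random.randn(2000).astype("float32")},
}
dataset = DynamicItemDataset(data, output_keys=["id", "wav"])
batch = PaddedBatch([dataset[0], dataset[1]])
print(batch["wav"].data.shape)  # (2, 2000): the shorter wav is right-padded
print(batch["wav"].lengths)     # [0.8 1. ]: relative lengths
print(batch["id"])              # ['utt1', 'utt2']: non-array values stay a list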
|
||||||
|
|
||||||
|
|
||||||
|
# import collections
|
||||||
|
# import torch
|
||||||
|
# from data_utils import mod_default_collate
|
||||||
|
# # from speechbrain.utils.data_utils import recursive_to
|
||||||
|
# from data_utils import batch_pad_right
|
||||||
|
# from torch.utils.data._utils.collate import default_convert
|
||||||
|
# # from torch.utils.data._utils.pin_memory import (
|
||||||
|
# # pin_memory as recursive_pin_memory,
|
||||||
|
# # )
|
||||||
|
# import paddle
|
||||||
|
|
||||||
|
# PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
|
@@ -0,0 +1,409 @@
|
|||||||
|
import copy
|
||||||
|
import contextlib
|
||||||
|
from types import MethodType
|
||||||
|
from paddle.io import Dataset
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.data_pipeline import DataPipeline
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.dataio import load_data_json, load_data_csv
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicItemDataset(Dataset):
|
||||||
|
"""Dataset that reads, wrangles, and produces dicts.
|
||||||
|
|
||||||
|
Each data point dict provides some items (by key), for example, a path to a
|
||||||
|
wavefile with the key "wav_file". When a data point is fetched from this
|
||||||
|
Dataset, more items are produced dynamically, based on pre-existing items
|
||||||
|
and other dynamic created items. For example, a dynamic item could take the
|
||||||
|
wavfile path and load the audio from the disk.
|
||||||
|
|
||||||
|
The dynamic items can depend on other dynamic items: a suitable evaluation
|
||||||
|
order is used automatically, as long as there are no circular dependencies.
|
||||||
|
|
||||||
|
A specified list of keys is collected in the output dict. These can be items
|
||||||
|
in the original data or dynamic items. If some dynamic items are not
|
||||||
|
requested, nor depended on by other requested items, they won't be computed.
|
||||||
|
So for example if a user simply wants to iterate over the text, the
|
||||||
|
time-consuming audio loading can be skipped.
|
||||||
|
|
||||||
|
About the format:
|
||||||
|
Takes a dict of dicts as the collection of data points to read/wrangle.
|
||||||
|
The top level keys are data point IDs.
|
||||||
|
Each data point (example) dict should have the same keys, corresponding to
|
||||||
|
different items in that data point.
|
||||||
|
|
||||||
|
Altogether the data collection could look like this:
|
||||||
|
|
||||||
|
>>> data = {
|
||||||
|
... "spk1utt1": {
|
||||||
|
... "wav_file": "/path/to/spk1utt1.wav",
|
||||||
|
... "text": "hello world",
|
||||||
|
... "speaker": "spk1",
|
||||||
|
... },
|
||||||
|
... "spk1utt2": {
|
||||||
|
... "wav_file": "/path/to/spk1utt2.wav",
|
||||||
|
... "text": "how are you world",
|
||||||
|
... "speaker": "spk1",
|
||||||
|
... }
|
||||||
|
... }
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
The top-level key, the data point id, is implicitly added as an item
|
||||||
|
in the data point, with the key "id"
|
||||||
|
|
||||||
|
Each dynamic item is configured by three things: a key, a func, and a list
|
||||||
|
of argkeys. The key should be unique among all the items (dynamic or not) in
|
||||||
|
each data point. The func is any callable, and it returns the dynamic item's
|
||||||
|
value. The callable is called with the values of other items as specified
|
||||||
|
by the argkeys list (as positional args, passed in the order specified by
|
||||||
|
argkeys).
|
||||||
|
|
||||||
|
The dynamic_items configuration could look like this:
|
||||||
|
|
||||||
|
>>> import torch
|
||||||
|
>>> dynamic_items = [
|
||||||
|
... {"func": lambda l: torch.Tensor(l),
|
||||||
|
... "takes": ["wav_loaded"],
|
||||||
|
... "provides": "wav"},
|
||||||
|
... {"func": lambda path: [ord(c)/100 for c in path], # Fake "loading"
|
||||||
|
... "takes": ["wav_file"],
|
||||||
|
... "provides": "wav_loaded"},
|
||||||
|
... {"func": lambda t: t.split(),
|
||||||
|
... "takes": ["text"],
|
||||||
|
... "provides": "words"}]
|
||||||
|
|
||||||
|
With these, different views of the data can be loaded:
|
||||||
|
|
||||||
|
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
|
||||||
|
>>> from speechbrain.dataio.batch import PaddedBatch
|
||||||
|
>>> dataset = DynamicItemDataset(data, dynamic_items)
|
||||||
|
>>> dataloader = SaveableDataLoader(dataset, collate_fn=PaddedBatch,
|
||||||
|
... batch_size=2)
|
||||||
|
>>> # First, create encoding for words:
|
||||||
|
>>> dataset.set_output_keys(["words"])
|
||||||
|
>>> encoding = {}
|
||||||
|
>>> next_id = 1
|
||||||
|
>>> for batch in dataloader:
|
||||||
|
... for sent in batch.words:
|
||||||
|
... for word in sent:
|
||||||
|
... if word not in encoding:
|
||||||
|
... encoding[word] = next_id
|
||||||
|
... next_id += 1
|
||||||
|
>>> # Next, add an encoded words_tensor dynamic item:
|
||||||
|
>>> dataset.add_dynamic_item(
|
||||||
|
... func = lambda ws: torch.tensor([encoding[w] for w in ws],
|
||||||
|
... dtype=torch.long),
|
||||||
|
... takes = ["words"],
|
||||||
|
... provides = "words_encoded")
|
||||||
|
>>> # Now we can get word and audio tensors:
|
||||||
|
>>> dataset.set_output_keys(["id", "wav", "words_encoded"])
|
||||||
|
>>> batch = next(iter(dataloader))
|
||||||
|
>>> batch.id
|
||||||
|
['spk1utt1', 'spk1utt2']
|
||||||
|
>>> batch.wav # +ELLIPSIS
|
||||||
|
PaddedData(data=tensor([[0.4700, 1.1200, ...
|
||||||
|
>>> batch.words_encoded
|
||||||
|
PaddedData(data=tensor([[1, 2, 0, 0],
|
||||||
|
[3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))
|
||||||
|
|
||||||
|
Output keys can also be a map:
|
||||||
|
|
||||||
|
>>> dataset.set_output_keys({"id":"id", "signal": "wav", "words": "words_encoded"})
|
||||||
|
>>> batch = next(iter(dataloader))
|
||||||
|
>>> batch.words
|
||||||
|
PaddedData(data=tensor([[1, 2, 0, 0],
|
||||||
|
[3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))
|
||||||
|
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
data : dict
|
||||||
|
Dictionary containing single data points (e.g. utterances).
|
||||||
|
dynamic_items : list, optional
|
||||||
|
Configuration for the dynamic items produced when fetching an example.
|
||||||
|
List of DynamicItems or dicts with the format::
|
||||||
|
func: <callable> # To be called
|
||||||
|
takes: <list> # key or list of keys of args this takes
|
||||||
|
provides: key # key or list of keys that this provides
|
||||||
|
output_keys : dict, list, optional
|
||||||
|
List of keys (either directly available in data or dynamic items)
|
||||||
|
to include in the output dict when data points are fetched.
|
||||||
|
|
||||||
|
If a dict is given, it is used to map internal keys to output keys:
in each key:value pair, the key is the name used outside and the
value is the internal key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, data, dynamic_items=[], output_keys=[],
|
||||||
|
):
|
||||||
|
self.data = data
|
||||||
|
self.data_ids = list(self.data.keys())
|
||||||
|
static_keys = list(self.data[self.data_ids[0]].keys())
|
||||||
|
if "id" in static_keys:
|
||||||
|
raise ValueError("The key 'id' is reserved for the data point id.")
|
||||||
|
else:
|
||||||
|
static_keys.append("id")
|
||||||
|
self.pipeline = DataPipeline(static_keys, dynamic_items)
|
||||||
|
self.set_output_keys(output_keys)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_ids)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data_id = self.data_ids[index]
|
||||||
|
data_point = self.data[data_id]
|
||||||
|
return self.pipeline.compute_outputs({"id": data_id, **data_point})
|
||||||
|
|
||||||
|
def add_dynamic_item(self, func, takes=None, provides=None):
|
||||||
|
"""Makes a new dynamic item available on the dataset.
|
||||||
|
|
||||||
|
Two calling conventions. For DynamicItem objects, just use:
|
||||||
|
add_dynamic_item(dynamic_item).
|
||||||
|
But otherwise, should use:
|
||||||
|
add_dynamic_item(func, takes, provides).
|
||||||
|
|
||||||
|
See `speechbrain.utils.data_pipeline`.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
func : callable, DynamicItem
|
||||||
|
If a DynamicItem is given, adds that directly. Otherwise a
|
||||||
|
DynamicItem is created, and this specifies the callable to use. If
|
||||||
|
a generator function is given, then create a GeneratorDynamicItem.
|
||||||
|
Otherwise creates a normal DynamicItem.
|
||||||
|
takes : list, str
|
||||||
|
List of keys. When func is called, each key is resolved to
|
||||||
|
either an entry in the data or the output of another dynamic_item.
|
||||||
|
The func is then called with these as positional arguments,
|
||||||
|
in the same order as specified here.
|
||||||
|
A single arg can be given directly.
|
||||||
|
provides : str
|
||||||
|
Unique key or keys that this provides.
|
||||||
|
"""
|
||||||
|
self.pipeline.add_dynamic_item(func, takes, provides)
|
||||||
|
|
||||||
|
def set_output_keys(self, keys):
|
||||||
|
"""Use this to change the output keys.
|
||||||
|
|
||||||
|
These are the keys that are actually evaluated when a data point
|
||||||
|
is fetched from the dataset.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
keys : dict, list
|
||||||
|
List of keys (str) to produce in output.
|
||||||
|
|
||||||
|
If a dict is given, it is used to map internal keys to output keys:
in each key:value pair, the key is the name used outside and the
value is the internal key.
|
||||||
|
"""
|
||||||
|
self.pipeline.set_output_keys(keys)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def output_keys_as(self, keys):
|
||||||
|
"""Context manager to temporarily set output keys.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> dataset = DynamicItemDataset({"a":{"x":1,"y":2},"b":{"x":3,"y":4}},
|
||||||
|
... output_keys = ["x"])
|
||||||
|
>>> with dataset.output_keys_as(["y"]):
|
||||||
|
... print(dataset[0])
|
||||||
|
{'y': 2}
|
||||||
|
>>> print(dataset[0])
|
||||||
|
{'x': 1}
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
Not thread-safe. While in this context manager, the output keys
|
||||||
|
are affected for any call.
|
||||||
|
"""
|
||||||
|
saved_output = self.pipeline.output_mapping
|
||||||
|
self.pipeline.set_output_keys(keys)
|
||||||
|
yield self
|
||||||
|
self.pipeline.set_output_keys(saved_output)
|
||||||
|
|
||||||
|
def filtered_sorted(
|
||||||
|
self,
|
||||||
|
key_min_value={},
|
||||||
|
key_max_value={},
|
||||||
|
key_test={},
|
||||||
|
sort_key=None,
|
||||||
|
reverse=False,
|
||||||
|
select_n=None,
|
||||||
|
):
|
||||||
|
"""Get a filtered and/or sorted version of this, shares static data.
|
||||||
|
|
||||||
|
The reason to implement these operations in the same method is that
|
||||||
|
computing some dynamic items may be expensive, and this way the
|
||||||
|
filtering and sorting steps don't need to compute the dynamic items
|
||||||
|
twice.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
key_min_value : dict
|
||||||
|
Map from key (in data or in dynamic items) to limit, will only keep
|
||||||
|
data_point if data_point[key] >= limit
|
||||||
|
key_max_value : dict
|
||||||
|
Map from key (in data or in dynamic items) to limit, will only keep
|
||||||
|
data_point if data_point[key] <= limit
|
||||||
|
key_test : dict
|
||||||
|
Map from key (in data or in dynamic items) to func, will only keep
|
||||||
|
data_point if bool(func(data_point[key])) == True
|
||||||
|
sort_key : None, str
|
||||||
|
If not None, sort by data_point[sort_key]. Default is ascending
|
||||||
|
order.
|
||||||
|
reverse : bool
|
||||||
|
If True, sort in descending order.
|
||||||
|
select_n : None, int
|
||||||
|
If not None, only keep (at most) the first n filtered data_points.
|
||||||
|
The possible sorting is applied, but only on the first n data
|
||||||
|
points found. Meant for debugging.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FilteredSortedDynamicItemDataset
|
||||||
|
Shares the static data, but has its own output keys and
|
||||||
|
dynamic items (initially deep copied from this, so they have the
|
||||||
|
same dynamic items available)
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
Temporarily changes the output keys!
|
||||||
|
"""
|
||||||
|
filtered_sorted_ids = self._filtered_sorted_ids(
|
||||||
|
key_min_value, key_max_value, key_test, sort_key, reverse, select_n,
|
||||||
|
)
|
||||||
|
return FilteredSortedDynamicItemDataset(
|
||||||
|
self, filtered_sorted_ids
|
||||||
|
) # NOTE: defined below
|
||||||
|
|
||||||
|
def _filtered_sorted_ids(
|
||||||
|
self,
|
||||||
|
key_min_value={},
|
||||||
|
key_max_value={},
|
||||||
|
key_test={},
|
||||||
|
sort_key=None,
|
||||||
|
reverse=False,
|
||||||
|
select_n=None,
|
||||||
|
):
|
||||||
|
"""Returns a list of data ids, fulfilling the sorting and filtering."""
|
||||||
|
|
||||||
|
def combined_filter(computed):
|
||||||
|
"""Applies filter."""
|
||||||
|
for key, limit in key_min_value.items():
|
||||||
|
# NOTE: docstring promises >= so using that.
|
||||||
|
# Mathematically could also use < for nicer syntax, but
|
||||||
|
# maybe with some super special weird edge case someone can
|
||||||
|
# depend on the >= operator
|
||||||
|
if computed[key] >= limit:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
for key, limit in key_max_value.items():
|
||||||
|
if computed[key] <= limit:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
for key, func in key_test.items():
|
||||||
|
if bool(func(computed[key])):
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
temp_keys = (
|
||||||
|
set(key_min_value.keys())
|
||||||
|
| set(key_max_value.keys())
|
||||||
|
| set(key_test.keys())
|
||||||
|
| set([] if sort_key is None else [sort_key])
|
||||||
|
)
|
||||||
|
filtered_ids = []
|
||||||
|
with self.output_keys_as(temp_keys):
|
||||||
|
for i, data_id in enumerate(self.data_ids):
|
||||||
|
if select_n is not None and len(filtered_ids) == select_n:
|
||||||
|
break
|
||||||
|
data_point = self.data[data_id]
|
||||||
|
data_point["id"] = data_id
|
||||||
|
computed = self.pipeline.compute_outputs(data_point)
|
||||||
|
if combined_filter(computed):
|
||||||
|
if sort_key is not None:
|
||||||
|
# Add (main sorting index, current index, data_id)
|
||||||
|
# So that we maintain current sorting and don't compare
|
||||||
|
# data_id values ever.
|
||||||
|
filtered_ids.append((computed[sort_key], i, data_id))
|
||||||
|
else:
|
||||||
|
filtered_ids.append(data_id)
|
||||||
|
if sort_key is not None:
|
||||||
|
filtered_sorted_ids = [
|
||||||
|
tup[2] for tup in sorted(filtered_ids, reverse=reverse)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
filtered_sorted_ids = filtered_ids
|
||||||
|
return filtered_sorted_ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(
|
||||||
|
cls, json_path, replacements={}, dynamic_items=[], output_keys=[]
|
||||||
|
):
|
||||||
|
"""Load a data prep JSON file and create a Dataset based on it."""
|
||||||
|
data = load_data_json(json_path, replacements)
|
||||||
|
return cls(data, dynamic_items, output_keys)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_csv(
|
||||||
|
cls, csv_path, replacements={}, dynamic_items=[], output_keys=[]
|
||||||
|
):
|
||||||
|
"""Load a data prep CSV file and create a Dataset based on it."""
|
||||||
|
data = load_data_csv(csv_path, replacements)
|
||||||
|
return cls(data, dynamic_items, output_keys)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_arrow_dataset(
|
||||||
|
cls, dataset, replacements={}, dynamic_items=[], output_keys=[]
|
||||||
|
):
|
||||||
|
"""Loading a prepared huggingface dataset"""
|
||||||
|
# define an unbound method to generate pseudo keys
|
||||||
|
def keys(self):
|
||||||
|
"Returns the keys."
|
||||||
|
return [i for i in range(dataset.__len__())]
|
||||||
|
|
||||||
|
# bind this method to arrow dataset
|
||||||
|
dataset.keys = MethodType(keys, dataset)
|
||||||
|
return cls(dataset, dynamic_items, output_keys)
|
||||||
|
|
||||||
|
class FilteredSortedDynamicItemDataset(DynamicItemDataset):
|
||||||
|
"""Possibly filtered, possibly sorted DynamicItemDataset.
|
||||||
|
|
||||||
|
Shares the static data (reference).
|
||||||
|
Has its own dynamic_items and output_keys (deepcopy).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, from_dataset, data_ids):
|
||||||
|
self.data = from_dataset.data
|
||||||
|
self.data_ids = data_ids
|
||||||
|
self.pipeline = copy.deepcopy(from_dataset.pipeline)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(
|
||||||
|
cls, json_path, replacements={}, dynamic_items=None, output_keys=None
|
||||||
|
):
|
||||||
|
raise TypeError("Cannot create SubsetDynamicItemDataset directly!")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_csv(
|
||||||
|
cls, csv_path, replacements={}, dynamic_items=None, output_keys=None
|
||||||
|
):
|
||||||
|
raise TypeError("Cannot create SubsetDynamicItemDataset directly!")
|
||||||
|
|
||||||
|
|
||||||
|
def add_dynamic_item(datasets, func, takes=None, provides=None):
|
||||||
|
"""Helper for adding the same item to multiple datasets."""
|
||||||
|
for dataset in datasets:
|
||||||
|
dataset.add_dynamic_item(func, takes, provides)
|
||||||
|
|
||||||
|
|
||||||
|
def set_output_keys(datasets, output_keys):
|
||||||
|
"""Helper for setting the same item to multiple datasets."""
|
||||||
|
for dataset in datasets:
|
||||||
|
dataset.set_output_keys(output_keys)
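# --- Illustrative sketch (not part of the file above) ---
# A minimal end-to-end use of DynamicItemDataset as defined above, assuming the
# module layout used elsewhere in this change (paddlespeech.s2t.io.wav2vec2.dataset).
# The annotation dict, the key names and the toy dynamic item are made up for the example.
from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset

data = {
    "utt1": {"wav": "/path/a.wav", "duration": 1.2},
    "utt2": {"wav": "/path/b.wav", "duration": 3.4},
}
dataset = DynamicItemDataset(data)

def wav_basename(wav):
    # toy dynamic item: computed lazily from the static "wav" entry
    return wav.rsplit("/", 1)[-1]

dataset.add_dynamic_item(wav_basename, takes=["wav"], provides=["basename"])
dataset.set_output_keys(["id", "basename", "duration"])

# Filter and sort without touching the static data (shared by reference).
short_first = dataset.filtered_sorted(key_max_value={"duration": 2.0}, sort_key="duration")
print(short_first[0])  # expected, roughly: {'id': 'utt1', 'basename': 'a.wav', 'duration': 1.2}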
|
@ -0,0 +1,276 @@
|
|||||||
|
"""A dependency graph for finding evaluation order.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> # The basic use case is that you have a bunch of keys
|
||||||
|
>>> # and some of them depend on each other:
|
||||||
|
>>> database = []
|
||||||
|
>>> functions = {'read': {'func': lambda: (0,1,2),
|
||||||
|
... 'needs': []},
|
||||||
|
... 'process': {'func': lambda X: [x**2 for x in X],
|
||||||
|
... 'needs': ['read']},
|
||||||
|
... 'save': {'func': lambda x: database.append(x),
|
||||||
|
... 'needs': ['process']},
|
||||||
|
... 'print': {'func': lambda x,y: print(x, "became", y),
|
||||||
|
... 'needs': ['read', 'process']},
|
||||||
|
... 'auxiliary': {'func': lambda: (1,2,3),
|
||||||
|
... 'needs': []}}
|
||||||
|
>>> # If this is user supplied info, so you can't just hardcode the order,
|
||||||
|
>>> # a dependency graph may be needed.
|
||||||
|
>>> dg = DependencyGraph()
|
||||||
|
>>> # In simple cases, you can just encode the dependencies directly:
|
||||||
|
>>> for key, conf in functions.items():
|
||||||
|
... for needed in conf["needs"]:
|
||||||
|
... dg.add_edge(key, needed)
|
||||||
|
>>> # Now we can evaluate:
|
||||||
|
>>> outputs = {}
|
||||||
|
>>> for node in dg.get_evaluation_order():
|
||||||
|
... f = functions[node.key]['func']
|
||||||
|
... args = [outputs[needed] for needed in functions[node.key]['needs']]
|
||||||
|
... outputs[node.key] = f(*args)
|
||||||
|
(0, 1, 2) became [0, 1, 4]
|
||||||
|
>>> # This added nodes implicitly.
|
||||||
|
>>> # However, since 'auxiliary' didn't depend on anything,
|
||||||
|
>>> # it didn't get added!
|
||||||
|
>>> assert 'auxiliary' not in outputs
|
||||||
|
>>> # So to be careful, we should also manually add nodes for anything that
|
||||||
|
>>> # is not an intermediate step.
|
||||||
|
>>> _ = dg.add_node('auxiliary')
|
||||||
|
>>> assert 'auxiliary' in (node.key for node in dg.get_evaluation_order())
|
||||||
|
>>> # Arbitrary data can be added to nodes:
|
||||||
|
>>> dg2 = DependencyGraph()
|
||||||
|
>>> for key, conf in functions.items():
|
||||||
|
... _ = dg2.add_node(key, conf)
|
||||||
|
... for needed in conf["needs"]:
|
||||||
|
... dg2.add_edge(key, needed)
|
||||||
|
>>> # Now we get access to the data in evaluation:
|
||||||
|
>>> outputs2 = {}
|
||||||
|
>>> for key, _, conf in dg2.get_evaluation_order():
|
||||||
|
... f = conf['func']
|
||||||
|
... args = [outputs2[needed] for needed in conf['needs']]
... outputs2[key] = f(*args)
|
||||||
|
(0, 1, 2) became [0, 1, 4]
|
||||||
|
|
||||||
|
Authors:
|
||||||
|
* Aku Rouhe 2020
|
||||||
|
"""
|
||||||
|
import collections
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class CircularDependencyError(ValueError):
|
||||||
|
"""
|
||||||
|
An error caused by running into circular dependencies while searching for
|
||||||
|
an evaluation order in a DependencyGraph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
DGNode = collections.namedtuple("DGNode", ["key", "edges", "data"])
|
||||||
|
# A node in DependencyGraph.
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyGraph:
|
||||||
|
"""General-purpose dependency graph.
|
||||||
|
|
||||||
|
Essentially a directed acyclic graph.
|
||||||
|
Usually used to find an evaluation order for e.g. variable substitution
|
||||||
|
The relation that an edge between A and B represents is:
|
||||||
|
"A depends on B, i.e. B should be evaluated before A"
|
||||||
|
|
||||||
|
Nodes can be added explicitly or they can be created implicitly
|
||||||
|
while adding edges.
|
||||||
|
Nodes have keys, which should be some hashable value that identifies
|
||||||
|
the elements the graph represents in your use case. E.G. they can just
|
||||||
|
be the variable name you want to substitute.
|
||||||
|
However, if needed, more generally you can attach any data to a node
|
||||||
|
(e.g. a path in your tree), and if so desired, a unique key can be
|
||||||
|
created for you. You'll only need to know that key while adding edges
|
||||||
|
to/from it.
|
||||||
|
Implicit keys and explicit keys can also be mixed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.digraph = []
|
||||||
|
self.key2ind = {}
|
||||||
|
# Guard for manual duplicates (but not implicitly added ones)
|
||||||
|
self._manually_added_keys = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_unique_key():
|
||||||
|
"""Returns a unique hashable identifier."""
|
||||||
|
return uuid.uuid4()
|
||||||
|
|
||||||
|
def add_node(self, key=None, data=None):
|
||||||
|
"""Adds a node explicitly.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
key : hashable, optional
|
||||||
|
If not given, a key is created for you.
|
||||||
|
data : Any, optional
|
||||||
|
Any additional data you wish to attach to this node.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
hashable
|
||||||
|
The key that was used (either yours or generated).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If node with the given key has already been added explicitly
|
||||||
|
(with this method, not "add_edge").
|
||||||
|
"""
|
||||||
|
if key is None:
|
||||||
|
key = self.get_unique_key()
|
||||||
|
elif key in self._manually_added_keys:
|
||||||
|
raise ValueError("Adding duplicate node: {key}".format(key=key))
|
||||||
|
else:
|
||||||
|
self._manually_added_keys.append(key)
|
||||||
|
if key in self.key2ind: # Implicitly added already; don't add again.
|
||||||
|
ind = self.key2ind[key]
|
||||||
|
node = self.digraph[ind]
|
||||||
|
# All that this operation can do is add data:
|
||||||
|
self.digraph[ind] = DGNode(node.key, node.edges, data)
|
||||||
|
return key
|
||||||
|
self.key2ind[key] = len(self.digraph)
|
||||||
|
self.digraph.append(DGNode(key, [], data))
|
||||||
|
return key
|
||||||
|
|
||||||
|
def add_edge(self, from_key, to_key):
|
||||||
|
"""Adds an edge, and implicitly also creates nodes for keys which have
|
||||||
|
not been seen before. This will not let you add data to your nodes.
|
||||||
|
The relation encodes: "from_key depends on to_key"
|
||||||
|
(to_key must be evaluated before from_key).
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
from_key : hashable
The key of the node that depends on the other.
to_key : hashable
The key of the node that is depended on.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
from_ind = self._get_ind_and_add_if_new(from_key)
|
||||||
|
to_ind = self._get_ind_and_add_if_new(to_key)
|
||||||
|
edges_list = self.digraph[from_ind].edges
|
||||||
|
if to_ind not in edges_list:
|
||||||
|
edges_list.append(to_ind)
|
||||||
|
|
||||||
|
def _get_ind_and_add_if_new(self, key):
|
||||||
|
# Used internally to implicitly add nodes for unseen keys
|
||||||
|
if key not in self.key2ind:
|
||||||
|
self.key2ind[key] = len(self.digraph)
|
||||||
|
self.digraph.append(DGNode(key, [], None))
|
||||||
|
return self.key2ind[key]
|
||||||
|
|
||||||
|
def is_valid(self):
|
||||||
|
"""Checks if an evaluation order can be found.
|
||||||
|
|
||||||
|
A dependency graph is evaluatable if there are no circular
|
||||||
|
dependencies, i.e., the graph is acyclic.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
Indicating if the graph is evaluatable.
|
||||||
|
"""
|
||||||
|
return not self._find_first_cycle()
|
||||||
|
|
||||||
|
def get_evaluation_order(self, selected_keys=None):
|
||||||
|
"""Finds one valid evaluation order.
|
||||||
|
|
||||||
|
There can be many different valid
|
||||||
|
orders.
|
||||||
|
NOTE: Generates output one DGNode at a time. May generate DGNodes
|
||||||
|
before it finds a circular dependency. If you really need to know
|
||||||
|
whether an order can be found, check is_valid() first. However,
|
||||||
|
the algorithm for finding cycles is essentially the same as the one
|
||||||
|
used for finding an evaluation order, so for very large graphs...
|
||||||
|
Ah well, but maybe then you should be using some other solution
|
||||||
|
anyway.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
selected_keys : list, None
|
||||||
|
List of keys. If not None, only the selected keys are guaranteed
|
||||||
|
in the evaluation order (along with the keys they depend on).
|
||||||
|
|
||||||
|
Yields
|
||||||
|
------
|
||||||
|
DGNode
|
||||||
|
The added DGNodes in a valid evaluation order.
|
||||||
|
See the DGNode namedtuple above.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
CircularDependencyError
|
||||||
|
If a circular dependency is found.
|
||||||
|
"""
|
||||||
|
seen_ever = set()
|
||||||
|
|
||||||
|
def toposort(root_ind, visited):
|
||||||
|
"""Implementation of topsort."""
|
||||||
|
nonlocal seen_ever
|
||||||
|
here = visited + [root_ind]
|
||||||
|
if root_ind in visited:
|
||||||
|
raise CircularDependencyError(
|
||||||
|
"{cycle}".format(
|
||||||
|
cycle=" -> ".join(
|
||||||
|
str(self.digraph[i].key) for i in here
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if root_ind in seen_ever:
|
||||||
|
return # Yield nothing
|
||||||
|
seen_ever = seen_ever.union(set([root_ind]))
|
||||||
|
for to_ind in self.digraph[root_ind].edges:
|
||||||
|
for ind in toposort(to_ind, visited=here):
|
||||||
|
yield ind
|
||||||
|
yield root_ind
|
||||||
|
|
||||||
|
if selected_keys is None:
|
||||||
|
start_inds = range(len(self.digraph))
|
||||||
|
else:
|
||||||
|
start_inds = [self.key2ind[key] for key in selected_keys]
|
||||||
|
|
||||||
|
for start_ind in start_inds:
|
||||||
|
for ind in toposort(start_ind, []):
|
||||||
|
yield self.digraph[ind]
|
||||||
|
|
||||||
|
def _find_first_cycle(self):
|
||||||
|
"""Depth-first search based algorithm for finding cycles in the graph."""
|
||||||
|
seen_ever = set()
|
||||||
|
|
||||||
|
def cycle_dfs(root_ind, visited):
|
||||||
|
"""Implementation of cycle_dfs."""
|
||||||
|
nonlocal seen_ever
|
||||||
|
|
||||||
|
here = visited + [root_ind]
|
||||||
|
if root_ind in visited:
|
||||||
|
return here
|
||||||
|
if root_ind in seen_ever:
|
||||||
|
return []
|
||||||
|
seen_ever = seen_ever.union(set([root_ind]))
|
||||||
|
for to_ind in self.digraph[root_ind].edges:
|
||||||
|
cycle = cycle_dfs(to_ind, here)
|
||||||
|
if cycle:
|
||||||
|
return cycle
|
||||||
|
return []
|
||||||
|
|
||||||
|
for ind in range(len(self.digraph)):
|
||||||
|
if ind not in seen_ever:
|
||||||
|
cycle = cycle_dfs(ind, [])
|
||||||
|
if cycle:
|
||||||
|
return cycle
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
# Allows the syntax:
|
||||||
|
# 'key' in dependency_graph
|
||||||
|
return key in self.key2ind
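# --- Illustrative sketch (not part of the file above) ---
# Detecting a circular dependency with the public API defined in this module
# (add_edge, is_valid, get_evaluation_order, CircularDependencyError).
dg = DependencyGraph()
dg.add_edge("a", "b")  # a depends on b
dg.add_edge("b", "c")  # b depends on c
dg.add_edge("c", "a")  # c depends on a -> cycle

print(dg.is_valid())   # False

try:
    list(dg.get_evaluation_order())
except CircularDependencyError as err:
    print("circular dependency:", err)  # e.g. "a -> b -> c -> a"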
|
@ -0,0 +1,115 @@
|
|||||||
|
import paddlespeech.s2t.io.wav2vec2.dataloader as dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def _train_loader_specifics(self, dataset, loader_kwargs):
|
||||||
|
sampler = loader_kwargs.get("sampler", None)
|
||||||
|
# Shuffling should really only matter for the train stage. Shuffling
|
||||||
|
# will also lead to more padding in batches if the order was otherwise
|
||||||
|
# sorted by length.
|
||||||
|
shuffle = loader_kwargs.get("shuffle", False)
|
||||||
|
if shuffle and not self.distributed_launch:
|
||||||
|
if sampler is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both shuffle=True"
|
||||||
|
"and a sampler in loader_kwargs"
|
||||||
|
)
|
||||||
|
sampler = ReproducibleRandomSampler(dataset)
|
||||||
|
self.train_sampler = sampler
|
||||||
|
loader_kwargs["sampler"] = self.train_sampler
|
||||||
|
# Delete the shuffle flag, since you cannot specify both a sampler and
|
||||||
|
# shuffling:
|
||||||
|
del loader_kwargs["shuffle"]
|
||||||
|
|
||||||
|
# Possibly make a DistributedSampler or a wrapper for some other sampler
|
||||||
|
if self.distributed_launch and not isinstance(dataset, IterableDataset):
|
||||||
|
drop_last = loader_kwargs.get("drop_last", False)
|
||||||
|
# num_replicas arg is equal to world_size
|
||||||
|
# and retrieved automatically within
|
||||||
|
# DistributedSampler obj.
|
||||||
|
if sampler is not None:
|
||||||
|
self.train_sampler = DistributedSamplerWrapper(
|
||||||
|
sampler,
|
||||||
|
rank=self.rank,
|
||||||
|
drop_last=drop_last,
|
||||||
|
shuffle=shuffle,
|
||||||
|
)
|
||||||
|
|
||||||
|
# with DistributedSamplerWrapper, one must disable shuffling for dataloader
|
||||||
|
loader_kwargs["shuffle"] = False
|
||||||
|
loader_kwargs["sampler"] = self.train_sampler
|
||||||
|
elif loader_kwargs.get("batch_sampler") is None:
|
||||||
|
# no sampler and batch-sampler
|
||||||
|
self.train_sampler = DistributedSampler(
|
||||||
|
dataset, rank=self.rank, shuffle=True, drop_last=drop_last
|
||||||
|
)
|
||||||
|
|
||||||
|
# with DistributedSampler, one must disable shuffling for the dataloader
|
||||||
|
loader_kwargs["shuffle"] = False
|
||||||
|
loader_kwargs["sampler"] = self.train_sampler
|
||||||
|
else: # batch_sampler was specified
|
||||||
|
self.train_sampler = DistributedSamplerWrapper(
|
||||||
|
loader_kwargs.get("batch_sampler", None),
|
||||||
|
rank=self.rank,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
loader_kwargs["batch_sampler"] = self.train_sampler
|
||||||
|
elif self.distributed_launch and isinstance(dataset, IterableDataset):
|
||||||
|
logger.warning(
|
||||||
|
"Cannot automatically solve distributed sampling "
|
||||||
|
"for IterableDataset."
|
||||||
|
)
|
||||||
|
return loader_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataloader(
|
||||||
|
self, dataset, stage, **loader_kwargs
|
||||||
|
):
|
||||||
|
"""Creates DataLoaders for Datasets.
|
||||||
|
|
||||||
|
This is used by ``fit()`` and ``evaluate()`` if they just receive
|
||||||
|
Datasets.
|
||||||
|
|
||||||
|
Alternatively, this can be called from outside the Brain subclass.
|
||||||
|
In that case, the DataLoader should be passed to ``fit()`` in place
|
||||||
|
of the dataset.
|
||||||
|
|
||||||
|
The Stage.TRAIN DataLoader is handled specially. It has extra args for
|
||||||
|
shuffle and drop_last. In DDP a DistributedSampler is created (unless
|
||||||
|
the dataset is an IterableDataset).
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
Some important DataLoader arguments are passed via **loader_kwargs,
|
||||||
|
e.g., batch_size, num_workers, pin_memory.
|
||||||
|
|
||||||
|
NOTE
|
||||||
|
----
|
||||||
|
By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test
|
||||||
|
DataLoader being added to the checkpointer. If you need to add a
|
||||||
|
recoverable after saving checkpoints (e.g., at test time, after
|
||||||
|
checkpointing the training), and still be able to recover reasonably,
|
||||||
|
you should probably specify ``allow_partial_load=True``.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
dataset : Dataset
|
||||||
|
A set of data to use to create data loader. If the Dataset is a
|
||||||
|
DynamicItemDataset, PaddedBatch is used as the default collate_fn,
|
||||||
|
unless specified in loader_kwargs.
|
||||||
|
stage : Stage
|
||||||
|
The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
|
||||||
|
ckpt_prefix : str, None
|
||||||
|
Prefix to use for SaveableDataLoader Checkpoint name. The Stage
|
||||||
|
name is added to this to create the full key. Set to None to not
|
||||||
|
save the DataLoader.
|
||||||
|
**loader_kwargs : dict
|
||||||
|
Additional keyword arguments to the DataLoader.
|
||||||
|
E.g., batch_size, num_workers, pin_memory.
|
||||||
|
"""
|
||||||
|
# TRAIN stage is handled specially.
|
||||||
|
# if stage == train:
|
||||||
|
# loader_kwargs = _train_loader_specifics(dataset, loader_kwargs)
|
||||||
|
dataloader_ = dataloader.make_dataloader(
|
||||||
|
dataset, **loader_kwargs
|
||||||
|
)
|
||||||
|
return dataloader_
|
@ -0,0 +1,695 @@
|
|||||||
|
"""PyTorch compatible samplers.
|
||||||
|
|
||||||
|
These determine the order of iteration through a dataset.
|
||||||
|
|
||||||
|
Authors:
|
||||||
|
* Aku Rouhe 2020
|
||||||
|
* Samuele Cornell 2020
|
||||||
|
* Ralf Leibold 2020
|
||||||
|
* Artem Ploujnikov 2021
|
||||||
|
* Andreas Nautsch 2021
|
||||||
|
"""
|
||||||
|
import paddle
import torch
|
||||||
|
import logging
|
||||||
|
from operator import itemgetter
|
||||||
|
from paddle.io import (
|
||||||
|
RandomSampler,
|
||||||
|
WeightedRandomSampler,
|
||||||
|
Sampler,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset
|
||||||
|
from collections import Counter
|
||||||
|
from scipy.stats import lognorm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ReproducibleRandomSampler(RandomSampler):
|
||||||
|
"""A modification of RandomSampler which always returns the same values.
|
||||||
|
|
||||||
|
Also look at `torch.utils.data.RandomSampler`. This has mostly
|
||||||
|
the same behaviour and arguments, except for adding 'seed' and 'epoch' and
|
||||||
|
not supporting 'generator'.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
Call `set_epoch` before every epoch. Otherwise, the sampler will produce the
|
||||||
|
same sequence of indices every epoch.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
data_source : Dataset
|
||||||
|
The data source to sample indices for.
|
||||||
|
seed : int
|
||||||
|
The base seed to use for the random number generator. It is recommended
|
||||||
|
to use a value which has a good mix of 0 and 1 bits.
|
||||||
|
epoch : int
|
||||||
|
The epoch to start at.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> import torch
|
||||||
|
>>> from speechbrain.utils.checkpoints import Checkpointer
|
||||||
|
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
|
||||||
|
>>> # An example "dataset"
|
||||||
|
>>> dataset = torch.arange(10).unsqueeze(1)
|
||||||
|
>>> # Create the random sampler:
|
||||||
|
>>> sampler = ReproducibleRandomSampler(dataset)
|
||||||
|
>>> dataloader = SaveableDataLoader(dataset, sampler = sampler,
|
||||||
|
... num_workers = 3)
|
||||||
|
>>> # Setup the checkpointer.
|
||||||
|
>>> # Note that the sampler doesn't need to be saved itself.
|
||||||
|
>>> tmpdir = getfixture('tmpdir')
|
||||||
|
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
|
||||||
|
>>> # Iterate:
|
||||||
|
>>> subset = []
|
||||||
|
>>> for i, data_point in enumerate(dataloader):
|
||||||
|
... # Say you save a checkpoint on the fourth batch:
|
||||||
|
... if i == 3:
|
||||||
|
... _ = checkpointer.save_checkpoint(end_of_epoch = False)
|
||||||
|
... # So let's save the numbers you would get if you continue
|
||||||
|
... if i >= 4:
|
||||||
|
... subset.append(data_point.item())
|
||||||
|
>>> # What if instead you had to restart the experiment?
|
||||||
|
>>> new_sampler = ReproducibleRandomSampler(dataset)
|
||||||
|
>>> new_dataloader = SaveableDataLoader(dataset, sampler = new_sampler,
|
||||||
|
... num_workers = 3)
|
||||||
|
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
|
||||||
|
>>> _ = new_checkpointer.recover_if_possible()
|
||||||
|
>>> # You'll get the same random order again:
|
||||||
|
>>> new_subset = [data_point.item() for data_point in new_dataloader]
|
||||||
|
>>> assert subset == new_subset
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source, seed=563375142, epoch=0, **kwargs):
|
||||||
|
if "generator" in kwargs:
|
||||||
|
MSG = (
|
||||||
|
"Cannot give a separate generator when using "
|
||||||
|
+ "ReproducibleRandomSampler"
|
||||||
|
)
|
||||||
|
raise ValueError(MSG)
|
||||||
|
super().__init__(data_source, **kwargs)
|
||||||
|
self.seed = int(seed)
|
||||||
|
self.epoch = epoch
|
||||||
|
self.gen = paddle.seed(1)
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
"""
|
||||||
|
You can also just access self.epoch, but we maintain this interface
|
||||||
|
to mirror torch.utils.data.distributed.DistributedSampler
|
||||||
|
"""
|
||||||
|
self.epoch = epoch
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.gen.manual_seed(self.seed + self.epoch)
|
||||||
|
return super().__iter__()
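# --- Illustrative sketch (not part of the file above) ---
# Intended epoch handling: call set_epoch() before each epoch so the shuffling
# order differs between epochs but stays reproducible across runs.
# The toy data source is a plain list; only its length is used by the sampler.
data_source = list(range(10))
sampler = ReproducibleRandomSampler(data_source)

for epoch in range(3):
    sampler.set_epoch(epoch)     # generator is re-seeded with seed + epoch
    print(epoch, list(sampler))  # a fixed permutation for each (seed, epoch) pair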
|
||||||
|
|
||||||
|
|
||||||
|
class ReproducibleWeightedRandomSampler(WeightedRandomSampler):
|
||||||
|
"""A reproducible modification of WeightedRandomSampler.
|
||||||
|
|
||||||
|
Also look at `torch.utils.data.WeightedRandomSampler`. This has
|
||||||
|
the same behaviour and arguments, except for adding 'seed' and 'epoch' and
|
||||||
|
not supporting 'generator'.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
Call `set_epoch` before every epoch. Otherwise, the sampler will produce the
|
||||||
|
same sequence of indices every epoch.
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
weights : sequence of float
|
||||||
|
Weights for each index. Doesn't need to sum to one.
|
||||||
|
num_samples : int
|
||||||
|
Number of samples to draw
|
||||||
|
replacement : bool
|
||||||
|
To draw with replacement or not (within an epoch of num_samples).
|
||||||
|
seed : int
|
||||||
|
The base seed to use for the random number generator. It is recommended
|
||||||
|
to use a value which has a good mix of 0 and 1 bits.
|
||||||
|
epoch : int
|
||||||
|
The epoch to start at.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> a = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)
|
||||||
|
>>> b = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)
|
||||||
|
>>> list(a)
|
||||||
|
[3, 1, 4, 4, 4]
|
||||||
|
>>> list(b)
|
||||||
|
[3, 1, 4, 4, 4]
|
||||||
|
>>> a.set_epoch(1)
|
||||||
|
>>> list(a)
|
||||||
|
[4, 5, 4, 4, 3]
|
||||||
|
>>> b.set_epoch(1)
|
||||||
|
>>> list(b)
|
||||||
|
[4, 5, 4, 4, 3]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weights,
|
||||||
|
num_samples,
|
||||||
|
replacement,
|
||||||
|
seed=129491412,
|
||||||
|
epoch=0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if "generator" in kwargs:
|
||||||
|
MSG = (
|
||||||
|
"Cannot give a separate generator when using "
|
||||||
|
+ "ReproducibleRandomSampler"
|
||||||
|
)
|
||||||
|
raise ValueError(MSG)
|
||||||
|
super().__init__(weights, num_samples, replacement, **kwargs)
|
||||||
|
self.seed = int(seed)
|
||||||
|
self.epoch = epoch
|
||||||
|
self.gen = paddle.seed(1)
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
"""
|
||||||
|
You can also just access self.epoch, but we maintain this interface
|
||||||
|
to mirror torch.utils.data.distributed.DistributedSampler
|
||||||
|
"""
|
||||||
|
self.epoch = epoch
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.gen.manual_seed(self.seed + self.epoch)
|
||||||
|
return super().__iter__()
|
||||||
|
|
||||||
|
class DynamicBatchSampler(Sampler):
|
||||||
|
"""This BatchSampler batches examples together by grouping them by their length.
|
||||||
|
|
||||||
|
Every example in the batch has approximately the same length, and
|
||||||
|
thus padding is minimized.
|
||||||
|
This enables faster training on datasets
|
||||||
|
where the length of examples can vary significantly (e.g., Librispeech).
|
||||||
|
Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length
|
||||||
|
|
||||||
|
Dynamic batching is performed by specifying a max_batch_length which is the
|
||||||
|
upper limit for the sum of the length of examples in a batch:
|
||||||
|
e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6
|
||||||
|
ex1 and ex2 will be placed, alone, in two distinct batches.
|
||||||
|
|
||||||
|
Length for each example can be obtained in two manners.
|
||||||
|
If the input dataset is a DynamicItemDataset it can be obtained by specifying a
|
||||||
|
length_func. Default assumes a "duration" entry is in the annotation.
|
||||||
|
Length for each example can also be passed to this class upon instantiation
|
||||||
|
by specifying a list containing the length for each example and passing it to
|
||||||
|
lengths_list.
|
||||||
|
|
||||||
|
Examples are grouped together by defining a set of possible discrete intervals
|
||||||
|
(buckets). Examples whose length fall into these intervals can be batched together.
|
||||||
|
|
||||||
|
The number of buckets can be specified by using the arg num_buckets.
|
||||||
|
There is usually an optimal range for the value of this argument.
|
||||||
|
|
||||||
|
If num_buckets == 1, all examples can be batched together. You have maximum randomization
|
||||||
|
but your training speed will be slower, because a large fraction of the values will be padding,
|
||||||
|
as long and short examples can be batched together.
|
||||||
|
As the number of buckets grows only examples with similar
|
||||||
|
length can be grouped together.
|
||||||
|
This trades-off speed with randomization.
|
||||||
|
TLDR: Low number -> better randomization, High number -> faster training.
|
||||||
|
NOTE: if set too high, the training speed will decrease. If num_buckets approaches the number of examples in the dataset, the batch size
|
||||||
|
will be small impacting training speed and possibly performance.
|
||||||
|
|
||||||
|
The buckets can also be specified by passing a list to the bucket_boundaries
|
||||||
|
argument instead of specifying a left_bucket_length and a bucket_length_multiplier.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> import torch
|
||||||
|
>>> import speechbrain as sb
|
||||||
|
>>> from speechbrain.dataio.sampler import DynamicBatchSampler
|
||||||
|
>>> from speechbrain.dataio.dataset import DynamicItemDataset
|
||||||
|
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
|
||||||
|
>>> from speechbrain.dataio.batch import PaddedBatch
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)])
|
||||||
|
>>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths}
|
||||||
|
>>> dataset = DynamicItemDataset(dataset)
|
||||||
|
>>> dataset.set_output_keys(["wav"])
|
||||||
|
>>> length_func = lambda x : len(x) # trivial in this example
|
||||||
|
>>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending')
|
||||||
|
>>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch)
|
||||||
|
>>> for i, b in enumerate(dataloader):
|
||||||
|
... data, length = b["wav"]
|
||||||
|
>>> assert data.shape[-1] == max(item_lengths)
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
dataset : torch.utils.data.Dataset
|
||||||
|
Pytorch Dataset from which elements will be sampled.
|
||||||
|
max_batch_length : int
|
||||||
|
Upper limit for the sum of the length of examples in a batch.
|
||||||
|
Should be chosen based on your GPU memory.
|
||||||
|
num_buckets : int
|
||||||
|
Number of discrete buckets used to group examples together.
|
||||||
|
If num_buckets == 1, all examples can be batched together. As the number of buckets grows only examples with similar
|
||||||
|
length can be grouped together. This trades-off speed with randomization.
|
||||||
|
Low number -> better randomization, High number -> faster training.
|
||||||
|
However if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size
|
||||||
|
will be small impacting training speed and possibly performance.
|
||||||
|
NOTE: you must specify either the bucket_boundaries manually or the number of buckets.
|
||||||
|
length_func : callable
|
||||||
|
Function used to get length of each example from the dataset.
|
||||||
|
This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object.
|
||||||
|
Can be anything: e.g. lambda x: x["duration"]*16000 returns number of samples
|
||||||
|
if duration key in the annotation is in seconds and the file has 16kHz sampling freq.
|
||||||
|
shuffle : bool
|
||||||
|
Whether or not to shuffle examples between epochs.
|
||||||
|
batch_ordering : string
|
||||||
|
If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length.
|
||||||
|
max_batch_ex: int
|
||||||
|
If set, it limits the maximum number of examples that can be in a batch, superseding max_batch_length
in instances where the number of examples would exceed the value specified here.
E.g., if you have a lot of short examples and the batch size for those would be too high, you can use this argument
|
||||||
|
to limit the batch size for these short examples.
|
||||||
|
bucket_boundaries : list
|
||||||
|
Overrides bucket_length_multiplier and left_bucket_length by specifying manually
|
||||||
|
the buckets' right boundaries.
|
||||||
|
lengths_list: list
|
||||||
|
Overrides length_func by passing a list containing the length of each example
|
||||||
|
in the dataset. This argument must be set when the dataset is a plain
|
||||||
|
Pytorch Dataset object and not a DynamicItemDataset object as length_func
|
||||||
|
cannot be used on Pytorch Datasets.
|
||||||
|
epoch : int
|
||||||
|
The epoch to start at.
|
||||||
|
drop_last : bool
|
||||||
|
If ``True``, the sampler will drop the last examples which
|
||||||
|
have not been grouped.
|
||||||
|
verbose: bool
|
||||||
|
If ``True``, also log the stats for each batch at the first epoch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
max_batch_length: int,
|
||||||
|
num_buckets: int = None,
|
||||||
|
length_func=lambda x: x["duration"],
|
||||||
|
shuffle: bool = True,
|
||||||
|
batch_ordering: str = "random",
|
||||||
|
max_batch_ex: int = None,
|
||||||
|
bucket_boundaries: List[int] = [],
|
||||||
|
lengths_list: List[int] = None,
|
||||||
|
seed: int = 42,
|
||||||
|
epoch: int = 0,
|
||||||
|
drop_last: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
):
|
||||||
|
self._dataset = dataset
|
||||||
|
self._ex_lengths = {}
|
||||||
|
ex_ids = self._dataset.data_ids
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
# We do not put a default on num_buckets to encourage users to play with this parameter
|
||||||
|
if num_buckets is None and len(bucket_boundaries) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Please specify either num_buckets or bucket boundaries."
|
||||||
|
"Check the docs, and/or the tutorial !"
|
||||||
|
)
|
||||||
|
|
||||||
|
if lengths_list is not None:
|
||||||
|
# take length of examples from this argument and bypass length_key
|
||||||
|
for indx in range(len(lengths_list)):
|
||||||
|
self._ex_lengths[str(indx)] = lengths_list[indx]
|
||||||
|
else:
|
||||||
|
# use length func
|
||||||
|
if not isinstance(dataset, DynamicItemDataset):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Dataset should be a Speechbrain DynamicItemDataset when using length function"
|
||||||
|
)
|
||||||
|
for indx in range(len(self._dataset)):
|
||||||
|
self._ex_lengths[str(indx)] = length_func(
|
||||||
|
self._dataset.data[ex_ids[indx]]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(bucket_boundaries) > 0:
|
||||||
|
if not all([x >= 0 for x in bucket_boundaries]):
|
||||||
|
raise ValueError(
|
||||||
|
"All elements in bucket boundaries should be non-negative (>= 0)."
|
||||||
|
)
|
||||||
|
if not len(set(bucket_boundaries)) == len(bucket_boundaries):
|
||||||
|
raise ValueError(
|
||||||
|
"Bucket_boundaries should not contain duplicates."
|
||||||
|
)
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
np.array(bucket_boundaries),
|
||||||
|
np.array(sorted(bucket_boundaries)),
|
||||||
|
err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!",
|
||||||
|
)
|
||||||
|
self._bucket_boundaries = np.array(sorted(bucket_boundaries))
|
||||||
|
else:
|
||||||
|
# use num_buckets
|
||||||
|
self._bucket_boundaries = np.array(
|
||||||
|
self._get_boundaries_through_warping(
|
||||||
|
max_batch_length=max_batch_length,
|
||||||
|
num_quantiles=num_buckets,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._max_batch_length = max_batch_length
|
||||||
|
self._shuffle_ex = shuffle
|
||||||
|
self._batch_ordering = batch_ordering
|
||||||
|
self._seed = seed
|
||||||
|
self._drop_last = drop_last
|
||||||
|
if max_batch_ex is None:
|
||||||
|
max_batch_ex = np.inf
|
||||||
|
self._max_batch_ex = max_batch_ex
|
||||||
|
# Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
|
||||||
|
self._bucket_lens = [
|
||||||
|
max(1, int(max_batch_length / self._bucket_boundaries[i]))
|
||||||
|
for i in range(len(self._bucket_boundaries))
|
||||||
|
] + [1]
|
||||||
|
self._epoch = epoch
|
||||||
|
self._generate_batches()
|
||||||
|
|
||||||
|
def get_durations(self, batch):
|
||||||
|
"""Gets durations of the elements in the batch."""
|
||||||
|
return [self._ex_lengths[str(idx)] for idx in batch]
|
||||||
|
|
||||||
|
def _get_boundaries_through_warping(
|
||||||
|
self, max_batch_length: int, num_quantiles: int,
|
||||||
|
) -> List[int]:
|
||||||
|
|
||||||
|
# NOTE: the following lines do not cover the case where there is only one example in the dataset
|
||||||
|
# warp frames (duration) distribution of train data
|
||||||
|
logger.info("Batch quantisation in latent space")
|
||||||
|
# linspace set-up
|
||||||
|
num_boundaries = num_quantiles + 1
|
||||||
|
# create latent linearly equal spaced buckets
|
||||||
|
latent_boundaries = np.linspace(
|
||||||
|
1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles,
|
||||||
|
)
|
||||||
|
# get quantiles using lognormal distribution
|
||||||
|
quantiles = lognorm.ppf(latent_boundaries, 1)
|
||||||
|
# scale up to max_batch_length
|
||||||
|
bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
|
||||||
|
# compute resulting bucket length multipliers
|
||||||
|
length_multipliers = [
|
||||||
|
bucket_boundaries[x + 1] / bucket_boundaries[x]
|
||||||
|
for x in range(num_quantiles - 1)
|
||||||
|
]
|
||||||
|
# logging
|
||||||
|
logger.info(
|
||||||
|
"Latent bucket boundary - buckets: {} - length multipliers: {}".format(
|
||||||
|
list(map("{:.2f}".format, bucket_boundaries)),
|
||||||
|
list(map("{:.2f}".format, length_multipliers)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return list(sorted(bucket_boundaries))
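# --- Illustrative sketch (not part of the file above) ---
# Stand-alone recomputation of the boundary warping above, for max_batch_length=100
# and num_quantiles=4, followed by the np.searchsorted lookup that
# _generate_batches() later uses to assign an example of a given length to a bucket.
import numpy as np
from scipy.stats import lognorm

max_batch_length, num_quantiles = 100, 4
num_boundaries = num_quantiles + 1
latent = np.linspace(1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles)
boundaries = lognorm.ppf(latent, 1)                                    # lognormal quantiles
boundaries = np.sort(boundaries * max_batch_length / boundaries[-1])  # scale so the last one == 100

print(boundaries)                         # four ascending right boundaries, roughly [18.6 33.5 55.5 100.]
print(np.searchsorted(boundaries, 37.0))  # 2 -> a length-37 example lands in the third bucket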
|
||||||
|
|
||||||
|
def _permute_batches(self):
|
||||||
|
|
||||||
|
if self._batch_ordering == "random":
|
||||||
|
# deterministically shuffle based on epoch and seed
|
||||||
|
gen = paddle.seed(1)
|
||||||
|
gen.manual_seed(self._seed + self._epoch)
|
||||||
|
sampler = paddle.randperm(len(self._batches)).tolist()  # type: ignore
|
||||||
|
tmp = []
|
||||||
|
for idx in sampler:
|
||||||
|
tmp.append(self._batches[idx])
|
||||||
|
self._batches = tmp
|
||||||
|
|
||||||
|
elif self._batch_ordering == "ascending":
|
||||||
|
self._batches = sorted(
|
||||||
|
self._batches,
|
||||||
|
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
|
||||||
|
)
|
||||||
|
elif self._batch_ordering == "descending":
|
||||||
|
self._batches = sorted(
|
||||||
|
self._batches,
|
||||||
|
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _generate_batches(self):
|
||||||
|
logger.info("DynamicBatchSampler: Generating dynamic batches")
|
||||||
|
if self._shuffle_ex:
|
||||||
|
# deterministically shuffle based on epoch and seed
|
||||||
|
gen = paddle.seed(1)
|
||||||
|
gen.manual_seed(self._seed + self._epoch)
|
||||||
|
sampler = paddle.randperm(len(self._dataset)).tolist() # type: ignore
|
||||||
|
else:
|
||||||
|
# take examples as they are: e.g. they have been sorted
|
||||||
|
sampler = range(len(self._dataset)) # type: ignore
|
||||||
|
|
||||||
|
self._batches = []
|
||||||
|
bucket_batches = [[] for i in self._bucket_lens]
|
||||||
|
|
||||||
|
stats_tracker = [
|
||||||
|
{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0}
|
||||||
|
for i in self._bucket_lens
|
||||||
|
]
|
||||||
|
|
||||||
|
for idx in sampler:
|
||||||
|
# length of pre-sampled audio
|
||||||
|
item_len = self._ex_lengths[str(idx)]
|
||||||
|
# bucket to fill up most padding
|
||||||
|
bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
|
||||||
|
# fill audio's duration into that bucket
|
||||||
|
bucket_batches[bucket_id].append(idx)
|
||||||
|
|
||||||
|
stats_tracker[bucket_id]["min"] = min(
|
||||||
|
stats_tracker[bucket_id]["min"], item_len
|
||||||
|
)
|
||||||
|
stats_tracker[bucket_id]["max"] = max(
|
||||||
|
stats_tracker[bucket_id]["max"], item_len
|
||||||
|
)
|
||||||
|
stats_tracker[bucket_id]["tot"] += item_len
|
||||||
|
stats_tracker[bucket_id]["n_ex"] += 1
|
||||||
|
# track #samples - why not duration/#frames; rounded up?
|
||||||
|
# keep track of durations, if necessary
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
|
||||||
|
or len(bucket_batches[bucket_id]) >= self._max_batch_ex
|
||||||
|
):
|
||||||
|
self._batches.append(bucket_batches[bucket_id])
|
||||||
|
bucket_batches[bucket_id] = []
|
||||||
|
# keep track of durations
|
||||||
|
|
||||||
|
# Dump remaining batches
|
||||||
|
if not self._drop_last:
|
||||||
|
for batch in bucket_batches:
|
||||||
|
if batch:
|
||||||
|
self._batches.append(batch)
|
||||||
|
|
||||||
|
self._permute_batches() # possibly reorder batches
|
||||||
|
|
||||||
|
if self._epoch == 0: # only log at first epoch
|
||||||
|
# frames per batch & their padding remaining
|
||||||
|
boundaries = [0] + self._bucket_boundaries.tolist()
|
||||||
|
|
||||||
|
for bucket_indx in range(len(self._bucket_boundaries)):
|
||||||
|
try:
|
||||||
|
num_batches = stats_tracker[bucket_indx]["tot"] // (
|
||||||
|
self._max_batch_length
|
||||||
|
)
|
||||||
|
pad_factor = (
|
||||||
|
stats_tracker[bucket_indx]["max"]
|
||||||
|
- stats_tracker[bucket_indx]["min"]
|
||||||
|
) / (
|
||||||
|
stats_tracker[bucket_indx]["tot"]
|
||||||
|
/ stats_tracker[bucket_indx]["n_ex"]
|
||||||
|
)
|
||||||
|
except ZeroDivisionError:
|
||||||
|
num_batches = 0
|
||||||
|
pad_factor = 0
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
(
|
||||||
|
"DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
|
||||||
|
+ "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
|
||||||
|
).format(
|
||||||
|
bucket_indx,
|
||||||
|
boundaries[bucket_indx],
|
||||||
|
boundaries[bucket_indx + 1],
|
||||||
|
self._bucket_lens[bucket_indx],
|
||||||
|
stats_tracker[bucket_indx]["n_ex"],
|
||||||
|
num_batches,
|
||||||
|
pad_factor * 100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
batch_stats = {
|
||||||
|
"tot_frames": [],
|
||||||
|
"tot_pad_frames": [],
|
||||||
|
"pad_%": [],
|
||||||
|
}
|
||||||
|
for batch in self._batches:
|
||||||
|
tot_frames = sum(
|
||||||
|
[self._ex_lengths[str(idx)] for idx in batch]
|
||||||
|
)
|
||||||
|
batch_stats["tot_frames"].append(tot_frames)
|
||||||
|
max_frames = max(
|
||||||
|
[self._ex_lengths[str(idx)] for idx in batch]
|
||||||
|
)
|
||||||
|
tot_pad = sum(
|
||||||
|
[
|
||||||
|
max_frames - self._ex_lengths[str(idx)]
|
||||||
|
for idx in batch
|
||||||
|
]
|
||||||
|
)
|
||||||
|
batch_stats["tot_pad_frames"].append(tot_pad)
|
||||||
|
batch_stats["pad_%"].append(tot_pad / tot_frames * 100)
|
||||||
|
|
||||||
|
padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
|
||||||
|
padding_details = "DynamicBatchSampler: " + padding_details
|
||||||
|
for i in range(len(self._batches)):
|
||||||
|
logger.info(
|
||||||
|
padding_details.format(
|
||||||
|
i,
|
||||||
|
batch_stats["tot_frames"][i],
|
||||||
|
len(self._batches[i]),
|
||||||
|
batch_stats["tot_pad_frames"][i],
|
||||||
|
batch_stats["pad_%"][i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for batch in self._batches:
|
||||||
|
yield batch
|
||||||
|
if self._shuffle_ex: # re-generate examples if ex_ordering == "random"
|
||||||
|
self._generate_batches()
|
||||||
|
if self._batch_ordering == "random":
|
||||||
|
# we randomly permute the batches only --> faster
|
||||||
|
self._permute_batches()
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
"""
|
||||||
|
You can also just access self.epoch, but we maintain this interface
|
||||||
|
to mirror torch.utils.data.distributed.DistributedSampler
|
||||||
|
"""
|
||||||
|
self._epoch = epoch
|
||||||
|
self._generate_batches()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._batches)
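# --- Illustrative sketch (not part of the file above) ---
# With shuffle=True, set_epoch() both re-seeds and regenerates the batches, so
# batch composition changes per epoch while remaining reproducible. Uses the
# DynamicBatchSampler defined above; the toy "duration" annotations are made up.
import numpy as np
from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset

lengths = np.random.randint(10, 100, size=20)
data = {"ex_{}".format(i): {"duration": int(l)} for i, l in enumerate(lengths)}
dataset = DynamicItemDataset(data)

bsampler = DynamicBatchSampler(
    dataset, max_batch_length=200, num_buckets=4,
    length_func=lambda x: x["duration"], shuffle=True,
)
for epoch in range(2):
    bsampler.set_epoch(epoch)
    print(epoch, len(bsampler), next(iter(bsampler)))  # number of batches, first batch's indices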
|
||||||
|
|
||||||
|
|
||||||
|
# Heavily inspired by Catalyst, which is under Apache 2.0 licence.
|
||||||
|
# https://github.com/catalyst-team/catalyst/blob/51428d7756e62b9b8ee5379f38e9fd576eeb36e5/catalyst/data/sampler.py#L522
|
||||||
|
# class DistributedSamplerWrapper(DistributedSampler):
|
||||||
|
# """This wrapper allows using any sampler (for example batch) with Distributed Data Parallel (DDP)
|
||||||
|
# correctly.
|
||||||
|
|
||||||
|
# Passing blindly the sampler to each DDP process will cause to have access
|
||||||
|
# within each process to all the data in the dataset instead of only a subset
|
||||||
|
# of it which is unique to each process. This wrapper prevents this and
|
||||||
|
# allows to use only a subset of the original data for each process.
|
||||||
|
|
||||||
|
# NOTE
|
||||||
|
# ----
|
||||||
|
# This is automatically applied to any sampler in the Brain class when DDP
|
||||||
|
# training is used.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def __init__(self, sampler, *args, **kwargs):
|
||||||
|
# # DistributedSampler only calls len() on dataset
|
||||||
|
# # so a sampler is fine to pass there, as well.
|
||||||
|
# super().__init__(dataset=sampler, *args, **kwargs)
|
||||||
|
# self.sampler = sampler
|
||||||
|
|
||||||
|
# def __iter__(self):
|
||||||
|
# # It is easiest to use a random access interface to the wrapped
|
||||||
|
# # sampler's indices, so we just fetch all indices from the wrapped
|
||||||
|
# # sampler
|
||||||
|
# sampler_indices = list(self.sampler.__iter__())
|
||||||
|
# indices_of_indices = super().__iter__()
|
||||||
|
# # Itemgetter fetches the wrapped sampler indices from the positions
|
||||||
|
# # pointed to by DistributedSampler
|
||||||
|
# return iter(itemgetter(*indices_of_indices)(sampler_indices))
|
||||||
|
|
||||||
|
# def set_epoch(self, epoch):
|
||||||
|
# """Pass set_epoch() through to DistributedSampler and the wrapper one"""
|
||||||
|
# super().set_epoch(epoch)
|
||||||
|
# if hasattr(self.sampler, "set_epoch"):
|
||||||
|
# self.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
|
||||||
|
class BalancingDataSampler(ReproducibleWeightedRandomSampler):
|
||||||
|
"""A data sampler that takes a single key from the dataset and
|
||||||
|
ensures an approximately equal distribution by that key
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
dataset: DynamicItemDataset
|
||||||
|
the dataset from which samples will be drawn
|
||||||
|
key: str
|
||||||
|
the key from which samples will be taken
|
||||||
|
num_samples : int
|
||||||
|
Number of samples to draw
|
||||||
|
replacement : bool
|
||||||
|
To draw with replacement or not (within an epoch of num_samples).
|
||||||
|
seed : int
|
||||||
|
The base seed to use for the random number generator. It is recommended
|
||||||
|
to use a value which has a good mix of 0 and 1 bits.
|
||||||
|
epoch : int
|
||||||
|
The epoch to start at.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from speechbrain.dataio.sampler import BalancingDataSampler
|
||||||
|
>>> from speechbrain.dataio.dataset import DynamicItemDataset
|
||||||
|
>>> sample_data = {
|
||||||
|
... 1: {"category": "A",
|
||||||
|
... "text": "This is a test"},
|
||||||
|
... 2: {"category": "A",
|
||||||
|
... "text": "This is a second test"},
|
||||||
|
... 3: {"category": "B",
|
||||||
|
... "text": "This is a third test"}
|
||||||
|
... }
|
||||||
|
>>> dataset = DynamicItemDataset(data=sample_data)
|
||||||
|
>>> sampler = BalancingDataSampler(
|
||||||
|
... dataset=dataset,
|
||||||
|
... key="category",
|
||||||
|
... num_samples=10
|
||||||
|
... )
|
||||||
|
>>> sampler.weights
|
||||||
|
tensor([0.5000, 0.5000, 1.0000], dtype=torch.float64)
|
||||||
|
>>> it = iter(sampler)
|
||||||
|
>>> [next(it) for _ in range(10)]
|
||||||
|
[2, 2, 1, 2, 2, 0, 1, 1, 1, 2]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
key,
|
||||||
|
num_samples=None,
|
||||||
|
replacement=True,
|
||||||
|
seed=563375142,
|
||||||
|
epoch=0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.dataset = dataset
|
||||||
|
self.key = key
|
||||||
|
if not num_samples:
|
||||||
|
num_samples = len(dataset)
|
||||||
|
weights = self._compute_weights()
|
||||||
|
super().__init__(
|
||||||
|
weights, num_samples, replacement, seed, epoch, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_weights(self):
|
||||||
|
with self.dataset.output_keys_as([self.key]):
|
||||||
|
class_ids = [item[self.key] for item in self.dataset]
|
||||||
|
class_counter = Counter(class_ids)
|
||||||
|
weights = 1 / paddle.to_tensor(
|
||||||
|
[class_counter[class_id] for class_id in class_ids]
|
||||||
|
)
|
||||||
|
return weights
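# --- Illustrative sketch (not part of the file above) ---
# Worked example of the weighting rule in _compute_weights(): two "A" examples
# and one "B" example get per-example weights 1/2, 1/2 and 1/1, which is what
# makes the classes roughly balanced when sampling with replacement.
from collections import Counter
import paddle

class_ids = ["A", "A", "B"]
counts = Counter(class_ids)  # Counter({'A': 2, 'B': 1})
weights = 1 / paddle.to_tensor([counts[c] for c in class_ids], dtype="float64")
print(weights.numpy())       # [0.5 0.5 1. ]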
|
@ -0,0 +1,162 @@
|
|||||||
|
import transformers
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
import dataset
|
||||||
|
import data_pipeline
|
||||||
|
from dataloader import make_dataloader
|
||||||
|
import dataio
|
||||||
|
import paddle
|
||||||
|
import tqdm
|
||||||
|
import numpy
|
||||||
|
def dataio_prepare(hparams):
|
||||||
|
"""This function prepares the datasets to be used in the brain class.
|
||||||
|
It also defines the data processing pipeline through user-defined functions."""
|
||||||
|
data_folder = hparams["data_folder"]
|
||||||
|
|
||||||
|
train_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["train_data"], replacements={"data_root": data_folder},
|
||||||
|
)
|
||||||
|
|
||||||
|
if hparams["sorting"] == "ascending":
|
||||||
|
# we sort training data to speed up training and get better results.
|
||||||
|
train_data = train_data.filtered_sorted(sort_key="duration")
|
||||||
|
# when sorting, do not shuffle in the dataloader! otherwise it is pointless
|
||||||
|
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||||
|
|
||||||
|
elif hparams["sorting"] == "descending":
|
||||||
|
train_data = train_data.filtered_sorted(
|
||||||
|
sort_key="duration", reverse=True
|
||||||
|
)
|
||||||
|
# when sorting, do not shuffle in the dataloader! otherwise it is pointless
|
||||||
|
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||||
|
|
||||||
|
elif hparams["sorting"] == "random":
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"sorting must be random, ascending or descending"
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["valid_data"], replacements={"data_root": data_folder},
|
||||||
|
)
|
||||||
|
valid_data = valid_data.filtered_sorted(sort_key="duration")
|
||||||
|
|
||||||
|
test_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["test_data"], replacements={"data_root": data_folder},
|
||||||
|
)
|
||||||
|
test_data = test_data.filtered_sorted(sort_key="duration")
|
||||||
|
|
||||||
|
datasets = [train_data, valid_data, test_data]
|
||||||
|
|
||||||
|
# Defining tokenizer and loading it
|
||||||
|
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-chinese')
|
||||||
|
|
||||||
|
# 2. Define audio pipeline:
|
||||||
|
@data_pipeline.takes("wav")
|
||||||
|
@data_pipeline.provides("sig")
|
||||||
|
def audio_pipeline(wav):
|
||||||
|
sig = dataio.read_audio(wav)
|
||||||
|
return sig
|
||||||
|
|
||||||
|
dataset.add_dynamic_item(datasets, audio_pipeline)
|
||||||
|
|
||||||
|
# 3. Define text pipeline:
|
||||||
|
@data_pipeline.takes("transcript")
|
||||||
|
@data_pipeline.provides("wrd", "tokens_list", "tokens")
|
||||||
|
def text_pipeline(wrd):
|
||||||
|
wrd = "".join(wrd.split(" "))
|
||||||
|
yield wrd
|
||||||
|
tokens_list = tokenizer(wrd)["input_ids"]
|
||||||
|
yield tokens_list
|
||||||
|
tokens = numpy.array(tokens_list, dtype="int64")
|
||||||
|
# tokens = paddle.to_tensor(tokens_list, dtype="int64")
|
||||||
|
yield tokens
|
||||||
|
|
||||||
|
dataset.add_dynamic_item(datasets, text_pipeline)
|
||||||
|
|
||||||
|
# 4. Set output:
|
||||||
|
dataset.set_output_keys(
|
||||||
|
datasets, ["id", "sig", "wrd", "tokens"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. If Dynamic Batching is used, we instantiate the needed samplers.
|
||||||
|
train_batch_sampler = None
|
||||||
|
valid_batch_sampler = None
|
||||||
|
if hparams["dynamic_batching"]:
|
||||||
|
from sampler import DynamicBatchSampler # noqa
|
||||||
|
|
||||||
|
dynamic_hparams = hparams["dynamic_batch_sampler"]
|
||||||
|
num_buckets = dynamic_hparams["num_buckets"]
|
||||||
|
|
||||||
|
train_batch_sampler = DynamicBatchSampler(
|
||||||
|
train_data,
|
||||||
|
dynamic_hparams["max_batch_len"],
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
length_func=lambda x: x["duration"],
|
||||||
|
shuffle=dynamic_hparams["shuffle_ex"],
|
||||||
|
batch_ordering=dynamic_hparams["batch_ordering"],
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_batch_sampler = DynamicBatchSampler(
|
||||||
|
valid_data,
|
||||||
|
dynamic_hparams["max_batch_len"],
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
length_func=lambda x: x["duration"],
|
||||||
|
shuffle=dynamic_hparams["shuffle_ex"],
|
||||||
|
batch_ordering=dynamic_hparams["batch_ordering"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
train_data,
|
||||||
|
valid_data,
|
||||||
|
test_data,
|
||||||
|
tokenizer,
|
||||||
|
train_batch_sampler,
|
||||||
|
valid_batch_sampler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
hparams_file = 'train_with_wav2vec.yaml'
|
||||||
|
with open(hparams_file) as fin:
|
||||||
|
hparams = load_hyperpyyaml(fin, None)
|
||||||
|
|
||||||
|
(
|
||||||
|
train_data,
|
||||||
|
valid_data,
|
||||||
|
test_data,
|
||||||
|
tokenizer,
|
||||||
|
train_bsampler,
|
||||||
|
valid_bsampler,
|
||||||
|
) = dataio_prepare(hparams)
|
||||||
|
|
||||||
|
train_dataloader_opts = hparams["train_dataloader_opts"]
|
||||||
|
valid_dataloader_opts = hparams["valid_dataloader_opts"]
|
||||||
|
|
||||||
|
if train_bsampler is not None:
|
||||||
|
train_dataloader_opts = {
|
||||||
|
"batch_sampler": train_bsampler,
|
||||||
|
"num_workers": hparams["num_workers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid_bsampler is not None:
|
||||||
|
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
|
||||||
|
|
||||||
|
|
||||||
|
train_set = make_dataloader(
|
||||||
|
train_data, stage='train', **train_dataloader_opts
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_set = make_dataloader(
|
||||||
|
valid_data,
|
||||||
|
stage='train',
|
||||||
|
**valid_dataloader_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(len(train_set))
|
||||||
|
|
||||||
|
for batch in valid_set:
|
||||||
|
print(batch)
|
||||||
|
print('done') # exit()
|
@ -0,0 +1,86 @@
|
|||||||
|
# ############################################################################
|
||||||
|
# Model: CTC-wav2vec2
|
||||||
|
# Encoder: wav2vec2
|
||||||
|
# Decoder: -
|
||||||
|
# Tokens: Char
|
||||||
|
# losses: CTC
|
||||||
|
# Training: AISHELL-1
|
||||||
|
# Authors: Yingzhi WANG 2022
|
||||||
|
# ############################################################################
|
||||||
|
|
||||||
|
seed: 10
|
||||||
|
__set_seed: !apply:torch.manual_seed [!ref <seed>]
|
||||||
|
output_folder: !ref /home/zhangtianhao/workspace/speechbrain/recipes/AISHELL-1/ASR/CTC/results/ctc_wav2vec/<seed>
|
||||||
|
cer_file: !ref <output_folder>/cer.txt
|
||||||
|
save_folder: !ref <output_folder>/save
|
||||||
|
train_log: !ref <output_folder>/train_log.txt
|
||||||
|
|
||||||
|
# Data files
|
||||||
|
data_folder: /home/zhangtianhao/workspace/PaddleSpeech/dataset/aishell # e.g., /path/to/aishell
|
||||||
|
|
||||||
|
skip_prep: False
|
||||||
|
ckpt_interval_minutes: 15 # save checkpoint every N min
|
||||||
|
train_data: !ref <output_folder>/train.csv
|
||||||
|
valid_data: !ref <output_folder>/dev.csv
|
||||||
|
test_data: !ref <output_folder>/test.csv
|
||||||
|
|
||||||
|
wav2vec2_hub: TencentGameMate/chinese-wav2vec2-large
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
number_of_epochs: 80
|
||||||
|
lr: 1.0
|
||||||
|
lr_wav2vec: 0.0001
|
||||||
|
sorting: ascending
|
||||||
|
auto_mix_prec: False
|
||||||
|
sample_rate: 16000
|
||||||
|
|
||||||
|
# With data_parallel batch_size is split into N jobs
|
||||||
|
# With DDP batch_size is multiplied by N jobs
|
||||||
|
# Must be 8 per GPU to fit 32GB of VRAM
|
||||||
|
batch_size: 12
|
||||||
|
test_batch_size: 8
|
||||||
|
|
||||||
|
dynamic_batching: False
|
||||||
|
dynamic_batch_sampler:
|
||||||
|
feats_hop_size: 0.01
|
||||||
|
max_batch_len: 15 # in terms of "duration" in annotations by default, second here
|
||||||
|
left_bucket_len: 200 # old implementation attribute
|
||||||
|
multiplier: 1.1 # old implementation attribute
|
||||||
|
shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
|
||||||
|
num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
|
||||||
|
batch_ordering: ascending
|
||||||
|
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# Dataloader options
|
||||||
|
train_dataloader_opts:
|
||||||
|
batch_size: !ref <batch_size>
|
||||||
|
num_workers: !ref <num_workers>
|
||||||
|
valid_dataloader_opts:
|
||||||
|
batch_size: !ref <test_batch_size>
|
||||||
|
num_workers: !ref <num_workers>
|
||||||
|
test_dataloader_opts:
|
||||||
|
batch_size: !ref <test_batch_size>
|
||||||
|
num_workers: !ref <num_workers>
|
||||||
|
|
||||||
|
wav2vec_output_dim: 1024
|
||||||
|
dnn_neurons: 1024
|
||||||
|
freeze_wav2vec: False
|
||||||
|
dropout: 0.15
|
||||||
|
|
||||||
|
tokenizer: !apply:transformers.BertTokenizer.from_pretrained
|
||||||
|
pretrained_model_name_or_path: bert-base-chinese
|
||||||
|
# bert-base-chinese tokens length
|
||||||
|
output_neurons: 21128
|
||||||
|
|
||||||
|
# Decoding parameters
|
||||||
|
# Be sure that the bos and eos index match with the BPEs ones
|
||||||
|
blank_index: 0
|
||||||
|
|
||||||
|
# AISHELL-1 has spaces between words in the transcripts,
|
||||||
|
# which Chinese writing normally does not do.
|
||||||
|
# If remove_spaces, spaces are removed
|
||||||
|
# from the transcript before computing CER.
|
||||||
|
# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
|
||||||
|
remove_spaces: True
|
||||||
|
split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
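# --- Illustrative sketch (not part of the config above) ---
# How hyperpyyaml resolves the !ref tags used in this file; load_hyperpyyaml is
# the same call the preparation script above uses. The folder names here are
# shortened, made-up stand-ins for the real paths.
from hyperpyyaml import load_hyperpyyaml

yaml_string = """
seed: 10
output_folder: !ref results/ctc_wav2vec/<seed>
cer_file: !ref <output_folder>/cer.txt
"""
hparams = load_hyperpyyaml(yaml_string)
print(hparams["cer_file"])  # results/ctc_wav2vec/10/cer.txt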