You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
6.0 KiB
173 lines
6.0 KiB
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataloader.py)
|
|
"""Paddle compatible DataLoaders
|
|
|
|
Essentially we extend Paddle DataLoader by adding the ability to save the
|
|
data loading state, so that a checkpoint may be saved in the middle of an
|
|
epoch.
|
|
|
|
Authors:
|
|
* Aku Rouhe 2020
|
|
"""
|
|
import collections
|
|
import functools
|
|
import logging
|
|
import warnings
|
|
|
|
import paddle
|
|
from paddle.io import DataLoader
|
|
|
|
from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
|
|
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
|
|
from paddlespeech.s2t.io.speechbrain.dataset import DynamicItemDataset
|
|
from paddlespeech.s2t.io.speechbrain.sampler import ReproducibleRandomSampler
|
|
# Two-field record stored for every padded key of a batch: the pair returned
# by the padding function is unpacked into ``data`` and ``lengths``.
PaddedData = collections.namedtuple("PaddedData", "data lengths")
|
|
import numpy
|
|
|
|
|
|
class Wav2vec2DataLoader(DataLoader):
    """Paddle ``DataLoader`` subclass with an extended constructor signature.

    On top of the options that are forwarded to ``paddle.io.DataLoader``, the
    constructor accepts ``sampler``, ``pin_memory``,
    ``multiprocessing_context`` and ``generator``. Only ``sampler`` has an
    effect; the other three are accepted and ignored.

    Arguments
    ---------
    dataset : Dataset
        Dataset to load from. ``dataset[0]`` is read once at construction
        time, so the dataset must be indexable and non-empty.
    batch_size, shuffle, batch_sampler, num_workers, collate_fn, drop_last,
    timeout, worker_init_fn
        Forwarded unchanged to the parent constructor.
    sampler : object, optional
        When not None, assigned to ``self.batch_sampler.sampler`` after the
        parent constructor has run.
    pin_memory, multiprocessing_context, generator
        Accepted but not used.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=None,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None,
                 multiprocessing_context=None,
                 generator=None):
        # The type of the first sample decides the parent's ``return_list``
        # flag: True for tuple/list samples, False for anything else.
        first_sample = dataset[0]
        sample_is_sequence = isinstance(first_sample, (tuple, list))

        # Options handed to the parent; the literal values below are fixed
        # by this subclass and cannot be overridden by callers.
        parent_options = dict(
            feed_list=None,
            places=None,
            return_list=sample_is_sequence,
            batch_sampler=batch_sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=True,
            use_shared_memory=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn, )
        super().__init__(dataset, **parent_options)

        # A custom sampler replaces the one held by the batch sampler that
        # the parent constructor set up.
        if sampler is None:
            return
        self.batch_sampler.sampler = sampler
|
|
|
|
|
|
def PaddedBatch(
        examples,
        padded_keys=None,
        device_prep_keys=None,
        padding_func=batch_pad_right,
        padding_kwargs=None,
        nonpadded_stack=True, ):
    """Collate a list of example dicts into one batch dict.

    The keys of the first example define the keys of the batch. For every
    key the per-example values are gathered into a list and then either
    padded or (optionally) collated.

    Arguments
    ---------
    examples : list
        Non-empty list of dict-like examples sharing the same keys.
    padded_keys : collection, optional
        Keys whose values are padded. When None, a key is padded if the
        value found in the first example is a ``numpy.ndarray``.
    device_prep_keys : collection, optional
        Accepted for interface compatibility; it has no influence on the
        returned batch.
    padding_func : callable
        Called as ``padding_func(values, **padding_kwargs)``. Its result is
        unpacked into ``PaddedData(data, lengths)``, so it has to yield
        exactly two items.
    padding_kwargs : dict, optional
        Extra keyword arguments for ``padding_func``. None means no extra
        arguments.
    nonpadded_stack : bool
        When True, the value list of every non-padded key is passed through
        ``mod_default_collate``; when False the plain list is kept.

    Returns
    -------
    dict
        Maps each key to a ``PaddedData`` (padded keys) or to the collated
        / raw list of values (all other keys).

    Raises
    ------
    IndexError
        If ``examples`` is an empty list (there is no first example to
        read the keys from).
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if padding_kwargs is None:
        padding_kwargs = {}

    batch = {}
    for key in examples[0].keys():
        values = [example[key] for example in examples]

        # An explicit key list wins; otherwise fall back to type sniffing.
        if padded_keys is None:
            needs_padding = isinstance(values[0], numpy.ndarray)
        else:
            needs_padding = key in padded_keys

        if needs_padding:
            batch[key] = PaddedData(*padding_func(values, **padding_kwargs))
        else:
            if nonpadded_stack:
                values = mod_default_collate(values)
            batch[key] = values
    return batch
|
|
|
|
|
|
def make_dataloader(dataset, stage, **loader_kwargs):
    """Build a ``Wav2vec2DataLoader`` for ``dataset``.

    When ``dataset`` is a ``DynamicItemDataset`` and the caller did not
    supply ``collate_fn``, ``PaddedBatch`` is used as the collate function.

    A truthy ``shuffle`` option is turned into a
    ``ReproducibleRandomSampler`` over ``dataset``: the sampler is passed on
    as ``sampler`` and the ``shuffle`` entry is dropped from the options.

    Arguments
    ---------
    dataset : Dataset
        The dataset to make a DataLoader for.
    stage : object
        Not used by this function.
    **loader_kwargs : dict
        Keyword args forwarded to ``Wav2vec2DataLoader``.

    Returns
    -------
    Wav2vec2DataLoader
        Loader constructed from ``dataset`` and the adjusted options.

    Raises
    ------
    ValueError
        If ``shuffle`` is truthy and a non-None ``sampler`` is given too.
    """
    # Only fill in a collate function when the caller left it unspecified.
    collate_unspecified = "collate_fn" not in loader_kwargs
    if collate_unspecified and isinstance(dataset, DynamicItemDataset):
        loader_kwargs["collate_fn"] = PaddedBatch

    shuffle_requested = loader_kwargs.get("shuffle", False)
    if shuffle_requested:
        if loader_kwargs.get("sampler") is not None:
            raise ValueError("Cannot specify both shuffle=True and a "
                             "sampler in loader_kwargs")
        loader_kwargs["sampler"] = ReproducibleRandomSampler(dataset)
        # Sampler and shuffle must not both be set. ``loader_kwargs`` is
        # this call's own **kwargs dict, so dropping the entry cannot
        # affect a dict owned by the caller.
        loader_kwargs.pop("shuffle")

    return Wav2vec2DataLoader(dataset, **loader_kwargs)
|