PaddleSpeech/paddlespeech/s2t/io/speechbrain/dataloader.py

# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataloader.py)
"""Paddle compatible DataLoaders
Essentially we extend Paddle DataLoader by adding the ability to save the
data loading state, so that a checkpoint may be saved in the middle of an
epoch.
Authors:
* Aku Rouhe 2020
"""
import collections
import functools
import logging
import warnings

import numpy
import paddle
from paddle.io import DataLoader

from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
from paddlespeech.s2t.io.speechbrain.dataset import DynamicItemDataset
from paddlespeech.s2t.io.speechbrain.sampler import ReproducibleRandomSampler

# Each padded key in a collated batch maps to a PaddedData namedtuple:
# the padded batch array plus the lengths needed to undo the padding.
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
class Wav2vec2DataLoader(DataLoader):
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=None,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None,
                 multiprocessing_context=None,
                 generator=None):
        # Note: pin_memory, multiprocessing_context and generator are
        # accepted for API compatibility but not forwarded to
        # paddle.io.DataLoader.
        # Return each batch as a list when examples are tuples/lists,
        # otherwise keep the structure produced by the collate_fn.
        if isinstance(dataset[0], (tuple, list)):
            return_list = True
        else:
            return_list = False
        super().__init__(
            dataset,
            feed_list=None,
            places=None,
            return_list=return_list,
            batch_sampler=batch_sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=True,
            use_shared_memory=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn)
        if sampler is not None:
            # Replace the sampler inside the auto-built batch_sampler.
            self.batch_sampler.sampler = sampler
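
# A minimal usage sketch (hypothetical `my_dataset`, not part of this
# module): the extra `sampler` argument is applied after construction by
# replacing the sampler inside the batch_sampler that paddle.io.DataLoader
# builds internally.
#
#   loader = Wav2vec2DataLoader(
#       my_dataset,  # any paddle.io.Dataset yielding dicts or tuples
#       batch_size=4,
#       sampler=ReproducibleRandomSampler(my_dataset),
#       collate_fn=PaddedBatch)
#   for batch in loader:
#       ...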
def PaddedBatch(
        examples,
        padded_keys=None,
        device_prep_keys=None,
        padding_func=batch_pad_right,
        padding_kwargs={},
        nonpadded_stack=True, ):
    """Collate a list of example dicts into a batch dict.

    Keys listed in padded_keys (or, by default, any key whose values are
    numpy arrays) are padded with padding_func and wrapped in PaddedData;
    all other keys are collated with mod_default_collate.
    """
    __length = len(examples)
    __keys = list(examples[0].keys())
    __padded_keys = []
    __device_prep_keys = []
    res = {}
    for key in __keys:
        values = [example[key] for example in examples]
        # Default convert usually does the right thing (numpy2tensor etc.)
        # values = default_convert(values)
        if (padded_keys is not None and key in padded_keys) or (
                padded_keys is None and isinstance(values[0], numpy.ndarray)):
            # Padding and PaddedData
            __padded_keys.append(key)
            padded = PaddedData(*padding_func(values, **padding_kwargs))
            res[key] = padded
        else:
            # Default collate usually does the right thing
            # (convert lists of equal sized tensors to batch tensors, etc.)
            if nonpadded_stack:
                values = mod_default_collate(values)
            res[key] = values
        if (device_prep_keys is not None and key in device_prep_keys) or (
                device_prep_keys is None and
                isinstance(values[0], paddle.Tensor)):
            __device_prep_keys.append(key)
    return res
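
# Illustration (hypothetical values): collating two variable-length feature
# arrays. Assuming batch_pad_right pads each array up to the batch maximum
# and returns (padded_array, lengths), the ndarray-valued key comes back as
# a PaddedData namedtuple, while other keys (e.g. string ids) pass through
# mod_default_collate unchanged.
#
#   batch = PaddedBatch([
#       {"id": "utt1", "feat": numpy.ones(300, dtype="float32")},
#       {"id": "utt2", "feat": numpy.ones(200, dtype="float32")},
#   ])
#   feats, lens = batch["feat"]  # unpack the PaddedData namedtuple
#   # feats has shape (2, 300); lens records each example's true extent
#   # batch["id"] -> ["utt1", "utt2"]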
def make_dataloader(dataset, stage, **loader_kwargs):
    """Makes a basic DataLoader.

    For DynamicItemDatasets (which return dicts), use
    PaddedBatch as the default collate_fn.

    Shuffling is implemented by ReproducibleRandomSampler.

    Arguments
    ---------
    dataset : Dataset
        The dataset to make a DataLoader for.
    stage : str
        The training stage. Unused in this implementation; kept for
        compatibility with speechbrain's make_dataloader.
    **loader_kwargs : dict
        Keyword args to DataLoader; see paddle.io.DataLoader for options.

    Returns
    -------
    Wav2vec2DataLoader
    """
    # PaddedBatch as default collation for DynamicItemDataset
    if "collate_fn" not in loader_kwargs and isinstance(dataset,
                                                        DynamicItemDataset):
        loader_kwargs["collate_fn"] = PaddedBatch
    # Reproducible random sampling
    if loader_kwargs.get("shuffle", False):
        if loader_kwargs.get("sampler") is not None:
            raise ValueError("Cannot specify both shuffle=True and a "
                             "sampler in loader_kwargs")
        sampler = ReproducibleRandomSampler(dataset)
        loader_kwargs["sampler"] = sampler
        # shuffle must be deleted because a sampler and shuffle cannot
        # both be set.
        # NOTE: the dict of loader options may get used elsewhere!
        # However, this del doesn't touch those because loader_kwargs
        # comes from a **kwargs dict.
        del loader_kwargs["shuffle"]
    # Create the loader
    dataloader = Wav2vec2DataLoader(dataset, **loader_kwargs)
    return dataloader
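
# Typical use (sketch; `train_data` is a hypothetical DynamicItemDataset
# built elsewhere): collate_fn defaults to PaddedBatch for
# DynamicItemDatasets, and shuffle=True is translated into a
# ReproducibleRandomSampler before the loader is built.
#
#   train_loader = make_dataloader(train_data, "train",
#                                  batch_size=8, shuffle=True)
#   for batch in train_loader:
#       feats, feat_lens = batch["feat"]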