# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/core.py)
import logging

from paddle.io import IterableDataset

from paddlespeech.s2t.io.speechbrain import dataloader
# NOTE: assumed location of the sampler classes used below; adjust the import
# path if the speechbrain-port sampler module lives elsewhere.
from paddlespeech.s2t.io.speechbrain.sampler import DistributedSampler
from paddlespeech.s2t.io.speechbrain.sampler import DistributedSamplerWrapper
from paddlespeech.s2t.io.speechbrain.sampler import ReproducibleRandomSampler

logger = logging.getLogger(__name__)


def _train_loader_specifics(self, dataset, loader_kwargs):
    """Resolve sampler and shuffle settings for the train-stage DataLoader.

    ``self`` is expected to be a Brain-like object carrying the attributes
    ``distributed_launch``, ``rank`` and ``train_sampler``.
    """
    sampler = loader_kwargs.get("sampler", None)
    # Shuffling should really only matter for the train stage. Shuffling
    # will also lead to more padding in batches if the order was otherwise
    # sorted by length.
    shuffle = loader_kwargs.get("shuffle", False)
    if shuffle and not self.distributed_launch:
        if sampler is not None:
            raise ValueError("Cannot specify both shuffle=True "
                             "and a sampler in loader_kwargs")
        sampler = ReproducibleRandomSampler(dataset)
        self.train_sampler = sampler
        loader_kwargs["sampler"] = self.train_sampler
        # Delete the shuffle flag, since you cannot specify both a sampler and
        # shuffling:
        del loader_kwargs["shuffle"]

    # Possibly make a DistributedSampler or a wrapper for some other sampler
    if self.distributed_launch and not isinstance(dataset, IterableDataset):
        drop_last = loader_kwargs.get("drop_last", False)
        # num_replicas arg is equal to world_size
        # and retrieved automatically within
        # DistributedSampler obj.
        if sampler is not None:
            self.train_sampler = DistributedSamplerWrapper(
                sampler,
                rank=self.rank,
                drop_last=drop_last,
                shuffle=shuffle, )
            # with DistributedSamplerWrapper, one must disable shuffling for dataloader
            loader_kwargs["shuffle"] = False
            loader_kwargs["sampler"] = self.train_sampler
        elif loader_kwargs.get("batch_sampler") is None:
            # no sampler and no batch_sampler
            self.train_sampler = DistributedSampler(
                dataset, rank=self.rank, shuffle=True, drop_last=drop_last)
            # with DistributedSampler, one must disable shuffling for dataloader
            loader_kwargs["shuffle"] = False
            loader_kwargs["sampler"] = self.train_sampler
        else:  # batch_sampler was specified
            self.train_sampler = DistributedSamplerWrapper(
                loader_kwargs.get("batch_sampler", None),
                rank=self.rank,
                shuffle=True, )
            loader_kwargs["batch_sampler"] = self.train_sampler
    elif self.distributed_launch and isinstance(dataset, IterableDataset):
        logger.warning("Cannot automatically solve distributed sampling "
                       "for IterableDataset.")
    return loader_kwargs


def make_dataloader(self, dataset, stage, **loader_kwargs):
    """Creates DataLoaders for Datasets.

    This is used by ``fit()`` and ``evaluate()`` if they just receive
    Datasets. Alternatively, this can be called from outside the Brain
    subclass. In that case, the DataLoader should be passed to ``fit()``
    in place of the dataset.

    The Stage.TRAIN DataLoader is handled specially. It has extra args for
    shuffle and drop_last. In DDP a DistributedSampler is created (unless
    the dataset is an IterableDataset).

    NOTE
    ----
    Some important DataLoader arguments are passed via **loader_kwargs,
    e.g., batch_size, num_workers, pin_memory.
    NOTE
    ----
    By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test
    DataLoader being added to the checkpointer. If you need to add a
    recoverable after saving checkpoints (e.g., at test time, after
    checkpointing the training), and still be able to recover reasonably,
    you should probably specify ``allow_partial_load=True``.

    Arguments
    ---------
    dataset : Dataset
        A set of data to use to create data loader. If the Dataset is a
        DynamicItemDataset, PaddedBatch is used as the default collate_fn,
        unless specified in loader_kwargs.
    stage : Stage
        The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
    ckpt_prefix : str, None
        Prefix to use for SaveableDataLoader Checkpoint name. The Stage
        name is added to this to create the full key. Set to None to not
        save the DataLoader.
    **loader_kwargs : dict
        Additional keyword arguments to the DataLoader. E.g., batch_size,
        num_workers, pin_memory.
    """
    dataloader_ = dataloader.make_dataloader(dataset, **loader_kwargs)
    return dataloader_
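

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# Assumptions: a plain ``paddle.io.Dataset`` is acceptable to
# ``dataloader.make_dataloader``, and a ``SimpleNamespace`` can stand in for
# the Brain-like ``self`` object (with ``distributed_launch``, ``rank`` and
# ``train_sampler`` attributes) that the functions above expect. ``stage`` is
# passed as a plain string because it is not inspected in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import paddle
    from paddle.io import Dataset

    class ToyDataset(Dataset):
        """A tiny in-memory dataset used only for this sketch."""

        def __init__(self, n=8):
            super().__init__()
            self.samples = [paddle.to_tensor([float(i)]) for i in range(n)]

        def __getitem__(self, idx):
            return self.samples[idx]

        def __len__(self):
            return len(self.samples)

    dataset = ToyDataset()

    # Stand-in for the Brain-like object; this is a single-process run, so
    # the distributed branches of _train_loader_specifics are not exercised.
    brain = SimpleNamespace(
        distributed_launch=False, rank=0, train_sampler=None)

    # Resolve train-stage sampler/shuffle settings first; with shuffle=True
    # (and no DDP) a ReproducibleRandomSampler would be installed instead.
    loader_kwargs = _train_loader_specifics(
        brain, dataset, {"batch_size": 2, "shuffle": False})

    train_loader = make_dataloader(brain, dataset, "train", **loader_kwargs)
    for batch in train_loader:
        print(batch)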