# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/batch.py)
"""Batch collation

Authors
 * Aku Rouhe 2020
"""
import collections

import numpy as np
import paddle

from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate

PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])


class PaddedBatch:
    """Collate_fn when examples are dicts and have variable-length sequences.

    Different elements in the examples get matched by key.
    All numpy arrays get converted to paddle.Tensor.
    Then, by default, all paddle.Tensor valued elements get padded and support
    collective pin_memory() and to() calls.
    Regular Python data types are just collected in a list.

    Arguments
    ---------
    examples : list
        List of example dicts, as produced by the DataLoader.
    padded_keys : list, None
        (Optional) List of keys to pad on. If None, pad all paddle.Tensors.
    device_prep_keys : list, None
        (Optional) Only these keys participate in collective memory pinning
        and moving with to(). If None, defaults to all items with
        paddle.Tensor values.
    padding_func : callable, optional
        Called with a list of tensors to be padded together. Needs to return
        two tensors: the padded data, and another tensor for the data lengths.
    padding_kwargs : dict
        (Optional) Extra kwargs to pass to padding_func, e.g. mode, value.
    nonpadded_stack : bool
        Whether to apply tensor stacking on values that didn't get padded.
        This stacks if it can, but doesn't error out if it cannot.
        Default: True, which usually does the right thing.
    """

    def __init__(
            self,
            examples,
            padded_keys=None,
            device_prep_keys=None,
            padding_func=batch_pad_right,
            padding_kwargs=None,
            nonpadded_stack=True, ):
        # Avoid a mutable default argument for the padding kwargs.
        padding_kwargs = padding_kwargs if padding_kwargs is not None else {}
        self.__length = len(examples)
        self.__keys = list(examples[0].keys())
        self.__padded_keys = []
        self.__device_prep_keys = []
        for key in self.__keys:
            values = [example[key] for example in examples]
            # Convert numpy arrays to tensors element-wise (numpy2tensor), so
            # that variable-length sequences and non-numeric values (e.g. id
            # strings) stay a plain Python list instead of being stacked into
            # a single tensor, which would fail for ragged data.
            values = [
                paddle.to_tensor(value) if isinstance(value, np.ndarray) else
                value for value in values
            ]
            if (padded_keys is not None and key in padded_keys) or (
                    padded_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                # Padding and PaddedData
                self.__padded_keys.append(key)
                padded = PaddedData(*padding_func(values, **padding_kwargs))
                setattr(self, key, padded)
            else:
                # Stack equal-sized non-padded values into a batch tensor
                # when possible; mod_default_collate leaves other types as-is.
                if nonpadded_stack:
                    values = mod_default_collate(values)
                setattr(self, key, values)
            if (device_prep_keys is not None and key in device_prep_keys) or (
                    device_prep_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                self.__device_prep_keys.append(key)

    def __len__(self):
        return self.__length

    def __getitem__(self, key):
        if key in self.__keys:
            return getattr(self, key)
        else:
            raise KeyError(f"Batch doesn't have key: {key}")

    def __iter__(self):
        """Iterates over the different elements of the batch."""
        return iter(getattr(self, key) for key in self.__keys)
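

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how PaddedBatch
# behaves as a collate_fn on a toy batch. The keys, shapes, and dtypes below
# are illustrative assumptions, not values from this repository; the exact
# contents of `lengths` depend on batch_pad_right (relative lengths in the
# SpeechBrain original this file is ported from).
if __name__ == "__main__":
    examples = [
        {"id": "utt1", "feats": np.ones((3, 5), dtype="float32")},
        {"id": "utt2", "feats": np.ones((5, 5), dtype="float32")},
    ]
    batch = PaddedBatch(examples)
    # "feats" is variable-length, so it is padded into a PaddedData tuple.
    print(batch.feats.data.shape)  # expected: [2, 5, 5]
    print(batch.feats.lengths)  # per-example lengths from batch_pad_right
    # "id" holds strings, not tensors, so it stays a plain list.
    print(batch["id"])  # ['utt1', 'utt2']
    # Keys can also be iterated in example order.
    for value in batch:
        print(type(value))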