PaddleSpeech/paddlespeech/s2t/io/speechbrain/batch.py

# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/batch.py)
"""Batch collation
Authors
* Aku Rouhe 2020
"""
import collections

import numpy as np
import paddle

from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate

PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])


class PaddedBatch:
    """Collate_fn for examples that are dicts with variable-length sequences.

    Different elements in the examples get matched by key.
    All numpy arrays get converted to paddle.Tensor. Then, by default, all
    paddle.Tensor valued elements get padded and support collective
    pin_memory() and to() calls.
    Regular Python data types are just collected in a list.

    Arguments
    ---------
    examples : list
        List of example dicts, as produced by a Dataloader.
    padded_keys : list, None
        (Optional) List of keys to pad on. If None, pad all paddle.Tensors.
    device_prep_keys : list, None
        (Optional) Only these keys participate in collective memory pinning
        and moving with to(). If None, defaults to all items with
        paddle.Tensor values.
    padding_func : callable, optional
        Called with a list of tensors to be padded together. Needs to return
        two tensors: the padded data and a tensor of the data lengths.
    padding_kwargs : dict
        (Optional) Extra kwargs to pass to padding_func, e.g. mode, value.
    nonpadded_stack : bool
        Whether to apply tensor stacking on values that didn't get padded.
        This stacks if it can, but doesn't error out if it cannot.
        Default: True, which usually does the right thing.
    """

    def __init__(
            self,
            examples,
            padded_keys=None,
            device_prep_keys=None,
            padding_func=batch_pad_right,
            padding_kwargs={},
            nonpadded_stack=True, ):
        self.__length = len(examples)
        self.__keys = list(examples[0].keys())
        self.__padded_keys = []
        self.__device_prep_keys = []
        for key in self.__keys:
            values = [example[key] for example in examples]
            # Convert numpy arrays to tensors element by element; converting
            # the whole list at once would fail for variable-length sequences.
            values = [
                paddle.to_tensor(value)
                if isinstance(value, np.ndarray) else value for value in values
            ]
            if (padded_keys is not None and key in padded_keys) or (
                    padded_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                # Pad the values together and store them as a PaddedData pair.
                self.__padded_keys.append(key)
                padded = PaddedData(*padding_func(values, **padding_kwargs))
                setattr(self, key, padded)
            else:
                if nonpadded_stack:
                    values = mod_default_collate(values)
                setattr(self, key, values)
            if (device_prep_keys is not None and key in device_prep_keys) or (
                    device_prep_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                self.__device_prep_keys.append(key)

    def __len__(self):
        return self.__length

    def __getitem__(self, key):
        if key in self.__keys:
            return getattr(self, key)
        else:
            raise KeyError(f"Batch doesn't have key: {key}")

    def __iter__(self):
        """Iterates over the different elements of the batch."""
        return iter((getattr(self, key) for key in self.__keys))
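

# A minimal usage sketch, not part of the original module. It assumes
# `batch_pad_right` behaves like its speechbrain counterpart, padding a list
# of tensors to a common length and returning (padded_data, lengths), and
# that `mod_default_collate` leaves non-tensor values as a plain list. The
# example dicts and the "id"/"wav" key names are hypothetical.
if __name__ == "__main__":
    examples = [
        {"id": "utt1", "wav": paddle.to_tensor([0.1, 0.2, 0.3])},
        {"id": "utt2", "wav": paddle.to_tensor([0.4, 0.5])},
    ]
    batch = PaddedBatch(examples)
    print(len(batch))  # 2 examples in the batch
    print(batch.id)  # non-tensor values collected as-is: ["utt1", "utt2"]
    padded, lengths = batch.wav  # PaddedData: [2, 3] padded tensor + lengths
    print(padded.shape, lengths)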