You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.1 KiB
108 lines
4.1 KiB
2 years ago
|
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
#
|
||
|
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/batch.py)
|
||
|
"""Batch collation
|
||
|
|
||
|
Authors
|
||
|
* Aku Rouhe 2020
|
||
|
"""
|
||
|
import collections

import numpy as np
import paddle

from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
|
||
|
|
||
|
# Result of padding one key across a batch: ``data`` is the padded value and
# ``lengths`` is the lengths value returned alongside it by the padding function.
PaddedData = collections.namedtuple("PaddedData", "data lengths")
|
||
|
|
||
|
|
||
|
class PaddedBatch:
    """Collate_fn when examples are dicts and have variable-length sequences.

    Different elements in the examples get matched by key; the keys are taken
    from the first example and every example must contain them.
    Numeric numpy arrays and numpy scalars get converted to paddle.Tensor
    (element by element, so ragged sequences are fine).
    Then, by default, all paddle.Tensor valued elements get padded.
    Regular Python data types are just collected in a list.

    Arguments
    ---------
    examples : list
        List of example dicts, as produced by a DataLoader. Must not be empty.
    padded_keys : list, None
        (Optional) List of keys to pad on. If None, pad all paddle.Tensors.
    device_prep_keys : list, None
        (Optional) Keys recorded as candidates for collective device
        preparation. If None, defaults to all items with paddle.Tensor values.
        Note: this class defines no pin_memory() / to() methods; the keys are
        only recorded.
    padding_func : callable, optional
        Called with a list of tensors to be padded together. Needs to return
        two tensors: the padded data, and another tensor for the data lengths.
    padding_kwargs : dict, None
        (Optional) Extra kwargs to pass to padding_func, e.g. mode, value.
        None (the default) means no extra kwargs.
    nonpadded_stack : bool
        Whether to apply Tensor stacking on values that didn't get padded.
        This stacks if it can, but doesn't error out if it cannot.
        Default: True, usually does the right thing.
    """

    def __init__(
            self,
            examples,
            padded_keys=None,
            device_prep_keys=None,
            padding_func=batch_pad_right,
            padding_kwargs=None,
            nonpadded_stack=True, ):
        self.__length = len(examples)
        if self.__length == 0:
            # The keys are read from examples[0]; fail with a clear message.
            raise IndexError("PaddedBatch requires at least one example")
        # None instead of a shared mutable {} default.
        if padding_kwargs is None:
            padding_kwargs = {}
        self.__keys = list(examples[0].keys())
        self.__padded_keys = []
        # Kept from the speechbrain original; nothing in this class reads it.
        self.__device_prep_keys = []
        for key in self.__keys:
            # Convert element-wise: a single paddle.to_tensor over the whole
            # list fails on ragged sequences / strings, and would make every
            # key look like a tensor (and thus be padded).
            values = [
                self._default_convert(example[key]) for example in examples
            ]

            if (padded_keys is not None and key in padded_keys) or (
                    padded_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                # Padding and PaddedData
                self.__padded_keys.append(key)
                padded = PaddedData(*padding_func(values, **padding_kwargs))
                setattr(self, key, padded)
            else:
                if nonpadded_stack:
                    values = mod_default_collate(values)
                setattr(self, key, values)
            if (device_prep_keys is not None and key in device_prep_keys) or (
                    device_prep_keys is None and
                    isinstance(values[0], paddle.Tensor)):
                self.__device_prep_keys.append(key)

    @staticmethod
    def _default_convert(value):
        """Convert numeric numpy data to a paddle.Tensor, else return as-is.

        Numeric ``np.ndarray`` values and numpy bool/number scalars become
        tensors via ``paddle.to_tensor``. Everything else (paddle.Tensor,
        Python scalars, strings, lists, string/object numpy arrays) is
        returned unchanged.
        """
        if isinstance(value, np.ndarray):
            # Byte-string, unicode, object and void arrays have no tensor form.
            if value.dtype.kind in "SUOV":
                return value
            return paddle.to_tensor(value)
        if isinstance(value, (np.bool_, np.number)):
            return paddle.to_tensor(value)
        return value

    def __len__(self):
        """Return the number of examples in the batch."""
        return self.__length

    def __getitem__(self, key):
        """Return the collated value stored for ``key``.

        Padded keys yield a ``PaddedData(data, lengths)`` namedtuple.

        Raises
        ------
        KeyError
            If ``key`` is not one of the batch's keys.
        """
        if key not in self.__keys:
            raise KeyError(f"Batch doesn't have key: {key}")
        return getattr(self, key)

    def __iter__(self):
        """Iterates over the different elements of the batch.

        Values are yielded in the key order of the first example.
        """
        return (getattr(self, key) for key in self.__keys)
|