# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataset.py)
import contextlib
import copy
import logging
from types import MethodType

from paddle.io import Dataset

from paddlespeech.s2t.io.speechbrain.data_pipeline import DataPipeline
from paddlespeech.s2t.io.speechbrain.dataio import load_data_csv
from paddlespeech.s2t.io.speechbrain.dataio import load_data_json

logger = logging.getLogger(__name__)


class DynamicItemDataset(Dataset):
    """Dataset that reads, wrangles, and produces dicts.

    Each data point dict provides some items (by key), for example, a path to
    a wav file with the key "wav_file". When a data point is fetched from this
    Dataset, more items are produced dynamically, based on pre-existing items
    and other dynamically created items. For example, a dynamic item could
    take the wav file path and load the audio from disk.

    The dynamic items can depend on other dynamic items: a suitable evaluation
    order is used automatically, as long as there are no circular dependencies.

    A specified list of keys is collected in the output dict. These can be
    items in the original data or dynamic items. If some dynamic items are not
    requested, nor depended on by other requested items, they won't be
    computed. So, for example, if a user simply wants to iterate over the
    text, the time-consuming audio loading can be skipped.

    About the format:
    Takes a dict of dicts as the collection of data points to read/wrangle.
    The top-level keys are data point IDs.
    Each data point (example) dict should have the same keys, corresponding to
    different items in that data point.

    Altogether the data collection could look like this:

    >>> data = {
    ...     "spk1utt1": {
    ...         "wav_file": "/path/to/spk1utt1.wav",
    ...         "text": "hello world",
    ...         "speaker": "spk1",
    ...     },
    ...     "spk1utt2": {
    ...         "wav_file": "/path/to/spk1utt2.wav",
    ...         "text": "how are you world",
    ...         "speaker": "spk1",
    ...     }
    ... }

    NOTE
    ----
    The top-level key, the data point id, is implicitly added as an item
    in the data point, with the key "id".

    Each dynamic item is configured by three things: a key, a func, and a list
    of argkeys. The key should be unique among all the items (dynamic or not)
    in each data point. The func is any callable, and it returns the dynamic
    item's value. The callable is called with the values of other items as
    specified by the argkeys list (as positional args, passed in the order
    specified by argkeys).

    Arguments
    ---------
    data : dict
        Dictionary containing single data points (e.g. utterances).
    dynamic_items : list, optional
        Configuration for the dynamic items produced when fetching an example.
        List of DynamicItems or dicts with the format::
            func: <callable> # To be called
            takes: <list> # key or list of keys of args this takes
            provides: key # key or list of keys that this provides
    output_keys : dict, list, optional
        List of keys (either directly available in data or dynamic items)
        to include in the output dict when data points are fetched.

        If a dict is given, it is used to map internal keys to output keys.
        In each key:value pair of the output_keys dict, the key appears
        outside, and the value is the internal key.
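
    For example, a minimal usage sketch with the data above (the str.upper
    dynamic item is illustrative, not part of the original recipe):

    >>> dataset = DynamicItemDataset(data)
    >>> dataset.add_dynamic_item(
    ...     str.upper, takes="text", provides="text_upper")
    >>> dataset.set_output_keys(["id", "text_upper"])
    >>> dataset[0]["text_upper"]
    'HELLO WORLD'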
    """

    def __init__(
            self,
            data,
            dynamic_items=[],
            output_keys=[], ):
        self.data = data
        self.data_ids = list(self.data.keys())
        static_keys = list(self.data[self.data_ids[0]].keys())
        if "id" in static_keys:
            raise ValueError("The key 'id' is reserved for the data point id.")
        else:
            static_keys.append("id")
        self.pipeline = DataPipeline(static_keys, dynamic_items)
        self.set_output_keys(output_keys)

    def __len__(self):
        return len(self.data_ids)

    def __getitem__(self, index):
        data_id = self.data_ids[index]
        data_point = self.data[data_id]
        return self.pipeline.compute_outputs({"id": data_id, **data_point})

    def add_dynamic_item(self, func, takes=None, provides=None):
        """Makes a new dynamic item available on the dataset.

        There are two calling conventions. For DynamicItem objects, just use:
        add_dynamic_item(dynamic_item).
        Otherwise, use:
        add_dynamic_item(func, takes, provides).

        Arguments
        ---------
        func : callable, DynamicItem
            If a DynamicItem is given, adds that directly. Otherwise a
            DynamicItem is created, and this specifies the callable to use. If
            a generator function is given, then create a GeneratorDynamicItem.
            Otherwise creates a normal DynamicItem.
        takes : list, str
            List of keys. When func is called, each key is resolved to
            either an entry in the data or the output of another dynamic_item.
            The func is then called with these as positional arguments,
            in the same order as specified here.
            A single arg can be given directly.
        provides : str
            Unique key or keys that this provides.
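
        For example, a sketch of the list-of-keys form (the function and the
        keys are illustrative):

        >>> def label(text, speaker):
        ...     return speaker + ": " + text
        >>> dataset.add_dynamic_item(
        ...     label, takes=["text", "speaker"], provides="labeled_text")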
        """
        self.pipeline.add_dynamic_item(func, takes, provides)

    def set_output_keys(self, keys):
        """Use this to change the output keys.

        These are the keys that are actually evaluated when a data point
        is fetched from the dataset.

        Arguments
        ---------
        keys : dict, list
            List of keys (str) to produce in output.

            If a dict is given, it is used to map internal keys to output
            keys. In each key:value pair of the dict, the key appears
            outside, and the value is the internal key.
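
            For example, a sketch of the dict form (names illustrative):

            >>> dataset.set_output_keys({"utt_id": "id", "words": "text"})

            Each output dict then has keys "utt_id" and "words", filled from
            the internal items "id" and "text".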
        """
        self.pipeline.set_output_keys(keys)

    @contextlib.contextmanager
    def output_keys_as(self, keys):
        """Context manager to temporarily set output keys.

        NOTE
        ----
        Not thread-safe. While in this context manager, the output keys
        are affected for any call.
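
        For example, a sketch (the keys are illustrative):

        >>> with dataset.output_keys_as(["id", "text"]):
        ...     first_two = [dataset[i] for i in range(2)]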
        """
        saved_output = self.pipeline.output_mapping
        self.pipeline.set_output_keys(keys)
        yield self
        self.pipeline.set_output_keys(saved_output)

    def filtered_sorted(
            self,
            key_min_value={},
            key_max_value={},
            key_test={},
            sort_key=None,
            reverse=False,
            select_n=None, ):
        """Get a filtered and/or sorted version of this; shares static data.

        The reason to implement these operations in the same method is that
        computing some dynamic items may be expensive, and this way the
        filtering and sorting steps don't need to compute the dynamic items
        twice.

        Arguments
        ---------
        key_min_value : dict
            Map from key (in data or in dynamic items) to limit; will only
            keep data_point if data_point[key] >= limit.
        key_max_value : dict
            Map from key (in data or in dynamic items) to limit; will only
            keep data_point if data_point[key] <= limit.
        key_test : dict
            Map from key (in data or in dynamic items) to func; will only
            keep data_point if bool(func(data_point[key])) == True.
        sort_key : None, str
            If not None, sort by data_point[sort_key]. Default is ascending
            order.
        reverse : bool
            If True, sort in descending order.
        select_n : None, int
            If not None, only keep (at most) the first n filtered data points.
            The possible sorting is applied, but only on the first n data
            points found. Meant for debugging.

        Returns
        -------
        FilteredSortedDynamicItemDataset
            Shares the static data, but has its own output keys and
            dynamic items (initially deep copied from this, so they have the
            same dynamic items available).

        NOTE
        ----
        Temporarily changes the output keys!
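
        For example, a sketch assuming a numeric "duration" item exists
        (either in the data or as a dynamic item):

        >>> subset = dataset.filtered_sorted(
        ...     key_max_value={"duration": 10.0}, sort_key="duration")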
        """
        filtered_sorted_ids = self._filtered_sorted_ids(
            key_min_value,
            key_max_value,
            key_test,
            sort_key,
            reverse,
            select_n, )
        return FilteredSortedDynamicItemDataset(
            self, filtered_sorted_ids)  # NOTE: defined below

    def _filtered_sorted_ids(
            self,
            key_min_value={},
            key_max_value={},
            key_test={},
            sort_key=None,
            reverse=False,
            select_n=None, ):
        """Returns a list of data ids, fulfilling the sorting and filtering."""

        def combined_filter(computed):
            """Applies the filters; returns True if the data point is kept."""
            for key, limit in key_min_value.items():
                # NOTE: the docstring promises >=, so use that. Mathematically,
                # < would give nicer syntax, but someone may depend on the
                # >= operator in some edge case.
                if computed[key] >= limit:
                    continue
                return False
            for key, limit in key_max_value.items():
                if computed[key] <= limit:
                    continue
                return False
            for key, func in key_test.items():
                if bool(func(computed[key])):
                    continue
                return False
            return True

        temp_keys = (set(key_min_value.keys()) | set(key_max_value.keys()) |
                     set(key_test.keys()) |
                     set([] if sort_key is None else [sort_key]))
        filtered_ids = []
        with self.output_keys_as(temp_keys):
            for i, data_id in enumerate(self.data_ids):
                if select_n is not None and len(filtered_ids) == select_n:
                    break
                data_point = self.data[data_id]
                data_point["id"] = data_id
                computed = self.pipeline.compute_outputs(data_point)
                if combined_filter(computed):
                    if sort_key is not None:
                        # Add (main sorting index, current index, data_id)
                        # so that the current order is maintained and data_id
                        # values are never compared.
                        filtered_ids.append((computed[sort_key], i, data_id))
                    else:
                        filtered_ids.append(data_id)
        if sort_key is not None:
            filtered_sorted_ids = [
                tup[2] for tup in sorted(filtered_ids, reverse=reverse)
            ]
        else:
            filtered_sorted_ids = filtered_ids
        return filtered_sorted_ids

    @classmethod
    def from_json(cls,
                  json_path,
                  replacements={},
                  dynamic_items=[],
                  output_keys=[]):
        """Load a data prep JSON file and create a Dataset based on it."""
        data = load_data_json(json_path, replacements)
        return cls(data, dynamic_items, output_keys)

    @classmethod
    def from_csv(cls,
                 csv_path,
                 replacements={},
                 dynamic_items=[],
                 output_keys=[]):
        """Load a data prep CSV file and create a Dataset based on it.
        """
        data = load_data_csv(csv_path, replacements)
        return cls(data, dynamic_items, output_keys)

    @classmethod
    def from_arrow_dataset(cls,
                           dataset,
                           replacements={},
                           dynamic_items=[],
                           output_keys=[]):
        """Load a prepared HuggingFace (Arrow) dataset.
        """

        # Define an unbound method that generates pseudo keys (row indices),
        # so the Arrow dataset can be treated like the dict-of-dicts format.
        def keys(self):
            "Returns the pseudo keys (row indices)."
            return list(range(len(dataset)))

        # Bind this method to the arrow dataset instance.
        dataset.keys = MethodType(keys, dataset)
        return cls(dataset, dynamic_items, output_keys)


class FilteredSortedDynamicItemDataset(DynamicItemDataset):
    """Possibly filtered, possibly sorted DynamicItemDataset.

    Shares the static data (reference).
    Has its own dynamic_items and output_keys (deepcopy).
    """

    def __init__(self, from_dataset, data_ids):
        self.data = from_dataset.data
        self.data_ids = data_ids
        self.pipeline = copy.deepcopy(from_dataset.pipeline)

    @classmethod
    def from_json(cls,
                  json_path,
                  replacements={},
                  dynamic_items=None,
                  output_keys=None):
        raise TypeError(
            "Cannot create FilteredSortedDynamicItemDataset directly!")

    @classmethod
    def from_csv(cls,
                 csv_path,
                 replacements={},
                 dynamic_items=None,
                 output_keys=None):
        raise TypeError(
            "Cannot create FilteredSortedDynamicItemDataset directly!")


def add_dynamic_item(datasets, func, takes=None, provides=None):
    """Helper for adding the same dynamic item to multiple datasets.
    """
    for dataset in datasets:
        dataset.add_dynamic_item(func, takes, provides)


def set_output_keys(datasets, output_keys):
    """Helper for setting the same output keys on multiple datasets."""
    for dataset in datasets:
        dataset.set_output_keys(output_keys)