PaddleSpeech/paddlespeech/t2s/datasets/data_table.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from multiprocessing import Manager
from typing import Any
from typing import Callable
from typing import Dict
from typing import List

import numpy as np
from paddle.io import Dataset


class DataTable(Dataset):
"""Dataset to load and convert data for general purpose.
Args:
data (List[Dict[str, Any]]): Metadata, a list of meta datum, each of which is composed of several fields
fields (List[str], optional): Fields to use, if not specified, all the fields in the data are used, by default None
converters (Dict[str, Callable], optional): Converters used to process each field, by default None
use_cache (bool, optional): Whether to use cache, by default False
Raises:
ValueError:
If there is some field that does not exist in data.
ValueError:
If there is some field in converters that does not exist in fields.
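
    Example:
        A minimal sketch; the field names and converter below are
        illustrative, not required by the class::

            meta = [{'utt_id': '001', 'feats': 'feats/001.npy'},
                    {'utt_id': '002', 'feats': 'feats/002.npy'}]
            table = DataTable(
                meta, fields=['utt_id', 'feats'],
                converters={'feats': np.load})
            example = table[0]  # {'utt_id': '001', 'feats': <loaded array>}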
"""

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: List[str]=None,
                 converters: Dict[str, Callable]=None,
                 use_cache: bool=False):
        # metadata
        self.data = data
        assert len(data) > 0, "This dataset has no examples"
        # peek at an example to get the existing fields.
        first_example = self.data[0]
        fields_in_data = first_example.keys()

        # check that all the requested fields exist
        if fields is None:
            self.fields = list(fields_in_data)
        else:
            for field in fields:
                if field not in fields_in_data:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data are "
                        f"{list(fields_in_data)}.")
            self.fields = fields
        # check converters
        if converters is None:
            self.converters = {}
        else:
            for field in converters.keys():
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a nonexistent field ({field}).")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
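        # Keep the cache in a multiprocessing.Manager list so that entries
        # cached by one DataLoader worker process are visible to the others.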
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [None for _ in range(len(self))]

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converter to each requested field.

        Args:
            meta_datum (Dict[str, Any]): Meta datum.

        Returns:
            Dict[str, Any]: Converted example.
        """
        example = {}
        for field in self.fields:
            converter = self.converters.get(field, None)
            meta_datum_field = meta_datum[field]
            if converter is not None:
                converted_field = converter(meta_datum_field)
            else:
                converted_field = meta_datum_field
            example[field] = converted_field
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Args:
            idx (int): Index of the example to get.

        Returns:
            Dict[str, Any]: A converted example.
        """
        if self.use_cache and self.caches[idx] is not None:
            return self.caches[idx]
        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)
        if self.use_cache:
            self.caches[idx] = example
        return example

    def __len__(self) -> int:
        """Return the size of the dataset.

        Returns:
            int: The length of the dataset.
        """
        return len(self.data)


class StarGANv2VCDataTable(DataTable):
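    """DataTable for StarGANv2-VC training.

    Each meta datum is expected to look like the dict shown in
    ``__getitem__``, with ``speech`` pointing to a saved mel spectrogram::

        {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
    """
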
    def __init__(self, data: List[Dict[str, Any]]):
        super().__init__(data)
        raw_data = data
        # group the meta datums by speaker id
        data_list_per_class = {}
        for item in raw_data:
            data_list_per_class.setdefault(item['spk_id'], []).append(item)
        self.data_list_per_class = data_list_per_class

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Args:
            idx (int): Index of the example to get.

        Returns:
            Dict[str, Any]: A converted example.
        """
        if self.use_cache and self.caches[idx] is not None:
            return self.caches[idx]
        data = self._get_metadata(idx)
        # cropping is done in batch_fn
        # each meta datum is a dict like:
        # {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
        ref_data = random.choice(self.data)
        ref_label = ref_data['spk_id']
        ref_data_2 = random.choice(self.data_list_per_class[ref_label])
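        # ref_data picks a random reference speaker; ref_data_2 is a second
        # utterance from that same speaker (StarGANv2-style training typically
        # consumes a pair of same-speaker references).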
        # mel_tensor, label, ref_mel_tensor, ref2_mel_tensor, ref_label
        new_example = {
            'utt_id': data['utt_id'],
            'mel': np.load(data['speech']),
            'label': int(data['spk_id']),
            'ref_mel': np.load(ref_data['speech']),
            'ref_mel_2': np.load(ref_data_2['speech']),
            'ref_label': int(ref_label)
        }
        if self.use_cache:
            self.caches[idx] = new_example
        return new_example