From 683679bec72009517f6352395b6a133018cc92dd Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 10 Feb 2022 12:41:24 +0000 Subject: [PATCH] merge data and datasets, test=tts --- paddlespeech/t2s/__init__.py | 1 - paddlespeech/t2s/data/__init__.py | 17 -- paddlespeech/t2s/data/dataset.py | 261 ------------------ paddlespeech/t2s/datasets/__init__.py | 1 - paddlespeech/t2s/datasets/am_batch_fn.py | 2 +- paddlespeech/t2s/{data => datasets}/batch.py | 0 paddlespeech/t2s/datasets/common.py | 92 ------ .../t2s/{data => datasets}/get_feats.py | 0 .../t2s/exps/fastspeech2/preprocess.py | 6 +- .../parallelwave_gan/synthesize_from_wav.py | 2 +- .../t2s/exps/gan_vocoder/preprocess.py | 2 +- .../t2s/exps/speedyspeech/preprocess.py | 2 +- paddlespeech/t2s/exps/tacotron2/preprocess.py | 2 +- .../t2s/exps/transformer_tts/preprocess.py | 2 +- paddlespeech/t2s/exps/waveflow/ljspeech.py | 4 +- 15 files changed, 11 insertions(+), 383 deletions(-) delete mode 100644 paddlespeech/t2s/data/__init__.py delete mode 100644 paddlespeech/t2s/data/dataset.py rename paddlespeech/t2s/{data => datasets}/batch.py (100%) delete mode 100644 paddlespeech/t2s/datasets/common.py rename paddlespeech/t2s/{data => datasets}/get_feats.py (100%) diff --git a/paddlespeech/t2s/__init__.py b/paddlespeech/t2s/__init__.py index 8a0acc48a..7d93c026e 100644 --- a/paddlespeech/t2s/__init__.py +++ b/paddlespeech/t2s/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -from . import data from . import datasets from . import exps from . import frontend diff --git a/paddlespeech/t2s/data/__init__.py b/paddlespeech/t2s/data/__init__.py deleted file mode 100644 index c605205d6..000000000 --- a/paddlespeech/t2s/data/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""t2s's infrastructure for data processing. -""" -from .batch import * -from .dataset import * diff --git a/paddlespeech/t2s/data/dataset.py b/paddlespeech/t2s/data/dataset.py deleted file mode 100644 index 2d6c03cb1..000000000 --- a/paddlespeech/t2s/data/dataset.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import six -from paddle.io import Dataset - -__all__ = [ - "split", - "TransformDataset", - "CacheDataset", - "TupleDataset", - "DictDataset", - "SliceDataset", - "SubsetDataset", - "FilterDataset", - "ChainDataset", -] - - -def split(dataset, first_size): - """A utility function to split a dataset into two datasets.""" - first = SliceDataset(dataset, 0, first_size) - second = SliceDataset(dataset, first_size, len(dataset)) - return first, second - - -class TransformDataset(Dataset): - def __init__(self, dataset, transform): - """Dataset which is transformed from another with a transform. - - Args: - dataset (Dataset): the base dataset. - transform (callable): the transform which takes an example of the base dataset as parameter and return a new example. - """ - self._dataset = dataset - self._transform = transform - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, i): - in_data = self._dataset[i] - return self._transform(in_data) - - -class CacheDataset(Dataset): - def __init__(self, dataset): - """A lazy cache of the base dataset. - - Args: - dataset (Dataset): the base dataset to cache. - """ - self._dataset = dataset - self._cache = dict() - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, i): - if i not in self._cache: - self._cache[i] = self._dataset[i] - return self._cache[i] - - -class TupleDataset(Dataset): - def __init__(self, *datasets): - """A compound dataset made from several datasets of the same length. An example of the `TupleDataset` is a tuple of examples from the constituent datasets. - - Args: - datasets: tuple[Dataset], the constituent datasets. - """ - if not datasets: - raise ValueError("no datasets are given") - length = len(datasets[0]) - for i, dataset in enumerate(datasets): - if len(dataset) != length: - raise ValueError("all the datasets should have the same length." - "dataset {} has a different length".format(i)) - self._datasets = datasets - self._length = length - - def __getitem__(self, index): - # SOA - batches = [dataset[index] for dataset in self._datasets] - if isinstance(index, slice): - length = len(batches[0]) - # AOS - return [ - tuple([batch[i] for batch in batches]) - for i in six.moves.range(length) - ] - else: - return tuple(batches) - - def __len__(self): - return self._length - - -class DictDataset(Dataset): - def __init__(self, **datasets): - """ - A compound dataset made from several datasets of the same length. An - example of the `DictDataset` is a dict of examples from the constituent - datasets. - - WARNING: paddle does not have a good support for DictDataset, because - every batch yield from a DataLoader is a list, but it cannot be a dict. - So you have to provide a collate function because you cannot use the - default one. - - Args: - datasets: Dict[Dataset], the constituent datasets. - """ - if not datasets: - raise ValueError("no datasets are given") - length = None - for key, dataset in six.iteritems(datasets): - if length is None: - length = len(dataset) - elif len(dataset) != length: - raise ValueError( - "all the datasets should have the same length." 
- "dataset {} has a different length".format(key)) - self._datasets = datasets - self._length = length - - def __getitem__(self, index): - batches = { - key: dataset[index] - for key, dataset in six.iteritems(self._datasets) - } - if isinstance(index, slice): - length = len(six.next(six.itervalues(batches))) - return [{key: batch[i] - for key, batch in six.iteritems(batches)} - for i in six.moves.range(length)] - else: - return batches - - def __len__(self): - return self._length - - -class SliceDataset(Dataset): - def __init__(self, dataset, start, finish, order=None): - """A Dataset which is a slice of the base dataset. - - Args: - dataset (Dataset): the base dataset. - start (int): the start of the slice. - finish (int): the end of the slice, not inclusive. - order (List[int], optional): the order, it is a permutation of the valid example ids of the base dataset. If `order` is provided, the slice is taken in `order`. Defaults to None. - """ - if start < 0 or finish > len(dataset): - raise ValueError("subset overruns the dataset.") - self._dataset = dataset - self._start = start - self._finish = finish - self._size = finish - start - - if order is not None and len(order) != len(dataset): - raise ValueError( - "order should have the same length as the dataset" - "len(order) = {} which does not euqals len(dataset) = {} ". - format(len(order), len(dataset))) - self._order = order - - def __len__(self): - return self._size - - def __getitem__(self, i): - if i >= 0: - if i >= self._size: - raise IndexError('dataset index out of range') - index = self._start + i - else: - if i < -self._size: - raise IndexError('dataset index out of range') - index = self._finish + i - - if self._order is not None: - index = self._order[index] - return self._dataset[index] - - -class SubsetDataset(Dataset): - def __init__(self, dataset, indices): - """A Dataset which is a subset of the base dataset. - - Args: - dataset (Dataset): the base dataset. - indices (Iterable[int]): the indices of the examples to pick. - """ - self._dataset = dataset - if len(indices) > len(dataset): - raise ValueError("subset's size larger that dataset's size!") - self._indices = indices - self._size = len(indices) - - def __len__(self): - return self._size - - def __getitem__(self, i): - index = self._indices[i] - return self._dataset[index] - - -class FilterDataset(Dataset): - def __init__(self, dataset, filter_fn): - """A filtered dataset. - - Args: - dataset (Dataset): the base dataset. - filter_fn (callable): a callable which takes an example of the base dataset and return a boolean. - """ - self._dataset = dataset - self._indices = [ - i for i in range(len(dataset)) if filter_fn(dataset[i]) - ] - self._size = len(self._indices) - - def __len__(self): - return self._size - - def __getitem__(self, i): - index = self._indices[i] - return self._dataset[index] - - -class ChainDataset(Dataset): - def __init__(self, *datasets): - """A concatenation of the several datasets which the same structure. - - Args: - datasets (Iterable[Dataset]): datasets to concat. 
- """ - self._datasets = datasets - - def __len__(self): - return sum(len(dataset) for dataset in self._datasets) - - def __getitem__(self, i): - if i < 0: - raise IndexError("ChainDataset doesnot support negative indexing.") - - for dataset in self._datasets: - if i < len(dataset): - return dataset[i] - i -= len(dataset) - - raise IndexError("dataset index out of range") diff --git a/paddlespeech/t2s/datasets/__init__.py b/paddlespeech/t2s/datasets/__init__.py index fc64a82f2..caf20aac4 100644 --- a/paddlespeech/t2s/datasets/__init__.py +++ b/paddlespeech/t2s/datasets/__init__.py @@ -11,5 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .common import * from .ljspeech import * diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 655e06e37..4e3ad3c12 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -14,7 +14,7 @@ import numpy as np import paddle -from paddlespeech.t2s.data.batch import batch_sequences +from paddlespeech.t2s.datasets.batch import batch_sequences def tacotron2_single_spk_batch_fn(examples): diff --git a/paddlespeech/t2s/data/batch.py b/paddlespeech/t2s/datasets/batch.py similarity index 100% rename from paddlespeech/t2s/data/batch.py rename to paddlespeech/t2s/datasets/batch.py diff --git a/paddlespeech/t2s/datasets/common.py b/paddlespeech/t2s/datasets/common.py deleted file mode 100644 index 122a35aeb..000000000 --- a/paddlespeech/t2s/datasets/common.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pathlib import Path -from typing import List - -import librosa -import numpy as np -from paddle.io import Dataset - -__all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"] - - -class AudioSegmentDataset(Dataset): - """A simple dataset adaptor for audio files to train vocoders. - Read -> trim silence -> normalize -> extract a segment - """ - - def __init__(self, - file_paths: List[Path], - sample_rate: int, - length: int, - top_db: float): - self.file_paths = file_paths - self.sr = sample_rate - self.top_db = top_db - self.length = length # samples in the clip - - def __getitem__(self, i): - fpath = self.file_paths[i] - y, sr = librosa.load(fpath, sr=self.sr) - y, _ = librosa.effects.trim(y, top_db=self.top_db) - y = librosa.util.normalize(y) - y = y.astype(np.float32) - - # pad or trim - if y.size <= self.length: - y = np.pad(y, [0, self.length - len(y)], mode='constant') - else: - start = np.random.randint(0, 1 + len(y) - self.length) - y = y[start:start + self.length] - return y - - def __len__(self): - return len(self.file_paths) - - -class AudioDataset(Dataset): - """A simple dataset adaptor for the audio files. 
- Read -> trim silence -> normalize - """ - - def __init__(self, - file_paths: List[Path], - sample_rate: int, - top_db: float=60): - self.file_paths = file_paths - self.sr = sample_rate - self.top_db = top_db - - def __getitem__(self, i): - fpath = self.file_paths[i] - y, sr = librosa.load(fpath, sr=self.sr) - y, _ = librosa.effects.trim(y, top_db=self.top_db) - y = librosa.util.normalize(y) - y = y.astype(np.float32) - return y - - def __len__(self): - return len(self.file_paths) - - -class AudioFolderDataset(AudioDataset): - def __init__( - self, - root, - sample_rate, - top_db=60, - extension=".wav", ): - root = Path(root).expanduser() - file_paths = sorted(list(root.rglob("*{}".format(extension)))) - super().__init__(file_paths, sample_rate, top_db) diff --git a/paddlespeech/t2s/data/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py similarity index 100% rename from paddlespeech/t2s/data/get_feats.py rename to paddlespeech/t2s/datasets/get_feats.py diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py index fd6da2cb3..5bda75451 100644 --- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py +++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py @@ -27,9 +27,9 @@ import tqdm import yaml from yacs.config import CfgNode -from paddlespeech.t2s.data.get_feats import Energy -from paddlespeech.t2s.data.get_feats import LogMelFBank -from paddlespeech.t2s.data.get_feats import Pitch +from paddlespeech.t2s.datasets.get_feats import Energy +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import Pitch from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length from paddlespeech.t2s.datasets.preprocess_utils import get_input_token from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py index f5affb50b..def30e67a 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py +++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py @@ -23,7 +23,7 @@ import soundfile as sf import yaml from yacs.config import CfgNode -from paddlespeech.t2s.data.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator from paddlespeech.t2s.models.parallel_wavegan import PWGInference from paddlespeech.t2s.modules.normalizer import ZScore diff --git a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py index 47d0a2921..4871bca71 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py +++ b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py @@ -27,7 +27,7 @@ import tqdm import yaml from yacs.config import CfgNode -from paddlespeech.t2s.data.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur from paddlespeech.t2s.datasets.preprocess_utils import merge_silence from paddlespeech.t2s.utils import str2bool diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index db888fbac..3f81c4e14 100644 --- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py +++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py @@ -27,7 +27,7 @@ import tqdm import yaml from yacs.config import 
CfgNode -from paddlespeech.t2s.data.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur from paddlespeech.t2s.datasets.preprocess_utils import get_phones_tones diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py index ffbeaad92..7f41089eb 100644 --- a/paddlespeech/t2s/exps/tacotron2/preprocess.py +++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py @@ -27,7 +27,7 @@ import tqdm import yaml from yacs.config import CfgNode -from paddlespeech.t2s.data.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length from paddlespeech.t2s.datasets.preprocess_utils import get_input_token from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess.py b/paddlespeech/t2s/exps/transformer_tts/preprocess.py index 93158b671..7cfa91b9d 100644 --- a/paddlespeech/t2s/exps/transformer_tts/preprocess.py +++ b/paddlespeech/t2s/exps/transformer_tts/preprocess.py @@ -26,7 +26,7 @@ import tqdm import yaml from yacs.config import CfgNode as Configuration -from paddlespeech.t2s.data.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.frontend import English diff --git a/paddlespeech/t2s/exps/waveflow/ljspeech.py b/paddlespeech/t2s/exps/waveflow/ljspeech.py index 655b63dad..a6efa9ec2 100644 --- a/paddlespeech/t2s/exps/waveflow/ljspeech.py +++ b/paddlespeech/t2s/exps/waveflow/ljspeech.py @@ -17,8 +17,8 @@ import numpy as np import pandas from paddle.io import Dataset -from paddlespeech.t2s.data.batch import batch_spec -from paddlespeech.t2s.data.batch import batch_wav +from paddlespeech.t2s.datasets.batch import batch_spec +from paddlespeech.t2s.datasets.batch import batch_wav class LJSpeech(Dataset):
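
Downstream code that still imports from the removed paddlespeech.t2s.data package has to move to paddlespeech.t2s.datasets, exactly as the import rewrites in this patch do. A minimal sketch of that migration, restricted to symbols that actually appear in the diff (any other imports from the old package would follow the same pattern):

    # Old import paths (the data/ package is deleted by this patch):
    # from paddlespeech.t2s.data.batch import batch_sequences, batch_spec, batch_wav
    # from paddlespeech.t2s.data.get_feats import Energy, LogMelFBank, Pitch

    # New import paths (batch.py and get_feats.py are renamed into datasets/):
    from paddlespeech.t2s.datasets.batch import batch_sequences, batch_spec, batch_wav
    from paddlespeech.t2s.datasets.get_feats import Energy, LogMelFBank, Pitch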