diff --git a/paddlespeech/t2s/datasets/sampler.py b/paddlespeech/t2s/datasets/sampler.py
new file mode 100644
index 00000000..a69bc860
--- /dev/null
+++ b/paddlespeech/t2s/datasets/sampler.py
@@ -0,0 +1,181 @@
+import math
+
+import numpy as np
+from paddle.io import BatchSampler
+
+
+class ErnieSATSampler(BatchSampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+    In such a case, each process can pass an ErnieSATSampler instance
+    as a DataLoader sampler, and load a subset of the original dataset
+    that is exclusive to it.
+
+    .. note::
+        Dataset is assumed to be of constant size.
+
+    Args:
+        dataset(paddle.io.Dataset): an instance of `paddle.io.Dataset`, or
+            another Python object which implements `__len__`, for
+            BatchSampler to get the number of samples in the data source.
+        batch_size(int): number of sample indices in a mini-batch.
+        num_replicas(int, optional): number of processes in distributed
+            training. If :attr:`num_replicas` is None, :attr:`num_replicas`
+            will be retrieved from :code:`paddle.distributed.ParallelEnv`.
+            Default None.
+        rank(int, optional): the rank of the current process among
+            :attr:`num_replicas` processes. If :attr:`rank` is None,
+            :attr:`rank` is retrieved from
+            :code:`paddle.distributed.ParallelEnv`. Default None.
+        shuffle(bool): whether to shuffle indices order before generating
+            batch indices. Default False.
+        drop_last(bool): whether to drop the last incomplete batch when the
+            dataset size is not divisible by the batch size. Default False.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            from paddle.io import Dataset
+
+            from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
+
+            # init with dataset
+            class RandomDataset(Dataset):
+                def __init__(self, num_samples):
+                    self.num_samples = num_samples
+
+                def __getitem__(self, idx):
+                    image = np.random.random([784]).astype('float32')
+                    label = np.random.randint(0, 9, (1, )).astype('int64')
+                    return image, label
+
+                def __len__(self):
+                    return self.num_samples
+
+            dataset = RandomDataset(100)
+            sampler = ErnieSATSampler(dataset, batch_size=64)
+            for data in sampler:
+                # do something
+                break
+    """
+
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=False,
+                 drop_last=False):
+        self.dataset = dataset
+
+        assert isinstance(batch_size, int) and batch_size > 0, \
+            "batch_size should be a positive integer"
+        self.batch_size = batch_size
+        assert isinstance(shuffle, bool), \
+            "shuffle should be a boolean value"
+        self.shuffle = shuffle
+        assert isinstance(drop_last, bool), \
+            "drop_last should be a boolean value"
+
+        from paddle.fluid.dygraph.parallel import ParallelEnv
+
+        if num_replicas is not None:
+            assert isinstance(num_replicas, int) and num_replicas > 0, \
+                "num_replicas should be a positive integer"
+            self.nranks = num_replicas
+        else:
+            self.nranks = ParallelEnv().nranks
+
+        if rank is not None:
+            assert isinstance(rank, int) and rank >= 0, \
+                "rank should be a non-negative integer"
+            self.local_rank = rank
+        else:
+            self.local_rank = ParallelEnv().local_rank
+
+        self.drop_last = drop_last
+        self.epoch = 0
+        self.num_samples = int(
+            math.ceil(len(self.dataset) * 1.0 / self.nranks))
+        self.total_size = self.num_samples * self.nranks
+
+    def __iter__(self):
+        num_samples = len(self.dataset)
+        indices = np.arange(num_samples).tolist()
+        # pad with leading indices so every rank gets the same count
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        def _get_indices_by_batch_size(indices):
+            subsampled_indices = []
+            last_batch_size = self.total_size % (
+                self.batch_size * self.nranks)
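+            # NOTE: total_size is a multiple of nranks, and so is
+            # batch_size * nranks, hence last_batch_size (the final,
+            # possibly incomplete global batch) always splits evenly
+            # across ranks; the assert below checks this invariant.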
+            assert last_batch_size % self.nranks == 0
+            last_local_batch_size = last_batch_size // self.nranks
+
+            # full batches: each rank takes a contiguous batch_size slice
+            # out of every (batch_size * nranks) block
+            for i in range(self.local_rank * self.batch_size,
+                           len(indices) - last_batch_size,
+                           self.batch_size * self.nranks):
+                subsampled_indices.extend(indices[i:i + self.batch_size])
+
+            # the last (incomplete) batch: split it evenly across ranks
+            indices = indices[len(indices) - last_batch_size:]
+            subsampled_indices.extend(indices[
+                self.local_rank * last_local_batch_size:(
+                    self.local_rank + 1) * last_local_batch_size])
+            return subsampled_indices
+
+        if self.nranks > 1:
+            indices = _get_indices_by_batch_size(indices)
+
+        assert len(indices) == self.num_samples
+        _sample_iter = iter(indices)
+
+        batch_indices_list = []
+        batch_indices = []
+        for idx in _sample_iter:
+            batch_indices.append(idx)
+            if len(batch_indices) == self.batch_size:
+                batch_indices_list.append(batch_indices)
+                batch_indices = []
+        if not self.drop_last and len(batch_indices) > 0:
+            batch_indices_list.append(batch_indices)
+
+        if self.shuffle:
+            # shuffle whole batches (not single samples), seeded by the
+            # epoch so that all ranks agree on the batch order
+            np.random.RandomState(self.epoch).shuffle(batch_indices_list)
+            self.epoch += 1
+
+        for batch_indices in batch_indices_list:
+            yield batch_indices
+
+    def __len__(self):
+        num_samples = self.num_samples
+        num_samples += int(not self.drop_last) * (self.batch_size - 1)
+        return num_samples // self.batch_size
+
+    def set_epoch(self, epoch):
+        """Set the epoch number.
+
+        When :attr:`shuffle=True`, this number is used as the seed of the
+        random ordering. Users normally need not set this; each epoch then
+        uses a different random ordering on all replicas (workers). If the
+        same number is set at each epoch, this sampler yields the same
+        ordering at all epochs.
+
+        Arguments:
+            epoch (int): Epoch number.
+
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+                from paddle.io import Dataset
+
+                from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
+
+                # init with dataset
+                class RandomDataset(Dataset):
+                    def __init__(self, num_samples):
+                        self.num_samples = num_samples
+
+                    def __getitem__(self, idx):
+                        image = np.random.random([784]).astype('float32')
+                        label = np.random.randint(0, 9, (1, )).astype('int64')
+                        return image, label
+
+                    def __len__(self):
+                        return self.num_samples
+
+                dataset = RandomDataset(100)
+                sampler = ErnieSATSampler(dataset, batch_size=64)
+
+                for epoch in range(10):
+                    sampler.set_epoch(epoch)
+        """
+        self.epoch = epoch
diff --git a/paddlespeech/t2s/exps/ernie_sat/normalize.py b/paddlespeech/t2s/exps/ernie_sat/normalize.py
index 74cdae2a..fed11147 100644
--- a/paddlespeech/t2s/exps/ernie_sat/normalize.py
+++ b/paddlespeech/t2s/exps/ernie_sat/normalize.py
@@ -118,7 +118,7 @@ def main():
             record["spk_emb"] = str(item["spk_emb"])
 
         output_metadata.append(record)
-    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata.sort(key=itemgetter('speech_lengths'))
     output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
     with jsonlines.open(output_metadata_path, 'w') as writer:
         for item in output_metadata:
diff --git a/paddlespeech/t2s/exps/ernie_sat/preprocess.py b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
index fc9e0888..486ed13a 100644
--- a/paddlespeech/t2s/exps/ernie_sat/preprocess.py
+++ b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
@@ -165,7 +165,7 @@ def process_sentences(config,
             if record:
                 results.append(record)
-    results.sort(key=itemgetter("utt_id"))
+    results.sort(key=itemgetter("speech_lengths"))
     # replace 'w' with 'a' to write from the end of file
     with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
         for item in results:
diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py
index ccd1245e..af653ef8 100644
--- a/paddlespeech/t2s/exps/ernie_sat/train.py
+++ b/paddlespeech/t2s/exps/ernie_sat/train.py
@@ -31,6 +31,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
 from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
 from paddlespeech.t2s.models.ernie_sat import ErnieSAT
 from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
 from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
@@ -86,7 +87,7 @@ def train_sp(args, config):
         seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
         text_masking=config["model"]["text_masking"])
 
-    train_sampler = DistributedBatchSampler(
+    train_sampler = ErnieSATSampler(
         train_dataset,
         batch_size=config.batch_size,
         shuffle=True,
diff --git a/paddlespeech/t2s/training/updaters/standard_updater.py b/paddlespeech/t2s/training/updaters/standard_updater.py
index b1c48620..668d2fc6 100644
--- a/paddlespeech/t2s/training/updaters/standard_updater.py
+++ b/paddlespeech/t2s/training/updaters/standard_updater.py
@@ -27,7 +27,7 @@ from timer import timer
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updater import UpdaterBase
 from paddlespeech.t2s.training.updater import UpdaterState
-
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
 
 class StandardUpdater(UpdaterBase):
     """An example of over-simplification. Things may not be that simple, but
@@ -165,7 +165,8 @@ class StandardUpdater(UpdaterBase):
             # NOTE: all batch sampler for distributed training should
             # subclass DistributedBatchSampler and implement `set_epoch` method
             batch_sampler = self.dataloader.batch_sampler
-            if isinstance(batch_sampler, DistributedBatchSampler):
+            if isinstance(batch_sampler, DistributedBatchSampler) \
+                    or isinstance(batch_sampler, ErnieSATSampler):
                 batch_sampler.set_epoch(self.state.epoch)
             self.train_iterator = iter(self.dataloader)
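
A minimal usage sketch (not part of the patch above), assuming a
`paddle.io.DataLoader` and reusing the `RandomDataset` placeholder from the
docstring examples; the training step is elided:

    import numpy as np
    from paddle.io import DataLoader, Dataset

    from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

    class RandomDataset(Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            return np.random.random([80]).astype('float32')

        def __len__(self):
            return self.num_samples

    dataset = RandomDataset(100)
    sampler = ErnieSATSampler(dataset, batch_size=8, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler)

    for epoch in range(2):
        # reseed the batch-level shuffle each epoch, as StandardUpdater
        # does via batch_sampler.set_epoch(self.state.epoch)
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here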