Merge pull request #2263 from oyjxer/pc

[TTS]add ernie-sat sampler
TianYuan 2 years ago committed by GitHub
commit f9a6970a62

@@ -0,0 +1,181 @@
import math

import numpy as np
import paddle
from paddle.io import BatchSampler


class ErnieSATSampler(BatchSampler):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset(paddle.io.Dataset): this could be a `paddle.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
num_replicas(int, optional): porcess number in distributed training.
If :attr:`num_replicas` is None, :attr:`num_replicas` will be
retrieved from :code:`paddle.distributed.ParallenEnv`.
Default None.
rank(int, optional): the rank of the current process among :attr:`num_replicas`
processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
:code:`paddle.distributed.ParallenEnv`. Default None.
shuffle(bool): whther to shuffle indices order before genrating
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset

            from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = ErnieSATSampler(dataset, batch_size=64)

            for data in sampler:
                # do something
                break
"""
def __init__(self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False):
self.dataset = dataset
assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer"
self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value"
self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value"
from paddle.fluid.dygraph.parallel import ParallelEnv
if num_replicas is not None:
assert isinstance(num_replicas, int) and num_replicas > 0, \
"num_replicas should be a positive integer"
self.nranks = num_replicas
else:
self.nranks = ParallelEnv().nranks
if rank is not None:
assert isinstance(rank, int) and rank >= 0, \
"rank should be a non-negative integer"
self.local_rank = rank
else:
self.local_rank = ParallelEnv().local_rank
self.drop_last = drop_last
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
def __iter__(self):
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()
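        # pad by wrapping around to the front of the list, so the index list
        # length equals total_size and splits evenly across replicas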
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
        # subsample: pick out this rank's share of the padded index list
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            # walk the index list in strides of batch_size * nranks, taking
            # one batch_size-sized slice per stride for this rank
            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            # split the leftover tail (smaller than a full round of batches)
            # evenly across ranks
            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(indices[
                self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices
if self.nranks > 1:
indices = _get_indices_by_batch_size(indices)
assert len(indices) == self.num_samples
_sample_iter = iter(indices)
batch_indices_list = []
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
batch_indices_list.append(batch_indices)
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
batch_indices_list.append(batch_indices)
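        # NOTE: unlike DistributedBatchSampler, which shuffles sample indices
        # before batching, this sampler shuffles whole batches, so batches
        # built from length-sorted metadata keep similar-length utterances
        # together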
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(batch_indices_list)
self.epoch += 1
for batch_indices in batch_indices_list:
yield batch_indices
def __len__(self):
num_samples = self.num_samples
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size
def set_epoch(self, epoch):
"""
Sets the epoch number. When :attr:`shuffle=True`, this number is used
as seeds of random numbers. By default, users may not set this, all
replicas (workers) use a different random ordering for each epoch.
If set same number at each epoch, this sampler will yield the same
ordering at all epoches.
Arguments:
epoch (int): Epoch number.
        Examples:
            .. code-block:: python

                import numpy as np
                from paddle.io import Dataset

                from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

                # init with dataset
                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = ErnieSATSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
        """
self.epoch = epoch
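
For reference, a minimal sketch of how the new sampler plugs into a DataLoader (the toy dataset and hyperparameters are illustrative, not from this PR; num_replicas=1 and rank=0 force single-process behaviour so the sketch runs outside distributed training):

import paddle
from paddle.io import DataLoader, Dataset

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return paddle.to_tensor([float(idx)])

dataset = ToyDataset()
sampler = ErnieSATSampler(
    dataset, batch_size=4, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # pin the shuffle seed for this epoch
    for batch in loader:
        pass  # training step goes here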

@@ -118,7 +118,7 @@ def main():
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
-    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata.sort(key=itemgetter('speech_lengths'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:

@@ -165,7 +165,7 @@ def process_sentences(config,
if record:
results.append(record)
-    results.sort(key=itemgetter("utt_id"))
+    results.sort(key=itemgetter("speech_lengths"))
    # replace 'w' with 'a' so new records are appended to the end of the file
with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
for item in results:
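
Sorting the dumped metadata by speech_lengths instead of utt_id means neighbouring entries have similar durations, so the fixed-size batches the sampler cuts from this order waste less padding. A toy illustration (hypothetical length values):

lengths = [120, 45, 300, 90]  # toy speech_lengths values
order = sorted(range(len(lengths)), key=lambda i: lengths[i])
# order == [1, 3, 0, 2]: neighbouring indices now have similar lengths,
# and shuffling at the batch level (see the sampler above) preserves
# these groupings.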

@@ -31,6 +31,7 @@ from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.ernie_sat import ErnieSAT
from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
@@ -86,7 +87,7 @@ def train_sp(args, config):
seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
text_masking=config["model"]["text_masking"])
-    train_sampler = DistributedBatchSampler(
+    train_sampler = ErnieSATSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,

@@ -27,7 +27,7 @@ from timer import timer
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updater import UpdaterBase
from paddlespeech.t2s.training.updater import UpdaterState
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
@@ -165,7 +165,8 @@ class StandardUpdater(UpdaterBase):
        # NOTE: all batch samplers for distributed training should
        # subclass DistributedBatchSampler and implement the `set_epoch` method
batch_sampler = self.dataloader.batch_sampler
-        if isinstance(batch_sampler, DistributedBatchSampler):
+        if isinstance(batch_sampler, DistributedBatchSampler) \
+                or isinstance(batch_sampler, ErnieSATSampler):
batch_sampler.set_epoch(self.state.epoch)
self.train_iterator = iter(self.dataloader)
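
A quick way to see why the updater calls set_epoch: with shuffle=True the batch order is seeded by the epoch number, so pinning the epoch reproduces the ordering. A small sketch (assumes the sampler module above is importable; a plain list stands in for a dataset since only __len__ is used):

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

sampler = ErnieSATSampler(
    list(range(10)), batch_size=2, num_replicas=1, rank=0, shuffle=True)
sampler.set_epoch(5)
first = list(sampler)
sampler.set_epoch(5)
second = list(sampler)
assert first == second  # same epoch seed, same batch order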
