Merge pull request #2263 from oyjxer/pc

[TTS]add ernie-sat sampler
TianYuan 2 years ago committed by GitHub
commit f9a6970a62

@@ -0,0 +1,181 @@
import math

import numpy as np
import paddle
from paddle.io import BatchSampler


class ErnieSATSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In such a case, each process can pass an ErnieSATSampler instance
    as a DataLoader sampler, and load a subset of the original dataset that
    is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(paddle.io.Dataset): this could be a `paddle.io.Dataset` implementation
            or any other python object which implements `__len__`, for the sampler
            to get the number of samples in the data source.
        batch_size(int): number of sample indices in a mini-batch.
        num_replicas(int, optional): process number in distributed training.
            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
            retrieved from :code:`paddle.distributed.ParallelEnv`.
            Default None.
        rank(int, optional): the rank of the current process among :attr:`num_replicas`
            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
            :code:`paddle.distributed.ParallelEnv`. Default None.
        shuffle(bool): whether to shuffle the order of the batches before yielding
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the dataset
            size is not divisible by the batch size. Default False.

    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset

            from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = ErnieSATSampler(dataset, batch_size=64)

            for data in sampler:
                # do something
                break
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value"

        from paddle.fluid.dygraph.parallel import ParallelEnv

        if num_replicas is not None:
            assert isinstance(num_replicas, int) and num_replicas > 0, \
                "num_replicas should be a positive integer"
            self.nranks = num_replicas
        else:
            self.nranks = ParallelEnv().nranks

        if rank is not None:
            assert isinstance(rank, int) and rank >= 0, \
                "rank should be a non-negative integer"
            self.local_rank = rank
        else:
            self.local_rank = ParallelEnv().local_rank

        self.drop_last = drop_last
        self.epoch = 0
        # pad the dataset so that every rank sees the same number of samples
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        # pad with the leading indices so that len(indices) == total_size
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample: pick this rank's share of each global batch
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(indices[
                self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        # group consecutive indices into batches, preserving the original
        # (e.g. length-sorted) sample order inside each batch
        batch_indices_list = []
        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                batch_indices_list.append(batch_indices)
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            batch_indices_list.append(batch_indices)

        # shuffle whole batches rather than individual samples,
        # seeded by the current epoch for reproducibility
        if self.shuffle:
            np.random.RandomState(self.epoch).shuffle(batch_indices_list)
            self.epoch += 1

        for batch_indices in batch_indices_list:
            yield batch_indices

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as the seed of the random number generator. By default users may not
        set this; in that case, all replicas (workers) use a different random
        ordering for each epoch. If the same number is set at each epoch, this
        sampler will yield the same ordering at all epochs.

        Arguments:
            epoch (int): Epoch number.

        Examples:
            .. code-block:: python

                import numpy as np
                from paddle.io import Dataset

                from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

                # init with dataset
                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = ErnieSATSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
        """
        self.epoch = epoch
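
Taken together, the new sampler keeps each mini-batch's membership fixed (consecutive indices, which after the preprocessing change below means utterances of similar length) and only shuffles the order of whole batches. A minimal single-process sketch of plugging it into a `paddle.io.DataLoader` follows; the `ToyDataset` and the sizes are made up for illustration, while the real training script builds its dataset from `metadata.jsonl` and uses `build_erniesat_collate_fn`:

# Sketch only (not part of this PR): driving ErnieSATSampler with a DataLoader.
# ToyDataset and the sizes below are hypothetical.
import numpy as np
from paddle.io import DataLoader, Dataset

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler


class ToyDataset(Dataset):
    def __init__(self, num_samples=10):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        return np.random.random([80]).astype('float32')

    def __len__(self):
        return self.num_samples


dataset = ToyDataset(10)
# shuffle=True shuffles the order of whole batches, not the samples inside them
sampler = ErnieSATSampler(dataset, batch_size=4, shuffle=True, drop_last=False)
loader = DataLoader(dataset, batch_sampler=sampler)

for batch in loader:
    print(batch.shape)  # [4, 80] for full batches, [2, 80] for the trailing one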

@@ -118,7 +118,7 @@ def main():
            record["spk_emb"] = str(item["spk_emb"])
        output_metadata.append(record)
-    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata.sort(key=itemgetter('speech_lengths'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:

@@ -165,7 +165,7 @@ def process_sentences(config,
            if record:
                results.append(record)
-    results.sort(key=itemgetter("utt_id"))
+    results.sort(key=itemgetter("speech_lengths"))
    # replace 'w' with 'a' to write from the end of file
    with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
        for item in results:
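
The two hunks above change the `metadata.jsonl` sort key from `utt_id` to `speech_lengths`, so neighbouring records now have similar durations and the consecutive batches built by `ErnieSATSampler` need less padding. A toy illustration of the new ordering (the records are made up):

# Toy illustration of sorting metadata by speech_lengths; records are hypothetical.
from operator import itemgetter

records = [
    {"utt_id": "a", "speech_lengths": 312},
    {"utt_id": "b", "speech_lengths": 97},
    {"utt_id": "c", "speech_lengths": 205},
]
records.sort(key=itemgetter("speech_lengths"))
print([r["utt_id"] for r in records])  # ['b', 'c', 'a'] -- shortest first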

@@ -31,6 +31,7 @@ from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.ernie_sat import ErnieSAT
from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater

@@ -86,7 +87,7 @@ def train_sp(args, config):
        seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=config["model"]["text_masking"])

-    train_sampler = DistributedBatchSampler(
+    train_sampler = ErnieSATSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,

@@ -27,7 +27,7 @@ from timer import timer
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updater import UpdaterBase
from paddlespeech.t2s.training.updater import UpdaterState
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but

@@ -165,7 +165,8 @@ class StandardUpdater(UpdaterBase):
        # NOTE: all batch sampler for distributed training should
        # subclass DistributedBatchSampler and implement `set_epoch` method
        batch_sampler = self.dataloader.batch_sampler
-        if isinstance(batch_sampler, DistributedBatchSampler):
+        if isinstance(batch_sampler, DistributedBatchSampler) \
+                or isinstance(batch_sampler, ErnieSATSampler):
            batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)
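
Since `ErnieSATSampler` does not subclass `DistributedBatchSampler`, the updater now checks for it explicitly so that `set_epoch` is still called at the start of every epoch. A rough sketch of the effect (dataset and batch size are arbitrary): reseeding with a new epoch number changes the order in which whole batches are yielded, while each batch keeps the same members.

# Sketch (not from the PR): set_epoch reseeds the batch-level shuffle each epoch.
from paddle.io import Dataset

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler


class TinyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


sampler = ErnieSATSampler(TinyDataset(), batch_size=2, shuffle=True)
for epoch in range(2):
    sampler.set_epoch(epoch)     # what StandardUpdater.new_epoch() does
    print(epoch, list(sampler))  # batch order typically differs per epoch,
                                 # batch membership does not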
