|
|
@ -65,6 +65,7 @@ class SimpleShardList(IterableDataset):
|
|
|
|
|
|
|
|
|
|
|
|
def split_by_node(src, group=None):
|
|
|
|
def split_by_node(src, group=None):
|
|
|
|
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
|
|
|
|
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
|
|
|
|
|
|
|
|
logger.info(f"world_size:{world_size}, rank:{rank}")
|
|
|
|
if world_size > 1:
|
|
|
|
if world_size > 1:
|
|
|
|
for s in islice(src, rank, None, world_size):
|
|
|
|
for s in islice(src, rank, None, world_size):
|
|
|
|
yield s
|
|
|
|
yield s
|
|
|
@ -83,6 +84,7 @@ def single_node_only(src, group=None):
|
|
|
|
|
|
|
|
|
|
|
|
def split_by_worker(src):
|
|
|
|
def split_by_worker(src):
|
|
|
|
rank, world_size, worker, num_workers = utils.paddle_worker_info()
|
|
|
|
rank, world_size, worker, num_workers = utils.paddle_worker_info()
|
|
|
|
|
|
|
|
logger.info(f"num_workers:{num_workers}, worker:{worker}")
|
|
|
|
if num_workers > 1:
|
|
|
|
if num_workers > 1:
|
|
|
|
for s in islice(src, worker, None, num_workers):
|
|
|
|
for s in islice(src, worker, None, num_workers):
|
|
|
|
yield s
|
|
|
|
yield s
|
|
|
|