135 lines
3.9 KiB

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any
from typing import Callable
from typing import Iterator
from typing import Union
from ..utils.log import Logger
# Module-level logger, created with this module's name; paddle_worker_info()
# uses it to report a failed `paddle.io` import.
logger = Logger(__name__)
def make_seed(*args):
    """Fold the hashes of all arguments into one non-negative 31-bit integer.

    Starting from 0, each argument updates the accumulator as
    ``acc * 31 + hash(arg)``, masked to the low 31 bits.  With no arguments
    the result is 0.

    Note: the value is built from the built-in ``hash()``, so it is only as
    reproducible as ``hash()`` is for the given argument types.
    """
    acc = 0
    for value in args:
        acc = acc * 31 + hash(value)
        # Keep the running value within 31 bits (always non-negative).
        acc &= 0x7FFFFFFF
    return acc
class PipelineStage:
    """Class whose ``invoke`` is left unimplemented."""

    def invoke(self, *args, **kw):
        """Accept any arguments and raise ``NotImplementedError``."""
        raise NotImplementedError()
def identity(x: Any) -> Any:
    """Pass-through helper: hand back exactly the object that was given."""
    return x
def safe_eval(s: str, expr: str = "{}"):
    """Substitute ``s`` into ``expr`` and evaluate the result.

    Args:
        s: text to insert; may contain only ASCII letters, digits and
            underscores.
        expr: ``str.format`` template receiving ``s`` as its single
            positional argument.  It is NOT validated.

    Returns:
        Whatever ``eval`` produces for the formatted expression.

    Raises:
        ValueError: if ``s`` contains any character outside ``[A-Za-z0-9_]``.
    """
    # No helper locals are introduced on purpose: eval() below can see this
    # function's local names, so extra names would be visible to the expression.
    if re.search("[^A-Za-z0-9_]", s) is not None:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    # SECURITY: eval() executes arbitrary code.  Only `s` is sanitised above,
    # so `expr` must come from a trusted source.
    return eval(expr.format(s))
def lookup_sym(sym: str, modules: list):
    """Find the first module in ``modules`` that provides attribute ``sym``.

    Modules are imported one at a time, in order, via
    ``importlib.import_module(name, package="webdataset")``.

    Args:
        sym: attribute name to search for.
        modules: module names to try, in priority order.

    Returns:
        The first attribute value that is not ``None``, or ``None`` when no
        module provides one (an attribute whose value is ``None`` is skipped).
    """
    for module_name in modules:
        candidate = getattr(
            importlib.import_module(module_name, package="webdataset"),
            sym,
            None,
        )
        if candidate is None:
            continue
        return candidate
    return None
def repeatedly0(
    loader: Iterator,
    nepochs: int = sys.maxsize,
    nbatches: int = sys.maxsize,
):
    """Iterate over ``loader`` again and again, capping each pass.

    Args:
        loader: object to iterate; it is iterated afresh on every epoch.
        nepochs: number of passes to make over ``loader``.
        nbatches: maximum number of items taken from ``loader`` per pass.

    Yields:
        The items produced by ``loader``, at most ``nbatches`` per pass.
    """
    for _ in range(nepochs):
        yield from itt.islice(loader, nbatches)
def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size as the length of the first element of ``batch``."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from ``source`` until a limit is reached.

    ``source`` is iterated over and over.  Every limit is checked *after* a
    sample has been yielded (``nepochs`` after a complete pass), so at least
    one sample / one pass is produced even when a limit is 0.

    Args:
        source: iterable to draw samples from; re-iterated on every pass.
        nepochs: stop after this many complete passes (``None`` = no limit).
        nbatches: stop after this many yielded samples (``None`` = no limit).
        nsamples: stop once the accumulated ``batchsize(sample)`` values reach
            this total (``None`` = no limit).
        batchsize: callable returning the size of one yielded sample; only
            consulted when ``nsamples`` is given.

    Yields:
        Samples from ``source``, unchanged.
    """
    epoch = 0
    batch = 0
    total = 0
    while True:
        produced = False
        for sample in source:
            produced = True
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # Use the caller-supplied size function rather than always
                # falling back to guess_batchsize.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        if not produced:
            # An empty (or exhausted one-shot) source can never make
            # progress; stop instead of spinning in this loop forever.
            return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return
def paddle_worker_info(group=None):
    """Return node and worker info for Paddle and some distributed environments.

    The node rank and world size are taken from the ``RANK`` and
    ``WORLD_SIZE`` environment variables when both are set, otherwise from
    ``paddle.distributed``.  The worker index and worker count are taken from
    ``WORKER`` and ``NUM_WORKERS`` when both are set, otherwise from
    ``paddle.io.get_worker_info()``.

    Args:
        group: optional process group.  When falsy it is replaced by
            ``paddle.distributed.get_group()``; the value is not otherwise
            used by this function.

    Returns:
        Tuple ``(rank, world_size, worker, num_workers)``.  Values that could
        not be determined keep their defaults ``(0, 1, 0, 1)``.

    Raises:
        SystemExit: if ``WORKER``/``NUM_WORKERS`` are not both set and
            ``paddle.io`` cannot be imported.
    """
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed

            # The resolved group is not passed on to the calls below; the
            # lookup is kept so the call sequence stays unchanged.
            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            # Paddle is not installed: keep the single-node defaults.
            pass

    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            from paddle.io import get_worker_info

            worker_info = get_worker_info()
            # get_worker_info() may return None; the defaults are kept then.
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError as err:
            logger.info(f"not found {err}")
            # sys.exit() instead of the `site`-provided exit() helper, which
            # is meant for interactive use and is not guaranteed to exist.
            sys.exit(-1)

    return rank, world_size, worker, num_workers
def paddle_worker_seed(group=None):
    """Derive an RNG seed from the node rank and the worker index.

    The seed is ``rank * 1000 + worker``, using the values reported by
    ``paddle_worker_info(group=group)``; the world size and worker count
    are not part of the result.
    """
    node_rank, _world_size, worker_index, _num_workers = paddle_worker_info(
        group=group
    )
    return worker_index + 1000 * node_rank