# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""

import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union

from ..utils.log import Logger

logger = Logger(__name__)


def make_seed(*args):
    """Fold the hashes of all arguments into a single 31-bit seed."""
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed


class PipelineStage:
    """Base class for pipeline stages; subclasses must implement invoke."""

    def invoke(self, *args, **kw):
        raise NotImplementedError


def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely."""
    # Only identifier-like strings (letters, digits, underscores) may be
    # substituted into the template; anything else is rejected before eval.
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))


def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules."""
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None


def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Repeatedly return batches from a DataLoader."""
    for _ in range(nepochs):
        yield from itt.islice(loader, nbatches)


def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: Optional[int] = None,
    nbatches: Optional[int] = None,
    nsamples: Optional[int] = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator."""
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # Use the caller-supplied batchsize function (previously this
                # hard-coded guess_batchsize, ignoring the parameter).
                total += batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def paddle_worker_info(group=None):
    """Return node and worker info for Paddle and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed

            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            from paddle.io import get_worker_info

            # Call the imported function directly; `paddle` itself may not be
            # bound in this scope if the env-var branch above was taken.
            worker_info = get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError as e:
            logger.info(f"paddle.io not found: {e}")
            sys.exit(-1)
    return rank, world_size, worker, num_workers


def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return rank * 1000 + worker
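

# Hedged usage sketch (not part of the original module): illustrates how
# `repeatedly` caps iteration by sample count rather than epochs, and how
# `make_seed` folds several values into one deterministic seed. The batch
# shape, a (inputs, labels) tuple of equal-length lists, is an assumption
# made for illustration only.
def _demo_repeatedly():
    batches = [([0, 1, 2, 3], ["a", "b", "c", "d"])] * 10  # 4 samples per batch
    # Stop after ~6 samples: guess_batchsize reads len(batch[0]) == 4, so
    # this yields 2 batches (running totals 4, then 8 >= 6).
    capped = list(repeatedly(iter(batches), nsamples=6))
    assert len(capped) == 2
    # Seeds derived from the same arguments are identical within a process;
    # note that hash() of str varies across processes unless PYTHONHASHSEED
    # is fixed, while int hashing is stable.
    assert make_seed(0, 17) == make_seed(0, 17)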
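

# Hedged usage sketch (not part of the original module): shows how the
# environment-variable path of `paddle_worker_info` feeds into
# `paddle_worker_seed`. The specific RANK/WORKER values are invented for
# illustration; real values come from the launcher and the DataLoader.
def _demo_worker_seed():
    os.environ.update(
        {"RANK": "2", "WORLD_SIZE": "4", "WORKER": "3", "NUM_WORKERS": "8"}
    )
    assert paddle_worker_info() == (2, 4, 3, 8)
    # rank * 1000 + worker keeps seeds distinct as long as num_workers < 1000.
    assert paddle_worker_seed() == 2003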
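

# Hedged usage sketch (not part of the original module): demonstrates that
# `safe_eval` only interpolates identifier-like strings into the template,
# so strings carrying call syntax or quotes are rejected before eval runs.
def _demo_safe_eval():
    assert safe_eval("123") == 123  # digits and identifiers are allowed
    assert safe_eval("abc", "len('{}')") == 3  # the caller controls the template
    try:
        safe_eval("__import__('os')")  # punctuation is rejected
    except ValueError:
        pass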