|
|
|
@ -16,6 +16,9 @@ import re
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Any, Callable, Iterator, Optional, Union
|
|
|
|
|
|
|
|
|
|
from ..utils.log import Logger
|
|
|
|
|
|
|
|
|
|
logger = Logger(__name__)
|
|
|
|
|
|
|
|
|
|
def make_seed(*args):
|
|
|
|
|
seed = 0
|
|
|
|
@ -112,13 +115,14 @@ def paddle_worker_info(group=None):
|
|
|
|
|
num_workers = int(os.environ["NUM_WORKERS"])
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
import paddle.io.get_worker_info
|
|
|
|
|
from paddle.io import get_worker_info
|
|
|
|
|
worker_info = paddle.io.get_worker_info()
|
|
|
|
|
if worker_info is not None:
|
|
|
|
|
worker = worker_info.id
|
|
|
|
|
num_workers = worker_info.num_workers
|
|
|
|
|
except ModuleNotFoundError:
|
|
|
|
|
pass
|
|
|
|
|
except ModuleNotFoundError as E:
|
|
|
|
|
logger.info(f"not found {E}")
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
return rank, world_size, worker, num_workers
|
|
|
|
|
|
|
|
|
|