adopt multi-machine training

pull/2062/head
huangyuxin 3 years ago
parent ac1b301657
commit 429221dc03

@@ -67,7 +67,7 @@ maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is auto
 resample_rate: 16000
 shuffle_size: 1500
 sort_size: 1000
-num_workers: 0
+num_workers: 8
 prefetch_factor: 10
 dist_sampler: True
 num_encs: 1
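
The config change above switches data loading to eight worker processes (num_workers: 0 -> 8) while keeping prefetching and the distributed sampler. As a rough illustration of how these knobs usually fit together in Paddle (a sketch only: the dataset class, batch size, and field names are invented, and prefetch_factor is assumed to be available on the installed Paddle's DataLoader, which is only true for recent releases):

# Hypothetical example, not repo code: wiring num_workers / prefetch_factor /
# a distributed sampler into a paddle.io.DataLoader.
import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler

class ToyDataset(Dataset):                    # stand-in for the real speech dataset
    def __len__(self):
        return 1500
    def __getitem__(self, idx):
        return paddle.randn([16000]), idx     # fake 1 s of 16 kHz audio plus a label

dataset = ToyDataset()
sampler = DistributedBatchSampler(dataset, batch_size=16, shuffle=True)  # dist_sampler: True
loader = DataLoader(
    dataset,
    batch_sampler=sampler,
    num_workers=8,        # num_workers: 8, load batches in subprocesses
    prefetch_factor=10,   # prefetch_factor: 10, assumed supported by this Paddle version
)
for audio, label in loader:
    break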

@@ -45,8 +45,7 @@ python3 -u ${BIN_DIR}/train.py \
 --benchmark-batch-size ${benchmark_batch_size} \
 --benchmark-max-step ${benchmark_max_step}
 else
-#NCCL_SOCKET_IFNAME=eth0
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
+NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --seed ${seed} \
 --config ${config_path} \
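
The launch branch now exports NCCL_SOCKET_IFNAME=eth0, pinning NCCL traffic to one network interface, which multi-machine runs frequently need, and passes ${ips_config}, which for paddle.distributed.launch is typically an --ips "host1,host2" style list of the participating machines. Once the launcher has started one trainer process per GPU on every host, each process can query its place in the job. A minimal sketch of that trainer-side view (illustration only, not part of this diff; it assumes it is executed under paddle.distributed.launch):

# Hypothetical check_rank.py: print this process's global rank and the job size.
import paddle
import paddle.distributed as dist

def main():
    dist.init_parallel_env()             # reads the environment set by the launcher
    rank = dist.get_rank()               # global rank across all machines
    world_size = dist.get_world_size()   # total number of trainer processes
    print(f"trainer {rank}/{world_size} on {paddle.device.get_device()}")

if __name__ == "__main__":
    main()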

@@ -65,6 +65,7 @@ class SimpleShardList(IterableDataset):
 def split_by_node(src, group=None):
     rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
+    logger.info(f"world_size:{world_size}, rank:{rank}")
     if world_size > 1:
         for s in islice(src, rank, None, world_size):
             yield s
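
split_by_node deals shards out round-robin: rank r of world_size w keeps shards r, r + w, r + 2w, and so on, so every shard is consumed by exactly one node. A toy example of the islice pattern (assumed illustration, not repo code):

from itertools import islice

shards = [f"shard-{i:04d}.tar" for i in range(10)]

def node_slice(src, rank, world_size):
    # same striding as split_by_node above
    if world_size > 1:
        yield from islice(src, rank, None, world_size)
    else:
        yield from src

print(list(node_slice(shards, rank=1, world_size=4)))
# ['shard-0001.tar', 'shard-0005.tar', 'shard-0009.tar']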
@@ -83,6 +84,7 @@ def single_node_only(src, group=None):
 def split_by_worker(src):
     rank, world_size, worker, num_workers = utils.paddle_worker_info()
+    logger.info(f"num_workers:{num_workers}, worker:{worker}")
     if num_workers > 1:
         for s in islice(src, worker, None, num_workers):
             yield s
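
split_by_worker applies the same striding a second time, but across the DataLoader workers inside one node, so after both splits each (rank, worker) pair sees a disjoint subset of the shards. A small illustration of how the two splits compose (assumed example with made-up numbers):

from itertools import islice

shards = list(range(16))
rank, world_size = 0, 2        # set per node by the distributed launcher
worker, num_workers = 1, 4     # set per DataLoader worker process

per_node = islice(shards, rank, None, world_size)              # 0, 2, 4, ..., 14
per_worker = list(islice(per_node, worker, None, num_workers))
print(per_worker)                                               # [2, 10]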

@@ -16,6 +16,9 @@ import re
 import sys
 from typing import Any, Callable, Iterator, Optional, Union
 
+from ..utils.log import Logger
+
+logger = Logger(__name__)
 
 def make_seed(*args):
     seed = 0
@@ -112,13 +115,14 @@ def paddle_worker_info(group=None):
         num_workers = int(os.environ["NUM_WORKERS"])
     else:
         try:
-            import paddle.io.get_worker_info
+            from paddle.io import get_worker_info
             worker_info = paddle.io.get_worker_info()
             if worker_info is not None:
                 worker = worker_info.id
                 num_workers = worker_info.num_workers
-        except ModuleNotFoundError:
-            pass
+        except ModuleNotFoundError as E:
+            logger.info(f"not found {E}")
+            exit(-1)
     return rank, world_size, worker, num_workers
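
Two things change in paddle_worker_info: the broken "import paddle.io.get_worker_info" statement (an attribute cannot be imported with a plain import) becomes a proper from-import, and a missing Paddle is now logged and made fatal instead of being silently swallowed. The paddle.io.get_worker_info() call it falls back on only returns useful values inside a DataLoader worker process; a small sketch of that behaviour (assumed example, not repo code):

# Each worker yields its own strided slice, so the union covers 0..19 exactly once.
from paddle.io import DataLoader, IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    def __iter__(self):
        info = get_worker_info()              # None when iterated in the main process
        worker = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        yield from range(worker, 20, num_workers)

loader = DataLoader(RangeDataset(), batch_size=4, num_workers=2)
values = []
for batch in loader:
    values.extend(batch.numpy().tolist())
print(sorted(values))                         # [0, 1, 2, ..., 19]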
