Merge branch 'webdataset' of https://github.com/Jackwaterveg/DeepSpeech into webdataset

pull/2062/head
huangyuxin 2 years ago
commit 6ec6921255

@ -67,7 +67,7 @@ maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is auto
resample_rate: 16000
shuffle_size: 1500
sort_size: 1000
num_workers: 0
num_workers: 8
prefetch_factor: 10
dist_sampler: True
num_encs: 1

@ -45,8 +45,7 @@ python3 -u ${BIN_DIR}/train.py \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
#NCCL_SOCKET_IFNAME=eth0
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \

@ -65,6 +65,7 @@ class SimpleShardList(IterableDataset):
def split_by_node(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
logger.info(f"world_size:{world_size}, rank:{rank}")
if world_size > 1:
for s in islice(src, rank, None, world_size):
yield s
@ -83,6 +84,7 @@ def single_node_only(src, group=None):
def split_by_worker(src):
rank, world_size, worker, num_workers = utils.paddle_worker_info()
logger.info(f"num_workers:{num_workers}, worker:{worker}")
if num_workers > 1:
for s in islice(src, worker, None, num_workers):
yield s

@ -16,6 +16,9 @@ import re
import sys
from typing import Any, Callable, Iterator, Optional, Union
from ..utils.log import Logger
logger = Logger(__name__)
def make_seed(*args):
seed = 0
@ -112,13 +115,14 @@ def paddle_worker_info(group=None):
num_workers = int(os.environ["NUM_WORKERS"])
else:
try:
import paddle.io.get_worker_info
from paddle.io import get_worker_info
worker_info = paddle.io.get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError:
pass
except ModuleNotFoundError as E:
logger.info(f"not found {E}")
exit(-1)
return rank, world_size, worker, num_workers

@ -33,7 +33,7 @@ from ..log import logger
from ..utils import CLI_TIMER
from ..utils import stats_wrapper
from ..utils import timer_register
from paddlespeech.s2t.audio.transformation import Transformation
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig

@ -104,7 +104,7 @@ class StreamDataLoader():
if self.dist_sampler:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_node,
streamdata.split_by_node if train_mode else streamdata.placeholder(),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)

Loading…
Cancel
Save