Merge branch 'webdataset' of https://github.com/Jackwaterveg/DeepSpeech into webdataset

pull/2062/head
huangyuxin 2 years ago
commit 6ec6921255

@ -67,7 +67,7 @@ maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is auto
resample_rate: 16000 resample_rate: 16000
shuffle_size: 1500 shuffle_size: 1500
sort_size: 1000 sort_size: 1000
num_workers: 0 num_workers: 8
prefetch_factor: 10 prefetch_factor: 10
dist_sampler: True dist_sampler: True
num_encs: 1 num_encs: 1

@ -45,8 +45,7 @@ python3 -u ${BIN_DIR}/train.py \
--benchmark-batch-size ${benchmark_batch_size} \ --benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step} --benchmark-max-step ${benchmark_max_step}
else else
#NCCL_SOCKET_IFNAME=eth0 NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \ --ngpu ${ngpu} \
--seed ${seed} \ --seed ${seed} \
--config ${config_path} \ --config ${config_path} \

@ -65,6 +65,7 @@ class SimpleShardList(IterableDataset):
def split_by_node(src, group=None): def split_by_node(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
logger.info(f"world_size:{world_size}, rank:{rank}")
if world_size > 1: if world_size > 1:
for s in islice(src, rank, None, world_size): for s in islice(src, rank, None, world_size):
yield s yield s
@ -83,6 +84,7 @@ def single_node_only(src, group=None):
def split_by_worker(src): def split_by_worker(src):
rank, world_size, worker, num_workers = utils.paddle_worker_info() rank, world_size, worker, num_workers = utils.paddle_worker_info()
logger.info(f"num_workers:{num_workers}, worker:{worker}")
if num_workers > 1: if num_workers > 1:
for s in islice(src, worker, None, num_workers): for s in islice(src, worker, None, num_workers):
yield s yield s

@ -16,6 +16,9 @@ import re
import sys import sys
from typing import Any, Callable, Iterator, Optional, Union from typing import Any, Callable, Iterator, Optional, Union
from ..utils.log import Logger
logger = Logger(__name__)
def make_seed(*args): def make_seed(*args):
seed = 0 seed = 0
@ -112,13 +115,14 @@ def paddle_worker_info(group=None):
num_workers = int(os.environ["NUM_WORKERS"]) num_workers = int(os.environ["NUM_WORKERS"])
else: else:
try: try:
import paddle.io.get_worker_info from paddle.io import get_worker_info
worker_info = paddle.io.get_worker_info() worker_info = paddle.io.get_worker_info()
if worker_info is not None: if worker_info is not None:
worker = worker_info.id worker = worker_info.id
num_workers = worker_info.num_workers num_workers = worker_info.num_workers
except ModuleNotFoundError: except ModuleNotFoundError as E:
pass logger.info(f"not found {E}")
exit(-1)
return rank, world_size, worker, num_workers return rank, world_size, worker, num_workers

@ -33,7 +33,7 @@ from ..log import logger
from ..utils import CLI_TIMER from ..utils import CLI_TIMER
from ..utils import stats_wrapper from ..utils import stats_wrapper
from ..utils import timer_register from ..utils import timer_register
from paddlespeech.s2t.audio.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig

@ -104,7 +104,7 @@ class StreamDataLoader():
if self.dist_sampler: if self.dist_sampler:
base_dataset = streamdata.DataPipeline( base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist), streamdata.SimpleShardList(shardlist),
streamdata.split_by_node, streamdata.split_by_node if train_mode else streamdata.placeholder(),
streamdata.split_by_worker, streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception) streamdata.tarfile_to_samples(streamdata.reraise_exception)
) )

Loading…
Cancel
Save