|
|
@ -22,17 +22,16 @@ import paddle
|
|
|
|
from paddle.io import BatchSampler
|
|
|
|
from paddle.io import BatchSampler
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddlespeech.audio.streamdata as streamdata
|
|
|
|
|
|
|
|
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
|
|
|
|
from paddlespeech.s2t.io.batchfy import make_batchset
|
|
|
|
from paddlespeech.s2t.io.batchfy import make_batchset
|
|
|
|
from paddlespeech.s2t.io.converter import CustomConverter
|
|
|
|
from paddlespeech.s2t.io.converter import CustomConverter
|
|
|
|
from paddlespeech.s2t.io.dataset import TransformDataset
|
|
|
|
from paddlespeech.s2t.io.dataset import TransformDataset
|
|
|
|
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
|
|
|
|
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
|
import paddlespeech.audio.streamdata as streamdata
|
|
|
|
|
|
|
|
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
|
|
|
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["BatchDataLoader", "StreamDataLoader"]
|
|
|
|
__all__ = ["BatchDataLoader", "StreamDataLoader"]
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
@ -61,6 +60,7 @@ def batch_collate(x):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return x[0]
|
|
|
|
return x[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_preprocess_cfg(preprocess_conf_file):
|
|
|
|
def read_preprocess_cfg(preprocess_conf_file):
|
|
|
|
augment_conf = dict()
|
|
|
|
augment_conf = dict()
|
|
|
|
preprocess_cfg = CfgNode(new_allowed=True)
|
|
|
|
preprocess_cfg = CfgNode(new_allowed=True)
|
|
|
@ -84,6 +84,7 @@ def read_preprocess_cfg(preprocess_conf_file):
|
|
|
|
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
|
|
|
|
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
|
|
|
|
return augment_conf
|
|
|
|
return augment_conf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamDataLoader():
|
|
|
|
class StreamDataLoader():
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
manifest_file: str,
|
|
|
|
manifest_file: str,
|
|
|
@ -131,10 +132,14 @@ class StreamDataLoader():
|
|
|
|
world_size = paddle.distributed.get_world_size()
|
|
|
|
world_size = paddle.distributed.get_world_size()
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.warninig(e)
|
|
|
|
logger.warninig(e)
|
|
|
|
logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
|
|
|
|
logger.warninig(
|
|
|
|
assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...")
|
|
|
|
"can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
assert (len(shardlist) >= world_size,
|
|
|
|
|
|
|
|
"the length of shard list should >= number of gpus/xpus/...")
|
|
|
|
|
|
|
|
|
|
|
|
update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
|
|
|
|
update_n_iter_processes = int(
|
|
|
|
|
|
|
|
max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
|
|
|
|
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
|
|
|
|
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
|
|
|
|
if update_n_iter_processes != self.n_iter_processes:
|
|
|
|
if update_n_iter_processes != self.n_iter_processes:
|
|
|
|
self.n_iter_processes = update_n_iter_processes
|
|
|
|
self.n_iter_processes = update_n_iter_processes
|
|
|
@ -142,44 +147,50 @@ class StreamDataLoader():
|
|
|
|
|
|
|
|
|
|
|
|
if self.dist_sampler:
|
|
|
|
if self.dist_sampler:
|
|
|
|
base_dataset = streamdata.DataPipeline(
|
|
|
|
base_dataset = streamdata.DataPipeline(
|
|
|
|
streamdata.SimpleShardList(shardlist),
|
|
|
|
streamdata.SimpleShardList(shardlist), streamdata.split_by_node
|
|
|
|
streamdata.split_by_node if train_mode else streamdata.placeholder(),
|
|
|
|
if train_mode else streamdata.placeholder(),
|
|
|
|
streamdata.split_by_worker,
|
|
|
|
streamdata.split_by_worker,
|
|
|
|
streamdata.tarfile_to_samples(streamdata.reraise_exception)
|
|
|
|
streamdata.tarfile_to_samples(streamdata.reraise_exception))
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
base_dataset = streamdata.DataPipeline(
|
|
|
|
base_dataset = streamdata.DataPipeline(
|
|
|
|
streamdata.SimpleShardList(shardlist),
|
|
|
|
streamdata.SimpleShardList(shardlist),
|
|
|
|
streamdata.split_by_worker,
|
|
|
|
streamdata.split_by_worker,
|
|
|
|
streamdata.tarfile_to_samples(streamdata.reraise_exception)
|
|
|
|
streamdata.tarfile_to_samples(streamdata.reraise_exception))
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.dataset = base_dataset.append_list(
|
|
|
|
self.dataset = base_dataset.append_list(
|
|
|
|
streamdata.audio_tokenize(symbol_table),
|
|
|
|
streamdata.audio_tokenize(symbol_table),
|
|
|
|
streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
|
|
|
|
streamdata.audio_data_filter(
|
|
|
|
|
|
|
|
frame_shift=frame_shift,
|
|
|
|
|
|
|
|
max_length=maxlen_in,
|
|
|
|
|
|
|
|
min_length=minlen_in,
|
|
|
|
|
|
|
|
token_max_length=maxlen_out,
|
|
|
|
|
|
|
|
token_min_length=minlen_out),
|
|
|
|
streamdata.audio_resample(resample_rate=resample_rate),
|
|
|
|
streamdata.audio_resample(resample_rate=resample_rate),
|
|
|
|
streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
|
|
|
|
streamdata.audio_compute_fbank(
|
|
|
|
streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
|
|
|
|
num_mel_bins=num_mel_bins,
|
|
|
|
|
|
|
|
frame_length=frame_length,
|
|
|
|
|
|
|
|
frame_shift=frame_shift,
|
|
|
|
|
|
|
|
dither=dither),
|
|
|
|
|
|
|
|
streamdata.audio_spec_aug(**augment_conf)
|
|
|
|
|
|
|
|
if train_mode else streamdata.placeholder(
|
|
|
|
|
|
|
|
), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
|
|
|
|
streamdata.shuffle(shuffle_size),
|
|
|
|
streamdata.shuffle(shuffle_size),
|
|
|
|
streamdata.sort(sort_size=sort_size),
|
|
|
|
streamdata.sort(sort_size=sort_size),
|
|
|
|
streamdata.batched(batch_size),
|
|
|
|
streamdata.batched(batch_size),
|
|
|
|
streamdata.audio_padding(),
|
|
|
|
streamdata.audio_padding(),
|
|
|
|
streamdata.audio_cmvn(cmvn_file)
|
|
|
|
streamdata.audio_cmvn(cmvn_file))
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if paddle.__version__ >= '2.3.2':
|
|
|
|
if paddle.__version__ >= '2.3.2':
|
|
|
|
self.loader = streamdata.WebLoader(
|
|
|
|
self.loader = streamdata.WebLoader(
|
|
|
|
self.dataset,
|
|
|
|
self.dataset,
|
|
|
|
num_workers=self.n_iter_processes,
|
|
|
|
num_workers=self.n_iter_processes,
|
|
|
|
prefetch_factor=self.prefetch_factor,
|
|
|
|
prefetch_factor=self.prefetch_factor,
|
|
|
|
batch_size=None
|
|
|
|
batch_size=None)
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.loader = streamdata.WebLoader(
|
|
|
|
self.loader = streamdata.WebLoader(
|
|
|
|
self.dataset,
|
|
|
|
self.dataset,
|
|
|
|
num_workers=self.n_iter_processes,
|
|
|
|
num_workers=self.n_iter_processes,
|
|
|
|
batch_size=None
|
|
|
|
batch_size=None)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
def __iter__(self):
|
|
|
|
return self.loader.__iter__()
|
|
|
|
return self.loader.__iter__()
|
|
|
@ -188,7 +199,9 @@ class StreamDataLoader():
|
|
|
|
return self.__iter__()
|
|
|
|
return self.__iter__()
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
def __len__(self):
|
|
|
|
logger.info("Stream dataloader does not support calculate the length of the dataset")
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
"Stream dataloader does not support calculate the length of the dataset"
|
|
|
|
|
|
|
|
)
|
|
|
|
return -1
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -358,7 +371,9 @@ class DataLoaderFactory():
|
|
|
|
config['maxlen_out'] = float('inf')
|
|
|
|
config['maxlen_out'] = float('inf')
|
|
|
|
config['dist_sampler'] = False
|
|
|
|
config['dist_sampler'] = False
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
|
|
|
|
raise KeyError(
|
|
|
|
|
|
|
|
"not valid mode type!!, please input one of 'train, valid, test, align'"
|
|
|
|
|
|
|
|
)
|
|
|
|
return StreamDataLoader(
|
|
|
|
return StreamDataLoader(
|
|
|
|
manifest_file=config.manifest,
|
|
|
|
manifest_file=config.manifest,
|
|
|
|
train_mode=config.train_mode,
|
|
|
|
train_mode=config.train_mode,
|
|
|
@ -380,8 +395,7 @@ class DataLoaderFactory():
|
|
|
|
prefetch_factor=config.prefetch_factor,
|
|
|
|
prefetch_factor=config.prefetch_factor,
|
|
|
|
dist_sampler=config.dist_sampler,
|
|
|
|
dist_sampler=config.dist_sampler,
|
|
|
|
cmvn_file=config.cmvn_file,
|
|
|
|
cmvn_file=config.cmvn_file,
|
|
|
|
vocab_filepath=config.vocab_filepath,
|
|
|
|
vocab_filepath=config.vocab_filepath, )
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if mode == 'train':
|
|
|
|
if mode == 'train':
|
|
|
|
config['manifest'] = config.train_manifest
|
|
|
|
config['manifest'] = config.train_manifest
|
|
|
@ -427,7 +441,9 @@ class DataLoaderFactory():
|
|
|
|
config['dist_sampler'] = False
|
|
|
|
config['dist_sampler'] = False
|
|
|
|
config['shortest_first'] = False
|
|
|
|
config['shortest_first'] = False
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
|
|
|
|
raise KeyError(
|
|
|
|
|
|
|
|
"not valid mode type!!, please input one of 'train, valid, test, align'"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return BatchDataLoader(
|
|
|
|
return BatchDataLoader(
|
|
|
|
json_file=config.manifest,
|
|
|
|
json_file=config.manifest,
|
|
|
@ -450,4 +466,3 @@ class DataLoaderFactory():
|
|
|
|
num_encs=config.num_encs,
|
|
|
|
num_encs=config.num_encs,
|
|
|
|
dist_sampler=config.dist_sampler,
|
|
|
|
dist_sampler=config.dist_sampler,
|
|
|
|
shortest_first=config.shortest_first)
|
|
|
|
shortest_first=config.shortest_first)
|
|
|
|
|
|
|
|
|
|
|
|