[s2t] DataLoader with BatchSampler or DistributedBatchSampler (#1242)

* BatchSampler or DistributedBatchSampler

* format
pull/1245/head
Hui Zhang 4 years ago committed by GitHub
parent 6d93f3e55e
commit c81a3f0f83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
num_encs=1) num_encs=1,
dist_sampler=True)
self.valid_loader = BatchDataLoader( self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest, json_file=config.data.dev_manifest,
@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
num_encs=1) num_encs=1,
dist_sampler=True)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
# test dataset, return raw text # test dataset, return raw text
@ -335,7 +337,8 @@ class U2STTrainer(Trainer):
augmentation_config, # aug will be off when train_mode=False augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
num_encs=1) num_encs=1,
dist_sampler=False)
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
@ -542,7 +545,8 @@ class U2STTester(U2STTrainer):
len_refs += metrics['len_refs'] len_refs += metrics['len_refs']
num_ins += metrics['num_ins'] num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms) rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu)) logger.info("RTF: %f, instance (%d), batch BELU = %f" %
(rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms) rtf = num_time / (num_frames * stride_ms)
msg = "Test: " msg = "Test: "

@ -65,7 +65,8 @@ class CustomConverter():
# text data (output): (text_len, ) # text data (output): (text_len, )
ys_data.append(ud) ys_data.append(ud)
assert xs_data[0][0] is not None, "please check Reader and Augmentation impl." assert xs_data[0][
0] is not None, "please check Reader and Augmentation impl."
xs_pad, ilens = [], [] xs_pad, ilens = [], []
for xs in xs_data: for xs in xs_data:
@ -88,12 +89,16 @@ class CustomConverter():
ys_pad, olens = [], [] ys_pad, olens = [], []
for ys in ys_data: for ys in ys_data:
ys_pad.append(pad_list( ys_pad.append(
[np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], pad_list([
self.ignore_id)) np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys
], self.ignore_id))
olens.append(np.array(
[y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])) olens.append(
np.array([
y[0].shape[0] if isinstance(y, tuple) else y.shape[0]
for y in ys
]))
if not self.load_aux_output: if not self.load_aux_output:
ys_pad, olens = ys_pad[0], olens[0] ys_pad, olens = ys_pad[0], olens[0]

@ -18,6 +18,7 @@ from typing import Text
import jsonlines import jsonlines
import numpy as np import numpy as np
from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
@ -76,7 +77,8 @@ class BatchDataLoader():
subsampling_factor: int=1, subsampling_factor: int=1,
load_aux_input: bool=False, load_aux_input: bool=False,
load_aux_output: bool=False, load_aux_output: bool=False,
num_encs: int=1): num_encs: int=1,
dist_sampler: bool=False):
self.json_file = json_file self.json_file = json_file
self.train_mode = train_mode self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0 self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@ -94,6 +96,7 @@ class BatchDataLoader():
self.n_iter_processes = n_iter_processes self.n_iter_processes = n_iter_processes
self.load_aux_input = load_aux_input self.load_aux_input = load_aux_input
self.load_aux_output = load_aux_output self.load_aux_output = load_aux_output
self.dist_sampler = dist_sampler
# read json data # read json data
with jsonlines.open(json_file, 'r') as reader: with jsonlines.open(json_file, 'r') as reader:
@ -145,11 +148,18 @@ class BatchDataLoader():
self.dataset = TransformDataset(self.minibaches, self.converter, self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader) self.reader)
self.sampler = DistributedBatchSampler( if self.dist_sampler:
dataset=self.dataset, self.sampler = DistributedBatchSampler(
batch_size=1, dataset=self.dataset,
shuffle=not self.use_sortagrad if self.train_mode else False, batch_size=1,
) shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )
else:
self.sampler = BatchSampler(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )
self.dataloader = DataLoader( self.dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
@ -181,5 +191,8 @@ class BatchDataLoader():
echo += f"subsampling_factor: {self.subsampling_factor}, " echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, " echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, " echo += f"num_workers: {self.n_iter_processes}, "
echo += f"load_aux_input: {self.load_aux_input}, "
echo += f"load_aux_output: {self.load_aux_output}, "
echo += f"dist_sampler: {self.dist_sampler}, "
echo += f"file: {self.json_file}" echo += f"file: {self.json_file}"
return echo return echo

@ -203,12 +203,15 @@ def evaluate(args):
get_tone_ids = True get_tone_ids = True
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
elif args.lang == 'en': elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences) input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en'}!")

Loading…
Cancel
Save