[s2t] DataLoader with BatchSampler or DistributedBatchSampler (#1242)

* batchsampler or distributebatchsampler

* format
Hui Zhang authored 4 years ago · committed by GitHub
parent 6d93f3e55e
commit c81a3f0f83

@@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
- num_encs=1)
+ num_encs=1,
+ dist_sampler=True)
self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest,
@@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
- num_encs=1)
+ num_encs=1,
+ dist_sampler=True)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
@@ -335,7 +337,8 @@ class U2STTrainer(Trainer):
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
- num_encs=1)
+ num_encs=1,
+ dist_sampler=False)
logger.info("Setup test Dataloader!")
@@ -542,7 +545,8 @@ class U2STTester(U2STTrainer):
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms)
- logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu))
+ logger.info("RTF: %f, instance (%d), batch BELU = %f" %
+             (rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms)
msg = "Test: "
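As an aside on the RTF value logged above: real-time factor is decode time divided by the audio duration implied by num_frames and the frame shift. A tiny worked sketch, assuming num_time is wall-clock decode time in seconds and stride_ms has already been converted to seconds (both assumptions, not taken from this diff):

    # Hypothetical numbers, for illustration only.
    num_time = 120.0     # seconds spent decoding
    num_frames = 60000   # total feature frames
    stride_ms = 0.01     # 10 ms frame shift, assumed expressed in seconds

    rtf = num_time / (num_frames * stride_ms)  # 120 / 600 = 0.2, i.e. 5x faster than real time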

@@ -65,8 +65,9 @@ class CustomConverter():
# text data (output): (text_len, )
ys_data.append(ud)
- assert xs_data[0][0] is not None, "please check Reader and Augmentation impl."
+ assert xs_data[0][
+     0] is not None, "please check Reader and Augmentation impl."
xs_pad, ilens = [], []
for xs in xs_data:
# perform subsampling
@@ -79,22 +80,26 @@ class CustomConverter():
# perform padding and convert to tensor
# currently only support real number
xs_pad.append(pad_list(xs, 0).astype(self.dtype))
if not self.load_aux_input:
xs_pad, ilens = xs_pad[0], ilens[0]
break
# NOTE: this is for multi-output (e.g., speech translation)
ys_pad, olens = [], []
for ys in ys_data:
- ys_pad.append(pad_list(
-     [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
-     self.ignore_id))
- olens.append(np.array(
-     [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]))
+ ys_pad.append(
+     pad_list([
+         np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys
+     ], self.ignore_id))
+ olens.append(
+     np.array([
+         y[0].shape[0] if isinstance(y, tuple) else y.shape[0]
+         for y in ys
+     ]))
if not self.load_aux_output:
ys_pad, olens = ys_pad[0], olens[0]
break
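A note on the reformatted calls above: pad_list(list_of_arrays, fill_value) pads variable-length sequences to the longest one in the list, while olens records the original lengths. A minimal illustrative re-implementation for 1-D targets (pad_list_sketch and the sample values are hypothetical, not the project's own helper):

    import numpy as np

    def pad_list_sketch(xs, fill_value):
        # Pad a list of 1-D arrays to the length of the longest one.
        max_len = max(x.shape[0] for x in xs)
        out = np.full((len(xs), max_len), fill_value, dtype=xs[0].dtype)
        for i, x in enumerate(xs):
            out[i, :x.shape[0]] = x
        return out

    ignore_id = -1                                  # assumed padding id for targets
    ys = [np.array([5, 6, 7]), np.array([8, 9])]
    ys_pad = pad_list_sketch(ys, ignore_id)         # [[5, 6, 7], [8, 9, -1]]
    olens = np.array([y.shape[0] for y in ys])      # [3, 2]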

@@ -18,6 +18,7 @@ from typing import Text
import jsonlines
import numpy as np
+ from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
@@ -76,7 +77,8 @@ class BatchDataLoader():
subsampling_factor: int=1,
load_aux_input: bool=False,
load_aux_output: bool=False,
- num_encs: int=1):
+ num_encs: int=1,
+ dist_sampler: bool=False):
self.json_file = json_file
self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -94,6 +96,7 @@ class BatchDataLoader():
self.n_iter_processes = n_iter_processes
self.load_aux_input = load_aux_input
self.load_aux_output = load_aux_output
+ self.dist_sampler = dist_sampler
# read json data
with jsonlines.open(json_file, 'r') as reader:
@@ -145,11 +148,18 @@ class BatchDataLoader():
self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader)
- self.sampler = DistributedBatchSampler(
-     dataset=self.dataset,
-     batch_size=1,
-     shuffle=not self.use_sortagrad if self.train_mode else False,
- )
+ if self.dist_sampler:
+     self.sampler = DistributedBatchSampler(
+         dataset=self.dataset,
+         batch_size=1,
+         shuffle=not self.use_sortagrad if self.train_mode else False,
+         drop_last=False, )
+ else:
+     self.sampler = BatchSampler(
+         dataset=self.dataset,
+         batch_size=1,
+         shuffle=not self.use_sortagrad if self.train_mode else False,
+         drop_last=False, )
self.dataloader = DataLoader(
dataset=self.dataset,
@@ -181,5 +191,8 @@ class BatchDataLoader():
echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, "
+ echo += f"load_aux_input: {self.load_aux_input}, "
+ echo += f"load_aux_output: {self.load_aux_output}, "
+ echo += f"dist_sampler: {self.dist_sampler}, "
echo += f"file: {self.json_file}"
return echo
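For context on the new dist_sampler flag: when it is True the loader builds a paddle.io.DistributedBatchSampler, which shards batches across ranks for data-parallel training; when False it falls back to a plain paddle.io.BatchSampler, as the test loader above does. A minimal, self-contained sketch of that switch using only public paddle.io APIs; ToyDataset and the literal flag value are illustrative, not part of this commit:

    import paddle
    from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

    class ToyDataset(paddle.io.Dataset):
        def __init__(self, n=8):
            self.data = [paddle.to_tensor([float(i)]) for i in range(n)]

        def __getitem__(self, idx):
            return self.data[idx]

        def __len__(self):
            return len(self.data)

    dataset = ToyDataset()
    dist_sampler = False  # True when launched with paddle.distributed data parallel

    if dist_sampler:
        # Shards batches across ranks; each process iterates only its own slice.
        sampler = DistributedBatchSampler(
            dataset=dataset, batch_size=1, shuffle=True, drop_last=False)
    else:
        # Single-process batching; this process iterates every example.
        sampler = BatchSampler(
            dataset=dataset, batch_size=1, shuffle=True, drop_last=False)

    loader = DataLoader(dataset=dataset, batch_sampler=sampler)
    for batch in loader:
        pass  # with batch_size=1, each batch holds a single example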

@@ -203,12 +203,15 @@ def evaluate(args):
get_tone_ids = True
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
-     sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+     sentence,
+     merge_sentences=merge_sentences,
+     get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif args.lang == 'en':
- input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
+ input_ids = frontend.get_input_ids(
+     sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
