fix batch sampler set_epoch when epoch starts

pull/1273/head
Hui Zhang 3 years ago
parent 680eac02b9
commit 6f651d762e

@@ -240,7 +240,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)
             self.valid_loader = BatchDataLoader(
                 json_file=config.dev_manifest,
@@ -259,7 +261,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
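Both BatchDataLoader call sites above now opt into distributed sampling explicitly. A minimal usage sketch, assuming the class's import path and trimming the constructor down to the arguments relevant here (the real `__init__`, patched below, takes many more parameters; the manifest path and values are placeholders):

```python
# Hypothetical sketch, not repo code: import path and argument values are
# assumptions for illustration only.
from paddlespeech.s2t.io.dataloader import BatchDataLoader

train_loader = BatchDataLoader(
    json_file="data/manifest.train",  # jsonlines manifest, as in the diff
    train_mode=True,
    sortagrad=0,               # 0 disables SortaGrad shortest-first warm-up
    dist_sampler=True,         # shard batches across distributed workers
    shortest_first=False)      # keep the default long-to-short ordering
```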

@@ -78,7 +78,8 @@ class BatchDataLoader():
                  load_aux_input: bool=False,
                  load_aux_output: bool=False,
                  num_encs: int=1,
-                 dist_sampler: bool=False):
+                 dist_sampler: bool=False,
+                 shortest_first: bool=False):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -97,6 +98,7 @@ class BatchDataLoader():
         self.load_aux_input = load_aux_input
         self.load_aux_output = load_aux_output
         self.dist_sampler = dist_sampler
+        self.shortest_first = shortest_first

         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -113,7 +115,7 @@ class BatchDataLoader():
             maxlen_out,
             minibatches,  # for debug
             min_batch_size=mini_batch_size,
-            shortest_first=self.use_sortagrad,
+            shortest_first=self.shortest_first or self.use_sortagrad,
             count=batch_count,
             batch_bins=batch_bins,
             batch_frames_in=batch_frames_in,
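`shortest_first` is OR-ed with the SortaGrad flag, so either the new constructor argument or a `sortagrad` config setting can force shortest-first ordering. A self-contained toy (not library code) showing what the flag means for batch order:

```python
# Toy illustration of shortest-first ordering; `batches` is a made-up
# stand-in for make_batchset() output: lists of (utt_id, n_frames) pairs.
batches = [[("a", 120)], [("b", 30)], [("c", 75)]]

def order_batches(batches, shortest_first=False):
    # Longest-first is the usual default; SortaGrad-style training flips it
    # so the easiest (shortest) utterances come first.
    return sorted(batches, key=lambda b: b[0][1], reverse=not shortest_first)

print(order_batches(batches, shortest_first=True))  # b (30), c (75), a (120)
```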
@@ -149,13 +151,13 @@ class BatchDataLoader():
             self.reader)

         if self.dist_sampler:
-            self.sampler = DistributedBatchSampler(
+            self.batch_sampler = DistributedBatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
                 drop_last=False, )
         else:
-            self.sampler = BatchSampler(
+            self.batch_sampler = BatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
@@ -163,7 +165,7 @@ class BatchDataLoader():
         self.dataloader = DataLoader(
             dataset=self.dataset,
-            batch_sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )
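The rename from `self.sampler` to `self.batch_sampler` is the actual fix: the Trainer hunk below reads `self.train_loader.batch_sampler` and calls `set_epoch()` on it, so the attribute must exist under exactly that name. A runnable sketch of the same wiring against plain paddle APIs (`ToyDataset` is invented for illustration):

```python
import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    # Minimal dataset so the sampler/loader wiring can be shown end to end.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return paddle.to_tensor([idx])

dataset = ToyDataset()
batch_sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

# Trainer-side code can now reach the sampler by name and reseed it:
batch_sampler.set_epoch(0)
```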
@@ -194,5 +196,6 @@ class BatchDataLoader():
         echo += f"load_aux_input: {self.load_aux_input}, "
         echo += f"load_aux_output: {self.load_aux_output}, "
         echo += f"dist_sampler: {self.dist_sampler}, "
+        echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo

@@ -39,9 +39,6 @@ except ImportError:
 except Exception as e:
     logger.info("paddlespeech_ctcdecoders not installed!")
-#try:
-#except Exception as e:
-#    logger.info("ctcdecoder not installed!")

 __all__ = ['CTCDecoder']

@@ -67,18 +67,19 @@ class WarmupLR(LRScheduler):
         super().__init__(learning_rate, last_epoch, verbose)

     def __repr__(self):
-        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})"

     def get_lr(self):
         # self.last_epoch start from zero
         step_num = self.last_epoch + 1
         return self.base_lr * self.warmup_steps**0.5 * min(
             step_num**-0.5, step_num * self.warmup_steps**-1.5)

     def set_step(self, step: int=None):
         '''
-        It will update the learning rate in optimizer according to current ``epoch`` .
+        It will update the learning rate in optimizer according to current ``epoch``.
         The new learning rate will take effect on next ``optimizer.step`` .

         Args:
             step (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

         Returns:
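For context on the hunk above: `get_lr()` is the Transformer ("Noam") warm-up schedule; the rate rises linearly for `warmup_steps` steps, peaks at exactly `base_lr`, then decays as `step_num**-0.5`. A standalone check of the formula in plain Python:

```python
def warmup_lr(base_lr: float, warmup_steps: int, step_num: int) -> float:
    # Same expression as WarmupLR.get_lr(); step_num = last_epoch + 1.
    return base_lr * warmup_steps**0.5 * min(
        step_num**-0.5, step_num * warmup_steps**-1.5)

# The two min() branches meet at step_num == warmup_steps, where the
# schedule peaks at base_lr, then decays.
for step in (1, 12500, 25000, 100000):
    print(step, warmup_lr(0.002, 25000, step))
```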
@@ -94,7 +95,7 @@ class ConstantLR(LRScheduler):
         learning_rate (float): The initial learning rate. It is a python float number.
         last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
-        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

     Returns:
         ``ConstantLR`` instance to schedule learning rate.
     """

@@ -222,7 +222,7 @@ class Trainer():
             batch_sampler = self.train_loader.batch_sampler
             if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
                 logger.debug(
-                    f"train_loader.batch_sample set epoch: {self.epoch}")
+                    f"train_loader.batch_sample.set_epoch: {self.epoch}")
                 batch_sampler.set_epoch(self.epoch)

     def before_train(self):
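This is the behavior named in the commit title: `paddle.io.DistributedBatchSampler` shuffles with a seed derived from its epoch counter, so if `set_epoch()` is never called, every epoch replays the same batch order. A hedged sketch of the per-epoch pattern the Trainer now follows (the loop scaffolding is illustrative, not the Trainer's actual code):

```python
import paddle

def run_epochs(num_epochs, train_loader):
    for epoch in range(num_epochs):
        # Reseed the distributed sampler so each epoch gets a fresh shuffle.
        batch_sampler = getattr(train_loader, "batch_sampler", None)
        if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
            batch_sampler.set_epoch(epoch)
        for batch in train_loader:
            pass  # forward / backward / optimizer.step()
```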
