|
|
@ -75,7 +75,7 @@ class U2Trainer(Trainer):
|
|
|
|
def __init__(self, config, args):
|
|
|
|
def __init__(self, config, args):
|
|
|
|
super().__init__(config, args)
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
|
|
|
|
|
|
|
def train_batch(self, batch_data, msg):
|
|
|
|
def train_batch(self, batch_index, batch_data, msg):
|
|
|
|
train_conf = self.config.training
|
|
|
|
train_conf = self.config.training
|
|
|
|
self.model.train()
|
|
|
|
self.model.train()
|
|
|
|
|
|
|
|
|
|
|
@ -93,7 +93,7 @@ class U2Trainer(Trainer):
|
|
|
|
'train_ctc_loss': float(ctc_loss),
|
|
|
|
'train_ctc_loss': float(ctc_loss),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (self.iteration + 1) % train_conf.accum_grad == 0:
|
|
|
|
if (batch_index + 1) % train_conf.accum_grad == 0:
|
|
|
|
if dist.get_rank() == 0 and self.visualizer:
|
|
|
|
if dist.get_rank() == 0 and self.visualizer:
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
self.visualizer.add_scalar("train/{}".format(k), v,
|
|
|
|
self.visualizer.add_scalar("train/{}".format(k), v,
|
|
|
@ -105,7 +105,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
|
|
|
|
|
|
|
if (self.iteration + 1) % train_conf.log_interval == 0:
|
|
|
|
if (batch_index + 1) % train_conf.log_interval == 0:
|
|
|
|
msg += "time: {:>.3f}s, ".format(iteration_time)
|
|
|
|
msg += "time: {:>.3f}s, ".format(iteration_time)
|
|
|
|
msg += "batch size: {}, ".format(self.config.data.batch_size)
|
|
|
|
msg += "batch size: {}, ".format(self.config.data.batch_size)
|
|
|
|
msg += "accum: {}, ".format(train_conf.accum_grad)
|
|
|
|
msg += "accum: {}, ".format(train_conf.accum_grad)
|
|
|
@ -136,14 +136,14 @@ class U2Trainer(Trainer):
|
|
|
|
while self.epoch < self.config.training.n_epoch:
|
|
|
|
while self.epoch < self.config.training.n_epoch:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
data_start_time = time.time()
|
|
|
|
data_start_time = time.time()
|
|
|
|
for batch in self.train_loader:
|
|
|
|
for batch_index, batch in enumerate(self.train_loader):
|
|
|
|
dataload_time = time.time() - data_start_time
|
|
|
|
dataload_time = time.time() - data_start_time
|
|
|
|
msg = "Train: Rank: {}, ".format(dist.get_rank())
|
|
|
|
msg = "Train: Rank: {}, ".format(dist.get_rank())
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
|
|
|
|
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
|
|
|
|
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
self.train_batch(batch, msg)
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
data_start_time = time.time()
|
|
|
|
data_start_time = time.time()
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
self.logger.error(e)
|
|
|
|
self.logger.error(e)
|
|
|
|