Add epoch-train and eval timers to the U2 trainer; rename the CTCLoss grad norm type 'batchsize' to 'batch'

pull/820/head
Hui Zhang 3 years ago
parent 890a28f9bf
commit 8873ebe38c

@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Model from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
@ -184,40 +185,42 @@ class U2Trainer(Trainer):
self.save(tag='init') self.save(tag='init')
self.lr_scheduler.step(self.iteration) self.lr_scheduler.step(self.iteration)
if self.parallel: if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
self.train_loader.batch_sampler.set_epoch(self.epoch) self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch: while self.epoch < self.config.training.n_epoch:
self.model.train() with Timer("Epoch-Train Time Cost: {}"):
try: self.model.train()
data_start_time = time.time() try:
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
data_start_time = time.time() data_start_time = time.time()
except Exception as e: for batch_index, batch in enumerate(self.train_loader):
logger.error(e) dataload_time = time.time() - data_start_time
raise e msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
total_loss, num_seen_utts = self.valid() msg += "step: {}, ".format(self.iteration)
if dist.get_world_size() > 1: msg += "batch : {}/{}, ".format(batch_index + 1,
num_seen_utts = paddle.to_tensor(num_seen_utts) len(self.train_loader))
# the default operator in all_reduce function is sum. msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
dist.all_reduce(num_seen_utts) msg += "data time: {:>.3f}s, ".format(dataload_time)
total_loss = paddle.to_tensor(total_loss) self.train_batch(batch_index, batch, msg)
dist.all_reduce(total_loss) data_start_time = time.time()
cv_loss = total_loss / num_seen_utts except Exception as e:
cv_loss = float(cv_loss) logger.error(e)
else: raise e
cv_loss = total_loss / num_seen_utts
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))

@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
# instance for norm_by_times # instance for norm_by_times
# batchsize for norm_by_batchsize # batch for norm_by_batchsize
# frame for norm_by_total_logits_len # frame for norm_by_total_logits_len
assert grad_norm_type in ('instance', 'batchsize', 'frame', None) assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False self.norm_by_times = False
self.norm_by_batchsize = False self.norm_by_batchsize = False
self.norm_by_total_logits_len = False self.norm_by_total_logits_len = False
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance': if grad_norm_type == 'instance':
self.norm_by_times = True self.norm_by_times = True
if grad_norm_type == 'batchsize': if grad_norm_type == 'batch':
self.norm_by_times = True self.norm_by_times = True
if grad_norm_type == 'frame': if grad_norm_type == 'frame':
self.norm_by_total_logits_len = True self.norm_by_total_logits_len = True

Loading…
Cancel
Save