add timer for u2; refactor grad norm type

pull/820/head
Hui Zhang 4 years ago
parent 890a28f9bf
commit 8873ebe38c

@@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -184,11 +185,12 @@ class U2Trainer(Trainer):
             self.save(tag='init')
 
         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
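The new `hasattr` guard is the notable part of this hunk: `set_epoch` must be called on a distributed batch sampler once per epoch so each epoch gets a fresh shuffle order, but a loader built without a batch sampler would crash on the attribute access. A minimal sketch of the pattern using stock `paddle.io` types (the dataset and variable names are illustrative, not DeepSpeech's):

# Sketch of the guarded per-epoch reshuffle; names here are
# illustrative stand-ins, not DeepSpeech's loader construction.
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, TensorDataset

dataset = TensorDataset([paddle.randn([32, 10])])
sampler = DistributedBatchSampler(dataset, batch_size=4, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(3):
    # Without set_epoch, every epoch replays the same shuffle order;
    # the hasattr guard keeps loaders that lack a batch_sampler safe.
    if hasattr(loader, 'batch_sampler'):
        loader.batch_sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step would go here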
@@ -207,6 +209,7 @@ class U2Trainer(Trainer):
                     logger.error(e)
                     raise e
 
+            with Timer("Eval Time Cost: {}"):
                 total_loss, num_seen_utts = self.valid()
                 if dist.get_world_size() > 1:
                     num_seen_utts = paddle.to_tensor(num_seen_utts)
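`Timer` is imported from `deepspeech.training.timer`, but its body is not part of this diff; from the call sites (`Timer("Epoch-Train Time Cost: {}")`) it is presumably a context manager that formats the elapsed wall-clock time into the given message on exit. A minimal stand-in under that assumption:

# Hypothetical stand-in for deepspeech.training.timer.Timer, inferred
# from its usage in the diff; the real implementation may differ.
import time


class Timer:
    """Context manager that reports elapsed wall-clock time.

    The message is a format string with one placeholder for the
    duration, e.g. Timer("Epoch-Train Time Cost: {}").
    """

    def __init__(self, message):
        self.message = message

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.time() - self.start
        # Render seconds as H:MM:SS for readability.
        print(self.message.format(
            time.strftime('%H:%M:%S', time.gmtime(elapsed))))
        return False  # never suppress exceptions


with Timer("Epoch-Train Time Cost: {}"):
    time.sleep(0.1)  # stand-in for one training epoch

Wrapping the whole epoch and the whole eval pass, rather than timing inside the step loop, keeps the timing code out of the hot path and reports one number per phase.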

@@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
             f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
         # instance for norm_by_times
-        # batchsize for norm_by_batchsize
+        # batch for norm_by_batchsize
         # frame for norm_by_total_logits_len
-        assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
+        assert grad_norm_type in ('instance', 'batch', 'frame', None)
         self.norm_by_times = False
         self.norm_by_batchsize = False
         self.norm_by_total_logits_len = False
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batchsize':
-            self.norm_by_times = True
+        if grad_norm_type == 'batch':
+            self.norm_by_batchsize = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
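The rename from 'batchsize' to 'batch' only changes the accepted option string; each `grad_norm_type` value still maps to exactly one normalization flag. A standalone restatement of that mapping for illustration (the helper name is hypothetical, not DeepSpeech API):

# The option-to-flag mapping from the hunk above, extracted into a
# standalone helper for illustration; not part of the DeepSpeech API.
def resolve_grad_norm_flags(grad_norm_type):
    """Map a grad_norm_type string to the three CTC normalization flags.

    Returns (norm_by_times, norm_by_batchsize, norm_by_total_logits_len).
    """
    assert grad_norm_type in ('instance', 'batch', 'frame', None)
    return (
        grad_norm_type == 'instance',  # normalize grads per utterance length
        grad_norm_type == 'batch',     # normalize grads by batch size
        grad_norm_type == 'frame',     # normalize grads by total logits length
    )


# 'batchsize' is no longer accepted after the rename; callers pass 'batch'.
assert resolve_grad_norm_flags('batch') == (False, True, False)
assert resolve_grad_norm_flags(None) == (False, False, False)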
