timer info for st,u2 kaldi

pull/820/head
Hui Zhang 3 years ago
parent 28a0a64153
commit 2480be8ebc

@@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
@@ -190,6 +191,7 @@ class U2Trainer(Trainer):
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
@@ -208,6 +210,7 @@ class U2Trainer(Trainer):
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)

@@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
@@ -207,6 +208,7 @@ class U2STTrainer(Trainer):
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
@@ -225,6 +227,7 @@ class U2STTrainer(Trainer):
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
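The Timer imported and used in both trainers is a context manager that reports how long the wrapped block took ("Epoch-Train Time Cost" around the epoch training loop, "Eval Time Cost" around validation). Its implementation lives in deepspeech/training/timer.py and is not part of this commit; below is a minimal sketch of such a timer, assuming it fills the "{}" placeholder in the message with the elapsed wall-clock time on exit. The logging call and duration formatting here are illustrative assumptions, not the library's actual code.

import time
from contextlib import ContextDecorator


class Timer(ContextDecorator):
    """Sketch of a timing context manager: on exit, report a formatted
    message containing the elapsed wall-clock time of the block."""

    def __init__(self, message="Time cost: {}"):
        # `message` is expected to contain a `{}` placeholder,
        # e.g. "Epoch-Train Time Cost: {}" or "Eval Time Cost: {}".
        self.message = message

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = time.time() - self.start
        # The trainers above use a logger; print keeps this sketch self-contained.
        print(self.message.format(f"{elapsed:.3f} s"))
        return False  # do not swallow exceptions raised inside the block


if __name__ == "__main__":
    # Usage mirroring the hunks above.
    with Timer("Eval Time Cost: {}"):
        time.sleep(0.1)  # stand-in for the validation loop

Because Timer is a plain context manager, instrumenting the existing epoch-train and eval code only needs the single "with" line added in each hunk, and exceptions raised inside the block (as in the try/raise paths above) still propagate normally.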
