timer info for st,u2 kaldi

pull/820/head
Hui Zhang 3 years ago
parent 28a0a64153
commit 2480be8ebc

@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
@ -190,6 +191,7 @@ class U2Trainer(Trainer):
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch: while self.epoch < self.config.training.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
try: try:
data_start_time = time.time() data_start_time = time.time()
@ -208,6 +210,7 @@ class U2Trainer(Trainer):
logger.error(e) logger.error(e)
raise e raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid() total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1: if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts) num_seen_utts = paddle.to_tensor(num_seen_utts)

@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2_st import U2STModel from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
@ -207,6 +208,7 @@ class U2STTrainer(Trainer):
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch: while self.epoch < self.config.training.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
try: try:
data_start_time = time.time() data_start_time = time.time()
@ -225,6 +227,7 @@ class U2STTrainer(Trainer):
logger.error(e) logger.error(e)
raise e raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid() total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1: if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts) num_seen_utts = paddle.to_tensor(num_seen_utts)

Loading…
Cancel
Save