@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from collections import OrderedDict
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
@@ -36,6 +37,8 @@ from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
 from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
+from deepspeech.training.reporter import report
+from deepspeech.training.reporter import ObsScope
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
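The two `deepspeech.training.reporter` imports are the heart of this change: `report()` records a named value, and `ObsScope` decides which dictionary those values land in. As a rough mental model only (a minimal sketch of the pattern, not the actual `deepspeech.training.reporter` implementation), the pair can be thought of as a stack of observation dicts:

```python
# Minimal sketch of a report()/ObsScope-style reporter.
# Illustrative only -- not the real deepspeech.training.reporter.
from contextlib import contextmanager

_scopes = []  # stack of active observation dicts, innermost last


@contextmanager
def ObsScope(observation):
    """Route report() calls into `observation` while the scope is active."""
    _scopes.append(observation)
    try:
        yield observation
    finally:
        _scopes.pop()


def report(key, value):
    """Record a key/value pair in the innermost active scope (no-op outside one)."""
    if _scopes:
        _scopes[-1][key] = value
```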
@@ -121,12 +124,11 @@ class U2Trainer(Trainer):
         iteration_time = time.time() - start

-        if (batch_index + 1) % train_conf.log_interval == 0:
-            msg += "train time: {:>.3f}s, ".format(iteration_time)
-            msg += "batch size: {}, ".format(self.config.collator.batch_size)
-            msg += "accum: {}, ".format(train_conf.accum_grad)
-            msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                             for k, v in losses_np.items())
-            logger.info(msg)
+        for k, v in losses_np.items():
+            report(k, v)
+        report("batch_size", self.config.collator.batch_size)
+        report("accum", train_conf.accum_grad)
+        report("step_cost", iteration_time)

         if dist.get_rank() == 0 and self.visualizer:
             losses_np_v = losses_np.copy()
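Two things change in this hunk besides the mechanics: `train_batch` no longer formats or logs anything itself (that moves to the caller in the next hunk), and the old `train_conf.log_interval` gate is gone, so metrics are now recorded on every batch rather than every `log_interval` batches. With the sketch above, after `train_batch` returns inside an active scope, the observation dict would hold something like this (keys from this hunk, values made up):

```python
# Illustrative scope contents after train_batch() has reported.
OrderedDict([
    ('loss', 120.3), ('att_loss', 80.1), ('ctc_loss', 40.2),  # losses_np
    ('batch_size', 16),    # self.config.collator.batch_size
    ('accum', 2),          # train_conf.accum_grad
    ('step_cost', 0.532),  # iteration_time
])
```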
@@ -199,15 +201,25 @@ class U2Trainer(Trainer):
                     data_start_time = time.time()
                     for batch_index, batch in enumerate(self.train_loader):
                         dataload_time = time.time() - data_start_time
-                        msg = "Train: Rank: {}, ".format(dist.get_rank())
-                        msg += "epoch: {}, ".format(self.epoch)
-                        msg += "step: {}, ".format(self.iteration)
-                        msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                        len(self.train_loader))
-                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                        msg += "data time: {:>.3f}s, ".format(dataload_time)
-                        self.train_batch(batch_index, batch, msg)
-                        self.after_train_batch()
+                        msg = "Train:"
+                        observation = OrderedDict()
+                        with ObsScope(observation):
+                            report("Rank", dist.get_rank())
+                            report("epoch", self.epoch)
+                            report('step', self.iteration)
+                            report('step/total', (batch_index + 1) / len(self.train_loader))
+                            report("lr", self.lr_scheduler())
+                            self.train_batch(batch_index, batch, msg)
+                            self.after_train_batch()
+                            report('reader_cost', dataload_time)
+                        observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']
+                        observation['samples'] = observation['batch_size']
+                        observation['ips[sent./sec]'] = observation['batch_size'] / observation['batch_cost']
+                        for k, v in observation.items():
+                            msg += f" {k}: "
+                            msg += f"{v:>.8f}" if isinstance(v, float) else f"{v}"
+                            msg += ","
+                        logger.info(msg)
                         data_start_time = time.time()
                 except Exception as e:
                     logger.error(e)
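Putting the hunks together, the caller now owns the log line: it opens an `ObsScope`, lets `train_batch` report into it, derives `batch_cost`, `samples`, and `ips[sent./sec]` after the scope closes, and formats everything in one pass. Below is a self-contained dry run of that flow using the sketch above with made-up values; note the derived-metric lines assume `step_cost`, `reader_cost`, and `batch_size` were actually reported inside the scope, since a missing key raises `KeyError`:

```python
from collections import OrderedDict

observation = OrderedDict()
with ObsScope(observation):
    # Stand-ins for what train_batch() would report (values made up).
    report("loss", 120.3)
    report("batch_size", 16)
    report("step_cost", 0.532)
    report("reader_cost", 0.012)

# Derived metrics, as in the previous hunk.
observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips[sent./sec]'] = observation['batch_size'] / observation['batch_cost']

# Same one-pass formatting: floats get 8 decimal places, ints pass through.
msg = "Train:"
for k, v in observation.items():
    msg += f" {k}: "
    msg += f"{v:>.8f}" if isinstance(v, float) else f"{v}"
    msg += ","
print(msg)
# Train: loss: 120.30000000, batch_size: 16, step_cost: 0.53200000,
# reader_cost: 0.01200000, batch_cost: 0.54400000, samples: 16,
# ips[sent./sec]: 29.41176471,
```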