|
|
@ -13,16 +13,16 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from tensorboardX import SummaryWriter
|
|
|
|
from tensorboardX import SummaryWriter
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeech.training.timer import Timer
|
|
|
|
|
|
|
|
from deepspeech.training.reporter import report
|
|
|
|
|
|
|
|
from deepspeech.training.reporter import ObsScope
|
|
|
|
from deepspeech.training.reporter import ObsScope
|
|
|
|
|
|
|
|
from deepspeech.training.reporter import report
|
|
|
|
|
|
|
|
from deepspeech.training.timer import Timer
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils import profiler
|
|
|
|
from deepspeech.utils import profiler
|
|
|
|
from deepspeech.utils.checkpoint import Checkpoint
|
|
|
|
from deepspeech.utils.checkpoint import Checkpoint
|
|
|
@ -30,7 +30,6 @@ from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.utility import seed_all
|
|
|
|
from deepspeech.utils.utility import seed_all
|
|
|
|
from deepspeech.utils.utility import UpdateConfig
|
|
|
|
from deepspeech.utils.utility import UpdateConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["Trainer"]
|
|
|
|
__all__ = ["Trainer"]
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
@ -236,17 +235,21 @@ class Trainer():
|
|
|
|
report("Rank", dist.get_rank())
|
|
|
|
report("Rank", dist.get_rank())
|
|
|
|
report("epoch", self.epoch)
|
|
|
|
report("epoch", self.epoch)
|
|
|
|
report('step', self.iteration)
|
|
|
|
report('step', self.iteration)
|
|
|
|
report('step/total', (batch_index + 1) / len(self.train_loader))
|
|
|
|
report('step/total',
|
|
|
|
|
|
|
|
(batch_index + 1) / len(self.train_loader))
|
|
|
|
report("lr", self.lr_scheduler())
|
|
|
|
report("lr", self.lr_scheduler())
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
self.after_train_batch()
|
|
|
|
self.after_train_batch()
|
|
|
|
report('reader_cost', dataload_time)
|
|
|
|
report('reader_cost', dataload_time)
|
|
|
|
observation['batch_cost'] = observation['reader_cost']+observation['step_cost']
|
|
|
|
observation['batch_cost'] = observation[
|
|
|
|
|
|
|
|
'reader_cost'] + observation['step_cost']
|
|
|
|
observation['samples'] = observation['batch_size']
|
|
|
|
observation['samples'] = observation['batch_size']
|
|
|
|
observation['ips[sent./sec]'] = observation['batch_size'] / observation['batch_cost']
|
|
|
|
observation['ips[sent./sec]'] = observation[
|
|
|
|
|
|
|
|
'batch_size'] / observation['batch_cost']
|
|
|
|
for k, v in observation.items():
|
|
|
|
for k, v in observation.items():
|
|
|
|
msg += f" {k}: "
|
|
|
|
msg += f" {k}: "
|
|
|
|
msg += f"{v:>.8f}" if isinstance(v, float) else f"{v}"
|
|
|
|
msg += f"{v:>.8f}" if isinstance(v,
|
|
|
|
|
|
|
|
float) else f"{v}"
|
|
|
|
msg += ","
|
|
|
|
msg += ","
|
|
|
|
logger.info(msg)
|
|
|
|
logger.info(msg)
|
|
|
|
data_start_time = time.time()
|
|
|
|
data_start_time = time.time()
|
|
|
|