|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
import logging.handlers
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
@ -26,6 +25,8 @@ from deepspeech.utils import mp_tools
|
|
|
|
|
|
|
|
|
|
__all__ = ["Trainer"]
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Trainer():
|
|
|
|
|
"""
|
|
|
|
@ -186,14 +187,14 @@ class Trainer():
|
|
|
|
|
while self.epoch < self.config.training.n_epoch:
|
|
|
|
|
try:
|
|
|
|
|
data_start_time = time.time()
|
|
|
|
|
for batch in self.train_loader:
|
|
|
|
|
for batch_index, batch in enumerate(self.train_loader):
|
|
|
|
|
dataload_time = time.time() - data_start_time
|
|
|
|
|
msg = "Train: Rank: {}, ".format(dist.get_rank())
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
|
|
|
|
|
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
|
|
|
|
|
self.train_batch(batch, msg)
|
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
|
data_start_time = time.time()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.logger.error(e)
|
|
|
|
@ -215,6 +216,7 @@ class Trainer():
|
|
|
|
|
exit(-1)
|
|
|
|
|
finally:
|
|
|
|
|
self.destory()
|
|
|
|
|
self.logger.info("Training Done.")
|
|
|
|
|
|
|
|
|
|
def setup_output_dir(self):
|
|
|
|
|
"""Create a directory used for output.
|
|
|
|
@ -279,41 +281,6 @@ class Trainer():
|
|
|
|
|
backup - how many backup file to keep
|
|
|
|
|
default value: 7
|
|
|
|
|
"""
|
|
|
|
|
when = 'D'
|
|
|
|
|
backup = 7
|
|
|
|
|
format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
|
|
|
|
|
formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
logger.setLevel("INFO")
|
|
|
|
|
|
|
|
|
|
stream_handler = logging.StreamHandler()
|
|
|
|
|
stream_handler.setFormatter(formatter)
|
|
|
|
|
logger.addHandler(stream_handler)
|
|
|
|
|
|
|
|
|
|
log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
|
|
|
|
|
# file_handler = logging.FileHandler(str(log_file))
|
|
|
|
|
# file_handler.setFormatter(formatter)
|
|
|
|
|
# logger.addHandler(file_handler)
|
|
|
|
|
|
|
|
|
|
# handler = logging.handlers.TimedRotatingFileHandler(
|
|
|
|
|
# str(self.output_dir / "warning.log"), when=when, backupCount=backup)
|
|
|
|
|
# handler.setLevel(logging.WARNING)
|
|
|
|
|
# handler.setFormatter(formatter)
|
|
|
|
|
# logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
# stop propagate for propagating may print
|
|
|
|
|
# log multiple times
|
|
|
|
|
logger.propagate = False
|
|
|
|
|
|
|
|
|
|
# global logger
|
|
|
|
|
stdout = False
|
|
|
|
|
save_path = str(log_file)
|
|
|
|
|
logging.basicConfig(
|
|
|
|
|
level=logging.DEBUG if stdout else logging.INFO,
|
|
|
|
|
format=format,
|
|
|
|
|
datefmt='%Y/%m/%d %H:%M:%S',
|
|
|
|
|
filename=save_path if not stdout else None)
|
|
|
|
|
self.logger = logger
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
|