pull/815/head
Hui Zhang 3 years ago
parent 75cd366ddd
commit c29ee83a46

@ -0,0 +1,49 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import time
from deepspeech.utils.log import Log
__all__ = ["Timer"]
logger = Log(__name__).getlog()
class Timer():
"""To be used like this:
with Timer("Message") as value:
do some thing
"""
def __init__(self, message):
self.message = message
def duration(self) -> str:
elapsed_time = time.time() - self.start
time_str = str(datetime.timedelta(seconds=elapsed_time))
return time_str
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, type, value, traceback):
logger.info(self.message.format(self.duration()))
def __call__(self) -> float:
return time.time() - self.start
def __str__(self):
return self.duration()

@ -18,6 +18,7 @@ import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from deepspeech.training.timer import Timer
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
@ -184,13 +185,14 @@ class Trainer():
def train(self): def train(self):
"""The training process control by epoch.""" """The training process control by epoch."""
from_scratch = self.resume_or_scratch() with Timer("Load/Init Model: {}"):
if from_scratch: from_scratch = self.resume_or_scratch()
# save init model, i.e. 0 epoch if from_scratch:
self.save(tag='init', infos=None) # save init model, i.e. 0 epoch
self.lr_scheduler.step(self.epoch) self.save(tag='init', infos=None)
if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.lr_scheduler.step(self.epoch)
self.train_loader.batch_sampler.set_epoch(self.epoch) if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch: while self.epoch < self.config.training.n_epoch:
@ -240,14 +242,14 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended """The routine of the experiment after setup. This method is intended
to be used by the user. to be used by the user.
""" """
try: with Timer("Training Done: {}"):
self.train() try:
except KeyboardInterrupt: self.train()
self.save() except KeyboardInterrupt:
exit(-1) self.save()
finally: exit(-1)
self.destory() finally:
logger.info("Training Done.") self.destory()
def setup_output_dir(self): def setup_output_dir(self):
"""Create a directory used for output. """Create a directory used for output.

Loading…
Cancel
Save