|
|
|
@ -200,10 +200,8 @@ class Trainer():
|
|
|
|
|
batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
|
|
|
|
|
def after_train_batch(self):
|
|
|
|
|
if self.args.profiler_options:
|
|
|
|
|
profiler.add_profiler_step(self.args.profiler_options)
|
|
|
|
|
|
|
|
|
|
if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
|
|
|
|
|
profiler.add_profiler_step(self.args.profiler_options)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
|
|
|
|
|
sys.exit(
|
|
|
|
|