fix profiler

pull/831/head
Hui Zhang 3 years ago
parent 438e1bd34f
commit 7907319288

@ -185,7 +185,8 @@ class Trainer():
batch_sampler.set_epoch(self.epoch) batch_sampler.set_epoch(self.epoch)
def after_train_batch(self): def after_train_batch(self):
profiler.add_profiler_step(self.args.profiler_options) if self.args.profiler_options:
profiler.add_profiler_step(self.args.profiler_options)
def train(self): def train(self):
"""The training process control by epoch.""" """The training process control by epoch."""

@ -61,6 +61,9 @@ class ProfilerOptions(object):
self._parse_from_string(options_str) self._parse_from_string(options_str)
def _parse_from_string(self, options_str): def _parse_from_string(self, options_str):
if not options_str:
return
for kv in options_str.replace(' ', '').split(';'): for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=') key, value = kv.split('=')
if key == 'batch_range': if key == 'batch_range':

@ -48,7 +48,7 @@ training:
n_epoch: 10 n_epoch: 10
accum_grad: 1 accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 0.8
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 5.0 global_grad_clip: 5.0
log_interval: 1 log_interval: 1

@ -38,7 +38,7 @@ python3 -u ${BIN_DIR}/train.py \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--model_type ${model_type} \ --model_type ${model_type} \
--profiler_options ${profiler_options} \ --profiler_options "${profiler_options}" \
--seed ${seed} --seed ${seed}
if [ ${seed} != 0 ]; then if [ ${seed} != 0 ]; then

Loading…
Cancel
Save