Enable min_batch_num in train.py and update train info print.

pull/2/head
Xinghai Sun 8 years ago
parent 1cef98f210
commit 04a225ae4f

@ -143,11 +143,13 @@ def train():
train_batch_reader = train_generator.batch_reader_creator( train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=args.trainer_count,
sortagrad=args.use_sortagrad if args.init_model_path is None else False, sortagrad=args.use_sortagrad if args.init_model_path is None else False,
batch_shuffle=True) batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator( test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=1, # must be 1, but will have errors.
sortagrad=False, sortagrad=False,
batch_shuffle=False) batch_shuffle=False)
@ -157,11 +159,11 @@ def train():
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
cost_sum += event.cost cost_sum += event.cost
cost_counter += 1 cost_counter += 1
if event.batch_id % 50 == 0: if (event.batch_id + 1) % 100 == 0:
print("\nPass: %d, Batch: %d, TrainCost: %f" % print("\nPass: %d, Batch: %d, TrainCost: %f" % (
(event.pass_id, event.batch_id, cost_sum / cost_counter)) event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0 cost_sum, cost_counter = 0.0, 0
with gzip.open("params_tmp.tar.gz", 'w') as f: with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f) parameters.to_tar(f)
else: else:
sys.stdout.write('.') sys.stdout.write('.')

Loading…
Cancel
Save