@@ -143,11 +143,13 @@ def train():
     train_batch_reader = train_generator.batch_reader_creator(
         manifest_path=args.train_manifest_path,
         batch_size=args.batch_size,
+        min_batch_size=args.trainer_count,
         sortagrad=args.use_sortagrad if args.init_model_path is None else False,
         batch_shuffle=True)
     test_batch_reader = test_generator.batch_reader_creator(
         manifest_path=args.dev_manifest_path,
         batch_size=args.batch_size,
+        min_batch_size=1,  # must be 1, but will have errors.
         sortagrad=False,
         batch_shuffle=False)
@@ -157,11 +159,11 @@ def train():
         if isinstance(event, paddle.event.EndIteration):
             cost_sum += event.cost
             cost_counter += 1
-            if event.batch_id % 50 == 0:
-                print("\nPass: %d, Batch: %d, TrainCost: %f" %
-                      (event.pass_id, event.batch_id, cost_sum / cost_counter))
+            if (event.batch_id + 1) % 100 == 0:
+                print("\nPass: %d, Batch: %d, TrainCost: %f" % (
+                    event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
                 cost_sum, cost_counter = 0.0, 0
-                with gzip.open("params_tmp.tar.gz", 'w') as f:
+                with gzip.open("params.tar.gz", 'w') as f:
                     parameters.to_tar(f)
             else:
                 sys.stdout.write('.')