|
|
|
@ -86,7 +86,7 @@ class Trainer():
|
|
|
|
|
>>> config.merge_from_list(args.opts)
|
|
|
|
|
>>> config.freeze()
|
|
|
|
|
>>>
|
|
|
|
|
>>> if args.nprocs > 1 and args.device == "gpu":
|
|
|
|
|
>>> if args.nprocs > 0:
|
|
|
|
|
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
|
|
|
|
>>> else:
|
|
|
|
|
>>> main_sp(config, args)
|
|
|
|
@ -119,7 +119,7 @@ class Trainer():
|
|
|
|
|
def setup(self):
|
|
|
|
|
"""Setup the experiment.
|
|
|
|
|
"""
|
|
|
|
|
paddle.set_device(self.args.device)
|
|
|
|
|
paddle.set_device('gpu' self.args.nprocs > 0 else 'cpu')
|
|
|
|
|
if self.parallel:
|
|
|
|
|
self.init_parallel()
|
|
|
|
|
|
|
|
|
@ -139,7 +139,7 @@ class Trainer():
|
|
|
|
|
"""A flag indicating whether the experiment should run with
|
|
|
|
|
multiprocessing.
|
|
|
|
|
"""
|
|
|
|
|
return self.args.device == "gpu" and self.args.nprocs > 1
|
|
|
|
|
return elf.args.nprocs > 0
|
|
|
|
|
|
|
|
|
|
def init_parallel(self):
|
|
|
|
|
"""Init environment for multiprocess training.
|
|
|
|
|