pull/3514/head
USTCKAY 2 years ago
parent 5406e15434
commit 5ec8f11592

@ -44,14 +44,16 @@ from paddlespeech.t2s.utils import str2bool
def train_sp(args, config): def train_sp(args, config):
# decides device type and whether to run in parallel # decides device type and whether to run in parallel
# setup running environment correctly # setup running environment correctly
if args.ngpu > 0: if args.ngpu > 0 and paddle.is_compiled_with_cuda():
paddle.set_device("gpu") paddle.set_device("gpu")
elif args.nxpu > 0: elif args.nxpu > 0 and paddle.is_compiled_with_xpu():
paddle.set_device("xpu") paddle.set_device("xpu")
elif args.ngpu == 0 and args.nxpu == 0: elif args.ngpu == 0 and args.nxpu == 0:
paddle.set_device("cpu") paddle.set_device("cpu")
else: else:
print("ngpu or nxpu should >= 0 !") raise ValueError(
"Please make sure that the paddle you installed matches the device type you set, "
"and that ngpu and nxpu cannot be negative at the same time.")
world_size = paddle.distributed.get_world_size() world_size = paddle.distributed.get_world_size()
if world_size > 1: if world_size > 1:

Loading…
Cancel
Save