|
|
|
@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer
|
|
|
|
|
def train_sp(args, config):
|
|
|
|
|
# decides device type and whether to run in parallel
|
|
|
|
|
# setup running environment correctly
|
|
|
|
|
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
|
|
|
|
paddle.set_device("cpu")
|
|
|
|
|
else:
|
|
|
|
|
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
|
|
|
|
|
paddle.set_device("gpu")
|
|
|
|
|
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
|
|
|
|
|
paddle.set_device("npu")
|
|
|
|
|
else:
|
|
|
|
|
paddle.set_device("cpu")
|
|
|
|
|
world_size = paddle.distributed.get_world_size()
|
|
|
|
|
if world_size > 1:
|
|
|
|
|
paddle.distributed.init_parallel_env()
|
|
|
|
|