[NPU] test TransformerTTS with NPU

pull/1656/head
zhangkeliang 3 years ago
parent c7a9650a04
commit 22b4b441e1

@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
paddle.set_device("gpu")
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
paddle.set_device("npu")
else:
paddle.set_device("cpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()

Loading…
Cancel
Save