diff --git a/paddlespeech/t2s/exps/transformer_tts/train.py b/paddlespeech/t2s/exps/transformer_tts/train.py index 8695c06a..9b1ab76b 100644 --- a/paddlespeech/t2s/exps/transformer_tts/train.py +++ b/paddlespeech/t2s/exps/transformer_tts/train.py @@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer def train_sp(args, config): # decides device type and whether to run in parallel # setup running environment correctly - if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: - paddle.set_device("cpu") - else: + if paddle.is_compiled_with_cuda() and args.ngpu > 0: paddle.set_device("gpu") + elif paddle.is_compiled_with_npu() and args.ngpu > 0: + paddle.set_device("npu") + else: + paddle.set_device("cpu") world_size = paddle.distributed.get_world_size() if world_size > 1: paddle.distributed.init_parallel_env()