From 59b3de6a6db154c29088b807cb5f660ce7d25720 Mon Sep 17 00:00:00 2001 From: zhangkeliang Date: Sat, 8 Jan 2022 00:49:28 +0800 Subject: [PATCH] [NPU] test TransformerTTS with NPU --- paddlespeech/t2s/exps/transformer_tts/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddlespeech/t2s/exps/transformer_tts/train.py b/paddlespeech/t2s/exps/transformer_tts/train.py index 8695c06a..9b1ab76b 100644 --- a/paddlespeech/t2s/exps/transformer_tts/train.py +++ b/paddlespeech/t2s/exps/transformer_tts/train.py @@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer def train_sp(args, config): # decides device type and whether to run in parallel # setup running environment correctly - if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: - paddle.set_device("cpu") - else: + if paddle.is_compiled_with_cuda() and args.ngpu > 0: paddle.set_device("gpu") + elif paddle.is_compiled_with_npu() and args.ngpu > 0: + paddle.set_device("npu") + else: + paddle.set_device("cpu") world_size = paddle.distributed.get_world_size() if world_size > 1: paddle.distributed.init_parallel_env()