Merge pull request #1656 from windstamp/npu_dev_r01_20220407

[NPU] test TransformerTTS with NPU
r0.1
TianYuan 2 years ago committed by GitHub
commit dd2bf469ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
paddle.set_device("gpu")
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
paddle.set_device("npu")
else:
paddle.set_device("cpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()

Loading…
Cancel
Save