From 5fa22d94af6bd2cf55803bfb4599bbbc394d4c97 Mon Sep 17 00:00:00 2001 From: Zhangjingyu06 Date: Tue, 24 May 2022 08:58:59 +0000 Subject: [PATCH] deepspeech2 modify for kunlun --- paddlespeech/s2t/exps/deepspeech2/bin/train.py | 5 +++++ paddlespeech/s2t/training/trainer.py | 9 ++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index e2c68d4be..cb4867ef2 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -33,6 +33,11 @@ if __name__ == "__main__": parser = default_argument_parser() parser.add_argument( "--model_type", type=str, default='offline', help='offline/online') + parser.add_argument( + '--nxpu', + type=int, + default=1, + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print("model_type:{}".format(args.model_type)) print_arguments(args, globals()) diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index 84da251aa..864c1bd69 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -112,7 +112,14 @@ class Trainer(): logger.info(f"Rank: {self.rank}/{self.world_size}") # set device - paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + if self.args.ngpu == 0: + paddle.set_device('cpu') + if self.args.nxpu == 0: + paddle.set_device('cpu') + else: + paddle.set_device('xpu') + elif self.args.ngpu > 0: + paddle.set_device("gpu") if self.parallel: self.init_parallel()