From 4646f7cc8de954497e9edc8ff10ca95b171d8fdb Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 20 Apr 2022 09:48:46 +0000 Subject: [PATCH] add paddle device set for ort and inference, test=doc --- paddlespeech/t2s/exps/inference.py | 4 ++++ paddlespeech/t2s/exps/inference_streaming.py | 4 ++++ paddlespeech/t2s/exps/ort_predict.py | 4 ++++ paddlespeech/t2s/exps/ort_predict_e2e.py | 3 +++ paddlespeech/t2s/exps/ort_predict_streaming.py | 3 +++ 5 files changed, 18 insertions(+) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 7a19a113..98e73e10 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +import paddle import soundfile as sf from timer import timer @@ -101,6 +102,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index ef6d1a4a..b680f19a 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -100,6 +101,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index adbd6809..2e8596de 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,6 +16,7 @@ from pathlib import Path import jsonlines import numpy as np +import paddle import soundfile as sf from timer import timer @@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): + # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: test_metadata = list(reader) @@ -143,6 +145,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index ae5e900b..a2ef8e4c 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -178,6 +179,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 5568ed39..5d2c66bc 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -246,6 +247,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args)