add paddle device set for ort and inference, test=doc

pull/1727/head
TianYuan 2 years ago
parent c74fa9ada8
commit 4646f7cc8d

@ -14,6 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -101,6 +102,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -100,6 +101,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,

@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# construct dataset for evaluation # construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader: with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader) test_metadata = list(reader)
@ -143,6 +145,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -178,6 +179,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -246,6 +247,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

Loading…
Cancel
Save