Merge pull request #1727 from yt605155624/refactor_syn_util

[TTS]add paddle device set for ort and inference
pull/1732/head
TianYuan 3 years ago committed by GitHub
commit e089268642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -101,6 +102,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -100,6 +101,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,

@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# construct dataset for evaluation # construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader: with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader) test_metadata = list(reader)
@ -143,6 +145,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -178,6 +179,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -246,6 +247,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

Loading…
Cancel
Save