update notes, test=doc

pull/1665/head
TianYuan 3 years ago
parent e0d222e674
commit 124eb6af8f

@ -69,13 +69,13 @@ def ort_predict(args):
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(args, filed='voc')
# am warmup # am warmup
for batch in [27, 38, 54]: for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(batch, )) data = np.random.randint(1, 266, size=(T, ))
am_sess.run(None, {"text": data}) am_sess.run(None, {"text": data})
# voc warmup # voc warmup
for batch in [227, 308, 544]: for T in [227, 308, 544]:
data = np.random.rand(batch, 80).astype("float32") data = np.random.rand(T, 80).astype("float32")
voc_sess.run(None, {"logmel": data}) voc_sess.run(None, {"logmel": data})
print("warm up done!") print("warm up done!")
@ -120,9 +120,7 @@ def parse_args():
'--voc', '--voc',
type=str, type=str,
default='hifigan_csmsc', default='hifigan_csmsc',
choices=[ choices=['hifigan_csmsc', 'mb_melgan_csmsc'],
'hifigan_csmsc', 'mb_melgan_csmsc'
],
help='Choose vocoder type of tts task.') help='Choose vocoder type of tts task.')
# other # other
parser.add_argument( parser.add_argument(

@ -69,13 +69,13 @@ def ort_predict(args):
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(args, filed='voc')
# am warmup # am warmup
for batch in [27, 38, 54]: for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(batch, )) data = np.random.randint(1, 266, size=(T, ))
am_sess.run(None, {"text": data}) am_sess.run(None, {"text": data})
# voc warmup # voc warmup
for batch in [227, 308, 544]: for T in [227, 308, 544]:
data = np.random.rand(batch, 80).astype("float32") data = np.random.rand(T, 80).astype("float32")
voc_sess.run(None, {"logmel": data}) voc_sess.run(None, {"logmel": data})
print("warm up done!") print("warm up done!")

Loading…
Cancel
Save