|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import argparse
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
@ -120,6 +121,15 @@ def parse_args():
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_suffix(inference_dir, model_name):
|
|
|
|
|
if os.path.exists(os.path.join(inference_dir, model_name + ".pdmodel")):
|
|
|
|
|
return ".pdmodel"
|
|
|
|
|
elif os.path.exists(os.path.join(inference_dir, model_name + ".json")):
|
|
|
|
|
return ".json"
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("model file not found!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# only inference for models trained with csmsc now
|
|
|
|
|
def main():
|
|
|
|
|
args = parse_args()
|
|
|
|
@ -133,9 +143,10 @@ def main():
|
|
|
|
|
tones_dict=args.tones_dict)
|
|
|
|
|
|
|
|
|
|
# am_predictor
|
|
|
|
|
suffix = get_model_suffix(args.inference_dir, args.am)
|
|
|
|
|
am_predictor = get_predictor(
|
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
|
model_file=args.am + ".pdmodel",
|
|
|
|
|
model_file=args.am + suffix,
|
|
|
|
|
params_file=args.am + ".pdiparams",
|
|
|
|
|
device=args.device,
|
|
|
|
|
use_trt=args.use_trt,
|
|
|
|
@ -146,9 +157,10 @@ def main():
|
|
|
|
|
am_dataset = args.am[args.am.rindex('_') + 1:]
|
|
|
|
|
|
|
|
|
|
# voc_predictor
|
|
|
|
|
suffix = get_model_suffix(args.inference_dir, args.voc)
|
|
|
|
|
voc_predictor = get_predictor(
|
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
|
model_file=args.voc + ".pdmodel",
|
|
|
|
|
model_file=args.voc + suffix,
|
|
|
|
|
params_file=args.voc + ".pdiparams",
|
|
|
|
|
device=args.device,
|
|
|
|
|
use_trt=args.use_trt,
|
|
|
|
|