|
|
@ -12,6 +12,7 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import os
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
@ -96,13 +97,16 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
paddle.set_device(args.device)
|
|
|
|
paddle.set_device(args.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# set model_suffix
|
|
|
|
|
|
|
|
model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel"
|
|
|
|
|
|
|
|
|
|
|
|
# frontend
|
|
|
|
# frontend
|
|
|
|
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
|
|
|
|
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
|
|
|
|
|
|
|
|
|
|
|
|
# am_predictor
|
|
|
|
# am_predictor
|
|
|
|
am_predictor = get_predictor(
|
|
|
|
am_predictor = get_predictor(
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
model_file=args.am + ".pdmodel",
|
|
|
|
model_file=args.am + model_suffix,
|
|
|
|
params_file=args.am + ".pdiparams",
|
|
|
|
params_file=args.am + ".pdiparams",
|
|
|
|
device=args.device,
|
|
|
|
device=args.device,
|
|
|
|
use_trt=args.use_trt,
|
|
|
|
use_trt=args.use_trt,
|
|
|
|