|
|
|
@ -37,7 +37,7 @@ class Rhy_predictor():
|
|
|
|
|
model_dir: os.PathLike=MODEL_HOME, ):
|
|
|
|
|
uncompress_path = download_and_decompress(
|
|
|
|
|
rhy_frontend_models['rhy_e2e'][model_version], model_dir)
|
|
|
|
|
with open(os.path.join(uncompress_path, 'default.yaml')) as f:
|
|
|
|
|
with open(os.path.join(uncompress_path, 'rhy_default.yaml')) as f:
|
|
|
|
|
config = CfgNode(yaml.safe_load(f))
|
|
|
|
|
self.punc_list = []
|
|
|
|
|
with open(os.path.join(uncompress_path, 'rhy_token'), 'r') as f:
|
|
|
|
@ -45,11 +45,11 @@ class Rhy_predictor():
|
|
|
|
|
self.punc_list.append(line.strip())
|
|
|
|
|
self.punc_list = [0] + self.punc_list
|
|
|
|
|
self.make_rhy_dict()
|
|
|
|
|
self.model = DefinedClassifier[config["model_type"]](**config["model"])
|
|
|
|
|
self.model = DefinedClassifier["ErnieLinear"](**config["model"])
|
|
|
|
|
pretrained_token = config['data_params']['pretrained_token']
|
|
|
|
|
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
|
|
|
|
|
state_dict = paddle.load(
|
|
|
|
|
os.path.join(uncompress_path, 'snapshot_iter_153000.pdz'))
|
|
|
|
|
os.path.join(uncompress_path, 'snapshot_iter_2600.pdz'))
|
|
|
|
|
self.model.set_state_dict(state_dict["main_params"])
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|