You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
3.8 KiB
121 lines
3.8 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import pandas as pd
|
|
import yaml
|
|
from paddle import nn
|
|
from paddle.io import DataLoader
|
|
from sklearn.metrics import classification_report
|
|
from sklearn.metrics import precision_recall_fscore_support
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.text.models.ernie_linear import ErnieLinear
|
|
from paddlespeech.text.models.ernie_linear import PuncDataset
|
|
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
|
|
|
|
# Registries mapping the string keys used in the YAML config to the classes
# they select: ``model_type`` indexes DefinedClassifier and ``dataset_type``
# indexes DefinedDataset. DefinedLoss is not read by the code in this file.
DefinedClassifier = dict(ErnieLinear=ErnieLinear)

DefinedLoss = dict(ce=nn.CrossEntropyLoss)

DefinedDataset = dict(
    Punc=PuncDataset,
    Ernie=PuncDatasetFromErnieTokenizer,
)
|
|
|
|
|
|
def evaluation(y_pred,
               y_test,
               *,
               labels=(1, 2, 3),
               target_names=('COMMA', 'PERIOD', 'QUESTION')):
    """Build a per-class and macro-averaged precision/recall/F1 table.

    Args:
        y_pred: Predicted class ids, one per token.
        y_test: Ground-truth class ids, aligned element-wise with ``y_pred``.
        labels: Class ids to score. The default ``(1, 2, 3)`` skips id 0
            (named ``O`` in the original column list -- presumably the
            "no punctuation" class), so it does not dominate the averages.
        target_names: Column names for ``labels``, in the same order.

    Returns:
        A ``pandas.DataFrame`` indexed by ``['Precision', 'Recall', 'F1']``
        with one column per entry of ``target_names`` plus an ``OVERALL``
        column holding the macro average over ``labels``.

    Raises:
        ValueError: If ``labels`` and ``target_names`` differ in length.
    """
    labels = list(labels)
    target_names = list(target_names)
    if len(labels) != len(target_names):
        raise ValueError(
            f"labels ({len(labels)}) and target_names ({len(target_names)}) "
            "must have the same length")

    # sklearn's argument order is (y_true, y_pred).
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_test, y_pred, average=None, labels=labels)
    overall = precision_recall_fscore_support(
        y_test, y_pred, average='macro', labels=labels)

    result = pd.DataFrame(
        np.array([precision, recall, f1]),
        columns=target_names,
        index=['Precision', 'Recall', 'F1'])
    # ``overall`` is (precision, recall, f1, support); support is None when
    # an average is requested, so only the first three values are kept.
    result['OVERALL'] = overall[:3]
    return result
|
|
|
|
|
|
def test(args):
    """Evaluate a trained checkpoint on the test set described by a config.

    Loads the YAML config at ``args.config``, builds the test dataset and the
    model it names, restores weights from ``args.checkpoint`` (its
    ``"main_params"`` entry), runs inference, and prints sklearn's
    ``classification_report`` followed by the table from ``evaluation``.

    Args:
        args: Parsed CLI namespace providing ``config`` and ``checkpoint``.
    """
    with open(args.config, encoding="utf-8") as f:
        config = CfgNode(yaml.safe_load(f))
    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)

    test_dataset = DefinedDataset[config["dataset_type"]](
        train_path=config["test_path"], **config["data_params"])
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    model = DefinedClassifier[config["model_type"]](**config["model"])
    state_dict = paddle.load(args.checkpoint)
    model.set_state_dict(state_dict["main_params"])
    model.eval()

    # Class names ordered by id, so punc_list[i] names class id i.
    id2punc = test_dataset.id2punc
    punc_list = [id2punc[i] for i in range(len(id2punc))]

    test_total_label = []
    test_total_predict = []
    # Inference only: disable gradient tracking to avoid building the
    # autograd graph for every batch.
    with paddle.no_grad():
        for inputs, label in test_loader:
            label = paddle.reshape(label, shape=[-1])
            _, logit = model(inputs)
            pred = paddle.argmax(logit, axis=1)
            test_total_label.extend(label.numpy().tolist())
            test_total_predict.extend(pred.numpy().tolist())

    # Pass ``labels`` explicitly so target_names stay aligned with class ids
    # even when some class never appears in the labels or predictions.
    t = classification_report(
        test_total_label,
        test_total_predict,
        labels=list(range(len(punc_list))),
        target_names=punc_list)
    print(t)
    # evaluation() takes (y_pred, y_test): predictions first, then labels.
    t2 = evaluation(test_total_predict, test_total_label)
    print('=========================================================')
    print(t2)
|
|
|
|
|
|
def main():
    """Parse command-line arguments, select the device, and run ``test``."""
    parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
    parser.add_argument(
        "--config", type=str, required=True, help="ErnieLinear config file.")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="snapshot to load.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")

    args = parser.parse_args()

    # Reject a negative count up front instead of printing a warning and
    # carrying on with paddle's default device; parser.error() exits.
    if args.ngpu < 0:
        parser.error(f"--ngpu must be >= 0 (use 0 for CPU), got {args.ngpu}")
    paddle.set_device("cpu" if args.ngpu == 0 else "gpu")

    test(args)


if __name__ == "__main__":
    main()
|