~~~diff
@@ -55,7 +55,7 @@ class MyDataProcessor(DataProcessor):
         f = open(file_path, 'r', encoding='utf-8')  # read the data, specifying utf-8, the encoding commonly used for Chinese text
         train_data = []
         index = 0  # the ID value
-        for line in f.readline():  # following XnliProcessor's example
+        for line in f.readlines():  # following XnliProcessor's example
             guid = "train-%d" % index
             line = line.replace('\n', '').split('\t')  # strip the newline; the raw data is tab-separated
             text_a = tokenization.convert_to_unicode(str(line[1]))  # position 0 is the index, position 1 is the actual text; see train_sentiment.txt
~~~

`f.readline()` returns only the first line as a single string, so the loop would iterate over its characters one by one; `f.readlines()` returns the list of all lines, which is what the loop needs.
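To see what the parsing above expects, here is a minimal sketch with a made-up sample line; the layout (index, text, label separated by tabs) is the one the comments describe for train_sentiment.txt:

~~~python
# Hypothetical sample line: "<index>\t<text>\t<label>"
line = '0\t这家餐厅的菜很好吃\t2\n'
fields = line.replace('\n', '').split('\t')
print(fields[1])  # the sentence text -> 这家餐厅的菜很好吃
print(fields[2])  # the label -> 2
~~~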
@@ -111,7 +111,60 @@ class XnliProcessor(DataProcessor):

**Here is the complete class:**
~~~python
class MyDataProcessor(DataProcessor):
    """Processor for the Chinese sentiment classification data set."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        file_path = os.path.join(data_dir, 'train_sentiment.txt')
        f = open(file_path, 'r', encoding='utf-8')  # read the data, specifying utf-8, the encoding commonly used for Chinese text
        train_data = []
        index = 0  # the ID value
        for line in f.readlines():  # following XnliProcessor's example
            guid = "train-%d" % index
            line = line.replace('\n', '').split('\t')  # strip the newline; the raw data is tab-separated
            text_a = tokenization.convert_to_unicode(str(line[1]))  # position 0 is the index, position 1 is the actual text; see train_sentiment.txt
            label = str(line[2])  # the label holds nothing but a number, so converting it to a string is enough
            train_data.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))  # we have no text_b here, so pass None
            index += 1  # increment so every example gets a distinct id
        f.close()
        return train_data  # the training data has now been fully read

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        file_path = os.path.join(data_dir, 'test_sentiment.txt')
        f = open(file_path, 'r', encoding='utf-8')
        dev_data = []
        index = 0
        for line in f.readlines():
            guid = "dev-%d" % index
            line = line.replace('\n', '').split('\t')
            text_a = tokenization.convert_to_unicode(str(line[1]))
            label = str(line[2])
            dev_data.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            index += 1
        f.close()
        return dev_data

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        file_path = os.path.join(data_dir, 'test.csv')
        f = open(file_path, 'r', encoding='utf-8')
        test_data = []
        index = 0
        for line in f.readlines():
            guid = "test-%d" % index
            line = line.replace('\n', '').split('\t')
            text_a = tokenization.convert_to_unicode(str(line[0]))  # test.csv has no index column: the text comes first
            label = str(line[1])  # ...and the label second
            test_data.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            index += 1
        f.close()
        return test_data

    def get_labels(self):
        """Gets the list of labels for this data set."""
        return ["0", "1", "2"]  # following XnliProcessor, changed to return 0, 1, 2
~~~
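For the new processor to be reachable from the command line, it also has to be registered in the `processors` dict inside `main()` of run_classifier.py; that dict is what the `task_name` flag is looked up against. A minimal sketch, assuming we register it under the (hypothetical) name `my`:

~~~python
# In main() of run_classifier.py: task_name -> processor class.
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "my": MyDataProcessor,  # our new processor; the key must match --task_name (lowercased)
}
~~~

run_classifier.py lowercases the flag (`task_name = FLAGS.task_name.lower()`) before the lookup, so the dict key should be lowercase.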

@@ -152,3 +205,33 @@ def main(_):

-output_dir=my_model
~~~

> task_name: which task to run; in main() each name is mapped to its processor class
>
> do_train: whether to run training
>
> do_eval: whether to run evaluation
>
> data_dir: path to the data
>
> vocab_file: the vocabulary file
>
> bert_config_file: the BERT configuration
>
> init_checkpoint: the checkpoint used to initialize the weights
>
> max_seq_length: the maximum sequence length
>
> train_batch_size: the batch size used during training
>
> learning_rate: the learning rate
>
> num_train_epochs: the number of training epochs
>
> output_dir: the output path
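
Put together, the run parameters might look like the following. Everything except `-output_dir=my_model` is an assumption here: the paths are placeholders for wherever the data and the downloaded Chinese BERT checkpoint actually live, `my` assumes the processor was registered under that name, and the numeric values are simply run_classifier.py's defaults.

~~~
-task_name=my
-do_train=true
-do_eval=true
-data_dir=data
-vocab_file=chinese_L-12_H-768_A-12/vocab.txt
-bert_config_file=chinese_L-12_H-768_A-12/bert_config.json
-init_checkpoint=chinese_L-12_H-768_A-12/bert_model.ckpt
-max_seq_length=128
-train_batch_size=32
-learning_rate=5e-5
-num_train_epochs=3.0
-output_dir=my_model
~~~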
![1610334445172](assets/1610334445172.png)

Once the parameters are set, just hit Run.

![1610334464000](assets/1610334464000.png)