Update. Training Chinese classification of BERT model

pull/2/head
benjas 4 years ago
parent 2e880a45c7
commit 3276bf462f

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

@ -55,7 +55,7 @@ class MyDataProcessor(DataProcessor):
f = open(file_path, 'r', encoding='utf-8') # 读取数据并指定中文常用的utf-8
train_data = []
index = 0 # ID值
for line in f.readline(): # 参考XnliProcessor
for line in f.readlines(): # 参考XnliProcessor
guid = "train-%d" % index
line = line.replace('\n', '').split('\t') # 处理换行符原数据是以tab分割
text_a = tokenization.convert_to_unicode(str(line[1])) # 第0位置是索引第1位置才是数据可以查看train_sentiment.txt
@ -111,7 +111,60 @@ class XnliProcessor(DataProcessor):
**以下是完整的**
~~~python
class MyDataProcessor(DataProcessor):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
file_path = os.path.join(data_dir, 'train_sentiment.txt')
f = open(file_path, 'r', encoding='utf-8') # 读取数据并指定中文常用的utf-8
train_data = []
index = 0 # ID值
for line in f.readlines(): # 参考XnliProcessor
guid = "train-%d" % index
line = line.replace("\n", "").split("\t") # 处理换行符原数据是以tab分割
text_a = tokenization.convert_to_unicode(str(line[1])) # 第0位置是索引第1位置才是数据可以查看train_sentiment.txt
label = str(line[2]) # 我们的label里没有什么东西只有数值所以转字符串即可
train_data.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) # 这里我们没text_b所以传入None
index += 1 # index每次不一样所以加等于1
return train_data # 这样数据就读取完成
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
file_path = os.path.join(data_dir, 'test_sentiment.txt')
f = open(file_path, 'r', encoding='utf-8')
dev_data = []
index = 0
for line in f.readlines():
guid = "dev-%d" % index
line = line.replace('\n', '').split('\t')
text_a = tokenization.convert_to_unicode(str(line[1]))
label = str(line[2])
dev_data.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
index += 1
return dev_data
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
file_path = os.path.join(data_dir, 'test.csv')
f = open(file_path, 'r', encoding='utf-8')
test_data = []
index = 0
for line in f.readlines():
guid = "test-%d" % index
line = line.replace('\n', '').split('\t')
text_a = tokenization.convert_to_unicode(str(line[0]))
label = str(line[1])
test_data.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
index += 1
return test_data
def get_labels(self):
"""Gets the list of labels for this data set."""
return ["0", "1", "2"] # 参考XnliProcessor改成返回012
~~~
@ -152,3 +205,33 @@ def main(_):
-output_dir=my_model
~~~
> task_name运行的模块在main里指定了名字对应的类
>
> do_train是否训练
>
> do_eval是否验证
>
> data_dir数据地址
>
> vocab_file词库表
>
> bert_config_filebert参数
>
> init_checkpoint初始化参数
>
> max_seq_length最长字符限制
>
> train_batch_size训练次数
>
> learning_rate学习率
>
> num_train_epochs循环训练次数
>
> output_dir输出路径
![1610334445172](assets/1610334445172.png)
设置参数完成run即可
![1610334464000](assets/1610334464000.png)

Loading…
Cancel
Save