Training a Chinese BERT classification model

**The data-reading method `get_train_examples`** (modeled on the reference processors)
~~~python
class MyDataProcessor(DataProcessor):
  """Data converter for a Chinese sequence classification data set."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    file_path = os.path.join(data_dir, 'train_sentiment.txt')
    f = open(file_path, 'r', encoding='utf-8')  # utf-8 handles the Chinese text
    train_data = []
    index = 0  # running sample ID
    for line in f.readlines():  # modeled on XnliProcessor; note readlines(), not readline()
      guid = "train-%d" % index
      line = line.replace('\n', '').split('\t')  # strip the newline; the raw data is tab-separated
      text_a = tokenization.convert_to_unicode(str(line[1]))  # position 0 is the row index, position 1 is the text (see train_sentiment.txt)
      label = str(line[2])  # the labels are plain numbers, so casting to str is enough
      train_data.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))  # we have no text_b, so pass None
      index += 1  # bump the ID so every guid is unique
    return train_data  # done reading the data
~~~
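A quick way to sanity-check the method above (a minimal sketch, not from the original; it assumes `data/train_sentiment.txt` exists with tab-separated `index`, `text`, `label` rows):
~~~python
# Minimal smoke test: assumes data/train_sentiment.txt contains rows
# such as "0\t这个手机很好用\t2" (index, text, label, tab-separated).
processor = MyDataProcessor()
examples = processor.get_train_examples('data')
print(len(examples))                        # number of training samples
print(examples[0].guid, examples[0].label)  # e.g. "train-0 2"
~~~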
> For reference, the method above was modeled on XnliProcessor:
~~~python
class XnliProcessor(DataProcessor):
  """Processor for the XNLI data set."""

  def __init__(self):
    self.language = "zh"

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(
        os.path.join(data_dir, "multinli",
                     "multinli.train.%s.tsv" % self.language))
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "train-%d" % (i)  # build the sample ID
      text_a = tokenization.convert_to_unicode(line[0])
      text_b = tokenization.convert_to_unicode(line[1])  # XNLI has both text_a and text_b; our data only has text_a, so we drop text_b
      label = tokenization.convert_to_unicode(line[2])
      if label == tokenization.convert_to_unicode("contradictory"):
        label = tokenization.convert_to_unicode("contradiction")
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))  # everything is handed to InputExample, which is just a plain container class; we follow the same template
    return examples
~~~
**Getting the labels**
~~~python
# Also modeled on XnliProcessor; just change the return value to "0", "1", "2".
# This method goes inside MyDataProcessor.
def get_labels(self):
  """Gets the list of labels for this data set."""
  return ["0", "1", "2"]
~~~
**The complete processor**

A sketch of the full class follows: the dev- and test-set loaders mirror `get_train_examples`, so the shared logic is factored into a helper. The file names `dev_sentiment.txt` and `test_sentiment.txt` are assumptions, not taken from the original.
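~~~python
class MyDataProcessor(DataProcessor):
  """Processor for a three-class Chinese sentiment data set."""

  def _create_examples(self, file_path, set_type):
    """Reads tab-separated rows of (index, text, label)."""
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
      for index, line in enumerate(f.readlines()):
        guid = "%s-%d" % (set_type, index)
        line = line.replace('\n', '').split('\t')
        text_a = tokenization.convert_to_unicode(str(line[1]))
        label = str(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        os.path.join(data_dir, 'train_sentiment.txt'), 'train')

  def get_dev_examples(self, data_dir):
    """See base class. The file name is an assumption."""
    return self._create_examples(
        os.path.join(data_dir, 'dev_sentiment.txt'), 'dev')

  def get_test_examples(self, data_dir):
    """See base class. The file name is an assumption."""
    return self._create_examples(
        os.path.join(data_dir, 'test_sentiment.txt'), 'test')

  def get_labels(self):
    """See base class."""
    return ["0", "1", "2"]
~~~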
#### Training the BERT Chinese classification model
Register the new processor in the `main` function:
~~~python
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "my": MyDataProcessor,  # the added entry, so the task_name run argument can find our processor
  }
~~~
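With the entry registered, `main` resolves the processor from the `task_name` flag; the lookup already present in run_classifier.py works roughly like this:
~~~python
task_name = FLAGS.task_name.lower()  # "my" after lower-casing
if task_name not in processors:
  raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()  # instantiates MyDataProcessor
label_list = processor.get_labels()  # ["0", "1", "2"]
~~~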
Run parameters
~~~
--task_name=my
--do_train=true
--do_eval=true
--data_dir=data
--vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt
--bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json
--init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt
--max_seq_length=70
--train_batch_size=32
--learning_rate=5e-5
--num_train_epochs=3.0
--output_dir=my_model
~~~
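These are program arguments for run_classifier.py (e.g. in an IDE run configuration); the equivalent single command-line invocation, with the same values, is:
~~~
python run_classifier.py \
  --task_name=my \
  --do_train=true \
  --do_eval=true \
  --data_dir=data \
  --vocab_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=../GLUE/BERT_BASE_DIR/chinese_L-12_H-768_A-12/bert_model.ckpt \
  --max_seq_length=70 \
  --train_batch_size=32 \
  --learning_rate=5e-5 \
  --num_train_epochs=3.0 \
  --output_dir=my_model
~~~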
