Create 源码解读.md

pull/2/head
benjas 5 years ago
parent 39d9cd1a0d
commit b100cfa87f

@ -0,0 +1,59 @@
### 源码解读
#### 数据读取模块
处理MRPC数据的类:
~~~python
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version).

    Reads `train.tsv`, `dev.tsv` and `test.tsv` from a data directory and
    converts every data row into an `InputExample` holding a text pair and
    a label.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        # Two possible labels, i.e. the task is a binary classification.
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets.

        Args:
          lines: Rows as returned by `self._read_tsv`; each row is indexed
            by column position.
          set_type: One of "train", "dev" or "test". Used as the prefix of
            each example's guid; "test" also switches off label reading.

        Returns:
          A list of `InputExample` objects, one per row after the first.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                # The first row is skipped (presumably the TSV header).
                continue
            guid = "%s-%s" % (set_type, i)
            # text_a and text_b are taken from columns 3 and 4 of the row.
            text_a = tokenization.convert_to_unicode(line[3])
            text_b = tokenization.convert_to_unicode(line[4])
            if set_type == "test":
                # Test rows get the fixed placeholder label "0" instead of
                # a value read from the file.
                label = "0"
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b,
                             label=label))
        return examples
~~~
读取训练数据代码:
~~~python
if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    # Total number of training steps: the number of training examples
    # divided by train_batch_size, multiplied by the number of epochs.
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    # Warmup steps are a fixed proportion of all training steps. The intent
    # is to keep the learning rate small at the start and return to the
    # configured learning rate once the warmup proportion has elapsed (the
    # schedule itself is applied elsewhere, not in this snippet).
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
~~~
Loading…
Cancel
Save