diff --git a/NLP通用框架BERT项目实战/assets/1610334445172.png b/NLP通用框架BERT项目实战/assets/1610334445172.png
new file mode 100644
index 0000000..1db6555
Binary files /dev/null and b/NLP通用框架BERT项目实战/assets/1610334445172.png differ
diff --git a/NLP通用框架BERT项目实战/assets/1610334464000.png b/NLP通用框架BERT项目实战/assets/1610334464000.png
new file mode 100644
index 0000000..91bacbb
Binary files /dev/null and b/NLP通用框架BERT项目实战/assets/1610334464000.png differ
diff --git a/NLP通用框架BERT项目实战/第三章——基于BERT的中文情感分析实战.md b/NLP通用框架BERT项目实战/第三章——基于BERT的中文情感分析实战.md
index 3ea566e..fd8220e 100644
--- a/NLP通用框架BERT项目实战/第三章——基于BERT的中文情感分析实战.md
+++ b/NLP通用框架BERT项目实战/第三章——基于BERT的中文情感分析实战.md
@@ -55,7 +55,7 @@ class MyDataProcessor(DataProcessor):
         f = open(file_path, 'r', encoding='utf-8') # read the data; utf-8 is the encoding commonly used for Chinese
         train_data = []
         index = 0 # running ID
-        for line in f.readline(): # modeled on XnliProcessor
+        for line in f.readlines(): # modeled on XnliProcessor
             guid = "train-%d" % index
             line = line.replace('\n', '').split('\t') # strip the newline; the raw data is tab-separated
             text_a = tokenization.convert_to_unicode(str(line[1])) # position 0 is the index, position 1 is the text -- see train_sentiment.txt
@@ -111,7 +111,60 @@ class XnliProcessor(DataProcessor):
 
 **The complete code is as follows**
 
 ~~~python
+class MyDataProcessor(DataProcessor):
+    """Data processor for the Chinese sentiment classification data set."""
+    def get_train_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for the train set."""
+        file_path = os.path.join(data_dir, 'train_sentiment.txt')
+        f = open(file_path, 'r', encoding='utf-8') # read the data; utf-8 is the encoding commonly used for Chinese
+        train_data = []
+        index = 0 # running ID
+        for line in f.readlines(): # modeled on XnliProcessor
+            guid = "train-%d" % index
+            line = line.replace("\n", "").split("\t") # strip the newline; the raw data is tab-separated
+            text_a = tokenization.convert_to_unicode(str(line[1])) # position 0 is the index, position 1 is the text -- see train_sentiment.txt
+            label = str(line[2]) # the label is just a number, so converting it to a string is enough
+            train_data.append(
+                InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) # there is no text_b here, so pass None
+            index += 1 # increment index so every example gets a unique ID
+        return train_data # data loading is done
+
+    def get_dev_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for the dev set."""
+        file_path = os.path.join(data_dir, 'test_sentiment.txt')
+        f = open(file_path, 'r', encoding='utf-8')
+        dev_data = []
+        index = 0
+        for line in f.readlines():
+            guid = "dev-%d" % index
+            line = line.replace('\n', '').split('\t')
+            text_a = tokenization.convert_to_unicode(str(line[1]))
+            label = str(line[2])
+            dev_data.append(
+                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+            index += 1
+        return dev_data
+
+    def get_test_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for prediction."""
+        file_path = os.path.join(data_dir, 'test.csv')
+        f = open(file_path, 'r', encoding='utf-8')
+        test_data = []
+        index = 0
+        for line in f.readlines():
+            guid = "test-%d" % index
+            line = line.replace('\n', '').split('\t')
+            text_a = tokenization.convert_to_unicode(str(line[0]))
+            label = str(line[1])
+            test_data.append(
+                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+            index += 1
+        return test_data
+
+    def get_labels(self):
+        """Gets the list of labels for this data set."""
+        return ["0", "1", "2"] # modeled on XnliProcessor, changed to return 0, 1, 2
 
 
 ~~~
@@ -152,3 +205,33 @@ def main(_):
 -output_dir=my_model
 ~~~
 
+> task_name: which processor to run; main() maps this name to the corresponding class
+>
+> do_train: whether to run training
+>
+> do_eval: whether to run evaluation
+>
+> data_dir: path to the data
+>
+> vocab_file: vocabulary file
+>
+> bert_config_file: BERT configuration file
+>
+> init_checkpoint: checkpoint used to initialize the model weights
+>
+> max_seq_length: maximum sequence length
+>
+> train_batch_size: batch size used for training
+>
+> learning_rate: learning rate
+>
+> num_train_epochs: number of training epochs
+>
+> output_dir: output directory
+
+![1610334445172](assets/1610334445172.png)
+
+With the parameters set, just click Run.
+
+![1610334464000](assets/1610334464000.png)
+
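
For the new processor to be reachable through the `task_name` flag described above, it also has to appear in the processor lookup inside `run_classifier.py`'s `main()`. The diff does not show that registration, so here is a minimal sketch of what it would look like; the key `"my_task"` is an assumption and must match whatever value is passed as `-task_name` (the script lowercases the flag before the lookup).

~~~python
# Sketch only -- not part of the commit above. Inside run_classifier.py's main(_),
# the existing processors dict maps the lowercased -task_name value to a processor
# class, so MyDataProcessor needs its own entry. The key "my_task" is a placeholder.
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "my_task": MyDataProcessor,  # the processor defined in this chapter
}
~~~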
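Before launching a full training run, it can be worth checking that the processor parses the files as expected. The following is a hypothetical sanity check, assuming the snippet sits next to `run_classifier.py` and that the data directory is named `data` (a placeholder; the diff does not state the actual path).

~~~python
# Hypothetical sanity check for MyDataProcessor; "data" is an assumed directory name.
from run_classifier import MyDataProcessor

processor = MyDataProcessor()
examples = processor.get_train_examples("data")  # expects data/train_sentiment.txt
print("loaded %d training examples" % len(examples))
first = examples[0]
print(first.guid, first.label, first.text_a[:20])  # e.g. guid "train-0", its label, first 20 chars of text
print(processor.get_labels())                      # ['0', '1', '2']
~~~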