Add. Loading dataset

pull/2/head
benjas 4 years ago
parent 7b787d6c09
commit 29f1c7f65c

@ -235,12 +235,191 @@
"print(tf.nn.embedding_lookup(wordVectors,firstSentence).eval().shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在整个训练集上面构造索引之前,我们先花一些时间来可视化我们所拥有的数据类型。这将帮助我们去决定如何设置最大序列长度的最佳值。在前面的例子中,我们设置了最大长度为 10但这个值在很大程度上取决于你输入的数据。\n",
"\n",
"训练集我们使用的是 IMDB 数据集。这个数据集包含 25000 条电影数据,其中 12500 条正向数据12500 条负向数据。这些数据都是存储在一个文本文件中,首先我们需要做的就是去解析这个文件。正向数据包含在一个文件中,负向数据包含在另一个文件中。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from os import listdir\n",
"from os.path import isfile, join\n",
"# 指定数据集位置,由于提供的数据都是一个个单独的文件,所以得一个个读取\n",
"positiveFiles = ['./training_data/positiveReviews/' + f for f in listdir('./training_data/positiveReviews/') if isfile(join('./training_data/positiveReviews/', f))]\n",
"negativeFiles = ['./training_data/negativeReviews/' + f for f in listdir('./training_data/negativeReviews/') if isfile(join('./training_data/negativeReviews/', f))]\n",
"numWords = []\n",
"for pf in positiveFiles:\n",
" with open(pf, \"r\", encoding='utf-8') as f:\n",
" line=f.readline()\n",
" counter = len(line.split())\n",
" numWords.append(counter) \n",
"print('Positive files finished')\n",
"\n",
"for nf in negativeFiles:\n",
" with open(nf, \"r\", encoding='utf-8') as f:\n",
" line=f.readline()\n",
" counter = len(line.split())\n",
" numWords.append(counter) \n",
"print('Negative files finished')\n",
"\n",
"numFiles = len(numWords)\n",
"print('The total number of files is', numFiles)\n",
"print('The total number of words in the files is', sum(numWords))\n",
"print('The average number of words in the files is', sum(numWords)/len(numWords))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"plt.hist(numWords, 50)\n",
"plt.xlabel('Sequence Length')\n",
"plt.ylabel('Frequency')\n",
"plt.axis([0, 1200, 0, 8000])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"从直方图和句子的平均单词数,我们认为将句子最大长度设置为绝大多数的长度 250 是可行的。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"maxSeqLength = 250"
]
},
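{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on this choice, we can compute what fraction of the reviews fit entirely within 250 words (a small sketch that reuses the numWords list built above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of reviews that fit entirely within maxSeqLength words,\n",
"# using the numWords list built while scanning the files above\n",
"coverage = sum(1 for n in numWords if n <= maxSeqLength) / len(numWords)\n",
"print('Reviews fully covered by maxSeqLength={}: {:.1%}'.format(maxSeqLength, coverage))"
]
},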
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 查看其中一条评论\n",
"fname = positiveFiles[3] #Can use any valid index (not just 3)\n",
"with open(fname) as f:\n",
" for lines in f:\n",
" print(lines)\n",
" exit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"接下来,我们将它转换成一个索引矩阵。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 删除标点符号、括号、问号等,只留下字母数字字符\n",
"import re\n",
"strip_special_chars = re.compile(\"[^A-Za-z0-9 ]+\")\n",
"\n",
"def cleanSentences(string):\n",
" string = string.lower().replace(\"<br />\", \" \")\n",
" return re.sub(strip_special_chars, \"\", string.lower())"
]
},
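{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick usage example of `cleanSentences` (the input string below is made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HTML line breaks become spaces, everything is lower-cased,\n",
"# and punctuation is stripped\n",
"print(cleanSentences('This movie was GREAT!<br />10/10, would watch again.'))\n",
"# -> this movie was great 1010 would watch again"
]
},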
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"firstFile = np.zeros((maxSeqLength), dtype='int32')\n",
"with open(fname) as f:\n",
" indexCounter = 0\n",
" line=f.readline()\n",
" cleanedLine = cleanSentences(line)\n",
" split = cleanedLine.split()\n",
" for word in split:\n",
" try:\n",
" firstFile[indexCounter] = wordsList.index(word)\n",
" except ValueError:\n",
" firstFile[indexCounter] = 399999 #Vector for unknown words\n",
" indexCounter = indexCounter + 1\n",
"firstFile"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在,我们用相同的方法来处理全部的 25000 条评论。我们将导入电影训练集,并且得到一个 25000 * 250 的矩阵。这是一个计算成本非常高的过程,可以直接使用理好的索引矩阵文件。"
]
},
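{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the full matrix could be built along the following lines. This is only a sketch (the driver lines are commented out, and `buildIdsMatrix` is a helper name introduced here); it assumes the file lists, `cleanSentences`, `wordsList`, and `maxSeqLength` defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: building the 25000 x 250 index matrix ourselves.\n",
"# Slow in practice (one wordsList.index lookup per word), which is why\n",
"# the next cell loads the precomputed file instead.\n",
"def buildIdsMatrix(files, startRow, ids):\n",
"    for fileCounter, fname in enumerate(files, start=startRow):\n",
"        with open(fname, 'r', encoding='utf-8') as f:\n",
"            line = f.readline()\n",
"        for indexCounter, word in enumerate(cleanSentences(line).split()):\n",
"            if indexCounter >= maxSeqLength:  # truncate long reviews\n",
"                break\n",
"            try:\n",
"                ids[fileCounter][indexCounter] = wordsList.index(word)\n",
"            except ValueError:\n",
"                ids[fileCounter][indexCounter] = 399999  # unknown word\n",
"\n",
"# ids = np.zeros((numFiles, maxSeqLength), dtype='int32')\n",
"# buildIdsMatrix(positiveFiles, 0, ids)       # rows 0..12499\n",
"# buildIdsMatrix(negativeFiles, 12500, ids)   # rows 12500..24999\n",
"# np.save('./training_data/idsMatrix.npy', ids)"
]
},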
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ids = np.load('./training_data/idsMatrix.npy')"
]
},
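{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check on the loaded matrix; it should match the 25000 x 250 size described above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: one 250-word index row per review\n",
"print(ids.shape)  # expected: (25000, 250)"
]
},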
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 构建LSTM网络模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from random import randint\n",
"\n",
"def getTrainBatch():\n",
" labels = []\n",
" arr = np.zeros([batchSize, maxSeqLength])\n",
" for i in range(batchSize):\n",
" if (i % 2 == 0): \n",
" num = randint(1,11499)\n",
" labels.append([1,0])\n",
" else:\n",
" num = randint(13499,24999)\n",
" labels.append([0,1])\n",
" arr[i] = ids[num-1:num]\n",
" return arr, labels\n",
"\n",
"def getTestBatch():\n",
" labels = []\n",
" arr = np.zeros([batchSize, maxSeqLength])\n",
" for i in range(batchSize):\n",
" num = randint(11499,13499)\n",
" if (num <= 12499):\n",
" labels.append([1,0])\n",
" else:\n",
" labels.append([0,1])\n",
" arr[i] = ids[num-1:num]\n",
" return arr, labels"
]
}
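,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick usage sketch of the batch helpers. Note that both functions rely on a `batchSize` value defined with the other hyperparameters; the value 24 below is purely illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batchSize = 24  # illustrative value; normally set with the other hyperparameters\n",
"nextBatch, nextBatchLabels = getTrainBatch()\n",
"print(nextBatch.shape)                   # (24, 250): one row of word indices per review\n",
"print(nextBatchLabels[:2])               # one-hot labels: [1, 0] positive, [0, 1] negative\n",
"print(np.array(nextBatchLabels).shape)   # (24, 2)"
]
}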
],
"metadata": {
