{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 使用LSTM进行情感分析" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 深度学习在自然语言处理中的应用\n", "自然语言处理是教会机器如何去处理或者读懂人类语言的系统,主要应用领域:\n", "\n", "* 对话系统 - 聊天机器人(小冰)\n", "* 情感分析 - 对一段文本进行情感识别(我们现在做)\n", "* 图文映射 - CNN和RNN的融合\n", "* 机器翻译 - 将一种语言翻译成另一种语言\n", "* 语音识别 - 将语音识别成文字,如王者荣耀\n", "\n", "请回顾[第四章——递归神经网络与词向量原理解读](https://github.com/ben1234560/AiLearning-Theory-Applying/blob/master/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8/%E7%AC%AC%E5%9B%9B%E7%AB%A0%E2%80%94%E2%80%94%E9%80%92%E5%BD%92%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E4%B8%8E%E8%AF%8D%E5%90%91%E9%87%8F%E5%8E%9F%E7%90%86%E8%A7%A3%E8%AF%BB.md)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 词向量模型\n", "计算机只认识数字!\n", "\n", "我们可以将一句话中的每个词都转换成一个向量\n", "\n", "它们的向量维度是一致的" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "词向量是具有空间一样的,并不是简单的映射!例如,我们希望单词“love”和“adore”这两个词在向量空间中是有一定的相关性的,因为他们有类似的定义,他们都在类似的上下文中使用。单词的向量表示也被称之为词嵌入。\n", "\n", "word2vec构建的词向量正如上图,相同含义的词在高维空间上是接近的,而不同含义的词差别很远。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Word2Vec\n", "为了去得到这些词嵌入,我们使用一个非常厉害的模型\"Word2vec\"。简单的说,这个模型根据上下文的语境来推断出毎个词的词向量。如果两个个词在上下文的语境中,可以被互相替换,那么这两个词的距离就非常近。在自然语言中,上下文的语境对分析词语的意义是非常重要的。比如,之前我们提到的\"adore\"和Tove\"这两个词,我们观察如下上下文的语境。\n", "\n", "从句子中我们可以看到,这两个词通常在句子中是表现积极的,而且-般比名词或者名词组合要好。这也说明了,这两个词可以被互相替换,他们的意思是非常相近的。对于句子的语法结构分析,上下文语境也是非常重要的。所有,这个模型的作用就是从一大堆句子(以 Wikipedia为例)中为毎个独一无二的单词进行建模,并且输出一个唯一的向量。word2vec模型的输出被称为一个嵌入矩阵\n", "\n", "这个嵌入矩阵包含训练集中每个词的一个向量。传统来讲,这个嵌入矩阵中的词向量数据会很大。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Recurrent Neural Networks(RNNs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "现在,我们已经得到了神经网络的输入数据——词向量,接下来让我们看看需要构建的神经网络。NLP数据的个独特之处是它是时间序列数据。每个单词的出现都依赖于它的前—个单词和后—个单词。由于这种依赖的存在,我们使用循环神经网络来处理这种时间序列数据。循环神经网络的结构和你之前看到的那些前馈神经网络的结枃可能有一些不一样。前馈神经网络由三部分组成,输入层,隐藏层和输出层。\n", "\n", "\n", "前馈神经网络和RNN之前的主要区别就是RNN考虑了时间的信息。在RNN中,句子中的每个单词都被考虑上了时间步骤。实际上,时间步长的数量将等于最大序列长度\n", "\n", "与每个时间步骤相关联的中间状态也被作为一个新的组件,称为隐藏状态向量h(t)。从抽象的角度来看,这个向量是用来封装和汇总前面时间步骤中所看到的所有信息。就像x(t)表示一个向量,它封装了一个特定单词的所有信息。\n", "\n", "隐藏状态是当前单词向量和前一步的隐藏状态冋量的函数。并且这两项之和需要通过激活函数来进行激活。\n", "\n", "\n", "\n", "如上图,第一个词The(Xt-1)经过神经元计算(Wxt-1),得出特征向量ht-1,再给第二个词movie使用,循环如此,直至最后,综合考虑前面的所有特征。\n", "\n", "从上图我们也能看到一个问题,就是越前面的数据,越无法感知,也就是俗称的梯度消失,所以引入LSTM,LSTM在上一章节已经了解过,这里不再重复。\n", "\n", "可回顾[第四章——递归神经网络与词向量原理解读](https://github.com/ben1234560/AiLearning-Theory-Applying/blob/master/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8/%E7%AC%AC%E5%9B%9B%E7%AB%A0%E2%80%94%E2%80%94%E9%80%92%E5%BD%92%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E4%B8%8E%E8%AF%8D%E5%90%91%E9%87%8F%E5%8E%9F%E7%90%86%E8%A7%A3%E8%AF%BB.md)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 项目流程\n", "\n", " 1.制作词向量,可以使用gensim库,也可以直接用现成的\n", " 2.词和ID的映射\n", " 3.构建RNN网络架构\n", " 4.训练模型\n", " 5.评估结果" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 导入数据\n", "首先,我们需要去创建词向量。为了简单起见,我们使用训陈练好的模型来创建。\n", "\n", "作为该领域的个最大玩家, Google已经帮助我们在大规模数据集上训练出来了word2vec模型,包括1000亿个不同的词!在这个模型中,谷歌能创建300万个词向量,每个向量维度为300。\n", "\n", "在理想情况下,我们将使用这些向量来构建模型,但是因为这个单词向量矩阵相当大(3.6G),我们用另外个现成的小—些的,该矩阵由Gove进行训练得到。矩阵将包含400000个词向量,每个向量的维数为50。\n", "\n", "我们将导入两个不同的数据结构,一个是包含40000个单词的 Python列表,一个是包含所有单词向量值得400000`*`50维的嵌入矩阵。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "数据集目录如下\n", "\n", "其中文件夹negativeReviews和positiveReviews里是一句话一个txt\n", "\n", "数据集地址:\n", "链接:https://pan.baidu.com/s/18vPGelYCXGqp5OCWZWz36A \n", "提取码:de0f" ] }, { "cell_type": "code", "execution_count": 1, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded the word list!\n", "Loaded the word vectors!\n", "400000\n", "(400000, 50)\n" ] } ], "source": [ "import numpy as np\n", "wordsList = np.load('./training_data/wordsList.npy')\n", "print('Loaded the word list!')\n", "wordsList = wordsList.tolist() #Originally loaded as numpy array\n", "wordsList = [word.decode('UTF-8') for word in wordsList] #Encode words as UTF-8\n", "wordVectors = np.load('./training_data/wordVectors.npy')\n", "print ('Loaded the word vectors!')\n", "print(len(wordsList))\n", "print(wordVectors.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们也可以在词库中搜索单词,比如 “baseball”,然后可以通过访问嵌入矩阵来得到相应的向量,如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-1.9327 , 1.0421 , -0.78515 , 0.91033 , 0.22711 , -0.62158 ,\n", " -1.6493 , 0.07686 , -0.5868 , 0.058831, 0.35628 , 0.68916 ,\n", " -0.50598 , 0.70473 , 1.2664 , -0.40031 , -0.020687, 0.80863 ,\n", " -0.90566 , -0.074054, -0.87675 , -0.6291 , -0.12685 , 0.11524 ,\n", " -0.55685 , -1.6826 , -0.26291 , 0.22632 , 0.713 , -1.0828 ,\n", " 2.1231 , 0.49869 , 0.066711, -0.48226 , -0.17897 , 0.47699 ,\n", " 0.16384 , 0.16537 , -0.11506 , -0.15962 , -0.94926 , -0.42833 ,\n", " -0.59457 , 1.3566 , -0.27506 , 0.19918 , -0.36008 , 0.55667 ,\n", " -0.70315 , 0.17157 ], dtype=float32)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "baseballIndex = wordsList.index('baseball')\n", "wordVectors[baseballIndex]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "现在我们有了向量,我们的第一步就是输入一个句子,然后构造它的向量表示。假设我们现在的输入句子是 “I thought the movie was incredible and inspiring”。为了得到词向量,我们可以使用 TensorFlow 的嵌入函数。这个函数有两个参数,一个是嵌入矩阵(在我们的情况下是词向量矩阵),另一个是每个词对应的索引。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10,)\n", "[ 41 804 201534 1005 15 7446 5 13767 0 0]\n" ] } ], "source": [ "import tensorflow as tf # 注意,这里是TensorFlow 1\n", "maxSeqLength = 10 # Maximum length of sentence 设置最大词数\n", "numDimensions = 300 # Dimensions for each word vector 设置每个单词最大维度\n", "firstSentence = np.zeros((maxSeqLength), dtype='int32')\n", "firstSentence[0] = wordsList.index(\"i\")\n", "firstSentence[1] = wordsList.index(\"thought\")\n", "firstSentence[2] = wordsList.index(\"the\")\n", "firstSentence[3] = wordsList.index(\"movie\")\n", "firstSentence[4] = wordsList.index(\"was\")\n", "firstSentence[5] = wordsList.index(\"incredible\")\n", "firstSentence[6] = wordsList.index(\"and\")\n", "firstSentence[7] = wordsList.index(\"inspiring\")\n", "#如果长度没有达到设置标准,用0来占位\n", "print(firstSentence.shape)\n", "print(firstSentence) #Shows the row index for each word" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "数据管道如下图所示:\n", "\n", "\n", "输出数据是一个 10*50 的词矩阵,其中包括 10 个词,每个词的向量维度是 50。就是去找到这些词对应的向量" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10, 50)\n" ] } ], "source": [ "with tf.Session() as sess:print(tf.nn.embedding_lookup(wordVectors,firstSentence).eval().shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在整个训练集上面构造索引之前,我们先花一些时间来可视化我们所拥有的数据类型。这将帮助我们去决定如何设置最大序列长度的最佳值。在前面的例子中,我们设置了最大长度为 10,但这个值在很大程度上取决于你输入的数据。\n", "\n", "训练集我们使用的是 IMDB 数据集。这个数据集包含 25000 条电影数据,其中 12500 条正向数据,12500 条负向数据。这些数据都是存储在一个文本文件中,首先我们需要做的就是去解析这个文件。正向数据包含在一个文件中,负向数据包含在另一个文件中。" ] }, { "cell_type": 
"code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Positive files finished\n", "Negative files finished\n", "The total number of files is 24625\n", "The total number of words in the files is 5685527\n", "The average number of words in the files is 230.88434517766498\n" ] } ], "source": [ "from os import listdir\n", "from os.path import isfile, join\n", "# 指定数据集位置,由于提供的数据都是一个个单独的文件,所以得一个个读取\n", "positiveFiles = ['./training_data/positiveReviews/' + f for f in listdir('./training_data/positiveReviews/') if isfile(join('./training_data/positiveReviews/', f))]\n", "negativeFiles = ['./training_data/negativeReviews/' + f for f in listdir('./training_data/negativeReviews/') if isfile(join('./training_data/negativeReviews/', f))]\n", "numWords = []\n", "for pf in positiveFiles:\n", " with open(pf, \"r\", encoding='utf-8') as f:\n", " line=f.readline()\n", " counter = len(line.split())\n", " numWords.append(counter) \n", "print('Positive files finished')\n", "\n", "for nf in negativeFiles:\n", " with open(nf, \"r\", encoding='utf-8') as f:\n", " line=f.readline()\n", " counter = len(line.split())\n", " numWords.append(counter) \n", "print('Negative files finished')\n", "\n", "numFiles = len(numWords)\n", "print('The total number of files is', numFiles)\n", "print('The total number of words in the files is', sum(numWords))\n", "print('The average number of words in the files is', sum(numWords)/len(numWords))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEKCAYAAADenhiQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHNVJREFUeJzt3X+UHWWd5/H3x0R+K0k0sJkkTsLaC6KrIbQhyIyjgiFEh+AMjPF4lhYzk9ldZtVxd8eg7kRRzsKuK8qOIlHQwCoQUCSLzIQ2gHN2ll8dwPB70gJCmwxpNj9A0WCY7/5R3xtuQv+4na7q2/fm8zrnnlv1raeqn8cK92s9T9VTigjMzMzK9KpmV8DMzNqPk4uZmZXOycXMzErn5GJmZqVzcjEzs9I5uZiZWekqTS6S/lLSQ5IelHS1pIMkzZZ0l6SNkq6VdECWPTDXe3P7rLrjnJfxxySdWmWdzcxs9CpLLpKmAx8DOiPiLcAEYAlwEXBxRHQA24CluctSYFtEvBG4OMsh6djc783AQuDrkiZUVW8zMxu9qrvFJgIHS5oIHAJsBt4DXJ/bVwFn5PLiXCe3nyxJGb8mInZGxBNALzCv4nqbmdkoTKzqwBHxC0lfAp4Cfg3cAqwHtkfErizWB0zP5enA07nvLkk7gNdl/M66Q9fvs5ukZcAygEMPPfT4Y445pvQ2mZm1s/Xr1z8bEVPLOFZlyUXSZIqrjtnAduA64LQBitbmn9Eg2waL7xmIWAmsBOjs7Iyenp59qLWZ2f5L0s/LOlaV3WKnAE9ERH9E/Bb4AfAOYFJ2kwHMADblch8wEyC3Hw5srY8PsI+ZmY1DVSaXp4D5kg7JsZOTgYeB24Azs0wXcGMur8l1cvutUcyquQZYkneTzQY6gLsrrLeZmY1SlWMud0m6HrgX2AXcR9Ft9SPgGklfzNjlucvlwFWSeimuWJbkcR6StJoiMe0Czo2Il6qqt5mZjZ7accp9j7mYmY2cpPUR0VnGsfyEvpmZlc7JxczMSufkYmZmpXNyMTOz0jm5mJlZ6ZxczMysdE4uZmZWOicXMzMrnZOLmZmVzsnFzMxK5+RiZmalc3IxM7PSObmYmVnpnFzMzKx0Ti5mZlY6JxczMyudk4uZmZXOycXMzEpXWXKRdLSk++s+z0n6hKQpkrolbczvyVleki6R1Ctpg6S5dcfqyvIbJXVVVWczMytHZcklIh6LiDkRMQc4HngBuAFYDqyLiA5gXa4DnAZ05GcZcCmApCnACuAEYB6wopaQzMxsfBqrbrGTgZ9FxM+BxcCqjK8CzsjlxcCVUbgTmCRpGnAq0B0RWyNiG9ANLByjepuZ2T4Yq+SyBLg6l4+MiM0A+X1ExqcDT9ft05exweJmZjZOVZ5cJB0AnA5cN1zRAWIxRHzvv7NMUo+knv7+/pFX1MzMSjMWVy6nAfdGxDO5/kx2d5HfWzLeB8ys228GsGmI+B4iYmVEdEZE59SpU0tugpmZjcRYJJcP8XKXGMAaoHbHVxdwY1387LxrbD6wI7vN1gILJE3OgfwFGTMzs3FqYpUHl3QI8F7gz+vCFwKrJS0FngLOyvjNwCKgl+LOsnMAImKrpC8A92S58yNia5X1NjOz0VHEK4YvWl5nZ2f09PQ0uxpmZi1F0vqI6CzjWH5C38zMSufkYmZmpXNyMTOz0jm5mJlZ6ZxczMysdE4uZmZWOicXMzMrnZOLmZmVzsnFzMxK5+RiZmalc3IxM7PSObmYmVnpnFzMzKx0Ti5mZlY6JxczMyudk4uZmZXOycXMzErn5GJmZqVzcjEzs9JVmlwkTZJ0vaRHJT0i6URJUyR1S9qY35OzrCRdIqlX0gZJc+uO05XlN0rqqrLOZmY2elVfuXwV+LuIOAZ4G/AIsBxYFxEdwLpcBzgN6MjPMuBSAElTgBXACcA8YEUtIZmZ2f
hUWXKR9FrgncDlABHxYkRsBxYDq7LYKuCMXF4MXBmFO4FJkqYBpwLdEbE1IrYB3cDCquptZmajV+WVy1FAP/BtSfdJ+pakQ4EjI2IzQH4fkeWnA0/X7d+XscHie5C0TFKPpJ7+/v7yW2NmZg2rMrlMBOYCl0bEccCveLkLbCAaIBZDxPcMRKyMiM6I6Jw6deq+1NfMzEpSZXLpA/oi4q5cv54i2TyT3V3k95a68jPr9p8BbBoibmZm41RlySUi/gl4WtLRGToZeBhYA9Tu+OoCbszlNcDZedfYfGBHdputBRZImpwD+QsyZmZm49TEio//H4DvSjoAeBw4hyKhrZa0FHgKOCvL3gwsAnqBF7IsEbFV0heAe7Lc+RGxteJ6m5nZKCjiFcMXLa+zszN6enqaXQ0zs5YiaX1EdJZxLD+hb2ZmpXNyMTOz0jm5mJlZ6ZxczMysdE4uZmZWOicXMzMrnZOLmZmVzsnFzMxK5+RiZmalc3IxM7PSObmYmVnpnFzMzKx0Ti5mZlY6JxczMyudk4uZmZXOycXMzErn5GJmZqVzcjEzs9JVmlwkPSnpAUn3S+rJ2BRJ3ZI25vfkjEvSJZJ6JW2QNLfuOF1ZfqOkrirrbGZmozcWVy7vjog5de9lXg6si4gOYF2uA5wGdORnGXApFMkIWAGcAMwDVtQSkpmZjU/N6BZbDKzK5VXAGXXxK6NwJzBJ0jTgVKA7IrZGxDagG1g41pU2M7PGVZ1cArhF0npJyzJ2ZERsBsjvIzI+HXi6bt++jA0W34OkZZJ6JPX09/eX3AwzMxuJiRUf/6SI2CTpCKBb0qNDlNUAsRgivmcgYiWwEqCzs/MV283MbOxUeuUSEZvyewtwA8WYyTPZ3UV+b8nifcDMut1nAJuGiJuZ2TjVUHKR9JaRHljSoZJeU1sGFgAPAmuA2h1fXcCNubwGODvvGpsP7Mhus7XAAkmTcyB/QcbMzGycarRb7BuSDgC+A3wvIrY3sM+RwA2San/nexHxd5LuAVZLWgo8BZyV5W8GFgG9wAvAOQARsVXSF4B7stz5EbG1wXqbmVkTKKKx4QlJHcBHKZLB3cC3I6K7wrrts87Ozujp6Wl2NczMWoqk9XWPjYxKw2MuEbER+CzwKeAPgEskPSrpj8qoiJmZtY9Gx1zeKuli4BHgPcAfRsSbcvniCutnZmYtqNExl78Bvgl8OiJ+XQvmbcafraRmZmbWshpNLouAX0fESwCSXgUcFBEvRMRVldXOzMxaUqNjLj8GDq5bPyRjZmZmr9BocjkoIn5ZW8nlQ6qpkpmZtbpGk8uv9poC/3jg10OUNzOz/VijYy6fAK6TVJt2ZRrwwWqqZGZmra6h5BIR90g6BjiaYiLJRyPit5XWzMzMWtZIZkV+OzAr9zlOEhFxZSW1MjOzltZQcpF0FfAvgfuBlzIcgJOLmZm9QqNXLp3AsdHoRGT2CrOW/2jI7U9e+L4xqomZWfUavVvsQeBfVFkRMzNrH41eubweeFjS3cDOWjAiTq+kVmZm1tIaTS6fq7ISZmbWXhq9Ffknkn4X6IiIH0s6BJhQbdXMzKxVNTrl/p8B1wOXZWg68MOqKmVmZq2t0QH9c4GTgOdg94vDjqiqUmZm1toaTS47I+LF2oqkiRTPuQxL0gRJ90m6KddnS7pL0kZJ10o6IOMH5npvbp9Vd4zzMv6YpFMbbZyZmTVHo8nlJ5I+DRws6b3AdcD/bnDfj1O8wbLmIuDiiOgAtgFLM74U2BYRb6R4u+VFAJKOBZYAbwYWAl+X5PEeM7NxrNHkshzoBx4A/hy4GRj2DZSSZgDvA76V66J4NfL1WWQVcEYuL851cvvJWX4xcE1E7IyIJ4BeYF6D9TYzsyZo9G6xf6Z4zfE3R3j8rwB/Bbwm118HbI+IXbneR3FzAPn9dP69XZJ2ZPnpwJ11x6zfZzdJy4BlAG94wxtGWE0zMytTo3eLPSHp8b0/w+zzfmBLRKyvDw9QNIbZNtQ+LwciVkZEZ0R0Tp06daiqmZlZxUYyt1jNQcBZwJRh9jkJOF3SotzntRRXMpMkTcyrlxlA7R0xfcBMoC9vGDgc2FoXr6nfZ1wYbt4wM7P9TUNXLhHx/+o+v4iIr1CMnQy1z3kRMSMiZlEMyN8aER8GbgPOzGJdwI25vCbXye235kSZa4AleTfZbKADuLvxJpqZ2VhrdMr9uXWrr6K4knnNIMWH8yngGklfBO4DLs/45cBVknoprliWAETEQ5JWAw8Du4BzI+KlVx7WzMzGi0a7xf5H3fIu4EngTxr9IxFxO3B7Lj/OAHd7RcRvKLrbBtr/AuCCRv+emZk1V6N3i7276oqYmVn7aLRb7JNDbY+IL5dTHTMzawcjuVvs7RSD6wB/CPw9+VyKmZlZvZG8LGxuRDwPIOlzwHUR8adVVczMzFpXo9O/vAF4sW79RWBW6bUxM7O20OiVy1XA3ZJuoHg6/gPAlZXVyszMWlqjd4tdIOlvgd/P0DkRcV911TIzs1bWaLcYwCHAcxHxVYopWmZXVCczM2txjU5cuYLiyfrzMvRq4H9VVSkzM2ttjV65fAA4HfgVQERsYt+nfzEzszbXaHJ5MSeRDABJh1ZXJTMza3WNJpfVki6jmC7/z4AfM/IXh5mZ2X6i0bvFviTpvcBzwNHAX0dEd6U1MzOzljVscpE0AVgbEacATihmZjasYbvF8t0pL0g6fAzqY2ZmbaDRJ/R/AzwgqZu8YwwgIj5WSa3MzKylNZpcfpQfMzOzYQ2ZXCS9ISKeiohVY1UhMzNrfcONufywtiDp+yM5sKSDJN0t6aeSHpL0+YzPlnSXpI2SrpV0QMYPzPXe3D6r7ljnZfwxSaeOpB5mZjb2hksuqls+aoTH3gm8JyLeBswBFkqaD1wEXBwRHcA2YGmWXwpsi4g3AhdnOSQdCywB3gwsBL6ed7CZmdk4NVxyiUGWhxWFX+bqq/MTwHuA6zO+CjgjlxfnOrn9ZEnK+DURsTMingB6gXkjqYuZmY2t4ZLL2yQ9J+l54K25/Jyk5yU9N9zBJU2QdD+wheIZmZ8B2yNiVxbpA6bn8nTytcm5fQfwuvr4APvU/61lknok9fT39w9XNTMzq9CQA/oRMarup3xGZo6kScANwJsGKpbfGmTbYPG9/9ZKYCVAZ2fniK6yzMysXCN5n8s+i4jtwO3AfIr5yWpJbQawKZf7gJkAuf1wYGt9fIB9zMxsHGr0OZcRkzQV+G1EbJd0MHAKxSD9bcCZwDVAF3Bj7rIm1+/I7bdGREhaA3xP0peB3wE6gLurqnezzFo+/GNET174vjGoiZnZ6FWWXIBpwKq8s+tVwOqIuEnSw8A1kr4I3AdcnuUvB66S1EtxxbIEICIekrQaeBjYBZyb3W1mZjZOVZZcImIDcNwA8ccZ4G6viPgNcNYgx7oAuKDsOpqZWTXGZMzFzMz2L04uZmZWOicXMzMrnZOLmZmVzsnFzMxK5+RiZmalc3IxM7PSObmYmVnpnFzMzKx0Ti5mZlY6JxczMyudk4uZmZXOycXMzErn5GJmZqVzc
jEzs9I5uZiZWemcXMzMrHROLmZmVrrKkoukmZJuk/SIpIckfTzjUyR1S9qY35MzLkmXSOqVtEHS3LpjdWX5jZK6qqqzmZmVo8orl13Af4yINwHzgXMlHQssB9ZFRAewLtcBTgM68rMMuBSKZASsAE4A5gEragnJzMzGp8qSS0Rsjoh7c/l54BFgOrAYWJXFVgFn5PJi4Moo3AlMkjQNOBXojoitEbEN6AYWVlVvMzMbvTEZc5E0CzgOuAs4MiI2Q5GAgCOy2HTg6brd+jI2WHzvv7FMUo+knv7+/rKbYGZmIzCx6j8g6TDg+8AnIuI5SYMWHSAWQ8T3DESsBFYCdHZ2vmJ7O5i1/EdDbn/ywveNUU3MzIZW6ZWLpFdTJJbvRsQPMvxMdneR31sy3gfMrNt9BrBpiLiZmY1TVd4tJuBy4JGI+HLdpjVA7Y6vLuDGuvjZedfYfGBHdputBRZImpwD+QsyZmZm41SV3WInAf8GeEDS/Rn7NHAhsFrSUuAp4KzcdjOwCOgFXgDOAYiIrZK+ANyT5c6PiK0V1tvMzEapsuQSEf+HgcdLAE4eoHwA5w5yrCuAK8qrnZmZVclP6JuZWemcXMzMrHROLmZmVjonFzMzK52Ti5mZlc7JxczMSufkYmZmpat8bjEbO8PNPQaef8zMxoavXMzMrHROLmZmVjonFzMzK52Ti5mZlc4D+g1oZKDczMxe5isXMzMrnZOLmZmVzsnFzMxK5zGX/cxw40d+yNLMyuArFzMzK11lVy6SrgDeD2yJiLdkbApwLTALeBL4k4jYJknAV4FFwAvARyLi3tynC/hsHvaLEbGq7Lr6bjAzs3JVeeXyHWDhXrHlwLqI6ADW5TrAaUBHfpYBl8LuZLQCOAGYB6yQNLnCOpuZWQkqSy4R8ffA1r3Ci4Halccq4Iy6+JVRuBOYJGkacCrQHRFbI2Ib0M0rE5aZmY0zYz3mcmREbAbI7yMyPh14uq5cX8YGi5uZ2Tg2Xgb0NUAshoi/8gDSMkk9knr6+/tLrZyZmY3MWN+K/IykaRGxObu9tmS8D5hZV24GsCnj79orfvtAB46IlcBKgM7OzgETkA3PtyqbWRnG+splDdCVy13AjXXxs1WYD+zIbrO1wAJJk3Mgf0HGzMxsHKvyVuSrKa46Xi+pj+KurwuB1ZKWAk8BZ2XxmyluQ+6luBX5HICI2CrpC8A9We78iNj7JgEzMxtnKksuEfGhQTadPEDZAM4d5DhXAFeUWDUzM6uYp3+xEWnkgVOPy5jZeLlbzMzM2oivXKx0vuPMzHzlYmZmpXNyMTOz0jm5mJlZ6ZxczMysdB7QtzHn25nN2p+vXMzMrHROLmZmVjp3i9m45GdlzFqbk4u1JI/bmI1v7hYzM7PS+crF2pa71syax1cuZmZWOl+52H7L4zZm1XFyMRtCIwloKE5Otr9q++Qy2h8Hs9Hw1ZHtr9o+uZiNd77xwNpRyyQXSQuBrwITgG9FxIVNrpLZmBirq28nsZf5inP0WiK5SJoAfA14L9AH3CNpTUQ83NyambWPVulCLuNHvVXa2spaIrkA84DeiHgcQNI1wGLAycVsP+PE0BpaJblMB56uW+8DTqgvIGkZsCxXd0p6cIzq1gyvB55tdiUq5Pa1tnZu3+626aIm16QaR5d1oFZJLhogFnusRKwEVgJI6omIzrGoWDO4fa3N7Wtd7dw2KNpX1rFa5Qn9PmBm3foMYFOT6mJmZsNoleRyD9AhabakA4AlwJom18nMzAbREt1iEbFL0l8AayluRb4iIh4aYpeVY1OzpnH7Wpvb17rauW1QYvsUEcOXMjMzG4FW6RYzM7MW4uRiZmala7vkImmhpMck9Upa3uz6jJSkmZJuk/SIpIckfTzjUyR1S9qY35MzLkmXZHs3SJrb3BY0RtIESfdJuinXZ0u6K9t3bd64gaQDc703t89qZr0bIWmSpOslPZrn8cR2On+S/jL/bT4o6WpJB7Xy+ZN0haQt9c/G7cv5ktSV5TdK6mpGWwYySPv+e/773CDpBkmT6radl+17TNKpdfGR/bZGRNt8KAb7fwYcBRwA/BQ4ttn1GmEbpgFzc/k1wD8CxwL/DVie8eXARbm8CPhbimeB5gN3NbsNDbbzk8D3gJtyfTWwJJe/Afy7XP73wDdyeQlwbbPr3kDbVgF/mssHAJPa5fxRPND8BHBw3Xn7SCufP+CdwFzgwbrYiM4XMAV4PL8n5/LkZrdtiPYtACbm8kV17Ts2fzcPBGbn7+mEffltbXrDS/4f8URgbd36ecB5za7XKNt0I8Wcao8B0zI2DXgsly8DPlRXfne58fqheE5pHfAe4Kb8D/XZun/su88jxR2CJ+byxCynZrdhiLa9Nn98tVe8Lc4fL8+WMSXPx03Aqa1+/oBZe/34juh8AR8CLquL71Gu2Z+927fXtg8A383lPX4za+dvX35b261bbKBpYqY3qS6jll0IxwF3AUdGxGaA/D4ii7Vim78C/BXwz7n+OmB7ROzK9fo27G5fbt+R5cero4B+4NvZ7fctSYfSJucvIn4BfAl4CthMcT7W0z7nr2ak56ulzuNePkpxNQYltq/dksuw08S0CkmHAd8HPhERzw1VdIDYuG2zpPcDWyJifX14gKLRwLbxaCJFF8SlEXEc8CuKbpXBtFT7cuxhMUWXye8AhwKnDVC0Vc/fcAZrT0u2U9JngF3Ad2uhAYrtU/vaLbm0xTQxkl5NkVi+GxE/yPAzkqbl9mnAloy3WptPAk6X9CRwDUXX2FeASZJqD/XWt2F3+3L74cDWsazwCPUBfRFxV65fT5Fs2uX8nQI8ERH9EfFb4AfAO2if81cz0vPVaueRvOng/cCHI/u6KLF97ZZcWn6aGEkCLgceiYgv121aA9TuQOmiGIupxc/Ou1jmAztql/PjUUScFxEzImIWxfm5NSI+DNwGnJnF9m5frd1nZvlx+/8II+KfgKcl1WaXPZni1RBtcf4ousPmSzok/63W2tcW56/OSM/XWmCBpMl5dbcgY+OSipcvfgo4PSJeqNu0BliSd/nNBjqAu9mX39ZmDzRVMHC1iOIOq58Bn2l2ffah/r9Hcbm5Abg/P4so+qnXARvze0qWF8WL1H4GPAB0NrsNI2jru3j5brGj8h9xL3AdcGDGD8r13tx+VLPr3UC75gA9eQ5/SHH3UNucP+DzwKPAg8BVFHcWtez5A66mGD/6LcX/Q1+6L+eLYuyiNz/nNLtdw7Svl2IMpfYb84268p/J9j0GnFYXH9Fvq6d/MTOz0rVbt5iZmY0DTi5mZlY6JxczMyudk4uZmZXOycXMzErn5GJtQdJncqbeDZLul3RCs+s0GpK+I+nM4Uvu8/HnSFpUt/45Sf+pqr9n+5+WeM2x2VAknUjxpPHciNgp6fUUM7fa4OYAncDNza6ItSdfuVg7mAY8GxE7ASLi2YjYBCDpeEk/kbRe0tq6KT2Ol/RTSXfkuy0ezPhHJP1N7cCSbpL0rlxekOXvlXRdzv+GpCclfT7jD0g6JuOHSfp2xjZI+uOhjtMISf9Z0j15
vM9nbJaK98Z8M6/ebpF0cG57e5bd3c58wvp84IN5lffBPPyxkm6X9Likj+3z2TDDycXawy3ATEn/KOnrkv4Ads/R9j+BMyPieOAK4ILc59vAxyLixEb+QF4NfRY4JSLmUjyB/8m6Is9m/FKg1r30XyimB/nXEfFW4NYGjjNUHRZQTMcxj+LK43hJ78zNHcDXIuLNwHbgj+va+W+znS8BRMSLwF9TvFtlTkRcm2WPoZg+fx6wIv/3M9sn7hazlhcRv5R0PPD7wLuBa1W8Ka8HeAvQXUyDxQRgs6TDgUkR8ZM8xFUMPLNvvfkUL1L6hzzWAcAdddtrE4yuB/4ol0+hmIOpVs9tKmaFHuo4Q1mQn/ty/TCKpPIUxWSS99fVYZaKtwu+JiL+b8a/R9F9OJgf5dXfTklbgCMppgsxGzEnF2sLEfEScDtwu6QHKCYbXA88tPfVSf7oDjbv0S72vKI/qLYb0B0RHxpkv535/RIv/3elAf7OcMcZioD/GhGX7REs3vuzsy70EnAwA0+TPpS9j+HfB9tn7hazlifpaEkddaE5wM8pJt6bmgP+SHq1pDdHxHZgh6Tfy/Ifrtv3SWCOpFdJmknRRQRwJ3CSpDfmsQ6R9K+GqdotwF/U1XPyPh6nZi3w0bqxnumSjhiscERsA57P2Xuh7ioKeJ7iNdpmlXBysXZwGLBK0sOSNlB0O30uxxbOBC6S9FOK2V/fkfucA3xN0h3Ar+uO9Q8Uryl+gOKNi/cCREQ/xbvir86/cSfFGMVQvghMzkH0nwLvHuFxLpPUl587IuIWiq6tO/Lq7HqGTxBLgZXZTlG8CRKKKfKP3WtA36w0nhXZ9nvZrXRTRLylyVUpnaTDIuKXubyc4r3wH29ytWw/4D5Vs/b2PknnUfy3/nOKqyazyvnKxczMSucxFzMzK52Ti5mZlc7JxczMSufkYmZmpXNyMTOz0v1/CffxoCRf5ugAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "plt.hist(numWords, 50)\n", "plt.xlabel('Sequence Length')\n", "plt.ylabel('Frequency')\n", "plt.axis([0, 1200, 0, 8000])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从直方图和句子的平均单词数,我们认为将句子最大长度设置为绝大多数的长度 250 是可行的。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "maxSeqLength = 250" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The beginning of this movie is excellent with tremendous sound and some nice humor, but once the film changes into animation it quickly loses its appeal.

One of the reasons that was so, at least for me, was that the colors in much of the animation are too muted, with too little contrast. It doesn't look good, at least on VHS. Once in a while it breaks out and looks great, but not often Also, the characters come and go too quickly. For example, I would have liked to have seen more of \"Moby Dick.\" When the film starts to drag, however, it picks up again with the entrance of the dragon and then the film finishes strong.

Overall, just not memorable enough or able to compete with the great animated films of the last dozen years.\n" ] } ], "source": [ "# Look at one of the reviews\n", "fname = positiveFiles[3] #Can use any valid index (not just 3)\n", "with open(fname) as f:\n", " for lines in f:\n", " print(lines)\n", " exit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we convert it into an index matrix." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Remove punctuation, parentheses, question marks, etc., and keep only alphanumeric characters\n", "import re\n", "strip_special_chars = re.compile(\"[^A-Za-z0-9 ]+\")\n", "\n", "def cleanSentences(string):\n", " string = string.lower().replace(\"<br />
\", \" \")\n", " return re.sub(strip_special_chars, \"\", string.lower())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([201534, 1084, 3, 37, 1005, 14, 4345, 17,\n", " 5977, 1507, 5, 77, 3082, 6202, 34, 442,\n", " 201534, 319, 1046, 75, 7673, 20, 1177, 7233,\n", " 47, 1574, 48, 3, 201534, 1997, 12, 15,\n", " 100, 22, 338, 10, 285, 15, 12, 201534,\n", " 5224, 6, 181, 3, 201534, 7673, 32, 317,\n", " 15717, 17, 317, 333, 3313, 20, 136283, 662,\n", " 219, 22, 338, 13, 20237, 442, 6, 7,\n", " 110, 20, 4573, 66, 5, 2146, 353, 34,\n", " 36, 456, 52, 201534, 2153, 326, 5, 242,\n", " 317, 1177, 10, 880, 41, 54, 33, 5572,\n", " 4, 33, 541, 56, 3, 32308, 4159, 61,\n", " 201534, 319, 2383, 4, 7280, 212, 20, 7199,\n", " 60, 378, 17, 201534, 4232, 3, 201534, 7394,\n", " 5, 127, 201534, 319, 9131, 562, 1250, 120,\n", " 36, 8787, 575, 46, 667, 4, 2797, 17,\n", " 201534, 353, 6092, 1588, 3, 201534, 76, 2068,\n", " 82, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0], dtype=int32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "firstFile = np.zeros((maxSeqLength), dtype='int32')\n", "with open(fname) as f:\n", " indexCounter = 0\n", " line=f.readline()\n", " cleanedLine = cleanSentences(line)\n", " split = cleanedLine.split()\n", " for word in split:\n", " try:\n", " firstFile[indexCounter] = wordsList.index(word)\n", " except ValueError:\n", " firstFile[indexCounter] = 399999 #Vector for unknown words\n", " indexCounter = indexCounter + 1\n", "firstFile" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "现在,我们用相同的方法来处理全部的 25000 条评论。我们将导入电影训练集,并且得到一个 25000 * 250 的矩阵。这是一个计算成本非常高的过程,可以直接使用理好的索引矩阵文件。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "ids = np.load('./training_data/idsMatrix.npy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 构建LSTM网络模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN Model\n", "现在,我们可以开始构建我们的 TensorFlow 图模型。首先,我们需要去定义一些超参数,比如批处理大小,LSTM的单元个数,分类类别和训练次数。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "batchSize = 24 # 梯度处理的大小\n", "lstmUnits = 64 # 隐藏层神经元数量\n", "numClasses = 2 # 分类数量,n/p\n", "iterations = 50000 # 迭代次数" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "与大多数 TensorFlow 图一样,现在我们需要指定两个占位符,一个用于数据输入,另一个用于标签数据。对于占位符,最重要的一点就是确定好维度。\n", "\n", "标签占位符代表一组值,每一个值都为 [1,0] 或者 [0,1],这个取决于数据是正向的还是负向的。输入占位符,是一个整数化的索引数组。 \n", "\n", "" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()\n", "\n", "labels = tf.placeholder(tf.float32, [batchSize, numClasses])\n", "input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "一旦,我们设置了我们的输入数据占位符,我们可以调用 tf. nn. 
embedding lookup0函数来得到我们的词向量。该函数最后将返回一个三维向量,第一个维度是批处理大小,第二个维度是句子长度,第三个维度是词向量长度。更清晰的表达,如下图际示" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32)\n", "data = tf.nn.embedding_lookup(wordVectors,input_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "现在我们已经得到了我们想要的数据形式,那么揭晓了我们看看如何才能将这种数据形式输入到我们的 LSTM 网络中。首先,我们使用 tf.nn.rnn_cell.BasicLSTMCell 函数,这个函数输入的参数是一个整数,表示需要几个 LSTM 单元。这是我们设置的一个超参数,我们需要对这个数值进行调试从而来找到最优的解。然后,我们会设置一个 dropout 参数,以此来避免一些过拟合。\n", "\n", "最后,我们将 LSTM cell 和三维的数据输入到 tf.nn.dynamic_rnn ,这个函数的功能是展开整个网络,并且构建一整个 RNN 模型。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:\n", "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", "For more information, please see:\n", " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", " * https://github.com/tensorflow/addons\n", " * https://github.com/tensorflow/io (for I/O related ops)\n", "If you depend on functionality not listed there, please file an issue.\n", "\n", "WARNING:tensorflow:From :1: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n", "WARNING:tensorflow:From :3: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n", "WARNING:tensorflow:From /opt/conda/envs/tensorflow_py3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:735: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use `layer.add_weight` method instead.\n", "WARNING:tensorflow:From /opt/conda/envs/tensorflow_py3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:739: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] } ], "source": [ "lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits) # 基本单元\n", "lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75) # 解决一些过拟合问题,output_keep_prob保留比例,这个在LSTM的讲解中有解释过\n", "value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32) # 构建网络,value是值h,_ 是中间传递结果,这里不需要分析所以去掉" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "堆栈 LSTM 网络是一个比较好的网络架构。也就是前一个LSTM 隐藏层的输出是下一个LSTM的输入。堆栈LSTM可以帮助模型记住更多的上下文信息,但是带来的弊端是训练参数会增加很多,模型的训练时间会很长,过拟合的几率也会增加。\n", "\n", "dynamic RNN 函数的第一个输出可以被认为是最后的隐藏状态向量。这个向量将被重新确定维度,然后乘以最后的权重矩阵和一个偏置项来获得最终的输出值。" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# 权重参数初始化\n", "weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))\n", "bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))\n", "value = tf.transpose(value, [1, 0, 2])\n", "# 获取最终的结果值\n", "last = tf.gather(value, int(value.get_shape()[0]) - 1) # 去ht\n", "prediction = (tf.matmul(last, weight) + bias) # 最终连上w和b" ] }, { "cell_type": "markdown", 
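The paragraph above notes that stacking LSTM layers can capture more context at the cost of extra parameters and training time. As a sketch only (the rest of this notebook keeps the single-layer model), here is how the same TF1 contrib API could be used to build a two-layer stack; the layer count and the 'stacked_lstm' scope name are arbitrary choices.

```python
# Illustrative two-layer (stacked) LSTM variant of the cell defined above.
def make_cell(units):
    cell = tf.contrib.rnn.BasicLSTMCell(units)
    return tf.contrib.rnn.DropoutWrapper(cell=cell, output_keep_prob=0.75)

stackedCell = tf.contrib.rnn.MultiRNNCell([make_cell(lstmUnits) for _ in range(2)])
# A separate variable scope keeps these weights from clashing with the single-layer model above.
stackedValue, _ = tf.nn.dynamic_rnn(stackedCell, data, dtype=tf.float32, scope='stacked_lstm')
# stackedValue has the same shape as `value`: [batchSize, maxSeqLength, lstmUnits]
```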
"metadata": {}, "source": [ "接下来,我们需要定义正确的预测函数和正确率评估参数。正确的预测形式是查看最后输出的0-1向量是否和标记的0-1向量相同。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))\n", "accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "之后,我们使用一个标准的交叉熵损失函数来作为损失值。对于优化器,我们选择 Adam,并且采用默认的学习率" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :1: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "\n", "Future major versions of TensorFlow will allow gradients to flow\n", "into the labels input on backprop by default.\n", "\n", "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n", "\n" ] } ], "source": [ "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))\n", "optimizer = tf.train.AdamOptimizer().minimize(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 训练与测试结果" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 超参数调整\n", "\n", "选择合适的超参数来训练你的神经网络是至关重要的。你会发现你的训练损失值与你选择的优化器(Adam,Adadelta,SGD,等等),学习率和网络架构都有很大的关系。特别是在RNN和LSTM中,单元数量和词向量的大小都是重要因素。\n", "\n", " * 学习率:RNN最难的一点就是它的训练非常困难,因为时间步骤很长。那么,学习率就变得非常重要了。如果我们将学习率设置的很大,那么学习曲线就会波动性很大,如果我们将学习率设置的很小,那么训练过程就会非常缓慢。根据经验,将学习率默认设置为 0.001 是一个比较好的开始。如果训练的非常缓慢,那么你可以适当的增大这个值,如果训练过程非常的不稳定,那么你可以适当的减小这个值。\n", " \n", " * 优化器:这个在研究中没有一个一致的选择,但是 Adam 优化器被广泛的使用。\n", " \n", " * LSTM单元的数量:这个值很大程度上取决于输入文本的平均长度。而更多的单元数量可以帮助模型存储更多的文本信息,当然模型的训练时间就会增加很多,并且计算成本会非常昂贵。\n", " \n", " * 词向量维度:词向量的维度一般我们设置为50到300。维度越多意味着可以存储更多的单词信息,但是你需要付出的是更昂贵的计算成本。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 训练\n", "\n", "训练过程的基本思路是,我们首先先定义一个 TensorFlow 会话。然后,我们加载一批评论和对应的标签。接下来,我们调用会话的 run 函数。这个函数有两个参数,第一个参数被称为 fetches 参数,这个参数定义了我们感兴趣的值。我们希望通过我们的优化器来最小化损失函数。第二个参数被称为 feed_dict 参数。这个数据结构就是我们提供给我们的占位符。我们需要将一个批处理的评论和标签输入模型,然后不断对这一组训练数据进行循环训练。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# 辅助函数\n", "from random import randint\n", "# 制作batch数据,通过数据集索引位置来设置训练集和预测集\n", "# 并让batch中正负样本各占一半,同事给定其当前标签\n", "def getTrainBatch():\n", " labels = []\n", " arr = np.zeros([batchSize, maxSeqLength])\n", " for i in range(batchSize):\n", " if (i % 2 == 0): \n", " num = randint(1,11499)\n", " labels.append([1,0])\n", " else:\n", " num = randint(13499,24999)\n", " labels.append([0,1])\n", " arr[i] = ids[num-1:num]\n", " return arr, labels\n", "\n", "def getTestBatch():\n", " labels = []\n", " arr = np.zeros([batchSize, maxSeqLength])\n", " for i in range(batchSize):\n", " num = randint(11499,13499)\n", " if (num <= 12499):\n", " labels.append([1,0])\n", " else:\n", " labels.append([0,1])\n", " arr[i] = ids[num-1:num]\n", " return arr, labels" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iteration 1001/50000... loss 0.6618650555610657... accuracy 0.5416666865348816...\n", "iteration 2001/50000... loss 0.6902415156364441... accuracy 0.625...\n", "iteration 3001/50000... loss 0.6586635112762451... accuracy 0.6666666865348816...\n", "iteration 4001/50000... loss 0.677685558795929... accuracy 0.5833333134651184...\n", "iteration 5001/50000... loss 0.723791778087616... 
accuracy 0.4166666567325592...\n", "iteration 6001/50000... loss 0.6554713845252991... accuracy 0.625...\n", "iteration 7001/50000... loss 0.6332544684410095... accuracy 0.625...\n", "iteration 8001/50000... loss 0.687525749206543... accuracy 0.4583333432674408...\n", "iteration 9001/50000... loss 0.7028540968894958... accuracy 0.4583333432674408...\n", "iteration 10001/50000... loss 0.38588687777519226... accuracy 0.875...\n", "saved to models/pretrained_lstm.ckpt-10000\n", "iteration 11001/50000... loss 0.5381348133087158... accuracy 0.75...\n", "iteration 12001/50000... loss 0.4389437735080719... accuracy 0.875...\n", "iteration 13001/50000... loss 0.1364736258983612... accuracy 0.9166666865348816...\n", "iteration 14001/50000... loss 0.44142401218414307... accuracy 0.8333333134651184...\n", "iteration 15001/50000... loss 0.25360292196273804... accuracy 0.875...\n", "iteration 16001/50000... loss 0.2593815326690674... accuracy 0.8333333134651184...\n", "iteration 17001/50000... loss 0.22918541729450226... accuracy 0.875...\n", "iteration 18001/50000... loss 0.5921979546546936... accuracy 0.8333333134651184...\n", "iteration 19001/50000... loss 0.12871651351451874... accuracy 0.9166666865348816...\n", "iteration 20001/50000... loss 0.27459517121315... accuracy 0.875...\n", "saved to models/pretrained_lstm.ckpt-20000\n", "iteration 21001/50000... loss 0.08171811699867249... accuracy 1.0...\n", "iteration 22001/50000... loss 0.11089088767766953... accuracy 1.0...\n", "iteration 23001/50000... loss 0.045842695981264114... accuracy 1.0...\n", "iteration 24001/50000... loss 0.38367900252342224... accuracy 0.875...\n", "iteration 25001/50000... loss 0.09028583019971848... accuracy 0.9583333134651184...\n", "iteration 26001/50000... loss 0.057328272610902786... accuracy 1.0...\n", "iteration 27001/50000... loss 0.05184454843401909... accuracy 1.0...\n", "iteration 28001/50000... loss 0.205277219414711... accuracy 0.875...\n", "iteration 29001/50000... loss 0.02419412136077881... accuracy 1.0...\n", "iteration 30001/50000... loss 0.15514272451400757... accuracy 0.9583333134651184...\n", "saved to models/pretrained_lstm.ckpt-30000\n", "iteration 31001/50000... loss 0.029978496953845024... accuracy 1.0...\n", "iteration 32001/50000... loss 0.18621210753917694... accuracy 0.9166666865348816...\n", "iteration 33001/50000... loss 0.02521480619907379... accuracy 1.0...\n", "iteration 34001/50000... loss 0.01872040517628193... accuracy 1.0...\n", "iteration 35001/50000... loss 0.03600594401359558... accuracy 1.0...\n", "iteration 36001/50000... loss 0.11539971083402634... accuracy 0.9583333134651184...\n", "iteration 37001/50000... loss 0.0046129003167152405... accuracy 1.0...\n", "iteration 38001/50000... loss 0.020226482301950455... accuracy 1.0...\n", "iteration 39001/50000... loss 0.023621728643774986... accuracy 1.0...\n", "iteration 40001/50000... loss 0.03808807209134102... accuracy 1.0...\n", "saved to models/pretrained_lstm.ckpt-40000\n", "iteration 41001/50000... loss 0.049396928399801254... accuracy 1.0...\n", "iteration 42001/50000... loss 0.03090967796742916... accuracy 1.0...\n", "iteration 43001/50000... loss 0.031491968780756... accuracy 1.0...\n", "iteration 44001/50000... loss 0.0735802948474884... accuracy 1.0...\n", "iteration 45001/50000... loss 0.0040315561927855015... accuracy 1.0...\n", "iteration 46001/50000... loss 0.007067581173032522... accuracy 1.0...\n", "iteration 47001/50000... loss 0.031160416081547737... accuracy 0.9583333134651184...\n", "iteration 48001/50000... 
loss 0.00609404593706131... accuracy 1.0...\n", "iteration 49001/50000... loss 0.027362242341041565... accuracy 1.0...\n" ] } ], "source": [ "sess = tf.InteractiveSession()\n", "saver = tf.train.Saver()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "for i in range(iterations):\n", " # Fetch a batch of training data with the helper defined above\n", " nextBatch, nextBatchLabels = getTrainBatch();\n", " sess.run(optimizer, {input_data: nextBatch, labels: nextBatchLabels}) \n", " # Print the current loss and accuracy every 1,000 iterations\n", " if (i % 1000 == 0 and i != 0):\n", " loss_ = sess.run(loss, {input_data: nextBatch, labels: nextBatchLabels})\n", " accuracy_ = sess.run(accuracy, {input_data: nextBatch, labels: nextBatchLabels})\n", "\n", " print(\"iteration {}/{}...\".format(i+1, iterations),\n", " \"loss {}...\".format(loss_),\n", " \"accuracy {}...\".format(accuracy_)) \n", " # Save the network every 10,000 training iterations, in case performance stops improving later on\n", " if (i % 10000 == 0 and i != 0):\n", " save_path = saver.save(sess, \"models/pretrained_lstm.ckpt\", global_step=i)\n", " print(\"saved to %s\" % save_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results above are slightly overfit, since the accuracy reached 1.0. Look at the visualized training curves:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the training curves above, the model seems to train quite well: the loss drops steadily and the accuracy keeps approaching 100%. However, when analyzing the training curves we should notice that the model has probably overfit the training set. Overfitting is a very common problem in machine learning: the model fits the training set too well, but its ability to generalize to the test set is much worse. In other words, a model that reaches a loss of 0 on the training set is not necessarily the best model. When training an LSTM, early stopping is a common way to prevent overfitting. The basic idea is to train on the training set while continuously measuring performance on the test set; once the test error stops decreasing, or starts to increase, we should stop training, because that is a sign that the network's performance is starting to degrade.\n", "\n", "Loading a pre-trained model requires another TensorFlow utility, called a Saver, on which we then call the restore function. This function takes two arguments: the current session and the saved model." ] },
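Before restoring the checkpoint, here is a sketch of the early-stopping idea described above, written against the helpers already defined in this notebook (getTrainBatch, getTestBatch, saver, and the placeholders). The patience value and the use of getTestBatch as a stand-in validation signal are assumptions for illustration; this is not the training run whose log is shown above.

```python
# Early-stopping variant of the training loop above (illustrative sketch only).
bestValAcc = 0.0
patience, badChecks = 5, 0  # stop after 5 evaluation rounds without improvement

for i in range(iterations):
    nextBatch, nextBatchLabels = getTrainBatch()
    sess.run(optimizer, {input_data: nextBatch, labels: nextBatchLabels})

    if i % 1000 == 0 and i != 0:
        # Measure accuracy on a held-out batch instead of the training batch.
        valBatch, valLabels = getTestBatch()
        valAcc = sess.run(accuracy, {input_data: valBatch, labels: valLabels})
        if valAcc > bestValAcc:
            bestValAcc = valAcc
            badChecks = 0
            saver.save(sess, "models/pretrained_lstm.ckpt", global_step=i)  # keep the best model so far
        else:
            badChecks += 1
        if badChecks >= patience:
            print("Stopping early at iteration", i)
            break
```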
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from models/pretrained_lstm.ckpt-40000\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/tensorflow_py3/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n", " warnings.warn('An interactive session is already active. This can '\n" ] } ], "source": [ "sess = tf.InteractiveSession()\n", "saver = tf.train.Saver()\n", "saver.restore(sess, tf.train.latest_checkpoint('models'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we load some movie reviews from our test set. Note that these are reviews the model has never seen during training." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for this batch: 91.66666865348816\n", "Accuracy for this batch: 83.33333134651184\n", "Accuracy for this batch: 79.16666865348816\n", "Accuracy for this batch: 87.5\n", "Accuracy for this batch: 79.16666865348816\n", "Accuracy for this batch: 79.16666865348816\n", "Accuracy for this batch: 91.66666865348816\n", "Accuracy for this batch: 87.5\n", "Accuracy for this batch: 75.0\n", "Accuracy for this batch: 87.5\n" ] } ], "source": [ "iterations = 10\n", "for i in range(iterations):\n", " nextBatch, nextBatchLabels = getTestBatch();\n", " print(\"Accuracy for this batch:\", (sess.run(accuracy, {input_data: nextBatch, labels: nextBatchLabels})) * 100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The batch accuracies vary quite a bit, which shows that some overfitting remains. The model built here is also a fairly simple one; in real applications you would stack more than one LSTM layer.\n", "\n", "When experimenting yourself, try several hyperparameter settings, especially the word vector dimensionality - see the hyperparameter tuning notes above." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }
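As a closing example (not part of the original notebook), here is one way the restored model could be used to score a single new review, reusing cleanSentences, wordsList, maxSeqLength and the placeholders defined above. The review text is made up, and unknown words fall back to index 399999, just as in the preprocessing step earlier.

```python
def sentence_to_ids(text, max_len=maxSeqLength):
    """Convert one raw review into a padded row of word indices, like a row of idsMatrix."""
    ids_row = np.zeros((max_len), dtype='int32')
    for i, word in enumerate(cleanSentences(text).split()[:max_len]):
        try:
            ids_row[i] = wordsList.index(word)
        except ValueError:
            ids_row[i] = 399999  # index reserved for unknown words
    return ids_row

review = "this movie was wonderful and the acting was great"
# The placeholders expect a full batch, so repeat the single review batchSize times.
batch = np.tile(sentence_to_ids(review), (batchSize, 1))
logits = sess.run(prediction, {input_data: batch})
# Column 0 corresponds to the [1,0] (positive) label, column 1 to [0,1] (negative).
print("positive" if np.argmax(logits[0]) == 0 else "negative")
```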