Add. Building LSTM network model

pull/2/head
benjas 4 years ago
parent 29f1c7f65c
commit 93712ba5d8

@@ -393,6 +393,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Helper functions\n",
"from random import randint\n",
"\n",
"def getTrainBatch():\n",
@@ -420,6 +421,151 @@
" arr[i] = ids[num-1:num]\n",
" return arr, labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RNN Model\n",
"现在,我们可以开始构建我们的 TensorFlow 图模型。首先我们需要去定义一些超参数比如批处理大小LSTM的单元个数分类类别和训练次数。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batchSize = 24 # 梯度处理的大小\n",
"lstmUnits = 64 # 隐藏层神经元数量\n",
"numClasses = 2 # 分类数量n/p\n",
"iterations = 50000 # 迭代次数"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"与大多数 TensorFlow 图一样,现在我们需要指定两个占位符,一个用于数据输入,另一个用于标签数据。对于占位符,最重要的一点就是确定好维度。\n",
"\n",
"标签占位符代表一组值,每一个值都为 [1,0] 或者 [0,1],这个取决于数据是正向的还是负向的。输入占位符,是一个整数化的索引数组。 \n",
"\n",
"<img src=\"assets/20210112150210.png\" width=\"100%\">"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"labels = tf.placeholder(tf.float32, [batchSize, numClasses])\n",
"input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"一旦,我们设置了我们的输入数据占位符,我们可以调用 tf. nn. embedding lookup0函数来得到我们的词向量。该函数最后将返回一个三维向量,第一个维度是批处理大小,第二个维度是句子长度,第三个维度是词向量长度。更清晰的表达,如下图际示"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32)\n",
"data = tf.nn.embedding_lookup(wordVectors,input_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在我们已经得到了我们想要的数据形式,那么揭晓了我们看看如何才能将这种数据形式输入到我们的 LSTM 网络中。首先,我们使用 tf.nn.rnn_cell.BasicLSTMCell 函数,这个函数输入的参数是一个整数,表示需要几个 LSTM 单元。这是我们设置的一个超参数,我们需要对这个数值进行调试从而来找到最优的解。然后,我们会设置一个 dropout 参数,以此来避免一些过拟合。\n",
"\n",
"最后,我们将 LSTM cell 和三维的数据输入到 tf.nn.dynamic_rnn ,这个函数的功能是展开整个网络,并且构建一整个 RNN 模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits) # 基本单元\n",
"lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75) # 解决一些过拟合问题output_keep_prob保留比例这个在LSTM的讲解中有解释过\n",
"value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32) # 构建网络value是值h_ 是中间传递结果,这里不需要分析所以去掉"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"堆栈 LSTM 网络是一个比较好的网络架构。也就是前一个LSTM 隐藏层的输出是下一个LSTM的输入。堆栈LSTM可以帮助模型记住更多的上下文信息但是带来的弊端是训练参数会增加很多模型的训练时间会很长过拟合的几率也会增加。\n",
"\n",
"dynamic RNN 函数的第一个输出可以被认为是最后的隐藏状态向量。这个向量将被重新确定维度,然后乘以最后的权重矩阵和一个偏置项来获得最终的输出值。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 权重参数初始化\n",
"weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))\n",
"bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))\n",
"value = tf.transpose(value, [1, 0, 2])\n",
"# 获取最终的结果值\n",
"last = tf.gather(value, int(value.get_shape()[0]) - 1) # 去ht\n",
"prediction = (tf.matmul(last, weight) + bias) # 最终连上w和b"
]
},
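{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical sketch of the stacking idea mentioned above (not part of the original notebook): tf.contrib.rnn.MultiRNNCell chains several LSTM layers so that each layer's output feeds the next. numLayers is an assumed hyperparameter, and the dropout keep ratio is reused from the single-layer setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical stacked-LSTM sketch; the notebook itself uses a single layer\n",
"numLayers = 2  # assumed number of stacked layers\n",
"cells = [tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(lstmUnits),\n",
"                                       output_keep_prob=0.75)\n",
"         for _ in range(numLayers)]\n",
"stackedCell = tf.contrib.rnn.MultiRNNCell(cells)\n",
"# dynamic_rnn unrolls the stacked cell exactly as it does the single cell\n",
"value, _ = tf.nn.dynamic_rnn(stackedCell, data, dtype=tf.float32)"
]
},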
{
"cell_type": "markdown",
"metadata": {},
"source": [
"接下来我们需要定义正确的预测函数和正确率评估参数。正确的预测形式是查看最后输出的0-1向量是否和标记的0-1向量相同。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))\n",
"accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"之后,我们使用一个标准的交叉熵损失函数来作为损失值。对于优化器,我们选择 Adam并且采用默认的学习率"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))\n",
"optimizer = tf.train.AdamOptimizer().minimize(loss)"
]
},
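{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the softmax cross-entropy computed above is, for a batch of $N$ examples with one-hot labels $y$ and logits $z$ (the prediction tensor):\n",
"\n",
"$$L = -\\frac{1}{N}\\sum_{i=1}^{N}\\sum_{c=1}^{2} y_{i,c}\\,\\log\\frac{e^{z_{i,c}}}{\\sum_{c'=1}^{2} e^{z_{i,c'}}}$$"
]
},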
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
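{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal training-loop sketch for the empty cell above (an assumption, not the original author's code): it draws batches with the getTrainBatch helper defined earlier and periodically reports the loss and accuracy on the current batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical training loop (assumed; not part of the original commit)\n",
"sess = tf.InteractiveSession()\n",
"sess.run(tf.global_variables_initializer())\n",
"\n",
"for i in range(iterations):\n",
"    nextBatch, nextBatchLabels = getTrainBatch()  # helper defined earlier\n",
"    sess.run(optimizer, {input_data: nextBatch, labels: nextBatchLabels})\n",
"    if i % 1000 == 0:\n",
"        l, acc = sess.run([loss, accuracy], {input_data: nextBatch, labels: nextBatchLabels})\n",
"        print('iteration %d: loss %.3f, accuracy %.3f' % (i, l, acc))"
]
}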
],
"metadata": {
