|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
# 第四章——多头注意力机制——QK矩阵相乘
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240421212923027.png" alt="语义关系学习" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240421212923027.png" alt="语义关系学习" width="550" />
|
|
|
|
|
|
|
|
|
|
### 前言
|
|
|
|
|
|
|
|
|
@ -8,7 +8,7 @@
|
|
|
|
|
|
|
|
|
|
放大语义关系学习(注意力机制)内部
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240501143058994.png" alt="image-20240501143058994" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240501143058994.png" alt="image-20240501143058994" width="550" />
|
|
|
|
|
|
|
|
|
|
> Wq/Wk/Wv(Linear):线性层。数学表达式是 `y = wx + b`,其中 `x` 是输入向量,`W`是权重矩阵,`b` 是偏置向量,`y` 是输出向量。
|
|
|
|
|
>
|
|
|
|
@ -28,13 +28,13 @@
|
|
|
|
|
|
|
|
|
|
一个值是标量(Scalar),一组值是向量(Vector),多组值是矩阵(Matrix)
|
|
|
|
|
|
|
|
|
|
<img src="../assets/scalar-vector-matrix.svg" alt="Scalar, Vector, Matrix" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/scalar-vector-matrix.svg" alt="Scalar, Vector, Matrix" width="300" />
|
|
|
|
|
|
|
|
|
|
矩阵也就是多维的向量,矩阵是可以多种维度的,如3列2行(上面图的),亦或者2行3列。
|
|
|
|
|
|
|
|
|
|
矩阵相乘(又叫点积相乘)如下:
|
|
|
|
|
|
|
|
|
|
<img src="../assets/matrix-multiply-a.svg" alt="矩阵乘法" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/matrix-multiply-a.svg" alt="矩阵乘法" width="300" />
|
|
|
|
|
|
|
|
|
|
~~~markdown
|
|
|
|
|
"点积" 是把 对称的元素相乘,然后把结果加起来:
|
|
|
|
@ -45,7 +45,7 @@
|
|
|
|
|
|
|
|
|
|
这里顺便补充下, 我们平时说的线形变换,其实就是一种特殊的矩阵相乘,即,矩阵乘以一个向量
|
|
|
|
|
|
|
|
|
|
<img src="../assets/419801703.png" alt="L15.png" style="zoom:30%;" />
|
|
|
|
|
<img src="../assets/419801703.png" alt="L15.png" width="300" />
|
|
|
|
|
|
|
|
|
|
最终输出的是[3,1]的矩阵,即一个向量`[16 4 7]`。从上面我们也知道,只要[3,2]里的这个2,能够对应上另外一个矩阵的行,就能够相乘。即[3,2]对应上面的[2,1],2跟2对应。即第一个矩阵的第二个数 能跟 第二个矩阵的第一个数对应上,就能相乘。
|
|
|
|
|
|
|
|
|
@ -55,7 +55,7 @@
|
|
|
|
|
|
|
|
|
|
假设我们有两个矩阵:A [1 2] 和 B [3 3]两个矩阵,画到象限表,如下图
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430191643501.png" alt="image-20240430191643501" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240430191643501.png" alt="image-20240430191643501" width="300" />
|
|
|
|
|
|
|
|
|
|
我们说A和B相似,如何判断相似,就看它们离的近不近,或者两个向量的夹角a比较小。并且我们肉眼看,A和C离的相对更远。
|
|
|
|
|
|
|
|
|
@ -71,7 +71,7 @@ A矩阵*B矩阵=B长度*A长度*cos(\theta)
|
|
|
|
|
$$
|
|
|
|
|
我们做一个浅绿色的垂线,它就变成一个直角三角形。在数学三角函数中,cos的邻边等于cos(θ)乘以斜边。也就是A的长度乘以cos(θ),等于黑色的线(B上的黑色线)
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430191813984.png" alt="image-20240430191813984" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240430191813984.png" alt="image-20240430191813984" width="300" />
|
|
|
|
|
|
|
|
|
|
也就是公式等同于,也就是红色乘以黑色的部分
|
|
|
|
|
$$
|
|
|
|
@ -81,19 +81,19 @@ $$
|
|
|
|
|
|
|
|
|
|
如果是C做垂线B,可能就是负数了。如果是三维平面或者四维屏幕,则是如下增加多条线
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430192046147.png" alt="image-20240430192046147" style="zoom: 33%;" />
|
|
|
|
|
<img src="../assets/image-20240430192046147.png" alt="image-20240430192046147" width="300" />
|
|
|
|
|
|
|
|
|
|
现在我们知道矩阵相乘能代表相似度的高低,回到实际中,过程图如下
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430194452839.png" alt="image-20240430194452839" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240430194452839.png" alt="image-20240430194452839" width="550" />
|
|
|
|
|
|
|
|
|
|
上面我放的文字,实际传给机器的时候是数值。
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430194746805.png" alt="image-20240430194746805" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240430194746805.png" alt="image-20240430194746805" width="550" />
|
|
|
|
|
|
|
|
|
|
通过矩阵相乘,即`LLM`和`me`的相似度是23,最终它们都会被投射到多维平面上。
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240430194034870.png" alt="image-20240430194034870" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240430194034870.png" alt="image-20240430194034870" width="300" />
|
|
|
|
|
|
|
|
|
|
当然时间向量的值一般是[-1,+1]区间的,而不是整数型,这里是一个简单的示例。而且还会经过不断的训练循环,来不断的调整每个文本的多维表达数值分别是多少,也就是LLM初始值假设是[1,2,3],可能训练的下一轮是[-1,3,1]下一轮又是[3,1,2],直到最终训练结束。
|
|
|
|
|
|
|
|
|
@ -119,19 +119,19 @@ $$
|
|
|
|
|
|
|
|
|
|
比如现在我们有4句话,同时我们复用GPT-2的768维向量
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240502132738816.png" alt="image-20240502132738816" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240502132738816.png" alt="image-20240502132738816" width="550" />
|
|
|
|
|
|
|
|
|
|
[4, 16, 768] = [batch_size, max_length, d_model],batch_size就是我们可以做并行的设置,做算法建模的同学应该对这个比较熟悉,越大的batch_size,意味着需要越大的内存和显存。max_length则是我们设置的最大长度,超过则截断(因为资源也是有限的,我们一般取能获取到绝大多数完整句子的长度即可)。768则是GPT-2的默认向量维度。
|
|
|
|
|
|
|
|
|
|
看上面的图,[4, 16, 768]复制成3份,分别去与Wq、Wk和Wv矩阵相乘。
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240502132811665.png" alt="image-20240502132811665" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240502132811665.png" alt="image-20240502132811665" width="550" />
|
|
|
|
|
|
|
|
|
|
如上图所示,Wq的也是[768, 768]维的矩阵,Wk、Wv同理,它们一开始会初始化值,训练过程会自动调整。
|
|
|
|
|
|
|
|
|
|
单独拿一个Q出来细看,[4, 16, 768]跟[768, 768]是怎么矩阵相乘的,实际上,相乘都是后两个维度跟768相乘,也就是[16,768]跟[768,768]。如下图所示:
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240501162941113.png" alt="image-20240501162941113" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240501162941113.png" alt="image-20240501162941113" width="550" />
|
|
|
|
|
|
|
|
|
|
把4个[16, 768]维的矩阵分别拿出来,去与[768, 768]维的矩阵相乘。原矩阵里的数值,经过W权重后,出来的Q里的值会不一样。即最终出来的QKV三个矩阵里的值都跟原始的有所变化。
|
|
|
|
|
|
|
|
|
@ -141,7 +141,7 @@ $$
|
|
|
|
|
|
|
|
|
|
上面我们看到单个头的是[4, 16, 768],前面我们也一直提到QKV的多头机制,如果按照GPT里的12头(Transformer原文中并没有规定是多少头),那么会这么切分,如下图:
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240502134443646.png" alt="image-20240502134443646" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240502134443646.png" alt="image-20240502134443646" width="550" />
|
|
|
|
|
|
|
|
|
|
可以看到我们将768维的矩阵,切成了12分,每份是64维。另外,由于大模型都是后两位数矩阵相乘,所以我们把头跟长互换,即[4, 16, 12, 64]转为[4, 12, 16, 64]。
|
|
|
|
|
|
|
|
|
@ -153,7 +153,7 @@ $$
|
|
|
|
|
|
|
|
|
|
QKV分别获得后,QK则是根据路线进行矩阵相乘,如下图
|
|
|
|
|
|
|
|
|
|
<img src="../assets/image-20240502212200231.png" alt="image-20240502212200231" style="zoom:50%;" />
|
|
|
|
|
<img src="../assets/image-20240502212200231.png" alt="image-20240502212200231" width="550" />
|
|
|
|
|
|
|
|
|
|
其中我们把K进行了翻转,方便相乘。矩阵相乘则是每个batch_size里的每个头进行矩阵相乘,即[16, 64]和[64, 16]进行矩阵相乘,相乘后则是变成了[16, 16]的矩阵。
|
|
|
|
|
|
|
|
|
|