@ -151,7 +151,7 @@ $$
QKV分别获得后,QK则是根据路线进行矩阵相乘,如下图
<img src="../assets/image-20240501173316308.png" alt="image-20240501173316308" style="zoom:50%;" />
<img src="../assets/image-20240502212200231.png" alt="image-20240502212200231" style="zoom:50%;" />
其中我们把K进行了翻转,方便相乘。矩阵相乘则是每个batch_size里的每个头进行矩阵相乘,即[16, 64]和[64, 16]进行矩阵相乘,相乘后则是变成了[16, 16]的矩阵。