高维矩阵乘法
NLP中经常为碰到高维矩阵运算,如Attention中的Q、K、V相乘,在此记录矩阵相乘的运算规则。
高维矩阵可视化
一维: 首先shape=[4]的一维矩阵非常简单,可以用下图表示
1 | [1,2,3,4] |
**二维:**shape=[2,3]的二维矩阵可视化如下
1 | [[1,2,3], |
为方便展示三维矩阵,旋转角度如下:
**三维:**一个shape=[2,2,3]的三维矩阵,可视化如下:
1 | [[[1,2,3], |
切片展示如下:
**四维:**shape=[2,2,2,3]的四维矩阵可视化如下:
1 | [[[[1,2,3], |
高维矩阵运算
从上面的结论可以看出:所有大于二维的,最终都是以二维为基础堆叠在一起的!!
所以在矩阵运算的时候,其实最后都可以转成我们常见的二维矩阵运算,遵循的原则是:在多维矩阵相乘中,需最后两维满足shape匹配原则,最后两维才是有数据的矩阵,前面的维度只是矩阵的排列而已!
相乘必须满足以下两个条件:
- 两个n维数组的前n-2维必须完全相同。例如(3,2,4,2)(3,2,2,3)前两维必须完全一致;
- 最后两维必须满足二阶矩阵乘法要求。例如(3,2,4,2)(3,2,2,3)的后两维可视为(4,2)x(2,3)满足矩阵乘法。
另,由于广播机制,第一维为1的,可以与第一维任何数相乘:
(3,2,4,2)*(1,2,2,3)——>>(3,2,4,3)
(1,2,4,2)*(3,2,2,3)——>>(3,2,4,3)
比如两个三维的矩阵相乘,分别为shape=[2,2,3]和shape=[2,3,2]
1 | a = |
计算的时候把a的第一个shape=[2,3]的矩阵和b的第一个shape=[3,2]的矩阵相乘,得到的shape=[2,2],即
同理,再把a,b个字的第二个shape=[2,3]的矩阵相乘,得到的shape=[2,2]。
最终把结果堆叠在一起,就是2个shape=[2,2]的矩阵堆叠在一起,结果为:
1 | [[[ 22. 28.] |