高维矩阵乘法

NLP中经常为碰到高维矩阵运算,如Attention中的Q、K、V相乘,在此记录矩阵相乘的运算规则。

高维矩阵可视化

一维: 首先shape=[4]的一维矩阵非常简单,可以用下图表示

1
[1,2,3,4]

一维矩阵

**二维:**shape=[2,3]的二维矩阵可视化如下

1
2
[[1,2,3],
[4,5,6]]

二维

为方便展示三维矩阵,旋转角度如下:

二维

**三维:**一个shape=[2,2,3]的三维矩阵,可视化如下:

1
2
3
4
5
[[[1,2,3],
[4,5,6]],

[[7,8,9],
[10,11,12]]]

三维

切片展示如下:

三维

**四维:**shape=[2,2,2,3]的四维矩阵可视化如下:

1
2
3
4
5
6
7
8
9
10
11
[[[[1,2,3],
[4,5,6]],

[[7,8,9],
[10,11,12]]],

[[[13,14,15],
[16,17,18]],

[[19,20,21],
[22,23,24]]]]

四维

高维矩阵运算

从上面的结论可以看出:所有大于二维的,最终都是以二维为基础堆叠在一起的!!

所以在矩阵运算的时候,其实最后都可以转成我们常见的二维矩阵运算,遵循的原则是:在多维矩阵相乘中,需最后两维满足shape匹配原则,最后两维才是有数据的矩阵,前面的维度只是矩阵的排列而已!

相乘必须满足以下两个条件:

  1. 两个n维数组的前n-2维必须完全相同。例如(3,2,4,2)(3,2,2,3)前两维必须完全一致;
  2. 最后两维必须满足二阶矩阵乘法要求。例如(3,2,4,2)(3,2,2,3)的后两维可视为(4,2)x(2,3)满足矩阵乘法。

另,由于广播机制,第一维为1的,可以与第一维任何数相乘:

(3,2,4,2)*(1,2,2,3)——>>(3,2,4,3)

(1,2,4,2)*(3,2,2,3)——>>(3,2,4,3)

比如两个三维的矩阵相乘,分别为shape=[2,2,3]和shape=[2,3,2]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
a = 
[[[ 1. 2. 3.]
[ 4. 5. 6.]]
[[ 7. 8. 9.]
[10. 11. 12.]]]

b =
[[[ 1. 2.]
[ 3. 4.]
[ 5. 6.]]

[[ 7. 8.]
[ 9. 10.]
[11. 12.]]]

计算的时候把a的第一个shape=[2,3]的矩阵和b的第一个shape=[3,2]的矩阵相乘,得到的shape=[2,2],即

matmul1

同理,再把a,b个字的第二个shape=[2,3]的矩阵相乘,得到的shape=[2,2]。

matmul2

最终把结果堆叠在一起,就是2个shape=[2,2]的矩阵堆叠在一起,结果为:

1
2
3
4
5
[[[ 22.  28.]
[ 49. 64.]]

[[220. 244.]
[301. 334.]]]

参考文献

  1. 【全面理解多维矩阵运算】多维(三维四维)矩阵向量运算-超强可视化
  2. 高维数组相乘的运算规则