矩阵乘法的 Strassen 算法

矩阵乘法的 Strassen 算法

Hexarhy

·

2021-06-13 18:12:26

·

算法·理论

Update

[2021/11/14] 感谢 @serverkiller 指出错误并修正。

[2021/11/14] 替换了扩展阅读第二个 link。

Introduction

Strassen 算法在 OI 上没有任何应用价值,不过了解一下理论计算机科学相关还蛮有意思的。

计算两个矩阵相乘的朴素方法是 \Theta(n^3)。而 Strassen 算法通过采用分治策略,并减少递归次数,实现在 \Theta(n^{\log 7}) 时间内完成。\log 7\approx 2.808,也就是 O(n^{2.81})。

前置知识: 矩阵基础,主定理。

其实主定理只需要知道结论即可,这里放一下简化结论:

T(n)=aT\left(\dfrac{n}{b}\right)+O(n^d),a\ge 1,b>1\\

T(n)=\begin{cases}

O(n^d)&,d>\log_{b}a\\

O(n^d\log n)&,d=\log_b a\\

O(n^{\log_b a})&,d<\log_b a

\end{cases}

Analysis

从分治开始

在介绍 Strassen 算法之前,先探讨如何用分治完成矩阵乘法。

分治策略其实非常直截了当,就是将 n\times n 的矩阵划分为四个 \dfrac{n}{2}\times \dfrac{n}{2} 的矩阵。当然,这种做法要求 n 是 2 的幂,但相关细节我们稍后探讨。

先以 2\times 2 的矩阵为例。

\begin{bmatrix}A_{1,1}&A_{1,2}\\A_{2,1}&A_{2,2}\end{bmatrix}\cdot\begin{bmatrix}B_{1,1}&B_{1,2}\\B_{2,1}&B_{2,2}\end{bmatrix}=\begin{bmatrix}C_{1,1}&C_{1,2}\\C_{2,1}&C_{2,2}\end{bmatrix}

具体写出计算矩阵 C 的等式:

C_{1,1}=A_{1,1}\cdot B_{1,1}+A_{1,2}\cdot B_{2,1}\\

C_{1,2}=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\

C_{2,1}=A_{2,1}\cdot B_{1,1}+A_{2,2}\cdot B_{2,1}\\

C_{2,2}=A_{2,1}\cdot B_{1,2}+A_{2,2}\cdot B_{2,2}\\

对于每一条公式,相当于计算两对 \dfrac{n}{2}\times\dfrac{n}{2} 矩阵乘法,再计算一次这样的矩阵加法。

直接利用这个运算方式就可以写出递归分治策略。形象化地说,就是将矩阵分解为左上,左下,右上,右下四个子矩阵再分别进行运算。

\boxed{

\begin{array}{ll}

&\textbf {Function Multiply}(A,B)\\

1& n\gets A.\text{rows}\\

2& \text{let}\ C\ \text{be a new } n\times n\ \text{matrix}\\

3& \textbf{if } n=1\\

4& \qquad c_{1,1}\gets a_{1,1}\times b_{1,1}\\

5&\textbf{else}\ \mathrm{Partion}\ A,B,C\\

6&\qquad C_{1,1}\gets\text {Multiply}(A_{1,1},B_{1,1})+\text {Multiply}(A_{1,2},B_{2,1})\\

7&\qquad C_{1,2}\gets\text {Multiply}(A_{1,1},B_{1,2})+\text {Multiply}(A_{1,2},B_{2,2})\\

8&\qquad C_{2,1}\gets\text {Multiply}(A_{2,1},B_{1,1})+\text {Multiply}(A_{2,2},B_{2,1})\\

9&\qquad C_{2,2}\gets\text {Multiply}(A_{2,1},B_{1,2})+\text {Multiply}(A_{2,2},B_{2,2})\\

10&\textbf{return}\ C

\end{array}}

注意 \rm Partion 部分只需要对应好分解后的矩阵的下标即可。具体的对应方法可以参考 Strassen 算法伪代码。

需要说明的是,《算法导论》认为,只需通过下标计算即可对实现分解子矩阵并操作,而不用 \Theta(n^2) 拷贝子矩阵。然而在亲自动手实现代码时,避免拷贝子矩阵来进行其他操作是异常麻烦的,书中也没有给出伪代码。况且拷贝也不影响总的时间复杂度,因为矩阵加法需要不可避免的 \Theta(n^2)。但拷贝操作对常数因子影响比较大。如果有谁会实现不用拷贝的请务必把代码发给我/kk

现在来看一下这段代码的时间复杂度。

我们分解出了 8 个子问题,每个子问题规模缩小了一半,同时花了 \Theta(n^2) 时间分解出子矩阵,花了 \Theta(n^2) 进行矩阵加法。容易写出其递归式:

\begin{aligned}

T(n)&=8T\left(\dfrac{n}{2}\right)+\Theta(n^2)

\end{aligned}

运用主定理即可求解时间复杂度为 \Theta(n^3)。

没有优化啊?这就来到 Strassen 算法的另一个核心:减少递归次数。

还能再少一次

Strassen 算法在朴素分治算法的基础上,只进行了 7 次递归。当然减少递归次数的代价就是多进行了几次矩阵加法,但幸好只是常数级别。

我们先对时间复杂度进行分析。递归式为:

T(n)=7T\left(\dfrac{n}{2}\right)+\Theta(n^2)

运用主定理求解出时间复杂度为 \Theta(n^{\log_2 7}),也就是 O(n^{2.81})。

步骤上就比朴素的分治算法要麻烦一些。

与朴素分治算法相同,分解出左上,左下,右上,右下四个子矩阵。

创建 10 个 \dfrac{n}2\times\dfrac n2 的矩阵 S_i,每个 S_i 保存两个子矩阵的和或差。

用子矩阵和 S_i 相乘,递归地计算 7 个 \dfrac{n}2\times\dfrac n2 的矩阵 P_i。

通过 P_i 的不同组合进行加减,得到 C 的子矩阵。

具体地,步骤 2 中创建的 10 个矩阵 S_i 分别为:

\begin{array}{ll}

S_1&=B_{1,2}-B_{2,2}\\

S_2&=A_{1,1}+A_{1,2}\\

S_3&=A_{2,1}+A_{2,2}\\

S_4&=B_{2,1}-B_{1,1}\\

S_5&=A_{1,1}+A_{2,2}\\

S_6&=B_{1,1}+B_{2,2}\\

S_8&=A_{1,2}-A_{2,2}\\

S_8&=B_{2,1}+B_{2,2}\\

S_9&=A_{1,1}-A_{2,1}\\

S_{10}&=B_{1,1}+B_{1,2}

\end{array}

步骤 3 中需要递归计算的 7 个矩阵 P_i 分别为:

\begin{array}{ll}

P_1&=A_{1,1}\cdot S_1\\

P_2&=S_2\cdot B_{2,2}\\

P_3&=S_3\cdot B_{1,1}\\

P_4&=A_{2,2}\cdot S_4\\

P_5&=S_5\cdot S_6\\

P_6&=S_7\cdot S_8\\

P_7&=S_9\cdot S_{10}

\end{array}

到了步骤 4,计算 C 的子矩阵的方法为:

\begin{array}{ll}

C_{1,1}&=P_5+P_4-P_2+P_6\\

C_{1,2}&=P_1+P_2\\

C_{2,1}&=P_3+P_4\\

C_{2,2}&=P_5+P_1-P_3-P_7

\end{array}

这些式子为什么正确?直接代入即可验证。由于验证过程过于冗长,这里只举一例 C_{1,2}=P_1+P_2。

\begin{aligned}

P_1+P_2&=A_{1,1}\cdot S_1+S_2\cdot B_{2,2}\\

&=A_{1,1}\cdot (B_{1,2}-B_{2,2})+(A_{1,1}+A_{1,2})\cdot B_{2,2}\\

&=A_{1,1}\cdot B_{1,2}-A_{1,1}\cdot B_{2,2}+A_{1,1}\cdot B_{2,2}+A_{1,2}\cdot B_{2,2}\\

&=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\

&=C_{1,2}

\end{aligned}

至于 Strassen 具体是如何想到构造出这些算式的,则留给我们作为无限的遐想。有兴趣可以浏览这里。

Exercises

节选自《算法导论》4.2 练习。

试只用三次乘法完成复数相乘 (a+b\mathrm{i})(c+d\mathrm{i})=(ac-bd)+(ad+bc)\mathrm{i}。

\begin{cases}\alpha=ac\\ \beta=bd\\\gamma=(a+b)(c+d)\\\end{cases}

则:

(a+b\mathrm{i})(c+d\mathrm{i})=(\alpha-\beta)+(\gamma-\alpha-\beta)\mathrm{i}

计算 \alpha,\beta,\gamma 只用了三次乘法,代价就是增加了加法次数。事实上 Gauss 早已发现了三次乘法进行复数相乘的方法,说不定 Strassen 是受到了这个启发?

若矩阵规模不是 2 的幂,如何应用 Strassen 算法?

用值 0 补齐到 2 的幂即可。这是实现代码时需要注意的地方。

已知用 k 次乘法操作完成两个 3\times 3 的矩阵相乘,那么满足在 o(n^{\log 7}) 的时间内完成 n\times n 的矩阵相乘,k 的最大值是多少?

容易列出递归式并求解:

T(n)&=kT\left(\dfrac n3\right)+O(n^2)\\

T(n)&=O(\log_3 k)\\

\log_3k&<\log_2 7\\

k_{\max}&=21

\end{aligned}

编写 Strassen 算法的伪代码。

凑合着看吧,这里把 \rm Partion 部分具体写了出来。

其中 A[1\sim n/2][1\sim n/2] 表示由 \forall i\in[1,n/2],\forall j\in[1,n/2],A_{i,j} 组成的子矩阵,其余类似。

\boxed{

\begin{array}{ll}

&\textbf{Function Strassen}(A, B)\\

1& n \gets A.\mathrm{rows}\\

2& \textbf{if}\ n = 1\\

3& \qquad \textbf{return}\ a[1, 1]\times b[1, 1]\\

4& \mathrm{let}\ C\ \mathrm{be\ a\ new}\ n \times n\ \mathrm{matrix}\\

5& A[1, 1] \gets A[1\sim n / 2][1\sim n / 2]\\

6& A[1, 2] \gets A[1\sim n / 2][n / 2 + 1\sim n]\\

7& A[2, 1] \gets A[n / 2 + 1\sim n][1\sim n / 2]\\

8& A[2, 2] \gets A[n / 2 + 1\sim n][n / 2 + 1\sim n]\\

9& B[1, 1] \gets B[1\sim n / 2][1\sim n / 2]\\

10& B[1, 2] \gets B[1\sim n / 2][n / 2 + 1\sim n]\\

11& B[2, 1] \gets B[n / 2 + 1\sim n][1\sim n / 2]\\

12& B[2, 2] \gets B[n / 2 + 1\sim n][n / 2 + 1\sim n]\\

13& S[1] \gets B[1, 2] - B[2, 2]\\

14& S[2] \gets A[1, 1] + A[1, 2]\\

15& S[3] \gets A[2, 1] + A[2, 2]\\

16& S[4] \gets B[2, 1] - B[1, 1]\\

17& S[5] \gets A[1, 1] + A[2, 2]\\

18& S[6] \gets B[1, 1] + B[2, 2]\\

19& S[7] \gets A[1, 2] - A[2, 2]\\

20& S[8] \gets B[2, 1] + B[2, 2]\\

21& S[9] \gets A[1, 1] - A[2, 1]\\

22& S[10] \gets B[1, 1] + B[1, 2]\\

23& P[1] \gets \mathrm{Strassen}(A[1, 1], S[1])\\

24& P[2] \gets \mathrm{Strassen}(S[2], B[2, 2])\\

25& P[3] \gets \mathrm{Strassen}(S[3], B[1, 1])\\

26& P[4] \gets \mathrm{Strassen}(A[2, 2], S[4])\\

27& P[5] \gets \mathrm{Strassen}(S[5], S[6])\\

28& P[6] \gets \mathrm{Strassen}(S[7], S[8])\\

29& P[7] \gets \mathrm{Strassen}(S[9], S[10])\\

30& C[1\sim n / 2][1\sim n / 2] \gets P[5] + P[4] - P[2] + P[6]\\

31& C[1\sim n / 2][n / 2 + 1\sim n] \gets P[1] + P[2]\\

32& C[n / 2 + 1\sim n][1\sim n / 2] \gets P[3] + P[4]\\

33& C[n / 2 + 1\sim n][n / 2 + 1\sim n] \gets P[5] + P[1] - P[3] - P[7]\\

34& \textbf{return}\ C\\

\end{array}}

至于 C++ 代码,就留作读者课后习题吧。

Notice

然而很遗憾的是,Strassen 算法由于使用了大量递归,多次创建临时矩阵,常数因子比朴素矩阵乘法大非常多,因此 OI 范围内几乎没有应用价值。朴素的矩阵乘法就足够了。

如果确实想要优化常数,可以从以下方面考虑:

使用指令集等卡常科技。

在子矩阵小到一定规模时使用朴素矩阵乘法。

尽可能地减少创建临时矩阵和复制次数。

据 w33z8kqrqk8zzzx33 在帖子里说道,使用指令集可以在 n=2^{10} 时用少于 0.2 s 的时间完成(朴素矩阵乘法用指令集应该也会快不少吧?)。

而实际应用中的大型矩阵乘法都依赖于硬件(cache, GPU 等)和分布式计算。

尽管 Strassen 算法看上去没有什么实际用处,《算法导论》依然指出了 Strassen 算法的一个最重要的意义:在理论研究上作出了突破性的贡献。

类比 1959 年发明的 Shell Sort(希尔排序),目前来看也没有任何应用价值,但这是计算机第一次在整数排序上突破了 \Theta(n^2) 的壁障(当时常见的还是插入排序和冒泡排序)。

同样地,是 Strassen 算法使得矩阵乘法在渐进上界上第一次快于 \Theta(n^3),并鼓舞着后人继续在这方面探索。1969 年 Strassen 发表论文的的标题为《高斯消元法并非最优》,也正揭示了该算法的意义所在。

Extension

Strassen 算法其实能推广到很多矩阵操作,在原论文也略有提及。其基本思路都是分治并减少递归次数。本文将对其粗略说明。至于 Strassen 构造的这些玄妙算式是如何想到的,则给后人留下了神秘的美感。

下面所有算法的时间复杂度分析都与前文矩阵乘法相类似,不再赘述。

矩阵求逆

一般的矩阵求逆是用 \Theta(n^3) 的高斯消元法完成的。而 Strassen 在矩阵乘法的基础上,得到了更快的算法。这也是 Strassen 论文标题的由来。

基本思路依然与矩阵乘法类似。

先将矩阵 A 按照左上,左下,右上,右下分解出子矩阵。

创建如下 7 个矩阵 P_i:

\begin{array}{ll}

P_1&=A_{1,1}^{-1}\\

P_2&=A_{2,1}\cdot P_1\\

P_3&=P_{1}\cdot A_{1,2}\\

P_4&=A_{2,1}\cdot P_3\\

P_5&=P_4-A_{2,2}\\

P_6&=P_5^{-1}\\

P_7&=P_{3}\cdot P_6\cdot P_2\\

\end{array}

然后计算出 A 的逆矩阵 A^{-1} 的 4 个子矩阵:

\begin{array}{ll}

A^{-1}_{1,1}&=P_1-P_7\\

A^{-1}_{1,2}&=P_3\cdot P_6\\

A^{-1}_{2,1}&=P_6\cdot P_2\\

A^{-1}_{2,2}&=-P_6

\end{array}

当子矩阵缩小到一定规模,我们就可以直接用高斯消元法求解来减小常数。

当然,递归求逆过程中也要顺便用 Strassen 算法求矩阵乘法。

需要说明的是,Strassen 假定了操作过程中所有矩阵都是可逆的,而对于更复杂的情况则束手无策。对于 Strassen 算法在矩阵求逆操作上的更深研究,限于篇幅请参考这里。

解线性方程组

Similar results hold for solving a system of linear equations or computing a

determinant.

Strassen 对于矩阵求逆和行列式计算的记录相当简略啊。

我们知道线性方程组可以写成矩阵的形式。

\begin{bmatrix}

a_{1,1}&a_{1,2}&a_{1,3}&\cdots& a_{1,m}\\

a_{2,1}&a_{2,2}&a_{2,3}&\cdots& a_{2,m}\\

a_{3,1}&a_{3,2}&a_{3,3}&\cdots& a_{3,m}\\

\vdots&\vdots&\vdots&\ddots&\vdots\\

a_{n,1}&a_{n,2}&a_{n,3}&\cdots& a_{n,m}\\

\end{bmatrix}

\cdot

\begin{bmatrix}

x_1\\

x_2\\

x_3\\

\vdots\\

x_n

\end{bmatrix}=

\begin{bmatrix}

y_1\\

y_2\\

y_3\\

\vdots\\

y_n

\end{bmatrix}

上式简记为 A\cdot x=B,则有 x=A^{-1}\cdot

B。

如果逆矩阵存在,则方程恰好有一解,直接用 Strassen 算法进行矩阵求逆即可。

行列式

同样把矩阵 A 分解成四个子矩阵,然后用一种简洁的计算方式求解即可:

\operatorname{det}(A)=\operatorname{det}(A_{1,1})\cdot \operatorname{det}\left(A_{2,2}-A_{2,1}\cdot A_{1,1}^{-1}\cdot A_{1,2}\right)

Reference

《算法导论》第三版第四章。

Strassen 原论文 - Gaussian Elimination is not Optimal

关于矩阵乘法的其它研究,感兴趣的可以阅读以下文献:

用 O(n^{2.37286}) 时间完成矩阵乘法 - Link

对于 n\times n 的矩阵左乘上 n\times n^p(p\le 0.294) 的矩阵,可以做到 O(n^2) 的时间复杂度 - Link

关于 GPU 上进行矩阵乘法的效率 - Link