神经网络中利用矩阵进行反向传播运算的实质

神经网络中利用矩阵进行反向传播运算的实质

训练神经网络模型时,为了优化目标函数,我们需要不断地迭代更新网络中的权值,而这一过程是通过反向传播算法(Backpropagation,BP)实现的。在神经网络中,训练样本和权值参数都被表示为矩阵的形式,因为这样更利于反向传播的计算。

之前学习反向传播算法的时候一直有误解,认为它需要用到大量的矩阵求导,但仔细理解后发现实际上用到的还是标量的求导,只不过用矩阵表示出来了而已。

本文中通过递推的方法,用矩阵来形象化地表示神经网络模型训练中反向传播的过程,并从单个输入样本逐步扩展到多个输入样本(mini-batch)。

单个输入样本计算

对于形如L=f\left(Y\right)=f\left(XW\right)=f\left(\sum_{i}^{n}w_ix_i\right)的目标函数来说(省略偏置项b,因为它可以被整合进X中),若中间项Y取单一值y,则其可以表示为两个向量相乘的形式

\begin{align}
\left[y
\right]
=
\left[
\begin{matrix}
x_1&x_2&\cdots&x_n
\end{matrix}
\right]
\times
\left[
\begin{matrix}
w_1\\
w_2\\
\vdots\\
w_n
\end{matrix}
\right]
\end{align}

若中间值取多维\left[
\begin{matrix}
y_1 & y_2 & ... & y_n
\end{matrix}
\right],则可以表示为两个矩阵相乘的形式

\begin{align}
\left[
\begin{matrix}
y_1&y_2&\cdots&y_n
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_1&x_2&\cdots&x_n
\end{matrix}
\right]
\times
\left[
\begin{matrix}
w_{11}&w_{12}&\cdots&w_{1c}\\
w_{21}&w_{22}&\cdots&w_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
w_{n1}&w_{n2}&\cdots&w_{nc}
\end{matrix}
\right]
\end{align}

对于其中每一个目标值,y_c=\sum_{i=1}^{n}w_{ic}x_i=w_{1c}x_1+w_{2c}x_2+\cdots+w_{nc}x_n。那么如果想要求得y_cw_{ic}的导数,只需要列出公式

\begin{align}
W'_{ic}=
\frac{\partial y_c}{\partial w_{ic}}
=
\frac{\partial \sum^n_{i=1}w_{ic}x_{i}}{w_{ic}}=x_i
\end{align}

那么对于参数矩阵的列向量W_c来说

\frac{\partial y_c}{\partial W_c}=X^T

假设目标函数L对于Y的导数为Y'
=
\left[
\begin{matrix}
dy_1 & dy_2 & \cdots & dy_c
\end{matrix}
\right],那么L对于W_cW的列向量)的偏导数W_c'则为

\begin{align}
W'_c
=
\left[
\begin{matrix}
dw_{1c}\\
dw_{2c}\\
\vdots\\
dw_{nc}
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_{1}\\
x_{2}\\
\vdots\\
x_{n}
\end{matrix}
\right]
\times
\left[
\begin{matrix}
d_{y_c}
\end{matrix}
\right]
\end{align}

W_c'
=
\frac{\partial L}{\partial W_c}
=
\frac{\partial L}{\partial y_c}
\frac{\partial y_c}{\partial W_c}
=
dy_cX^T
=
X^Tdy_c

那么L对于W的偏导数W'则可以通过矩阵表示为

\begin{align}
W'
=
\left[
\begin{matrix}
dw_{11}&dw_{12}&\cdots&dw_{1c}\\
dw_{21}&dw_{22}&\cdots&dw_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
dw_{n1}&dw_{n2}&\cdots&dw_{nc}
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_{1}\\
x_{2}\\
\vdots\\
x_{n}
\end{matrix}
\right]
\times
\left[
\begin{matrix}
d_{y_1}&d_{y_2}&\cdots&d_{y_c}
\end{matrix}
\right]
\end{align}

W'
=
\frac{\partial L}{\partial W}
=
\frac{\partial L}{\partial Y}
\frac{\partial Y}{\partial W}
=
X^TY'

多个输入样本计算

在神经网络中,我们通常采用mini-batch的方法进行训练,对于含有m个样本的mini-batch来说

\begin{align}
\left[
\begin{matrix}
y_{11}&y_{12}&\cdots&y_{1c}\\
y_{21}&y_{22}&\cdots&y_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
y_{m1}&y_{m2}&\cdots&y_{mc}
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_{11}&x_{12}&\cdots&x_{1n}\\
x_{21}&x_{22}&\cdots&x_{2n}\\
\vdots&\vdots&\ddots&\vdots\\
x_{m1}&x_{m2}&\cdots&x_{mn}
\end{matrix}
\right]
\times
\left[
\begin{matrix}
w_{11}&w_{12}&\cdots&w_{1c}\\
w_{21}&w_{22}&\cdots&w_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
w_{n1}&w_{n2}&\cdots&w_{nc}
\end{matrix}
\right]
\end{align}

其中,W_c'可表示为

\begin{align}
W'_c
=
\left[
\begin{matrix}
dw_{1c}\\
dw_{2c}\\
\vdots\\
dw_{nc}
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_{11}&x_{21}&\cdots&x_{m1}\\
x_{12}&x_{22}&\cdots&x_{m2}\\
\vdots&\vdots&\ddots&\vdots\\
x_{1n}&x_{2n}&\cdots&x_{mn}
\end{matrix}
\right]
\times
\left[
\begin{matrix}
dy_{1c}\\
dy_{2c}\\
\vdots\\
dy_{mc}
\end{matrix}
\right]
\end{align}

W_c'
=
\frac{\partial L}{\partial W_c}
=
\frac{\partial L}{\partial Y_c}
\frac{\partial Y_c}{\partial W_c}
=
X^TY_c'Y_cY的列向量),

W'表示为

\begin{align}
W'
=
\left[
\begin{matrix}
dw_{11}&dw_{12}&\cdots&dw_{1c}\\
dw_{21}&dw_{22}&\cdots&dw_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
dw_{n1}&dw_{n2}&\cdots&dw_{nc}
\end{matrix}
\right]
=
\left[
\begin{matrix}
x_{11}&x_{21}&\cdots&x_{m1}\\
x_{12}&x_{22}&\cdots&x_{m2}\\
\vdots&\vdots&\ddots&\vdots\\
x_{1n}&x_{2n}&\cdots&x_{mn}
\end{matrix}
\right]
\times
\left[
\begin{matrix}
dy_{11}&dy_{12}&\cdots&dy_{1c}\\
dy_{21}&dy_{22}&\cdots&dy_{2c}\\
\vdots&\vdots&\ddots&\vdots\\
dy_{m1}&dy_{m2}&\cdots&dy_{mc}
\end{matrix}
\right]
\end{align}

W'
=
\frac{\partial L}{\partial W}
=
\frac{\partial L}{\partial Y}
\frac{\partial Y}{\partial W}
=
X^TY'

快速计算方法

其实还有一种简便的方法可以推导上面的公式,对于Y=XW,假设Y的维度是M \times CX的维度是M \times NW的维度是N \times C,那么可以利用维度的关系进行导数的计算。Y'=\frac{\partial L}{\partial Y}的维度必然是M \times C,那么W'=\frac{\partial L}{\partial W}==\frac{\partial L}{\partial Y}\frac{\partial Y}{\partial W}的维度必然是N \times C且与X有关,那么必有W'=X^TY',同理必有X'=Y'W^T

参考

文章被以下专栏收录