张量分析推导 loss 对 X 的导数
利用张量分析书写矩阵乘法
在一个标准的神经网络线性层中,输出和输入的关系如下
T=XW+b(1)
其中
- T 是一个 b×h 的矩阵,这里 b 一般是
batch_size
,h 是隐藏层大小
- X 是一个 b×m 的矩阵 ,这里 m 一般是
emb_size * block_size
- W 是一个 m×h 的矩阵
- b 在
pytorch
中一般不是一个矩阵,是一个长度为 h 的向量,
通过使用了 pytorch
的扩展第一列扩展到所有列
我们可以使用张量分量的形式来表示
T=(xikw⋅jk+bj)gigj(2)
用张量推导导数
在深度学习中,上述说的 T 其实会是 loss 的一个函数,然后 T 又是 X 的函数,写出来是下述的样子
loss=l(T(X))
上述 l 可以看做是一个二阶张量 T 的标量函数,T 是二阶张量 X 的二阶张量函数,而张量函数的导数存在以下链式求导规则
∂X∂l=∂T∂l:∂X∂T
将上式按照张量分量展开如下
∂X∂l=∂tij∂lgigj:∂xmn∂tklgkglgmgn=∂tij∂l∂xmn∂tklδi⋅kδj⋅lgmgn=∂tij∂l∂xmn∂tijgmgn(4)
于是我们得到了基于张量分量的链式法则。在公式 (4) 中,其实 ∂tij∂l 是一个已知字段,在上一次 backward
中已经计算出来了,现在我们的目标是计算 loss 各相对于 W、X 和 b 的导数,将该链式法则带入 (2) 式
∂X∂l=∂tij∂l∂xmn∂(xikw⋅jk+bj)gmgn=∂tij∂lδi⋅mδk⋅nw⋅jkgmgn=∂tmj∂lw⋅jngmgn=∂tmj∂l(wj⋅n)⊤gmgn=∂T∂lW⊤(5)
上式中 (wki)⊤ 表示矩阵的转置。因此我们将上式改为矩阵形式就可以得到
∂X∂l=∂T∂lW⊤(6)
由此,我们用很简单规整的数学推导推出了 loss 相对于 X 的导数(或者称为梯度)
同样的思路,我们可以推导出 W 对应的导数
∂W∂l=∂tij∂l∂wmn∂(xikw⋅jk+bj)gmgn=∂tij∂lxikδkmδj⋅ngmgn=∂tin∂lxi⋅mgmgn=(x⋅im)⊤∂tin∂lgmgn=X⊤∂T∂l(7)
和 b 对应的导数
∂b∂l=∂tij∂l∂bm∂(wikx⋅jk+bj)gm=∂tij∂lδj⋅mgm=∂tim∂lgm=i∑∂tim∂lgm(8)
这里比较特殊,最终得到的结果其实是 ∂T∂l 按照行求和得到了一个长度为 h 的向量。
将上述的公式表示为 pytorch
代码即为
1 2 3
| dX = dT @ W.T dW = X.T @ dT db = dT.sum(0)
|