电脑基础 · 2023年3月31日

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

目录

  • 1 计算图原理
  • 2 基于计算图的传播
  • 3 神经网络计算图
  • 4 自动微分机
  • 5 Pytorch中的自动微分
    • 5.1 梯度缓存
    • 5.2 参数冻结

1 计算图原理

计算图(Computational Graph)是机器学习领域中推导神经网络和其他模型算法,以及软件编程实现的有效工具。

计算图的核心是将模型表示成一张拓扑有序(Topologically Ordered)有向无环图(Directed Acyclic Graph),其中每个节点
u
i
u_i
ui
包含数值信息(可以是标量、向量、矩阵或张量)和算子信息
f
i
f_i
fi
。拓扑有序指当前节点仅在全体指向它的节点被计算后才进行计算。

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)
计算图的优点在于:

  • 可以通过基本初等映射 的拓扑联结,形成复合的复杂模型,大多数神经网络模型都可以被计算图表示;
  • 便于实现自动微分机(Automatic Differentiation Machine),对给定计算图可基于链式法则由节点局部梯度进行反向传播。

计算图的基本概念如表所示,基于计算图的基本前向传播和反向传播算法如表

符号 含义

n
n
n
计算图的节点数

l
l
l
计算图的叶节点数

L
L
L
计算图的叶节点索引集

C
C
C
计算图的非叶节点索引集

E
E
E
计算图的有向边集合

u
i
u_i
ui
计算图中的第
i
i
i
节点或其值

d
i
d_i
di

u
i
u_i
ui
的维度

f
i
f_i
fi

u
i
u_i
ui
的算子

α
i
\alpha _i
αi

u
i
u_i
ui
的全体关联输入

J
j

i
\boldsymbol{J}_{j\rightarrow i}
Jji
节点
u
i
u_i
ui
关于节点
u
j
u_j
uj
的雅克比矩阵

P
i
\boldsymbol{P}_i
Pi
输出节点关于输入节点的雅克比矩阵

2 基于计算图的传播

基于计算图的前向传播算法如下

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)
基于计算图的反向传播算法如下

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

以第一节的图为例,可知
E
=
{
(
1
,
3
)
,
(
2
,
3
)
,
(
2
,
4
)
,
(
3
,
4
)
}
E=\left\{ \left( 1,3 \right) ,\left( 2,3 \right) ,\left( 2,4 \right) ,\left( 3,4 \right) \right\}
E={(1,3),(2,3),(2,4),(3,4)}
。首先进行前向传播:


{
u
3
=
u
1
+
u
2
=
5
u
4
=
u
2
u
3
=
15
\begin{cases} u_3=u_1+u_2=5\\ u_4=u_2u_3=15\\\end{cases}
{u3=u1+u2=5u4=u2u3=15


{
J
1

3
=

u
3
/

u
1
=
1
J
2

3
=

u
3
/

u
2
=
1
J
2

4
=

u
4
/

u
2
=
u
3
=
5
J
3

4
=

u
4
/

u
3
=
u
2
=
3
\begin{cases} \boldsymbol{J}_{1\rightarrow 3}={{\partial u_3}/{\partial u_1=}}1\\ \boldsymbol{J}_{2\rightarrow 3}={{\partial u_3}/{\partial u_2=}}1\\ \boldsymbol{J}_{2\rightarrow 4}={{\partial u_4}/{\partial u_2=}}u_3=5\\ \boldsymbol{J}_{3\rightarrow 4}={{\partial u_4}/{\partial u_3=}}u_2=3\\\end{cases}
J13=u3/u1=1J23=u3/u2=1J24=u4/u2=u3=5J34=u4/u3=u2=3

接着进行反向传播:


{
P
4
=
1
P
3
=
P
4
J
3

4
=
3
P
2
=
P
4
J
2

4
+
P
3
J
2

3
=
8
P
1
=
P
3
J
1

3
=
3
\begin{cases} \boldsymbol{P}_4=1\\ \boldsymbol{P}_3=\boldsymbol{P}_4\boldsymbol{J}_{3\rightarrow 4}=3\\ \boldsymbol{P}_2=\boldsymbol{P}_4\boldsymbol{J}_{2\rightarrow 4}+\boldsymbol{P}_3\boldsymbol{J}_{2\rightarrow 3}=8\\ \boldsymbol{P}_1=\boldsymbol{P}_3\boldsymbol{J}_{1\rightarrow 3}=3\\\end{cases}
P4=1P3=P4J34=3P2=P4J24+P3J23=8P1=P3J13=3

3 神经网络计算图

一个神经网络的计算图实例如下,所有参数都可以用之前的模型表示

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)


L
{
u
1
=
W
1

R
n
1
×
n
u
2
=
b
1

R
n
1
u
3
=
x

R
n
u
4
=
W
2

R
n
2
×
n
1
u
5
=
b
2

R
n
2
u
6
=
y

R
n
2
  
C
{
u
7
=
z
1

R
n
1
=
W
1
x
+
b
1
u
8
=
a
1

R
n
1
=
σ
(
z
1
)
u
9
=
z
2

R
n
2
=
W
2
a
1
+
b
2
u
10
=
y

R
n
2
=
σ
(
z
2
)
u
11
=
E

R
=
1
2
(
y

y
~
)
T
(
y

y
~
)
L\begin{cases} u_1=\boldsymbol{W}^1\in \mathbb{R} ^{n_1\times n_0}\\ u_2=\boldsymbol{b}^1\in \mathbb{R} ^{n_1}\\ u_3=\boldsymbol{x}\in \mathbb{R} ^{n_0}\\ u_4=\boldsymbol{W}^2\in \mathbb{R} ^{n_2\times n_1}\\ u_5=\boldsymbol{b}^2\in \mathbb{R} ^{n_2}\\ u_6=\boldsymbol{y}\in \mathbb{R} ^{n_2}\\\end{cases}\,\, C\begin{cases} u_7=\boldsymbol{z}^1\in \mathbb{R} ^{n_1}=\boldsymbol{W}^1\boldsymbol{x}+\boldsymbol{b}^1\\ u_8=\boldsymbol{a}^1\in \mathbb{R} ^{n_1}=\sigma \left( \boldsymbol{z}^1 \right)\\ u_9=\boldsymbol{z}^2\in \mathbb{R} ^{n_2}=\boldsymbol{W}^2\boldsymbol{a}^1+\boldsymbol{b}^2\\ u_{10}=\boldsymbol{y}\in \mathbb{R} ^{n_2}=\sigma \left( \boldsymbol{z}^2 \right)\\ u_{11}=E\in \mathbb{R} =\frac{1}{2}\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right) ^T\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right)\\\end{cases}
Lu1=W1Rn1×n0u2=b1Rn1u3=xRn0u4=W2Rn2×n1u5=b2Rn2u6=yRn2Cu7=z1Rn1=W1x+b1u8=a1Rn1=σ(z1)u9=z2Rn2=W2a1+b2u10=yRn2=σ(z2)u11=ER=21(yy~)T(yy~)

4 自动微分机

自动微分机的基本原理是:

  • 跟踪记录从输入张量到输出张量的计算过程,并生成一幅前向传播计算图,计算图中的节点与张量一一对应
  • 基于计算图反向传播原理即可链式地求解输出节点关于各节点的梯度

必须指出,Pytorch不允许张量对张量求导,故输出节点必须是标量,通常为损失函数或输出向量的加权和;为节约内存,每次反向传播后Pytorch会自动释放前向传播计算图,即销毁中间计算节点的梯度和节点间的连接结构。

5 Pytorch中的自动微分

Tensor在自动微分机中的重要属性如表所示。

属性 含义
device 该节点运行的设备环境,即CPU/GPU
requires_grad 自动微分机是否需要对该节点求导,缺省为False
grad 输出节点对该节点的梯度,缺省为None
grad_fn 中间计算节点关于全体输入节点的映射,记录了前向传播经过的操作。叶节点为None
is_leaf 该节点是否为叶节点

完成前向传播后,调用反向传播API即可更新各节点梯度,具体如下

backward(gradient=None, retain_graph=None, create_graph=None)

其中

  • gradient是权重向量,当输出节点
    y
    y
    y
    不为标量时需指定与其同维的gradient,并以标量
    g
    r
    a
    d
    i
    e
    n
    t
    T
    y
    gradient^Ty
    gradientTy
    为输出进行反向传播
  • retain_graph用于缓存前向传播计算图,可应用于一次传播测试多个损失函数等情形;
  • creat_graph用于构造导数计算图,可用于进一步求解高阶导数。

5.1 梯度缓存

中间计算节点的梯度需要通过retain_grad()方法进行缓存

w1 = torch.tensor([[2.], [3.]], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
x = torch.tensor([[10.], [20.]])
y = torch.mm(w1.transpose(0, 1), x) + b1
y.retain_grad()	# 若不缓存则y.grad=None
out = 3*y
out.backward()
>> tensor([[30.], [60.]]) tensor([3.]) None tensor([[3.]])

5.2 参数冻结

若希望冻结网络部分参数,只调整优化另一部分参数;或按顺序训练分支网络而屏蔽对主网络梯度的,可使用detach()方法从计算图中分离节点,阻断反向传播。分离的节点与原节点共享值内存,但不具有gradgrad_fn属性。

# 记第一层网络w1-b1为f,第二层网络w2-b2为g
w1 = torch.tensor([[2.], [3.]], requires_grad=True)
w2 = torch.tensor([3.], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
b2 = torch.tensor([2.], requires_grad=True)
x = torch.tensor([[10.], [20.]])
y = torch.mm(w1.transpose(0, 1), x) + b1
y_ = y.detach()
z = w2 * y_ + b2
out = 3*z
out.backward()
print(w1.grad, b1.grad, w2.grad, b2.grad)
>> None None tensor([243.]) tensor([3.]) # f被冻结,梯度不更新
# 若不使用detach冻结y之前的网络,则
>> tensor([[ 90.], [180.]]) tensor([9.]) tensor([243.]) tensor([3.])

🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇