1. 计算图
  2. PyTorch的动态图机制

一、计算图

1.1 计算图的表示

  • 计算图是用来描述运算的有向无环图
  • 计算图有两个主要元素:结点(Node)和 边(Edge)
    • 结点表示数据,如向量,矩阵,张量
    • 表示运算,如加、减、乘、除、卷积等
  • 用计算图表示:计算图与动态图机制 - 图1

image.png

  • 计算图与动态图机制 - 图3
  • 计算图与动态图机制 - 图4
  • 计算图与动态图机制 - 图5

    1.2 计算图与梯度求导

    image.png
  • 叶子结点:用户创建的结点称为叶子结点,如 X 与 W
    • 叶子节点是整个计算图的根基,例如前面求导的计算图,在前向传导中的a、b和y都要依据创建的叶子节点x和w进行计算的。同样,在反向传播过程中,所有梯度的计算都要依赖叶子节点。
    • 设置叶子节点主要是为了节省内存,在梯度反向传播结束之后,非叶子节点的梯度都会被释放掉。可以根据代码分析一下非叶子节点a、b和y的梯度情况。
    • is_leaf : 指示张量是否为叶子结点
    • 所有的梯度计算都需要叶子节点
    • image.png
  • grad_fn : 记录创建该张量时所用的方法(函数)
    • image.png
      1. y.grad_fn = <MulBackward0>
      2. a.grad_fn = <AddBackward0>
      3. b.grad_fn = <AddBackward0>
      ```python import torch

w = torch.tensor(1, dtype=float, requires_grad=True) x = torch.tensor(2, dtype=float, requires_grad=True)

a = torch.add(w, x) a.retain_grad() # 在反向传播之后,仍然保留非叶子节点的梯度 b = torch.add(w, 1) y = torch.mul(a, b)

y.backward() print(w.grad)

查看叶子结点

print(“is_leaf:\n”, w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

查看梯度

print(“gradient:\n”, w.grad, x.grad, a.grad, b.grad, y.grad) # 反向传播之后,会将非叶子节点释放掉,节省内存开销

查看 grad_fn

print(“grad_fn:\n”, w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn) # 记录创建该张量时所用的方法

![image.png](https://cdn.nlark.com/yuque/0/2020/png/353587/1588325528711-bb165cfa-411d-4076-878b-6586038d2601.png#align=left&display=inline&height=103&margin=%5Bobject%20Object%5D&name=image.png&originHeight=241&originWidth=1748&size=36458&status=done&style=none&width=746)
<a name="pRQA6"></a>
# 二、动态图机制
![image.png](https://cdn.nlark.com/yuque/0/2020/png/353587/1588324900484-6939a15c-d7a5-4372-918f-56f34a638378.png#align=left&display=inline&height=227&margin=%5Bobject%20Object%5D&name=image.png&originHeight=556&originWidth=1829&size=174485&status=done&style=none&width=746)<br />![](https://cdn.nlark.com/yuque/0/2020/gif/353587/1588325751404-799fb259-9ea1-4359-8d06-0e838a9daa55.gif#align=left&display=inline&height=448&margin=%5Bobject%20Object%5D&originHeight=448&originWidth=252&size=0&status=done&style=none&width=252)<br />**TensorFlow静态图演示**
> 在上面这个图中,框代表的就是节点,带箭头的线代表边。tensorflow使用的是静态图,是先将图搭建好之后,再input数据进去。

![image.png](https://cdn.nlark.com/yuque/0/2020/png/353587/1588326378657-eaee5b16-d1e9-47ee-97ff-3c8c654599f2.png#align=left&display=inline&height=365&margin=%5Bobject%20Object%5D&name=image.png&originHeight=619&originWidth=1264&size=396994&status=done&style=none&width=746)<br />**PyTorch动态图**
> 动态图的搭建是根据每一步的计算搭建的,而tensorflow是先搭建所有的计算图之后,再把数据输入进去。这就是动态图和静态图的区别。

```python
# 1. 创建四个tensot
W_h = torch.randn(20, 20, requires_grad=True)
W_x = torch.randn(20, 10, requires_grad=True)
x = torch.randn(1, 10)
prev_h = torch.randn(1, 20)

# 2. 定义各种数据操作
h2h = torch.mm(W_h, prev_h.t())  # 将W_h和prev_h进行相乘,得到一个新张量h2h
i2h = torch.mm(W_x, x.t())  # 将W_x和x进行相乘,等到一个新张量i2h
next_h = h2h + i2h  # 创建加法操作
next_h = next_h.tanh()  # 使用激活函数

# 3. 计算并回传loss
loss = next_h.sum()   #计算损失函数
loss.backward()   #梯度反向传播