TensorFlow2.0教程-自动求导

TensorFlow2.0教程-自动求导

TensorFlow2.0教程-自动求导

这节我们会介绍使用tensorflow2自动求导的方法。


一、Gradient tapes

tensorflow 提供tf.GradientTape api来实现自动求导功能。只要在tf.GradientTape()上下文中执行的操作,都会被记录与“tape”中,然后tensorflow使用反向自动微分来计算相关操作的梯度。

x = tf.ones((2,2))

# 需要计算梯度的操作
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.reduce_sum(x)
    z = tf.multiply(y,y)
# 计算z关于x的梯度
dz_dx = t.gradient(z, x)
print(dz_dx)
tf.Tensor(
[[8. 8.]
 [8. 8.]], shape=(2, 2), dtype=float32)

也可以输出对中间变量的导数

# 梯度求导只能每个tape一次
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.reduce_sum(x)
    z = tf.multiply(y,y)

dz_dy = t.gradient(z, y)
print(dz_dy)
tf.Tensor(8.0, shape=(), dtype=float32)

默认情况下GradientTape的资源会在执行tf.GradientTape()后被释放。如果想多次计算梯度,需要创建一个持久的GradientTape。

with tf.GradientTape(persistent=True) as t:
    t.watch(x)
    y = tf.reduce_sum(x)
    z = tf.multiply(y, y)

dz_dx = t.gradient(z,x)
print(dz_dx)
dz_dy = t.gradient(z, y)
print(dz_dy)
tf.Tensor(
[[8. 8.]
 [8. 8.]], shape=(2, 2), dtype=float32)
tf.Tensor(8.0, shape=(), dtype=float32)

二、记录控制流

因为tapes记录了整个操作,所以即使过程中存在python控制流(如if, while),梯度求导也能正常处理。

def f(x, y):
    output = 1.0
    # 根据y的循环
    for i in range(y):
        # 根据每一项进行判断
        if i> 1 and i<5:
            output = tf.multiply(output, x)
    return output

def grad(x, y):
    with tf.GradientTape() as t:
        t.watch(x)
        out = f(x, y)
        # 返回梯度
        return t.gradient(out, x)
# x为固定值
x = tf.convert_to_tensor(2.0)

print(grad(x, 6))
print(grad(x, 5))
print(grad(x, 4))
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

三、高阶梯度

GradientTape上下文管理器在计算梯度的同时也会保持梯度,所以GradientTape也可以实现高阶梯度计算,

x = tf.Variable(1.0)

with tf.GradientTape() as t1:
    with tf.GradientTape() as t2:
        y = x * x * x
    dy_dx = t2.gradient(y, x)
    print(dy_dx)
d2y_d2x = t1.gradient(dy_dx, x)
print(d2y_d2x)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(6.0, shape=(), dtype=float32)

之前有事暂停了更新,后面会逐渐恢复更新,感谢大家的支持和喜爱。

发布于 2019-06-20

文章被以下专栏收录