def test_state_grads():
    """Check the gradients reported for variable state-update ops.

    Two separate sessions are exercised:

    * ``tf.assign`` / ``tf.assign_add`` writing a full 3-element value
      into a 3-element variable;
    * ``tf.scatter_update`` / ``tf.scatter_add`` writing a 1-element value
      into index 0 of a 3-element variable.

    In each case the gradient of the op output is taken with respect to
    both the variable and the value being written, evaluated, and compared
    against the expected constants.
    """
    # NOTE(review): presumably the gradients for these state ops come from
    # gradient functions registered elsewhere in the project -- confirm.

    # --- dense updates: assign / assign_add ---
    with tf.Session() as sess:
        state = tf.Variable([0.0, 0.0, 0.0])
        delta = tf.ones((3,))

        assigned = tf.assign(state, delta)
        accumulated = tf.assign_add(state, delta)

        assign_grads = tf.gradients(assigned, [state, delta])
        add_grads = tf.gradients(accumulated, [state, delta])

        # One list of [d/d(variable), d/d(value)] per op.
        assign_vals, add_vals = sess.run((assign_grads, add_grads))

        # assign: expect 0 w.r.t. the variable and 1 w.r.t. the value.
        assert np.allclose(assign_vals[0], 0)
        assert np.allclose(assign_vals[1], 1)
        # assign_add: expect 1 w.r.t. both the variable and the value.
        assert np.allclose(add_vals[0], 1)
        assert np.allclose(add_vals[1], 1)

    # --- sparse updates: scatter_update / scatter_add on index 0 ---
    with tf.Session() as sess:
        state = tf.Variable([0.0, 0.0, 0.0])
        delta = tf.ones((1,))

        updated = tf.scatter_update(state, [0], delta)
        accumulated = tf.scatter_add(state, [0], delta)

        # Here the gradient is requested w.r.t. the variable's ref tensor.
        update_grads = tf.gradients(updated, [state._ref(), delta])
        add_grads = tf.gradients(accumulated, [state._ref(), delta])

        update_vals, add_vals = sess.run((update_grads, add_grads))

        # scatter_update: expect 0 at the written index and 1 elsewhere
        # w.r.t. the variable, and 1 w.r.t. the value.
        assert np.allclose(update_vals[0], [0, 1, 1])
        assert np.allclose(update_vals[1], 1)
        # scatter_add: expect 1 w.r.t. both the variable and the value.
        assert np.allclose(add_vals[0], 1)
        assert np.allclose(add_vals[1], 1)
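# Comment list  (web-page navigation text, not part of the code)
# Table of contents  (web-page navigation text, not part of the code)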