def test_dynamic_stitch():
    """Check the gradient of ``tf.dynamic_stitch`` when indices collide.

    Both data tensors are stitched into index 0. Per the
    ``tf.dynamic_stitch`` documentation, when an index repeats, the slice
    from the later data tensor appears in the result. So ``x`` is fully
    overwritten by ``tf.ones((1, 3))`` and has no influence on ``z``.

    Two things must hold:
    - the analytic Jacobian of ``z`` w.r.t. ``x`` must match the
      finite-difference one;
    - both must be all zeros.

    Uses TF1-style graph mode (``tf.Session``, ``tf.test.compute_gradient``).
    ``tf`` and ``np`` are presumably imported at module level; they are not
    visible here.
    """
    x = tf.zeros((1, 3))
    # Index 0 appears twice: the second data tensor (ones) overwrites x.
    y = tf.dynamic_stitch([[0], [0]], [x, tf.ones((1, 3))])
    z = tf.gather(y, [0])
    # compute_gradient runs its graph evaluations in the default session,
    # so one must be active while it is called.
    with tf.Session():
        analytic, numeric = tf.test.compute_gradient(x, (1, 3), z, (1, 3))
        assert np.allclose(analytic, numeric)
        # x never reaches z, so the gradient must vanish, not merely agree.
        assert np.allclose(analytic, 0.0)
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")