def test_reshape():
    """Check the forward value and the gradient of ``ad.reshape``.

    Builds ``y = reshape(x2, newshape=(1, 4))``, feeds a random 2x2 array for
    ``x2`` and asserts that:

    * ``y`` is a graph node (``ad.Node``),
    * the evaluated output has the requested shape ``(1, 4)``,
    * the gradient with respect to ``x2`` is an all-ones array with the
      shape of the fed input, ``(2, 2)``.
    """
    x2 = ad.Variable(name='x2')
    y = ad.reshape(x2, newshape=(1, 4))
    grad_x2, = ad.gradients(y, [x2])

    executor = ad.Executor([y, grad_x2])
    x2_val = np.random.randn(2, 2)
    # NOTE(review): the keyword is ``feed_shapes`` although array values are
    # passed -- confirm this matches the signature of ``Executor.run``.
    y_val, grad_x2_val = executor.run(feed_shapes={x2: x2_val})

    assert isinstance(y, ad.Node)
    assert y_val.shape == (1, 4)
    npt.assert_array_equal(grad_x2_val, np.ones((2, 2)))

    # Disabled case: reshape a (2, 6) input to rank 4.
    # NOTE(review): the expected gradient below has the *output* shape
    # (2, 1, 2, 3), whereas the active case above expects the gradient in the
    # *input* shape -- presumably this should be np.ones((2, 6)); confirm
    # before re-enabling.
    # x2 = ad.Variable(name='x2')
    # y = ad.reshape(x2, newshape=(2, 1, 2, 3))
    # grad_x2, = ad.gradients(y, [x2])
    # executor = ad.Executor([y, grad_x2])
    # x2_val = np.random.randn(2, 6)
    # y_val, grad_x2_val = executor.run(feed_shapes={x2: x2_val})
    #
    # assert isinstance(y, ad.Node)
    # assert y_val.shape == (2, 1, 2, 3)
    # npt.assert_array_equal(grad_x2_val, np.ones((2, 1, 2, 3)))
# Comment list
# Table of contents