# test_autodiff_cpu.py — file source

# Language: Python
# Views: 18 · Favorites: 0 · Likes: 0 · Comments: 0

# Project: Aurora · Author: upul (project source · file source)
def _check_reshape(in_shape, newshape):
    """Check the forward value and gradient of ``ad.reshape`` for one case.

    Builds ``y = ad.reshape(x2, newshape=newshape)``, evaluates ``y`` and the
    gradient of ``y`` w.r.t. ``x2`` on a random input of shape ``in_shape``,
    and verifies that:

    * ``y`` is an ``ad.Node`` and its value has shape ``newshape``;
    * the value of ``y`` equals NumPy's default (C-order) reshape of the
      input, i.e. reshape only changes the layout, not the data;
    * the gradient w.r.t. ``x2`` is all ones with the *input* shape.

    Args:
        in_shape: Shape of the random array fed for ``x2``.
        newshape: Target shape passed to ``ad.reshape``; must contain the
            same number of elements as ``in_shape``.
    """
    x2 = ad.Variable(name='x2')
    y = ad.reshape(x2, newshape=newshape)

    grad_x2, = ad.gradients(y, [x2])
    executor = ad.Executor([y, grad_x2])
    x2_val = np.random.randn(*in_shape)
    # Despite its name, `feed_shapes` receives the actual input arrays here.
    y_val, grad_x2_val = executor.run(feed_shapes={x2: x2_val})

    assert isinstance(y, ad.Node)
    assert y_val.shape == tuple(newshape)
    # allclose rather than exact equality in case the executor changes dtype.
    npt.assert_allclose(y_val, x2_val.reshape(newshape))
    # Presumably ad.gradients seeds the output gradient with ones (the
    # original (2, 2) assertion relies on this); reshape's backward pass must
    # map that seed back to the input's shape, not the output's.
    npt.assert_array_equal(grad_x2_val, np.ones(in_shape))


def test_reshape():
    """Reshape must preserve values and route the gradient back to the input's
    shape, both for same-rank and rank-changing targets."""
    _check_reshape((2, 2), (1, 4))
    # Rank-changing case. The gradient w.r.t. x2 has x2's shape (2, 6), not
    # the output shape (2, 1, 2, 3).
    _check_reshape((2, 6), (2, 1, 2, 3))
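# Comment list
# Article table of contents


# Questions


# Interview experiences


# Articles


# WeChat
# official account

# Scan the QR code to follow the official account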