test_logcumsumexp.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:shoelace 作者: rjagerman 项目源码 文件源码
def test_backward():

    # Construct test data
    x = Variable(np.array([5., 3., 3., 1., 0.]))
    g = Variable(np.ones(5))
    expected_result = np.array([0.7717692057972512, 0.562087881852882,
                                1.4058826163342215, 0.9213241007090265,
                                1.3389361953066183])

    # Generate object
    lcse = LogCumsumExp()

    # Run forward and backward pass
    lcse.forward((x.data,))
    result = lcse.backward((x.data, ), (g.data, ))

    # Assert that the result equals the expected result
    assert_true(np.array_equal(result[0], expected_result))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号