test_execution.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_variance_wgrad(input_tensor):
    inputs = input_tensor
    targets = ng.placeholder(inputs.axes)

    inp_stat = ng.variance(inputs, reduction_axes=inputs.axes.batch_axes())
    err = ng.sum(inp_stat - targets, out_axes=())
    d_inputs = ng.deriv(err, inputs)
    with executor([err, d_inputs], inputs, targets) as comp_func:

        input_value = rng.uniform(-0.1, 0.1, inputs.axes)
        target_value = rng.uniform(-0.1, 0.1, targets.axes)
        ng_f_res, ng_b_res = comp_func(input_value, target_value)

        np_f_res = np.sum(np.var(input_value, axis=1, keepdims=True) - target_value)

        ng.testing.assert_allclose(np_f_res, ng_f_res, atol=1e-4, rtol=1e-4)

        np_b_res = 2 * (input_value - np.mean(input_value, axis=1, keepdims=True))

        ng.testing.assert_allclose(np_b_res, ng_b_res, atol=1e-4, rtol=1e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号