test_execution.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_variance_sqrt_inverse(input_tensor):
    """Check fprop and bprop of ``1 / sqrt(var(x) + eps)`` against NumPy.

    The inverse standard deviation is reduced over the batch axes of
    ``input_tensor``. It is then compared against a ``targets`` placeholder
    that has the full input axes, and the scalar loss is the sum of the
    difference.

    The NumPy reference reduces over ``axis=1``. NOTE(review): this assumes
    the batch axis is the second dimension of the array produced for
    ``input_tensor`` -- confirm against the fixture.
    """
    x = input_tensor
    targets = ng.placeholder(x.axes)
    eps = 1e-3

    # Graph: reciprocal(sqrt(variance(x over batch axes) + eps)).
    batch_var = ng.variance(x, reduction_axes=x.axes.batch_axes())
    inv_std = ng.reciprocal(ng.sqrt(batch_var + eps))

    loss = ng.sum(inv_std - targets, out_axes=())
    grad = ng.deriv(loss, x)

    with executor([loss, grad], x, targets) as run:
        # Draw the input before the target so RNG consumption order is fixed.
        x_val = rng.uniform(-1, 1, x.axes)
        t_val = rng.uniform(-1, 1, targets.axes)
        ng_loss, ng_grad = run(x_val, t_val)

        # Reference forward statistic: 1 / sqrt(var + eps), kept 2-D so it
        # broadcasts against the target values.
        shifted_var = np.var(x_val, axis=1, keepdims=True) + eps
        ref_inv_std = 1.0 / np.sqrt(shifted_var)

        # Reference gradient: d/dx (var + eps)^(-1/2) summed over the
        # broadcast batch entries; the 1/N of d var / dx cancels, leaving
        # 2 * (x - mean).
        centered_twice = 2 * (x_val - np.mean(x_val, axis=1, keepdims=True))
        ref_grad = -0.5 * ref_inv_std / shifted_var * centered_twice

        ref_loss = np.sum(ref_inv_std - t_val)

        ng.testing.assert_allclose(ref_loss, ng_loss, atol=1e-4, rtol=1e-4)
        ng.testing.assert_allclose(ref_grad, ng_grad, atol=1e-4, rtol=1e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号