test_axes_transformer_dependent.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_idempotent_axes_a():
    """
    Reference case (a) for axes transformations combined with autodiff.

    Build ``w + w`` from a variable of ones whose axes have lengths 3 and 1.
    Cast the result onto the same axes it already has, and reduce it with a
    sum over all of those axes. Then check both the forward value and the
    gradient of that sum with respect to ``w``.
    """
    with ExecutorFactory() as executor_factory:
        shape_axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])

        weights = ng.variable(shape_axes, initial_value=np.ones((3, 1)))
        doubled = weights + weights

        # Cast onto the axes the tensor was built with. The assertions below
        # check that this cast leaves both the forward value and the
        # gradient unchanged.
        doubled = ng.cast_axes(doubled, shape_axes)
        total = ng.sum(doubled, reduction_axes=shape_axes)
        d_total_d_weights = ng.deriv(total, weights)

        # Executors are created (gradient first) and run (cost first) in the
        # same order as the reference test, so the ordering stays fixed.
        run_grad = executor_factory.executor(d_total_d_weights)
        run_cost = executor_factory.executor(total)

        cost_value = run_cost()
        grad_value = run_grad()
        expected_grad = 2. * np.ones((3, 1))

        # sum(1 + 1 over 3 elements) == 6; d/dw of sum(w + w) == 2 everywhere.
        assert cost_value == 6.0
        assert np.array_equal(grad_value, expected_grad)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号