test_axes_transformer_dependent.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_idempotent_axes_b():
    """
    Test axes transformations with autodiff, case b: the same variable is
    broadcast twice to its own axes, the two results are added, and the
    sum is differentiated with respect to that variable.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])

        # A (3, 1) variable of ones, used as both operands of the addition.
        w = ng.variable(axes, initial_value=np.ones((3, 1)))
        lhs = ng.broadcast(w, axes)
        rhs = ng.broadcast(w, axes)

        # Add the two broadcasts and cast the sum back onto the same axes.
        total = ng.cast_axes(ng.add(lhs, rhs), axes)

        cost = ng.sum(total, reduction_axes=axes)
        grad = ng.deriv(cost, w)

        # Build the gradient computation first, then the cost computation.
        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)

        # Three elements of (1 + 1) sum to 6; w feeds both operands, so
        # each element of the gradient is 2.
        cost_value = cost_comp()
        assert cost_value == 6.0

        expected_grad = np.ones((3, 1)) * 2.
        assert np.array_equal(grad_comp(), expected_grad)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号