def test_idempotent_axes_a():
    """
    Test axes transformations with autodiff, case a (reference test).

    Builds a (3, 1) variable ``w`` of ones, forms ``w + w``, casts the result
    onto the same ``axes`` that ``w`` was created with, and reduces it with
    ``ng.sum`` over those axes. Asserts both the forward value of the cost
    and its gradient with respect to ``w``.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        w = ng.variable(axes, initial_value=np.ones((3, 1)))
        result = w + w
        # Cast onto the same axes `w` was built with; per the test name this
        # cast is expected to be a no-op for both forward value and gradient.
        result = ng.cast_axes(result, axes)
        cost = ng.sum(result, reduction_axes=axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)
        cost_comp_val = cost_comp()
        grad_comp_val = grad_comp()

        # cost = sum(2 * ones((3, 1))) = 6, and d(cost)/dw = 2 for every element.
        grad_comp_np = np.ones((3, 1)) * 2.
        assert cost_comp_val == 6.0
        assert np.array_equal(grad_comp_val, grad_comp_np)
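# Source file: test_axes_transformer_dependent.py (language: python)
# Metadata from the page this was copied from: 23 reads, 0 bookmarks,
# 0 likes, 0 comments. (Page sections: "comment list", "table of contents".)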