def test_idempotent_axes_b():
    """
    Test axes transformations with autodiff, case b: broadcast is applied
    twice to the same tensor and the two results are added together.

    ``w`` is a variable of ones over axes of length (3, 1). Each broadcast
    targets ``w``'s own axes, so the cost ``sum(w + w)`` should equal 6.0.
    The gradient of that cost with respect to ``w`` should be 2 everywhere.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # Two separate broadcast ops over the *same* source tensor; autodiff
        # must accumulate the gradient contributions from both branches.
        left = ng.broadcast(w, axes)
        right = ng.broadcast(w, axes)
        summed = ng.add(left, right)
        summed = ng.cast_axes(summed, axes)

        cost = ng.sum(summed, reduction_axes=axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)

        # 3 * 1 elements, each 1 + 1 -> total 6; d(w + w)/dw == 2 elementwise.
        assert cost_comp() == 6.0
        assert np.array_equal(grad_comp(), np.ones((3, 1)) * 2.)
test_axes_transformer_dependent.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录