def test_flat_tensor_dot_tensor():
    """
    Ensure that a flattened argument axis is not unflattened in the result.

    Builds an all-ones constant over axes (H, W, C), flattens it with
    ``ng.flatten_at(a, 2)``, dots an all-ones constant over axes (C, K)
    with it, executes the graph, and checks that every element of the
    result equals ``C.length``.
    """
    H = ng.make_axis(2)
    W = ng.make_axis(7)
    C = ng.make_axis(3)
    K = ng.make_axis(11)

    axes_a = ng.make_axes([H, W, C])
    a = ng.constant(np.ones(axes_a.lengths), axes=axes_a)
    # NOTE(review): presumably flatten_at(a, 2) groups axes[:2] (H, W) and
    # axes[2:] (C) into flattened axes -- confirm against ng.flatten_at.
    flat_a = ng.flatten_at(a, 2)

    axes_b = ng.make_axes([C, K])
    b = ng.constant(np.ones(axes_b.lengths), axes=axes_b)

    # C appears in both operands; the dot presumably contracts over it.
    result = ng.dot(b, flat_a)

    with ExecutorFactory() as factory:
        result_fun = factory.executor(result)
        result_val = result_fun()

    # Both operands are all ones, so contracting over C should yield
    # C.length in every output element.
    # NOTE(review): this checks values only; it does not assert that the
    # result keeps a single flattened (H, W) axis (e.g. result_val.ndim == 2)
    # -- consider adding such a check once the expected axes are confirmed.
    result_correct = np.ones_like(result_val) * C.length
    ng.testing.assert_allclose(result_val, result_correct)
评论列表
文章目录