def test_linear_keep_batch_axes_ones(batch_axis, input_size, input_placeholder, output_size,
                                     transformer_factory):
    """Sanity-check ``Linear`` with ``keep_axes=[batch_axis]`` on all-ones data.

    The input is all ones and the weights are initialised with
    ``UniformInit(1.0, 1.0)``. Each output element should therefore equal the
    sum of one row of weights, i.e. exactly ``input_size``. Getting that exact
    value confirms the layer runs the expected number of operations.

    Args:
        batch_axis: axis passed to ``keep_axes`` (presumably a pytest fixture).
        input_size: expected value of every output element; presumably the
            length of the input axis that the layer reduces over.
        input_placeholder: graph placeholder, fed an all-ones array of its shape.
        output_size: ``nout`` of the layer under test.
        transformer_factory: not referenced in the body; presumably a fixture
            that configures the backend used by ``ExecutorFactory`` (TODO confirm).
    """
    x = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, keep_axes=[batch_axis], init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as ex:
        if ex.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        # Fetch W alongside the output so the all-ones premise can be verified.
        comp = ex.executor([out, layer.W], input_placeholder)
        output_values, w = comp(x)

    # The expected output below only holds if the initializer really produced
    # all-ones weights, so check that premise explicitly on the fetched W.
    assert np.all(w == 1.0)
    # Exact comparison (atol=rtol=0): summing ones is exact in floating point.
    assert np.allclose(
        np.ones(out.axes.lengths) * input_size,
        output_values,
        atol=0.0, rtol=0.0
    )
# (web-page residue) Comment list
# (web-page residue) Table of contents