test_linear_layer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_linear_keep_axes_ones(batch_axis, input_size, input_placeholder, output_size,
                               transformer_factory):
    """Sanity-check ``Linear(keep_axes=[])`` using all-ones inputs and weights.

    Both the input tensor and the layer weights are all ones. So every output
    element is simply a count of how many products were summed into it. This
    confirms that the layer performs the expected number of operations.

    The assertion expects every output element to equal
    ``input_size * batch_axis.length``. In other words, with ``keep_axes=[]``
    the batch axis is summed over as well as the input feature axis.

    The test is marked xfail on the 'hetr' transformer because of a known
    fork-safety issue on macOS.
    """
    ones_input = np.ones(input_placeholder.axes.lengths)
    # UniformInit(1.0, 1.0) fills the weights with ones (low == high).
    linear = Linear(nout=output_size, keep_axes=[], init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as executor_factory:
        transformer_name = executor_factory.transformer.transformer_name
        if transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        result_op = linear(input_placeholder)
        # The weights are fetched alongside the output but are not checked below.
        computation = executor_factory.executor([result_op, linear.W], input_placeholder)
        output_values, _weights = computation(ones_input)

    expected = np.ones(result_op.axes.lengths) * input_size * batch_axis.length
    # Zero tolerances: the sums of ones are integers and must match exactly.
    assert np.allclose(expected, output_values, atol=0.0, rtol=0.0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号