test_linear_layer.py source code

python

Project: ngraph    Author: NervanaSystems
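import numpy as np
import pytest

import ngraph as ng
from ngraph.frontends.neon import Linear, UniformInit
from ngraph.testing import ExecutorFactory


def test_linear_ones(input_size, input_placeholder, output_size):
    # Basic sanity check with all ones on the inputs and the weights: each
    # element of the output should equal the sum of the weights feeding that
    # output unit, i.e. input_size. This confirms that the correct number of
    # multiply-accumulate operations is being run.
    # input_size, input_placeholder and output_size are pytest fixtures.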
    x = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as ex:
        if ex.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        comp = ex.executor([out, layer.W], input_placeholder)
        output_values, w = comp(x)

    ng.testing.assert_allclose(
        np.ones(out.axes.lengths) * input_size,
        output_values,
        atol=0.0, rtol=0.0
    )
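The expected value in the assertion follows from simple arithmetic: with every input and every weight equal to 1, each output element is a sum of input_size products 1 * 1. Below is a minimal NumPy-only sketch of the same check, assuming a bias-free dense layer that computes W @ x with the feature axis first and the batch axis second (an illustrative layout, not necessarily ngraph's axis ordering). The helpers linear_forward and check_linear_ones are hypothetical and not part of ngraph.

```python
import numpy as np


def linear_forward(x, W):
    """Hypothetical bias-free dense layer: each output unit is the dot
    product of one row of W with the corresponding input column."""
    return W @ x


def check_linear_ones(input_size, output_size, batch_size=4):
    """Hypothetical NumPy analogue of test_linear_ones."""
    # Mimic UniformInit(1.0, 1.0): sampling U(low=1, high=1) always gives 1.0.
    W = np.random.uniform(1.0, 1.0, size=(output_size, input_size))
    # Assumed layout: feature axis first, batch axis second.
    x = np.ones((input_size, batch_size))
    out = linear_forward(x, W)
    # Every entry is a sum of input_size terms of 1 * 1, i.e. input_size.
    expected = np.ones((output_size, batch_size)) * input_size
    np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0)
    return out


out = check_linear_ones(input_size=3, output_size=5)
print(out.shape)   # (5, 4)
print(out[0, 0])   # 3.0
```

Choosing input_size different from output_size makes the check more discriminating: an implementation that reduced over the wrong axis would produce output_size instead of input_size, and the exact comparison (atol=0.0, rtol=0.0) would fail.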