test_model_wrappers.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_attribute_wrapper():
    """Verify that wrapped node attributes unwrap to the expected Python values.

    Each case attaches a ``test_attribute`` to a single ``Abs`` node, wraps the
    resulting model in ``ModelWrapper`` and checks what ``get_value()`` returns
    for scalar ints, floats, strings and tensors, and for lists of each.
    """
    def attribute_value_test(attribute_value):
        """Build a one-node model carrying ``attribute_value`` and return its unwrapped value."""
        node = make_node('Abs', ['X'], [], name='test_node', test_attribute=attribute_value)
        graph = make_graph([node], 'test_graph', [
            make_tensor_value_info('X', onnx.TensorProto.FLOAT, [1, 2]),
        ], [])
        model = make_model(graph, producer_name='ngraph')
        wrapped_attribute = ModelWrapper(model).graph.node[0].get_attribute('test_attribute')
        return wrapped_attribute.get_value()

    tensor = make_tensor('test_tensor', onnx.TensorProto.FLOAT, [1], [1])

    # Scalar attributes.  The builtin ``int``/``float`` replace the ``np.long``
    # and ``np.float`` aliases, which NumPy deprecated in 1.20 and removed in
    # 1.24.  The exact-type comparison is kept on purpose: ``isinstance`` would
    # also accept subclasses such as ``bool`` or ``np.float64``.
    int_value = attribute_value_test(1)
    assert int_value == 1
    assert type(int_value) is int

    float_value = attribute_value_test(1.0)
    assert float_value == 1.0
    assert type(float_value) is float

    assert attribute_value_test('test') == 'test'
    assert attribute_value_test(tensor)._proto == tensor

    # List attributes.
    assert attribute_value_test([1, 2, 3]) == [1, 2, 3]
    assert attribute_value_test([1.0, 2.0, 3.0]) == [1.0, 2.0, 3.0]
    assert attribute_value_test(['test1', 'test2']) == ['test1', 'test2']
    assert attribute_value_test([tensor, tensor])[1]._proto == tensor
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号