tensor_forest_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:deep-learning 作者: lbkchen 项目源码 文件源码
def testInferenceConstructionSparse(self):
    input_data = tf.SparseTensor(
        indices=[[0, 0], [0, 3],
                 [1, 0], [1, 7],
                 [2, 1],
                 [3, 9]],
        values=[-1.0, 0.0,
                -1., 2.,
                1.,
                -2.0],
        shape=[4, 10])

    params = tensor_forest.ForestHParams(
        num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.inference_graph(input_data)
    self.assertTrue(isinstance(graph, tf.Tensor))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号