tensor_forest_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:deep-learning 作者: lbkchen 项目源码 文件源码
def testTrainingConstructionRegression(self):
    input_data = [[-1., 0.], [-1., 2.],  # node 1
                  [1., 0.], [1., -2.]]  # node 2
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4, num_features=2, num_trees=10, max_nodes=1000,
        split_after_samples=25, regression=True).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertTrue(isinstance(graph, tf.Operation))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号