def testInferenceConstructionSparse(self):
    """Check that the forest inference graph can be built from sparse input.

    Builds a 4x10 ``tf.SparseTensor`` with six explicitly stored entries,
    constructs ``RandomForestGraphs`` from filled ``ForestHParams``, and
    asserts that ``inference_graph`` returns a ``tf.Tensor``.
    """
    # Rows 0 and 1 hold two entries each, rows 2 and 3 hold one each.
    # Note that [0, 3] stores an explicit 0.0 value.
    input_data = tf.SparseTensor(
        indices=[[0, 0], [0, 3],
                 [1, 0], [1, 7],
                 [2, 1],
                 [3, 9]],
        values=[-1.0, 0.0,
                -1.0, 2.0,
                1.0,
                -2.0],
        # NOTE(review): later TensorFlow releases renamed this keyword to
        # `dense_shape`; `shape` is kept to match the API this file targets
        # -- confirm against the pinned TF version before changing.
        shape=[4, 10])
    # num_features=10 matches the dense width of input_data above.
    params = tensor_forest.ForestHParams(
        num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
        split_after_samples=25).fill()
    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.inference_graph(input_data)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(graph, tf.Tensor)
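# NOTE(review): the two lines below look like web-page navigation residue
# picked up when this file was copied, not part of the test -- confirm and remove.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")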