def inference_graph(self, input_data, data_spec=None):
    """Constructs a TF graph for evaluating a random forest.

    Each tree's inference sub-graph is built under the device returned by
    the device assigner for that tree index. The per-tree outputs are then
    packed together, summed along axis 0 and divided by the number of trees
    under the device assigned to tree 0.

    Args:
      input_data: A tensor or SparseTensor or placeholder for input data.
      data_spec: A list of tf.dtype values specifying the original types of
        each column. When None, `[constants.DATA_FLOAT]` is used.

    Returns:
      The last op in the random forest inference graph: the per-tree
      outputs summed and divided by `params.num_trees`, named
      'probabilities'.
    """
    if data_spec is None:
        data_spec = [constants.DATA_FLOAT]

    per_tree_outputs = []
    for tree_index in range(self.params.num_trees):
        # Build each tree's sub-graph on the device assigned to that tree.
        with ops.device(self.device_assigner.get_device(tree_index)):
            if self.params.bagged_features:
                # Feature bagging: this tree gets its own view of the input.
                features = self._bag_features(tree_index, input_data)
            else:
                features = input_data
            tree_output = self.trees[tree_index].inference_graph(
                features, data_spec)
            per_tree_outputs.append(tree_output)

    # Combine the per-tree results on the device assigned to tree 0.
    with ops.device(self.device_assigner.get_device(0)):
        stacked = array_ops.pack(per_tree_outputs)
        summed = math_ops.reduce_sum(stacked, 0)
        return math_ops.div(
            summed, self.params.num_trees, name='probabilities')
评论列表
文章目录