def training_graph(self, input_data, input_labels, data_spec=None,
                   epoch=None, **tree_kwargs):
    """Builds the TF training ops for every tree in the random forest.

    The ops for tree `i` are created on the device returned by
    `self.device_assigner.get_device(i)`. When `params.bagging_fraction` is
    below 1.0, each tree trains on a randomly selected subset of the input
    rows; when `params.bagged_features` is set, the tree's data is also
    passed through `self._bag_features`. Every tree's training op is made
    to depend on that tree's initialization op.

    Args:
        input_data: A tensor or SparseTensor or placeholder for input data.
        input_labels: A tensor or placeholder for the labels that go with
            input_data.
        data_spec: A list of tf.dtype values giving the original type of
            each column. Defaults to [constants.DATA_FLOAT] when None.
        epoch: A tensor or placeholder for the epoch the training data
            comes from. Each tree is given [0] when this is None.
        **tree_kwargs: Keyword arguments forwarded to each tree's
            training_graph.

    Returns:
        An op named 'train' that groups the training ops of all trees.
    """
    if data_spec is None:
        data_spec = [constants.DATA_FLOAT]

    per_tree_train_ops = []
    for tree_index in range(self.params.num_trees):
        with ops.device(self.device_assigner.get_device(tree_index)):
            # A base seed of 0 is passed through unchanged; any other base
            # seed is offset by the tree index so each tree gets its own.
            base_seed = self.params.base_random_seed
            seed = base_seed if base_seed == 0 else base_seed + tree_index

            # Without bagging, every tree sees the full input.
            tree_data, tree_labels = input_data, input_labels

            fraction = self.params.bagging_fraction
            if fraction < 1.0:
                # TODO(thomaswc): This does sampling without replacement.
                # Consider also allowing sampling with replacement as an
                # option.
                # Length-1 tensor holding the size of the first dimension.
                input_shape = array_ops.shape(input_data)
                num_rows = array_ops.slice(input_shape, [0], [1])
                # One uniform draw per row; a row is kept when its draw
                # falls below the bagging fraction.
                draws = random_ops.random_uniform(num_rows, seed=seed)
                threshold = array_ops.ones_like(draws) * fraction
                keep_mask = math_ops.less(draws, threshold)
                kept_positions = array_ops.where(keep_mask)
                row_indices = array_ops.squeeze(
                    kept_positions, squeeze_dims=[1])
                # TODO(thomaswc): Calculate out-of-bag data and labels, and
                # store them for use in calculating statistics later.
                tree_data = array_ops.gather(input_data, row_indices)
                tree_labels = array_ops.gather(input_labels, row_indices)

            if self.params.bagged_features:
                tree_data = self._bag_features(tree_index, tree_data)

            tree = self.trees[tree_index]
            init_op = tree.tree_initialization()

            # The tree's training op must not run before it is initialized.
            with ops.control_dependencies([init_op]):
                if epoch is None:
                    tree_epoch = [0]
                else:
                    tree_epoch = epoch
                train_op = tree.training_graph(
                    tree_data, tree_labels, seed, data_spec=data_spec,
                    epoch=tree_epoch, **tree_kwargs)
                per_tree_train_ops.append(train_op)

    return control_flow_ops.group(*per_tree_train_ops, name='train')
评论列表
文章目录