tensor_forest.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def training_graph(self, input_data, input_labels, data_spec=None,
                   epoch=None, **tree_kwargs):
    """Builds the TF training graph for every tree in the forest.

    For each of `self.params.num_trees` trees, on the device chosen by
    `self.device_assigner`, this optionally subsamples the rows of the batch
    (when `bagging_fraction` < 1.0), optionally applies `_bag_features`
    (when `bagged_features` is set), and then adds that tree's training
    graph under a control dependency on the result of the tree's
    `tree_initialization()`.

    Args:
      input_data: A tensor or SparseTensor or placeholder for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      data_spec: A list of tf.dtype values specifying the original types of
        each column.  Falls back to `[constants.DATA_FLOAT]` when None.
      epoch: A tensor or placeholder for the epoch the training data comes
        from.  Each tree receives `[0]` when this is None.
      **tree_kwargs: Keyword arguments passed to each tree's training_graph.

    Returns:
      An op named 'train' that groups the per-tree training graphs.
    """
    if data_spec is None:
        data_spec = [constants.DATA_FLOAT]

    params = self.params
    per_tree_ops = []
    for tree_index in range(params.num_trees):
        with ops.device(self.device_assigner.get_device(tree_index)):
            # A base seed of 0 is passed through untouched; any other base
            # seed is offset by the tree index so each tree gets its own.
            base_seed = params.base_random_seed
            tree_seed = base_seed + tree_index if base_seed != 0 else base_seed

            bagged_data, bagged_labels = input_data, input_labels
            if params.bagging_fraction < 1.0:
                # Keep each row whose uniform random draw falls below
                # bagging_fraction.
                # TODO(thomaswc): This does sampling without replacement.
                # Consider also allowing sampling with replacement as an
                # option.
                num_rows = array_ops.slice(
                    array_ops.shape(input_data), [0], [1])
                draws = random_ops.random_uniform(num_rows, seed=tree_seed)
                threshold = array_ops.ones_like(draws) * params.bagging_fraction
                keep_row = math_ops.less(draws, threshold)
                kept_indices = array_ops.squeeze(
                    array_ops.where(keep_row), squeeze_dims=[1])
                # TODO(thomaswc): Calculate out-of-bag data and labels, and
                # store them for use in calculating statistics later.
                bagged_data = array_ops.gather(input_data, kept_indices)
                bagged_labels = array_ops.gather(input_labels, kept_indices)

            if params.bagged_features:
                bagged_data = self._bag_features(tree_index, bagged_data)

            tree = self.trees[tree_index]
            init_op = tree.tree_initialization()

            # The tree's training ops must not run before its initialization.
            with ops.control_dependencies([init_op]):
                tree_epoch = [0] if epoch is None else epoch
                train_op = tree.training_graph(
                    bagged_data, bagged_labels, tree_seed,
                    data_spec=data_spec, epoch=tree_epoch, **tree_kwargs)
                per_tree_ops.append(train_op)

    return control_flow_ops.group(*per_tree_ops, name='train')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号