random_forest.py source code

python

Project: lsdc  Author: febert
def predict(
      self, x=None, input_fn=None, axis=None, batch_size=None, outputs=None,
      as_iterable=True):
    """Returns predictions for given features.

    Args:
      x: features.
      input_fn: Input function. If set, x must be None.
      axis: Axis on which to argmax (for classification).
            Last axis is used by default. Note: this argument is not
            forwarded to the underlying estimator in the body below.
      batch_size: Override default batch size.
      outputs: list of `str`, name of the output to predict.
        If `None`, returns all.
      as_iterable: If True, return an iterable which keeps yielding predictions
        for each example until inputs are exhausted. Note: The inputs must
        terminate if you want the iterable to terminate (e.g. be sure to pass
        num_epochs=1 if you are using something like read_batch_features).

    Returns:
      Numpy array of predicted classes or regression values (or an iterable of
      predictions if as_iterable is True).
    """
    results = self._estimator.predict(
        x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
        as_iterable=as_iterable)

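    # Pick the named output to return: INFERENCE_PROB_NAME for regression,
    # INFERENCE_PRED_NAME for classification.
    predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
                    else eval_metrics.INFERENCE_PRED_NAME)
    if as_iterable:
      # Lazily yield the selected output for each example. The loop variable
      # is named `result` so it does not shadow the `x` argument.
      return (result[predict_name] for result in results)
    else:
      # `results` is a single dict of outputs covering the whole input.
      return results[predict_name]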
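
The two return shapes described in the docstring can be illustrated with a minimal, self-contained sketch. It does not use TensorFlow: `StubEstimator`, `select_output`, and the key `"predictions"` are hypothetical stand-ins introduced only for illustration. The sketch assumes that the wrapped estimator yields one dict per example when `as_iterable=True` and returns a single dict otherwise, which is what the indexing in the method above implies.

```python
# Hypothetical stand-ins: StubEstimator, select_output and the key
# "predictions" are illustrative only and are not part of TensorFlow.
PREDICT_NAME = "predictions"


class StubEstimator(object):
    """Mimics the two return shapes assumed for the wrapped estimator."""

    def __init__(self, values):
        self._values = list(values)

    def predict(self, as_iterable=True):
        if as_iterable:
            # One dict per example, produced lazily.
            return ({PREDICT_NAME: v} for v in self._values)
        # A single dict holding the outputs for all examples.
        return {PREDICT_NAME: list(self._values)}


def select_output(estimator, as_iterable=True):
    """Selects one named output, mirroring the final branch of predict()."""
    results = estimator.predict(as_iterable=as_iterable)
    if as_iterable:
        return (result[PREDICT_NAME] for result in results)
    return results[PREDICT_NAME]


estimator = StubEstimator([0, 1, 1])

lazy = select_output(estimator, as_iterable=True)    # generator
eager = select_output(estimator, as_iterable=False)  # plain list

lazy_values = list(lazy)   # consuming the generator materializes the values
leftover = list(lazy)      # a generator can only be consumed once
```

With `as_iterable=True` nothing is computed until the caller iterates, and the iterable ends only when the input ends, which is why the docstring recommends a terminating input such as `num_epochs=1`.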