head.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目：DeepLearning_VirtualReality_BigData_Project　作者：rashmitripathi　（项目源码 | 文件源码）
def __init__(self,
               n_classes,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None,
               loss_fn=_softmax_cross_entropy_loss,
               thresholds=None,
               metric_class_ids=None):
    """Head for multi-class classification (more than two classes).

    Args:
      n_classes: Number of classes, must be greater than 2 (for 2 classes, use
        `_BinaryLogisticHead`).
      label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. If provided, predictions, summary, metrics
        keys will be suffixed by `"/" + head_name` and the default variable
        scope will be `head_name`.
      loss_fn: Loss function; stored as-is on the head.
      thresholds: thresholds for eval. If None or empty, defaults to `(.5,)`.
      metric_class_ids: Iterable of class IDs for which we should report
        per-class metrics. Must all be in the range `[0, n_classes)`. None is
        treated as no IDs. Stored as a tuple.

    Raises:
      ValueError: if `n_classes` or `metric_class_ids` is invalid.
    """
    # Validate arguments *before* initializing the base class, so invalid input
    # fails fast with the documented ValueError instead of first being handed
    # to the parent initializer as `logits_dimension`.
    if (n_classes is None) or (n_classes <= 2):
      # 1-tuple operand: a non-scalar n_classes must not break the %-format
      # and turn the documented ValueError into a TypeError.
      raise ValueError("n_classes must be > 2: %s." % (n_classes,))

    # Materialize once into a tuple (also consumes one-shot iterables safely);
    # the `is None` test avoids relying on truthiness of array-like inputs.
    metric_class_ids = tuple(
        () if metric_class_ids is None else metric_class_ids)
    for class_id in metric_class_ids:
      if (class_id < 0) or (class_id >= n_classes):
        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))

    super(_MultiClassHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    # A falsy `thresholds` (None or empty) falls back to a single 0.5 threshold.
    self._thresholds = thresholds if thresholds else (.5,)
    self._loss_fn = loss_fn
    self._enable_centered_bias = enable_centered_bias
    self._metric_class_ids = metric_class_ids
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号