def __init__(self,
             n_classes,
             label_name=None,
             weight_column_name=None,
             enable_centered_bias=False,
             head_name=None,
             loss_fn=_softmax_cross_entropy_loss,
             thresholds=None,
             metric_class_ids=None):
    """Creates a `_Head` for multi-class classification.

    The base-class constructor is invoked first (with `n_classes` as the
    logits dimension); `n_classes` and `metric_class_ids` are validated
    afterwards.

    Args:
      n_classes: Number of classes. Must be strictly greater than 2; for the
        two-class case use `_BinaryLogisticHead` instead.
      label_name: String key of the label in the label dict. May be None
        when the label is a single tensor (single-headed models).
      weight_column_name: String name of the feature column holding example
        weights. Used to down-weight or boost examples during training; it
        is multiplied by the per-example loss.
      enable_centered_bias: Bool. If True, the estimator learns a centered
        bias variable per class and the rest of the model learns the
        residual after that bias.
      head_name: Name of the head. If given, prediction, summary and metric
        keys are suffixed with `"/" + head_name`, and the default variable
        scope is `head_name`.
      loss_fn: Loss function. Defaults to `_softmax_cross_entropy_loss`.
      thresholds: Thresholds for eval. Any falsy value (None, empty
        sequence) is replaced by `(.5,)`.
      metric_class_ids: Iterable of class IDs for which per-class metrics
        are reported; each must lie in `[0, n_classes)`. None means no
        per-class metrics. Stored internally as a tuple.

    Raises:
      ValueError: if `n_classes` is None or not greater than 2, or if any
        entry of `metric_class_ids` falls outside `[0, n_classes)`.
    """
    super(_MultiClassHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name,
    )

    # Binary problems are handled by a dedicated head, so reject anything
    # with fewer than three classes here.
    if n_classes is None or n_classes <= 2:
        raise ValueError("n_classes must be > 2: %s." % n_classes)

    self._thresholds = thresholds or (.5,)
    self._loss_fn = loss_fn
    self._enable_centered_bias = enable_centered_bias

    # Freeze the requested class IDs into a tuple before validating them.
    requested_ids = () if metric_class_ids is None else tuple(metric_class_ids)
    self._metric_class_ids = requested_ids
    for class_id in requested_ids:
        if class_id < 0 or class_id >= n_classes:
            raise ValueError(
                "Class ID %s not in [0, %s)." % (class_id, n_classes))
head.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录