def _predictions(logits, n_classes):
    """Builds the prediction tensors for a classifier head from its logits.

    Args:
      logits: Logits tensor. When `n_classes == 2` it is treated as a single
        logit column, and a zero column is prepended before the softmax.
        NOTE(review): presumably shape (batch, 1) in the binary case and
        (batch, n_classes) otherwise; confirm against callers.
      n_classes: Number of target classes. The value 2 switches on the
        binary/logistic path.

    Returns:
      A dict with these entries:
        `_PROBABILITIES`: the softmax over the (possibly zero-padded) logits.
        `_CLASSES`: the argmax index along axis 1, reshaped to shape (-1, 1).
        `_LOGISTIC`: binary case only; `sigmoid` of the original logits.
    """
    outputs = {}
    scores = logits
    if n_classes == 2:
        outputs[_LOGISTIC] = math_ops.sigmoid(scores)
        # Prepend a zero logit so the softmax over [0, x] yields
        # [1 - sigmoid(x), sigmoid(x)]. The argmax below is then index 1
        # when x > 0 and index 0 when x < 0. The call uses the legacy
        # concat(axis, values) argument order, concatenating along axis 1.
        zero_column = array_ops.zeros_like(scores)
        scores = array_ops.concat(1, [zero_column, scores])
    outputs[_PROBABILITIES] = nn.softmax(scores)
    # argmax over the logits gives the same index as argmax over the
    # softmax, because softmax is monotonic within each row.
    best_class = math_ops.argmax(scores, 1)
    outputs[_CLASSES] = array_ops.reshape(best_class, shape=(-1, 1))
    return outputs
评论列表
文章目录