def _logits_to_predictions(self, logits):
    """See `_MultiClassHead`.

    Args:
      logits: score tensor for this head. Presumably shape `[batch_size, 1]`
        (one score per example) — TODO confirm against callers.

    Returns:
      A dict with two entries:
        `PredictionKey.LOGITS`: the `logits` argument, unchanged.
        `PredictionKey.CLASSES`: `argmax` over axis 1 of `logits` with an
          all-zero tensor of the same shape prepended along axis 1.
    """
    # Pair every score with a zero "other class" score. For a single-column
    # input, argmax then picks index 1 where the score is positive and index
    # 0 where it is negative. Ties at exactly zero are resolved by argmax.
    # NOTE(review): the axis is passed first here, matching the pre-1.0
    # `concat(concat_dim, values)` signature. TF >= 1.0 expects
    # `concat(values, axis)` — confirm which TF version this file targets.
    paired_scores = array_ops.concat(
        1, [array_ops.zeros_like(logits), logits])
    return {
        prediction_key.PredictionKey.LOGITS: logits,
        prediction_key.PredictionKey.CLASSES: math_ops.argmax(
            paired_scores, 1),
    }
评论列表
文章目录