def softmax(logits, scope=None):
    """Applies softmax over the last dimension of an N-dimensional logit tensor.

    The logits are flattened to two dimensions, passed through `nn.softmax`,
    and reshaped back to the shape of the input. For two-dimensional logits
    this reduces to tf.nn.softmax. The last dimension (the number of classes)
    needs to have a specified number of elements.

    Args:
      logits: N-dimensional `Tensor` with logits, where N > 1.
      scope: Optional scope for variable_scope.

    Returns:
      A `Tensor` with same shape and type as logits.
    """
    # TODO(jrru): Add axis argument which defaults to last dimension.
    with variable_scope.variable_scope(scope, 'softmax', [logits]):
        static_shape = logits.get_shape()
        # Size of the trailing (class) dimension; the helper is asked to
        # enforce a minimum rank of 2.
        num_classes = utils.last_dimension(static_shape, min_rank=2)

        # Collapse all leading dimensions so the input is a 2-D
        # [-1, num_classes] tensor before the softmax is applied.
        flat_logits = array_ops.reshape(logits, [-1, num_classes])
        flat_probs = nn.softmax(flat_logits)

        # Reshape back using the runtime shape of `logits`, then re-attach
        # the static shape of the input to the result.
        probs = array_ops.reshape(flat_probs, array_ops.shape(logits))
        probs.set_shape(static_shape)
    return probs
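# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Table of contents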