def _entropy(self):
    """Return the entropy of each categorical distribution in the batch.

    Computes H(p) = -sum_k p_k * log(p_k) with p = softmax(logits). The
    log-probabilities come from `log_softmax`, which avoids taking the log of
    an explicitly computed softmax.

    Returns:
      A Tensor reshaped to `self.batch_shape()`, with its static shape set to
      `self.get_batch_shape()`, holding one entropy value per distribution.
    """
    # Local import: only `array_ops` and `nn_ops` are known to be in scope here.
    from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top

    # Flatten to [N, num_classes]; the reshape treats the last axis of
    # `logits` as the class axis. Presumably done because softmax/log_softmax
    # expect rank-2 input in the TF version this targets -- confirm.
    if self.logits.get_shape().ndims == 2:
        logits_2d = self.logits
    else:
        logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])

    probs_2d = nn_ops.softmax(logits_2d)
    # Written out explicitly instead of
    # softmax_cross_entropy_with_logits(logits, softmax(logits)). That op does
    # not backpropagate into its labels argument, and its gradient w.r.t.
    # logits (softmax(logits) - labels) is identically zero when
    # labels == softmax(logits). The entropy's gradient would then vanish.
    entropy_1d = -math_ops.reduce_sum(
        probs_2d * nn_ops.log_softmax(logits_2d), [1])

    ret = array_ops.reshape(entropy_1d, self.batch_shape())
    # NOTE(review): `batch_shape()` is presumably a dynamic (Tensor) shape, so
    # the reshape above drops static shape info; restate it here.
    ret.set_shape(self.get_batch_shape())
    return ret
评论列表
文章目录