def _preprocess(self, logits, targets):
    """Convert model outputs and labels into matching class-index tensors.

    Args:
        logits: Model output tensor. The predicted class is taken as the
            argmax over axis 1 (presumably shaped ``(batch, num_classes)`` --
            confirm against the caller).
        targets: Label tensor. Handled according to its static shape:
            * rank <= 1: returned unchanged (already class indices);
            * rank 2, static size 1 along axis 1: squeezed to rank 1;
            * rank 2, static size > 1 along axis 1: treated as one-hot and
              reduced with argmax over axis 1;
            * rank 2, unknown size along axis 1: squeezed along axis 1;
            * rank > 2: reduced with argmax over axis 1, then axis 1 is
              squeezed if the result still has rank > 1.

    Returns:
        A ``(predictions, targets)`` tuple, where ``predictions`` is the
        argmax of ``logits`` over axis 1.

    Raises:
        ValueError: If the static rank of ``targets`` is unknown.
    """
    # tf.arg_max(..., dimension=...) is deprecated and removed in TF 2.x;
    # tf.argmax(..., axis=...) is the supported spelling in TF 1.x and 2.x.
    predictions = tf.argmax(logits, axis=1)

    shape = targets.get_shape()
    if shape.ndims is None:
        # Without a static rank the branching below cannot be decided
        # (previously this surfaced as an opaque `None > 2` TypeError).
        raise ValueError("targets must have a statically known rank")

    # One-hot labels are (batch, C) with C > 1, i.e. rank 2. The old
    # `ndims > 2` test missed them, and the squeeze below then failed on a
    # non-singleton dimension. An unknown C keeps the original squeeze path.
    class_dim = shape.as_list()[1] if shape.ndims == 2 else None
    is_one_hot = class_dim is not None and class_dim > 1
    if shape.ndims > 2 or is_one_hot:
        targets = tf.argmax(targets, axis=1)

    # Erase the singleton dimension if it exists, e.g. (batch, 1) -> (batch,).
    if targets.get_shape().ndims > 1:
        targets = tf.squeeze(targets, axis=1)
    return predictions, targets
评论列表
文章目录