def build(self, predictions, targets, inputs=None):
    """Print the number of each kind of prediction.

    Builds the wrapped ``inner_metric`` and sets ``self.tensor`` to a
    ``tf.Print`` op that passes the inner metric's tensor through while
    printing the distinct predicted classes and how often each occurs.

    Args:
        predictions: Prediction tensor. Rank 1, or rank 2 with a single
            column, is treated as binary output; any other shape is treated
            as per-class scores with classes along axis 1.
        targets: Target tensor, forwarded to ``inner_metric.build``.
        inputs: Optional input tensor, forwarded to ``inner_metric.build``.
    """
    pshape = predictions.get_shape()
    self.inner_metric.build(predictions, targets, inputs)

    is_binary = len(pshape) == 1 or (len(pshape) == 2 and int(pshape[1]) == 1)

    # Resolve the default name before opening the scope; otherwise the
    # scope would be opened with ``self.name`` still unset.
    if is_binary:
        self.name = self.name or "binary_prediction_counts"
    else:
        self.name = self.name or "categorical_prediction_counts"

    with tf.name_scope(self.name):
        if is_binary:
            # argmax over a 1-D vector (or a single column) does not give a
            # per-sample class, so threshold each prediction instead.
            # NOTE(review): assumes predictions are probabilities in [0, 1]
            # (e.g. sigmoid output); a threshold of 0 would suit logits.
            flat = tf.reshape(predictions, [-1])
            classes = tf.cast(tf.greater(flat, 0.5), tf.int64)
        else:
            classes = tf.argmax(predictions, dimension=1)
        y, _, count = tf.unique_with_counts(classes)
        # tf.Print requires a tensor as its first argument, so attach the
        # print to the inner metric's tensor (not the metric object itself).
        self.tensor = tf.Print(self.inner_metric.tensor, [y, count],
                               name=self.inner_metric.name)
    # Mark as built only once every op above has been created successfully.
    self.built = True
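# NOTE(review): stray web-page navigation text found here ("评论列表" =
# "comment list", "文章目录" = "table of contents"); not part of the code.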