def single_label(predictions_batch, one_hot_labels_batch, moving_average=True):
    """Build single-label classification metrics, including a confusion matrix.

    Each row of ``predictions_batch`` is reduced to its arg-max (predicted)
    class and each row of ``one_hot_labels_batch`` to its arg-max (true)
    class. Both are re-encoded as boolean one-hot masks and handed to
    ``_metrics``; a confusion matrix is then added to the dict it returns.

    Args:
        predictions_batch: Tensor of per-class scores. The class count is read
            from the static size of dimension 1, so that dimension must be
            known at graph-construction time. Presumably shaped
            ``(batch, num_classes)`` (arg-max is taken over axis 1) — confirm
            against callers.
        one_hot_labels_batch: Tensor of true labels with the class dimension on
            axis 1 (reduced with arg-max, so one-hot rows are expected).
        moving_average: If True, the confusion matrix is a float32 exponential
            moving average (decay 0.9) of the per-batch matrices and the op
            that advances it is appended to ``d['update_op']``. If False, the
            per-batch matrices are summed into an int32 non-trainable variable.

    Returns:
        The dict returned by ``_metrics`` with an added ``'confusion_matrix'``
        entry built by ``tf.confusion_matrix`` with the true labels as
        ``labels`` and the arg-max predictions as ``predictions``. In
        moving-average mode ``'update_op'`` becomes
        ``[previous_update_op, ema_update_op]`` (so ``_metrics`` must supply
        an ``'update_op'`` entry in that mode).

    Raises:
        ValueError: If the static size of dimension 1 of ``predictions_batch``
            is unknown.
    """
    with tf.variable_scope('metrics'):
        # Only the class count is needed; the batch dimension may be None.
        num_outputs = predictions_batch.get_shape().as_list()[1]
        if num_outputs is None:
            # one_hot depth / confusion-matrix size must be static here.
            raise ValueError(
                'single_label needs a statically known class dimension '
                '(dimension 1 of predictions_batch).')

        # Most probable class per example, for predictions and true labels.
        predicted_batch = tf.argmax(predictions_batch, axis=1)
        real_label_batch = tf.argmax(one_hot_labels_batch, axis=1)

        # Boolean one-hot masks from which _metrics derives tp, tn, fp, fn.
        predicted_bool = tf.cast(tf.one_hot(predicted_batch, depth=num_outputs),
                                 dtype=tf.bool)
        real_bool = tf.cast(tf.one_hot(real_label_batch, depth=num_outputs),
                            dtype=tf.bool)
        d = _metrics(predicted_bool, real_bool, moving_average)

        # Confusion matrix for the current batch only.
        confusion_batch = tf.confusion_matrix(labels=real_label_batch,
                                              predictions=predicted_batch,
                                              num_classes=num_outputs)
        if moving_average:
            # The moving average is tracked in float32.
            confusion_batch = tf.cast(confusion_batch, dtype=tf.float32)
            ema = tf.train.ExponentialMovingAverage(decay=0.9)
            update_op = ema.apply([confusion_batch])
            # Reading the shadow average does not advance it; callers must
            # run d['update_op'] for the average to change.
            confusion_matrix = ema.average(confusion_batch)
            d['update_op'] = [d['update_op'], update_op]
        else:
            # Running total across batches.
            # NOTE(review): tf.Variable does not honour variable_scope reuse,
            # so every call to single_label creates a fresh accumulator.
            accumulator = tf.Variable(
                tf.zeros([num_outputs, num_outputs], dtype=tf.int32),
                name='confusion_matrix', trainable=False)
            # Every evaluation of this tensor adds the current batch's matrix
            # to the running total and yields the updated total.
            confusion_matrix = tf.assign_add(accumulator, confusion_batch)
        d['confusion_matrix'] = confusion_matrix
    return d
评论列表
文章目录