def dice_accuracy(decoded_predictions, annotations, class_nums):
    """Compute the mean and per-class Dice ratio of predictions vs. labels.

    For every label ``index`` in ``range(1, class_nums - 2)`` the Dice ratio
    ``2 * |A & P| / (|A| + |P|)`` is computed, where ``A`` / ``P`` are the
    sets of elements equal to ``index`` in ``annotations`` /
    ``decoded_predictions``. Counts are taken over all elements of the
    tensors, whatever their rank.

    Label 0 and the last two labels are never evaluated.
    NOTE(review): presumably these are background / ignored labels --
    confirm against the callers.

    Args:
        decoded_predictions: Label tensor produced by the model. It is
            compared element-wise against each class index.
        annotations: Ground-truth label tensor. Must be broadcast-compatible
            with ``decoded_predictions`` (presumably the same shape).
        class_nums: Python int, total number of labels.

    Returns:
        A tuple ``(mean_dice, per_class_dice)``:
          * ``mean_dice`` is a scalar float32 tensor: the sum of the
            per-class ratios divided by the number of evaluated classes that
            occur in at least one of the two tensors.
          * ``per_class_dice`` is a list of scalar float32 tensors, one per
            evaluated label, in increasing label order.
    """
    dice_sum = tf.constant(0, tf.float32)
    absent_classes = tf.constant(0, tf.float32)
    sublist = []

    for index in range(1, class_nums - 2):
        # Binary masks selecting the elements that carry the current label.
        annotation_mask = tf.cast(tf.equal(annotations, index), tf.float32)
        prediction_mask = tf.cast(
            tf.equal(decoded_predictions, index), tf.float32)

        # The product of two 0/1 masks is 1 exactly where both are set.
        intersection = tf.reduce_sum(annotation_mask * prediction_mask)
        total = tf.add(tf.reduce_sum(annotation_mask),
                       tf.reduce_sum(prediction_mask))

        # Guard the division for classes that occur in neither tensor.
        class_dice = 2.0 * intersection / tf.maximum(total, 1e-10)

        # Only a class missing from BOTH tensors is left out of the mean.
        # A class that is present but has zero overlap is a genuine miss and
        # must count as 0; excluding it would inflate the average.
        absent_classes = absent_classes + tf.cast(
            tf.equal(total, 0.0), tf.float32)

        sublist.append(class_dice)
        dice_sum = dice_sum + class_dice

    # Average over the evaluated classes that actually occur. The lower
    # bound avoids 0/0 when none of them do (the result is then 0).
    present_classes = tf.maximum(
        float(len(sublist)) - absent_classes, 1e-10)
    return dice_sum / present_classes, sublist
# Comment list
# Table of contents