ops.py source code

python

Project: 3D_Dense_Transformer_Networks  Author: JohnYC1995
import tensorflow as tf


def dice_accuracy(decoded_predictions, annotations, class_nums):
    """Return the mean Dice ratio and the list of per-class Dice ratios."""
    DiceRatio = tf.constant(0, tf.float32)
    # Count of classes whose Dice ratio is exactly 0; they are left out of the mean.
    misclassnum = tf.constant(0, tf.float32)
    class_num = tf.constant(class_nums, tf.float32)
    sublist = []
    # The loop covers labels 1 .. class_nums - 3, i.e. class_nums - 3 classes.
    for index in range(1, class_nums - 2):
        # Binary masks marking where each tensor equals the current label.
        current_annotation = tf.cast(
            tf.equal(tf.ones_like(annotations) * index, annotations), tf.float32)
        current_prediction = tf.cast(
            tf.equal(tf.ones_like(decoded_predictions) * index, decoded_predictions), tf.float32)
        # The two masks sum to 2 exactly where both carry the label.
        Overlap = tf.add(current_annotation, current_prediction)
        Common = tf.reduce_sum(
            tf.cast(tf.equal(tf.ones_like(Overlap) * 2, Overlap), tf.float32), [0, 1, 2, 3])
        annotation_num = tf.reduce_sum(current_annotation, [0, 1, 2, 3])
        predict_num = tf.reduce_sum(current_prediction, [0, 1, 2, 3])
        all_num = tf.add(annotation_num, predict_num)
        # Dice = 2 * |A and B| / (|A| + |B|); the clip avoids division by zero.
        Sub_DiceRatio = Common * 2 / tf.clip_by_value(all_num, 1e-10, 1e+10)
        misclassnum = tf.cond(tf.equal(Sub_DiceRatio, 0.0),
                              lambda: misclassnum + 1, lambda: misclassnum)
        sublist.append(Sub_DiceRatio)
        DiceRatio = DiceRatio + Sub_DiceRatio
    # Average over the scored classes that had a non-zero Dice ratio.
    DiceRatio = DiceRatio / tf.clip_by_value(class_num - misclassnum - 3, 1e-10, 1e+10)
    return DiceRatio, sublist
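
The arithmetic in dice_accuracy can be cross-checked on small arrays without building a TensorFlow graph. The block below is a minimal NumPy sketch, not part of the project: the name dice_reference is hypothetical, and it assumes integer label arrays of matching shape. It follows the same convention as the function above, scoring labels 1 .. class_nums - 3 and averaging only over classes whose Dice ratio is non-zero.

```python
import numpy as np


def dice_reference(predictions, annotations, class_nums):
    """Hypothetical NumPy reference: mean Dice and per-class Dice ratios."""
    per_class = []
    for index in range(1, class_nums - 2):
        # Boolean masks marking where each array equals the current label.
        pred_mask = predictions == index
        anno_mask = annotations == index
        common = np.logical_and(pred_mask, anno_mask).sum()
        total = pred_mask.sum() + anno_mask.sum()
        # Dice = 2 * |A and B| / (|A| + |B|), guarded against division by zero.
        per_class.append(2.0 * common / max(total, 1e-10))
    # Classes with a Dice ratio of 0 are excluded from the mean.
    scored = [d for d in per_class if d > 0.0]
    mean_dice = sum(scored) / max(len(scored), 1e-10)
    return mean_dice, per_class


annotations = np.array([[1, 1, 2, 0]])
predictions = np.array([[1, 2, 2, 0]])
mean_dice, per_class = dice_reference(predictions, annotations, class_nums=5)
print(mean_dice, per_class)
```

In this toy input, labels 1 and 2 each share one voxel between prediction and annotation out of three labelled voxels in total, so both per-class ratios are 2/3.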