utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:DeepCellSeg 作者: arbellea 项目源码 文件源码
def my_clustering_loss(net_out,feature_map):

    net_out_vec = tf.reshape(net_out,[-1,1])
    pix_num = net_out_vec.get_shape().as_list()[0]
    feature_vec = tf.reshape(feature_map,[pix_num,-1])
    net_out_vec = tf.div(net_out_vec, tf.reduce_sum(net_out_vec,keep_dims=True))
    not_net_out_vec = tf.subtract(tf.constant(1.),net_out_vec)
    mean_fg_var = tf.get_variable('mean_bg',shape = [feature_vec.get_shape().as_list()[1],1], trainable=False)
    mean_bg_var = tf.get_variable('mean_fg',shape = [feature_vec.get_shape().as_list()[1],1], trainable=False)

    mean_bg = tf.matmul(not_net_out_vec,feature_vec,True)
    mean_fg = tf.matmul(net_out_vec,feature_vec,True)

    feature_square = tf.square(feature_vec)

    loss = tf.add(tf.matmul(net_out_vec, tf.reduce_sum(tf.square(tf.subtract(feature_vec, mean_fg_var)), 1, True), True),
                  tf.matmul(not_net_out_vec, tf.reduce_sum(tf.square(tf.subtract(feature_vec,mean_bg_var)), 1, True), True))
    with tf.control_dependencies([loss]):
        update_mean = tf.group(tf.assign(mean_fg_var,mean_fg),tf.assign(mean_bg_var,mean_bg))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)

    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号