import tensorflow as tf

def my_clustering_loss(net_out, feature_map):
    # Flatten the per-pixel foreground scores and the feature map.
    net_out_vec = tf.reshape(net_out, [-1, 1])
    pix_num = net_out_vec.get_shape().as_list()[0]  # requires a statically known size
    feature_vec = tf.reshape(feature_map, [pix_num, -1])
    # Foreground/background weights, each normalized to sum to 1 so the
    # matmuls below produce weighted means of the feature vectors.
    not_net_out_vec = tf.subtract(tf.constant(1.), net_out_vec)
    net_out_vec = tf.div(net_out_vec, tf.reduce_sum(net_out_vec, keepdims=True))
    not_net_out_vec = tf.div(not_net_out_vec, tf.reduce_sum(not_net_out_vec, keepdims=True))
    # Non-trainable variables holding the means from the previous step,
    # shaped [1, C] so they broadcast against feature_vec ([pix_num, C]).
    num_channels = feature_vec.get_shape().as_list()[1]
    mean_fg_var = tf.get_variable('mean_fg', shape=[1, num_channels], trainable=False)
    mean_bg_var = tf.get_variable('mean_bg', shape=[1, num_channels], trainable=False)
    # Current weighted means of the features, shape [1, C].
    mean_fg = tf.matmul(net_out_vec, feature_vec, transpose_a=True)
    mean_bg = tf.matmul(not_net_out_vec, feature_vec, transpose_a=True)
    # Weighted within-cluster squared distances to the stored means.
    fg_dist = tf.reduce_sum(tf.square(feature_vec - mean_fg_var), 1, keepdims=True)
    bg_dist = tf.reduce_sum(tf.square(feature_vec - mean_bg_var), 1, keepdims=True)
    loss = tf.add(tf.matmul(net_out_vec, fg_dist, transpose_a=True),
                  tf.matmul(not_net_out_vec, bg_dist, transpose_a=True))
    # Refresh the stored means only after the loss has been computed.
    with tf.control_dependencies([loss]):
        update_mean = tf.group(tf.assign(mean_fg_var, mean_fg),
                               tf.assign(mean_bg_var, mean_bg))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
    return loss
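The function registers its mean-update op in `tf.GraphKeys.UPDATE_OPS`, so the caller has to run that collection alongside the training step, much like the usual batch-norm pattern. Below is a minimal usage sketch under TF 1.x graph mode; the tiny conv backbone, the sigmoid head, and the Adam learning rate are illustrative assumptions, not part of the original code.

```python
import tensorflow as tf

# Hypothetical backbone: features for a single 64x64 RGB image with 32 channels.
images = tf.placeholder(tf.float32, [1, 64, 64, 3])
feature_map = tf.layers.conv2d(images, 32, 3, padding='same', activation=tf.nn.relu)
# Per-pixel foreground probability head (sigmoid keeps values in [0, 1]).
net_out = tf.layers.conv2d(feature_map, 1, 1, activation=tf.nn.sigmoid)

loss = my_clustering_loss(net_out, feature_map)

# The mean-update op was added to UPDATE_OPS, so run it together with the
# training step by making the train op depend on it.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
```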