def create_cost_soft_min_distance_valid(self, c, s, v):
    """Create a soft-min distance cost between centers and valid points.

    For every center, the Euclidean distance to each valid point is
    computed and averaged over the batch. The distances are then reduced
    with a soft-min: a softmax(-d)-weighted sum of the distances d over the
    points. The result is the mean of these soft-min distances over all
    centers.

    Args:
        c: centers tensor; assumed to be (batch, n_centers, dim) -- TODO
            confirm against callers.
        s: points tensor; assumed to be (batch, n_points, dim) with the same
            batch and dim as ``c`` -- TODO confirm.
        v: per-batch count of valid points (presumably shape (batch,)).
            Only ``max(v)`` is used: every batch keeps its first ``max(v)``
            points.

    Returns:
        A scalar tensor: the mean soft-min distance over the centers.
    """
    # Keep only the points [0, max(v)).
    # NOTE(review): taking the max over the whole batch is exact only for
    # batch size 1. For larger batches, entries with fewer valid points
    # still include some padded points.
    n_valid = tf.cast(tf.reduce_max(v), tf.int32)
    points = tf.slice(s, [0, 0, 0], [-1, n_valid, -1])

    # Line up centers as (batch, n_centers, dim, 1) and points as
    # (batch, 1, dim, n_valid). squared_difference broadcasts them to
    # (batch, n_centers, dim, n_valid).
    # expand_dims is used instead of a reshape that hard-coded the unsliced
    # point count, which failed whenever max(v) < n_points.
    # Broadcasting replaces tiling by the batch size, which produced
    # mismatched shapes for batch > 1.
    centers = tf.expand_dims(c, 3)
    points = tf.transpose(tf.expand_dims(points, 3), perm=[0, 3, 2, 1])

    # Pairwise Euclidean distances: (batch, n_centers, n_valid).
    # NOTE(review): the gradient of sqrt is unbounded at 0, so a center that
    # coincides exactly with a point may yield inf/NaN gradients.
    dist = tf.sqrt(tf.reduce_sum(tf.squared_difference(centers, points),
                                 reduction_indices=2))
    # Hack kept from the original: average the distances over the batch so
    # the soft-min operates on a 2-D (n_centers, n_valid) matrix.
    dist = tf.reduce_mean(dist, reduction_indices=0)

    # Soft-min over the points: weights = softmax(-dist) along the last axis.
    weights = tf.nn.softmax(-dist)
    soft_min = tf.reduce_sum(weights * dist, reduction_indices=1)
    return tf.reduce_mean(soft_min)
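# Source listing: machine_vision_2.py (file source code)
# Language: python
# Views: 26
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Article table of contents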