machine_vision_2.py source code

python

Project: CElegansBehaviour  Author: ChristophKirst
import tensorflow as tf


def create_cost_soft_min_distance_valid(self, c, s, v):
    """Creates a soft-min distance cost from the centers to the valid points.

    c: centers, shape [batch, n_centers, dim]
    s: points, shape [batch, n_points, dim]
    v: valid point counts; only the first max(v) points of s are used
    """
    # keep only the first max(v) points (hack for batch size = 1)
    mm = tf.cast(tf.reduce_max(v), tf.int32)
    ss = s[:, :mm, :]

    # pairwise distances via broadcasting:
    # [batch, n_centers, 1, dim] vs [batch, 1, mm, dim] -> [batch, n_centers, mm]
    diff2 = tf.math.squared_difference(tf.expand_dims(c, 2), tf.expand_dims(ss, 1))
    dist = tf.sqrt(tf.reduce_sum(diff2, axis=3))
    dist = tf.reduce_mean(dist, axis=0)  # hack: get rid of batches here -> [n_centers, mm]

    # soft-min: weight each center's distances by softmax(-dist) over the points
    distmin = tf.reduce_sum(tf.nn.softmax(-dist) * dist, axis=1)  # [n_centers]
    return tf.reduce_mean(distmin)
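To check the cost numerically without TensorFlow, here is a minimal pure-Python sketch of the same quantity for batch size 1. For each center c_i it computes d_ij = ||c_i - s_j|| over the first n_valid points, forms the soft-min sum_j softmax(-d_i)_j * d_ij, and then averages over centers. The function name `soft_min_distance_cost` and the argument `n_valid` are hypothetical and not part of the CElegansBehaviour project; subtracting min(d) before exponentiating is an added numerical-stability step that leaves the softmax weights unchanged.

```python
import math
from typing import Sequence

Point = Sequence[float]


def soft_min_distance_cost(centers: Sequence[Point], points: Sequence[Point], n_valid: int) -> float:
    """Mean over centers of the softmin-weighted distance to the first n_valid points.

    Hypothetical reference helper mirroring the TensorFlow method for batch size 1.
    """
    if n_valid <= 0 or not centers:
        raise ValueError("need at least one center and one valid point")
    valid = points[:n_valid]  # analogue of slicing s to the first max(v) points
    per_center = []
    for c in centers:
        # Euclidean distance from this center to every valid point
        d = [math.dist(c, p) for p in valid]
        # softmax(-d): shifting by min(d) avoids underflow and cancels in the ratio
        d_min = min(d)
        w = [math.exp(-(di - d_min)) for di in d]
        per_center.append(sum(wi * di for wi, di in zip(w, d)) / sum(w))
    return sum(per_center) / len(per_center)
```

For example, `soft_min_distance_cost([(0.0, 0.0)], [(1.0, 0.0), (2.0, 0.0), (50.0, 0.0)], n_valid=2)` ignores the third point and returns (e + 2)/(e + 1) ≈ 1.269. That value lies between the hard minimum (1) and the plain mean (1.5): the soft-min leans toward the nearest point but still gives the farther one some weight, which keeps the cost differentiable with respect to every valid point.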