cluster.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:section-detection 作者: gulfaraz 项目源码 文件源码
def meanShift(n_updates=-1, tol=1e-5):
    """Build the graph ops that run mean-shift updates on the centroids.

    Starts from the module-level ``init_C`` and repeatedly replaces each
    centroid by the average of the points in ``input_X`` weighted by
    ``exp(-sum(((c - x) / window_radius) ** 2))``.

    NOTE(review): ``input_X``, ``init_C`` and ``window_radius`` are defined
    outside this function. The broadcasting below presumably expects
    ``input_X`` to be (n_points, n_dims) and ``init_C`` to be
    (n_centroids, n_dims) -- confirm against their definitions.

    Args:
        n_updates: When > 0, exactly this many update steps are unrolled
            into the graph. Any other value (default -1, including 0) runs
            a ``tf.while_loop`` until the convergence test fails.
        tol: Convergence threshold for the loop branch. Iteration continues
            while the largest per-centroid Euclidean displacement of the
            last step is greater than ``tol``. Ignored when
            ``n_updates > 0``.

    Returns:
        A tuple ``(C, history)`` where ``C`` is the centroid tensor after
        the last step and ``history`` is the result of gathering the
        initial centroids plus the centroids after every step (steps + 1
        entries) from the step-by-step TensorArray.
    """
    X1 = tf.expand_dims(tf.transpose(input_X), 0)
    X2 = tf.expand_dims(input_X, 0)
    C = init_C

    # dynamic_size=True lets the history grow beyond the initial 10000
    # slots; with a fixed-size array, a write at index >= 10000 fails at run
    # time when the loop (or a large n_updates) needs more steps than that.
    sbs_C = tf.TensorArray(dtype=tf.float32, size=10000, dynamic_size=True,
                           infer_shape=False)
    sbs_C = sbs_C.write(0, init_C)

    def _mean_shift_step(C):
        # One mean-shift update: kernel-weighted average of the points.
        C = tf.expand_dims(C, 2)
        # Squared, radius-scaled distance from every centroid to every point.
        Y = tf.reduce_sum(tf.pow((C - X1) / window_radius, 2), axis=1)
        gY = tf.exp(-Y)
        num = tf.reduce_sum(tf.expand_dims(gY, 2) * X2, axis=1)
        # keep_dims (the older spelling of keepdims) is kept to match the
        # TensorFlow version this file already targets (it uses tf.Print).
        denom = tf.reduce_sum(gY, axis=1, keep_dims=True)
        C = num / denom
        return C

    if n_updates > 0:
        # Fixed number of steps, unrolled in Python at graph-build time.
        n_steps = n_updates
        for i in range(n_updates):
            C = _mean_shift_step(C)
            sbs_C = sbs_C.write(i + 1, C)
    else:
        def _mean_shift(i, C, sbs_C, max_diff):
            new_C = _mean_shift_step(C)
            # Largest distance any single centroid moved during this step.
            max_diff = tf.reshape(tf.reduce_max(tf.sqrt(tf.reduce_sum(tf.pow(new_C - C, 2), axis=1))), [])
            sbs_C = sbs_C.write(i + 1, new_C)
            return i + 1, new_C, sbs_C, max_diff

        def _cond(i, C, sbs_C, max_diff):
            return max_diff > tol

        # max_diff starts at 1e10 so that the first iteration always runs.
        n_steps, C, sbs_C, _ = tf.while_loop(cond=_cond,
                                             body=_mean_shift,
                                             loop_vars=(tf.constant(0), C, sbs_C, tf.constant(1e10)))

        # Identity op that also prints the number of iterations performed.
        n_steps = tf.Print(n_steps, [n_steps])

    return C, sbs_C.gather(tf.range(n_steps + 1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号