def meanShift(n_updates=-1, tol=1e-5):
    """Build a TensorFlow (1.x graph-mode) mean-shift computation.

    Repeatedly moves every centroid in ``init_C`` to the Gaussian-weighted
    mean of the points in ``input_X``. A point at squared distance ``D``
    from a centroid gets weight ``exp(-D / window_radius**2)``.
    ``input_X``, ``init_C`` and ``window_radius`` are read from the
    enclosing scope, not passed in.

    NOTE(review): the broadcasting below assumes ``input_X`` is (N, d) and
    ``init_C`` is (k, d), both float32 -- confirm against where they are
    defined.

    Args:
        n_updates: if > 0, unroll exactly this many update steps into the
            graph. Otherwise (the default, and also for 0), run a
            ``tf.while_loop`` that stops once the largest per-centroid
            displacement in a step is no longer greater than ``tol``.
        tol: convergence threshold for the while-loop mode. It was
            previously hard-coded to 1e-5, which remains the default.

    Returns:
        ``(C, history)``. ``C`` is the final centroid tensor. ``history``
        stacks the centroids after every step, with ``history[0]`` equal
        to ``init_C``.
    """
    # Pre-broadcast the data once. Under the (N, d) assumption, X1 is
    # (1, d, N) for the distance computation and X2 is (1, N, d) for the
    # weighted sum.
    X1 = tf.expand_dims(tf.transpose(input_X), 0)
    X2 = tf.expand_dims(input_X, 0)
    C = init_C
    # dynamic_size=True: with a fixed, non-resizeable size of 10000, the
    # 10000th update failed with an out-of-bounds TensorArray write.
    sbs_C = tf.TensorArray(dtype=tf.float32, size=10000, dynamic_size=True,
                           infer_shape=False)
    sbs_C = sbs_C.write(0, init_C)

    def _mean_shift_step(C):
        """Return one mean-shift update of the centroids ``C``."""
        # Y[j, n]: squared distance from centroid j to point n, divided by
        # window_radius**2.
        Y = tf.reduce_sum(
            tf.pow((tf.expand_dims(C, 2) - X1) / window_radius, 2), axis=1)
        # Subtracting each row's minimum scales num and denom by the same
        # factor exp(min), so num / denom is unchanged. It also guarantees
        # denom >= 1: plain exp(-Y) could underflow to 0 for every point
        # and turn the centroid into NaN.
        gY = tf.exp(-(Y - tf.reduce_min(Y, axis=1, keep_dims=True)))
        num = tf.reduce_sum(tf.expand_dims(gY, 2) * X2, axis=1)
        denom = tf.reduce_sum(gY, axis=1, keep_dims=True)
        return num / denom

    if n_updates > 0:
        # Fixed number of steps, unrolled into the graph.
        for i in range(n_updates):
            C = _mean_shift_step(C)
            sbs_C = sbs_C.write(i + 1, C)
    else:
        def _mean_shift(i, C, sbs_C, max_diff):
            new_C = _mean_shift_step(C)
            # Largest Euclidean displacement of any centroid in this step.
            max_diff = tf.reshape(
                tf.reduce_max(
                    tf.sqrt(tf.reduce_sum(tf.pow(new_C - C, 2), axis=1))),
                [])
            sbs_C = sbs_C.write(i + 1, new_C)
            return i + 1, new_C, sbs_C, max_diff

        def _cond(i, C, sbs_C, max_diff):
            return max_diff > tol

        # The initial max_diff of 1e10 forces at least one step.
        n_updates, C, sbs_C, _ = tf.while_loop(
            cond=_cond,
            body=_mean_shift,
            loop_vars=(tf.constant(0), C, sbs_C, tf.constant(1e10)))
        # tf.Print is an identity op that also logs the iteration count
        # whenever the result is evaluated.
        n_updates = tf.Print(n_updates, [n_updates])
    # history[0] is init_C, followed by one entry per completed step.
    return C, sbs_C.gather(tf.range(n_updates + 1))
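# [web-page residue, not code] 评论列表 ("Comment list")
# [web-page residue, not code] 文章目录 ("Table of contents")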