def compute_mean(cluster_center, x, label, K, eta):
    """Update cluster centers with a moving average of per-cluster means.

    Input:
        cluster_center: variable holding the current K x D cluster centers;
            it is updated in place through its ``assign`` method.
        x: embedding of size N x D.
        label: integer cluster id for each row of x, a length-N vector
            (TF segment ids must match a prefix of x's shape).
        K: number of clusters.
        eta: weight kept on the previous center; the batch mean of each
            cluster gets weight (1 - eta).
    Output:
        The op returned by ``cluster_center.assign(...)``. Its value is the
        updated K x D centers:
            new = (1 - eta) * sum_k(x) / (count_k + eps) + eta * cluster_center

    NOTE(review): for a cluster with no members in this batch, the mean term
    is 0 / (0 + eps) = 0. Its center therefore becomes eta * cluster_center,
    i.e. it decays toward the origin instead of staying put. Confirm this is
    intended.
    """
    # Small constant that keeps empty clusters from producing 0/0. It is
    # created in x's dtype so the division also works for non-float32 inputs.
    tf_eps = tf.constant(1.0e-16, dtype=x.dtype)
    # Per-cluster member counts. ones_like follows label's dynamic shape, so
    # this works even when the batch size is unknown at graph-build time.
    # (tf.ones(label.get_shape()) requires a fully static shape.)
    cluster_size = tf.expand_dims(
        tf.unsorted_segment_sum(tf.ones_like(label, dtype=x.dtype), label, K),
        1)  # shape (K, 1) so it broadcasts against the (K, D) sums
    cluster_sum = tf.unsorted_segment_sum(x, label, K)
    cluster_center_new = ((1 - eta) * cluster_sum / (cluster_size + tf_eps)
                          + eta * cluster_center)
    return cluster_center.assign(cluster_center_new)
# (Web-page residue removed from code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents")