def assign_label(label, x, cluster_center):
    """Reassign every sample to the cluster center with the smallest distance.

    Input:
        label: object exposing ``assign`` that stores the cluster labels
            (presumably a tf.Variable with one entry per sample; the
            earlier documentation describes it as N x 1 -- TODO confirm
            the shape and dtype against the caller)
        x: embedding, described by the earlier documentation as N x D
        cluster_center: cluster centers, described by the earlier
            documentation as K x D
    Output:
        The value returned by ``label.assign`` after the new labels have
        been written into ``label``.
    """
    # pdist is defined outside this block; presumably it yields one row per
    # sample and one column per cluster center -- verify against its definition.
    distances = pdist(x, cluster_center)
    # Reduce over axis 1: for each row, the index of the smallest distance.
    nearest = tf.argmin(distances, 1)
    return label.assign(nearest)
评论列表
文章目录