def _init_clusters_random(data, num_clusters, random_seed):
    """Picks initial cluster centers by sampling rows of the input at random.

    Row indices are drawn with `tf.random_uniform` over the combined row
    count of all input tensors. Each index is drawn independently, so the
    same row may be selected more than once.

    Args:
      data: a list of Tensors, each a matrix whose rows are examples.
      num_clusters: an integer with the number of clusters to initialize.
      random_seed: seed passed to the PRNG op that samples the row indices.

    Returns:
      A Tensor holding num_clusters randomly chosen rows of data.
    """
    assert isinstance(data, list)
    # Total example count is the sum of the leading dimension of every input.
    rows_per_input = [tf.shape(inp)[0] for inp in data]
    total_rows = tf.add_n(rows_per_input)
    # Graph-time guard: sampling only runs once it has been checked that
    # num_clusters does not exceed the number of available rows.
    enough_rows = tf.assert_less_equal(num_clusters, total_rows)
    with tf.control_dependencies([enough_rows]):
        sampled = tf.random_uniform(
            [num_clusters],
            minval=0,
            maxval=tf.cast(total_rows, tf.int64),
            seed=random_seed,
            dtype=tf.int64)
    # Sampling is done in int64; narrow to int32 and wrap with the row count
    # so every index stays inside [0, total_rows).
    row_ids = tf.cast(sampled, tf.int32) % total_rows
    # NOTE(review): embedding_lookup is defined outside this block; presumably
    # 'div' treats the list as contiguous row ranges -- confirm against it.
    return embedding_lookup(data, row_ids, partition_strategy='div')
# (Web-page residue, not part of the code: "Comment list" / "Article table of contents".)