gmm_ops.py source code

python

Project: lsdc  Author: febert
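import tensorflow as tf  # TF 1.x graph-mode API (tf.random_uniform is not in TF 2.x)
from tensorflow.python.ops.embedding_ops import embedding_lookup


def _init_clusters_random(data, num_clusters, random_seed):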
  """Does random initialization of clusters.

  Args:
    data: a list of Tensors, each a matrix of data in which each row is an
      example.
    num_clusters: an integer with the number of clusters.
    random_seed: Seed for the PRNG used to pick the initial cluster rows.

  Returns:
    A Tensor with num_clusters rows of data, sampled uniformly at random with
    replacement (the same row may be picked more than once).
  """
  assert isinstance(data, list)
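  # Total number of examples across all shards of `data`.
  num_data = tf.add_n([tf.shape(inp)[0] for inp in data])
  # Fail at run time if more clusters than examples are requested.
  with tf.control_dependencies([tf.assert_less_equal(num_clusters, num_data)]):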
    indices = tf.random_uniform([num_clusters],
                                minval=0,
                                maxval=tf.cast(num_data, tf.int64),
                                seed=random_seed,
                                dtype=tf.int64)
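  # maxval is exclusive, so the modulo is only a safeguard. Indices are drawn
  # with replacement, so the same row can be selected more than once.
  indices = tf.cast(indices, tf.int32) % num_data
  # 'div' treats the shards in `data` as one contiguous run of rows, assuming
  # the first (num_data % len(data)) shards hold one more row than the rest.
  clusters_init = embedding_lookup(data, indices, partition_strategy='div')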
  return clusters_init
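The TensorFlow function above depends on the TF 1.x graph API. The following is a dependency-free Python sketch of the same idea. It sums the shard sizes, checks that enough rows exist, draws `num_clusters` global indices uniformly with replacement from a seeded PRNG, and then gathers each row from whichever shard holds it. The name `init_clusters_random` and the list-of-lists row representation are assumptions made for illustration and are not part of gmm_ops.py. One difference is deliberate. This gather walks the cumulative shard sizes, so it handles shards of any size. TF's `partition_strategy='div'` instead assumes the rows are split as evenly as possible, with the first N % P shards one row larger, so the two agree only when the shards are sized that way.

```python
import random
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence

Row = List[float]


def init_clusters_random(
    data: Sequence[Sequence[Row]], num_clusters: int, random_seed: int
) -> List[Row]:
    """Pick num_clusters rows uniformly at random, with replacement.

    Hypothetical pure-Python analogue of _init_clusters_random: `data` is a
    list of shards, each a list of rows, and rows are addressed by a global
    index that runs through shard 0, then shard 1, and so on.
    """
    shard_sizes = [len(shard) for shard in data]
    num_data = sum(shard_sizes)
    # Same precondition as tf.assert_less_equal(num_clusters, num_data).
    if num_clusters > num_data:
        raise ValueError("num_clusters (%d) exceeds number of rows (%d)"
                         % (num_clusters, num_data))
    rng = random.Random(random_seed)  # seeded for reproducibility
    # Draw global indices in [0, num_data); repeats are possible.
    indices = [rng.randrange(num_data) for _ in range(num_clusters)]
    # ends[p] is one past the last global index stored in shard p.
    ends = list(accumulate(shard_sizes))
    clusters = []
    for idx in indices:
        shard = bisect_right(ends, idx)  # first shard whose end exceeds idx
        start = ends[shard - 1] if shard > 0 else 0
        clusters.append(list(data[shard][idx - start]))
    return clusters
```

For example, `init_clusters_random([[[0.0, 0.0], [1.0, 1.0]], [[5.0, 5.0], [6.0, 6.0], [7.0, 7.0]]], num_clusters=3, random_seed=42)` returns three of the five rows. The same seed always gives the same picks, and a row can appear more than once.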