clustering_ops.py source code

python

Project: lsdc  Author: febert
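# Imports this excerpted method relies on (TensorFlow 1.x).
import tensorflow as tf
from tensorflow.python.framework import ops

def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):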
    """Creates an op for training for full batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A list of vectors, one per Tensor in 'inputs'. Each
        element of a vector gives the cluster id assigned to the corresponding
        input row.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op that assigns the centers recomputed from the full batch to
      'cluster_centers'.
    """
    cluster_sums = []
    cluster_counts = []
    # Small constant that keeps the division finite for empty clusters.
    epsilon = tf.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp):
        # Per-cluster sum of the rows assigned to each cluster.
        cluster_sums.append(tf.unsorted_segment_sum(inp,
                                                    cluster_idx,
                                                    self._num_clusters))
        # Per-cluster row count, obtained by segment-summing a column of ones.
        cluster_counts.append(tf.unsorted_segment_sum(
            tf.reshape(tf.ones(tf.reshape(tf.shape(inp)[0], [-1])), [-1, 1]),
            cluster_idx,
            self._num_clusters))
    with ops.colocate_with(cluster_centers):
      # New center = total sum / (total count + epsilon), over all inputs.
      new_clusters_centers = tf.add_n(cluster_sums) / (
          tf.cast(tf.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = tf.nn.l2_normalize(new_clusters_centers, dim=1)
    return tf.assign(cluster_centers, new_clusters_centers)
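
The arithmetic performed by the op above can be sketched in plain NumPy. This is only an illustration and not part of the lsdc project: the function name `full_batch_update` and its `num_clusters`, `l2_normalize`, and `epsilon` parameters are hypothetical stand-ins for `self._num_clusters`, `self._clusters_l2_normalized()`, and the `epsilon` constant, and the normalization step only approximates what `tf.nn.l2_normalize` does.

```python
import numpy as np


def full_batch_update(inputs, cluster_idx_list, num_clusters,
                      l2_normalize=False, epsilon=1e-6):
    """Hypothetical NumPy analogue of the full-batch center update.

    inputs: list of 2-D float arrays, each of shape (rows, dim).
    cluster_idx_list: list of 1-D int arrays, one cluster id per input row.
    """
    dim = inputs[0].shape[1]
    dtype = inputs[0].dtype
    cluster_sums = np.zeros((num_clusters, dim), dtype=dtype)
    cluster_counts = np.zeros((num_clusters, 1), dtype=dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
        # Unbuffered scatter-add, playing the role of unsorted_segment_sum.
        np.add.at(cluster_sums, cluster_idx, inp)
        np.add.at(cluster_counts, cluster_idx, 1.0)
    # Empty clusters have sum 0 and count 0, so they come out as 0 / epsilon = 0.
    new_centers = cluster_sums / (cluster_counts + epsilon)
    if l2_normalize:
        # Assumed approximation of tf.nn.l2_normalize along axis 1.
        sq_norm = np.sum(new_centers ** 2, axis=1, keepdims=True)
        new_centers = new_centers / np.sqrt(np.maximum(sq_norm, 1e-12))
    return new_centers


if __name__ == "__main__":
    points = [np.array([[0.0, 0.0], [2.0, 2.0]]), np.array([[10.0, 10.0]])]
    assignments = [np.array([0, 0]), np.array([1])]
    print(full_batch_update(points, assignments, num_clusters=3))
```

In this sketch a cluster that receives no rows ends up at the origin instead of producing a division by zero, which appears to be the purpose of adding `epsilon` to the counts in the original op.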