minibatch_sampler.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tensorflow 作者: luyishisi 项目源码 文件源码
def subsample_indicator(indicator, num_samples):
    """Subsample indicator vector.

    Given a boolean indicator vector with M elements set to `True`, the function
    assigns all but `num_samples` of these previously `True` elements to
    `False`. If `num_samples` is greater than M, the original indicator vector
    is returned.

    Args:
      indicator: a 1-dimensional boolean tensor indicating which elements
        are allowed to be sampled and which are not.
      num_samples: int32 scalar tensor

    Returns:
      a boolean tensor with the same shape as input (indicator) tensor
    """
    indices = tf.where(indicator)
    indices = tf.random_shuffle(indices)
    indices = tf.reshape(indices, [-1])

    num_samples = tf.minimum(tf.size(indices), num_samples)
    selected_indices = tf.slice(indices, [0], tf.reshape(num_samples, [1]))

    selected_indicator = ops.indices_to_dense_vector(selected_indices,
                                                     tf.shape(indicator)[0])

    return tf.equal(selected_indicator, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号