tf_sparse_fit.py source code

python

Project: tf-sparse-fit  Author: cmcneil
import tensorflow as tf


def threshold_by_percent_max(t, threshold, use_active_set=False):
    """Creates TensorFlow ops that threshold a tensor by a percentage of its
    maximum absolute value. Intended for thresholding gradients.
    Optionally maintains an active set.

    The maximum is taken along axis 0, so for a 2-D tensor each column is
    compared against its own maximum absolute value.

    Parameters
    ----------
    t: tensor
        The tensor to threshold by percent max.
    threshold: float
        A number between 0 and 1: the fraction of the maximum absolute value
        below which entries are set to 0.
    use_active_set: bool
        Whether to maintain an active set, i.e. a boolean variable recording
        every entry that has ever been at or above the threshold. Entries in
        the active set are kept even if they later fall below it.

    Returns
    -------
    A tensor of the same shape as t with every entry whose absolute value is
    below the threshold set to 0. If use_active_set is True, a tuple
    (thresholded tensor, number of entries in the active set) is returned.
    """
    with tf.name_scope("threshold_by_percent_max"):
        t = tf.convert_to_tensor(t, name="t")
        abs_t = tf.abs(t)
        thresh_percentile = tf.constant(threshold, dtype=t.dtype)
        zeros = tf.zeros_like(t)

        # Maximum absolute value along axis 0 (one per column), kept with a
        # leading axis of size 1 so it broadcasts against t without tiling.
        maximum = tf.reduce_max(abs_t, axis=0, keepdims=True)
        thresh_tensor = tf.multiply(thresh_percentile, maximum)
        above_thresh_values = tf.greater_equal(abs_t, thresh_tensor)

        if use_active_set:
            # Graph-mode state: starts all False and is OR-ed with the
            # above-threshold mask each time the returned ops are run.
            # t needs a fully defined static shape to create the variable.
            active_set = tf.Variable(tf.zeros_like(t, dtype=tf.bool),
                                     name="active_set", trainable=False)
            active_set_inc = active_set.assign(
                tf.logical_or(active_set, above_thresh_values),
                name="incremented_active_set")
            # Count the updated set, so the size reflects this step's OR.
            active_set_size = tf.reduce_sum(
                tf.cast(active_set_inc, tf.float32),
                name="size_of_active_set")
            return tf.where(active_set_inc, t, zeros), active_set_size
        return tf.where(above_thresh_values, t, zeros)
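To make the semantics concrete, here is a minimal pure-Python sketch of the same rule, assuming `t` is a 2-D list of rows and that each column is compared against `threshold` times its own maximum absolute value (the reduction along axis 0 in the function above). The names `above_threshold_mask`, `apply_mask`, `threshold_by_percent_max_ref` and `ActiveSet` are hypothetical and for illustration only; `ActiveSet` mimics the `active_set` variable by keeping a mask that persists across calls and is OR-ed with each new above-threshold mask.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Mask = List[List[bool]]


def above_threshold_mask(t: Matrix, threshold: float) -> Mask:
    """True where |t[i][j]| >= threshold * max_i |t[i][j]| (per-column max)."""
    n_cols = len(t[0])
    # Maximum absolute value of each column, i.e. a reduction along axis 0.
    col_max = [max(abs(row[j]) for row in t) for j in range(n_cols)]
    return [[abs(x) >= threshold * col_max[j] for j, x in enumerate(row)]
            for row in t]


def apply_mask(t: Matrix, mask: Mask) -> Matrix:
    """Keep entries where mask is True and replace the rest with 0.0."""
    return [[x if m else 0.0 for x, m in zip(row, mrow)]
            for row, mrow in zip(t, mask)]


def threshold_by_percent_max_ref(t: Matrix, threshold: float) -> Matrix:
    """Hypothetical reference for the use_active_set=False branch."""
    return apply_mask(t, above_threshold_mask(t, threshold))


class ActiveSet:
    """Hypothetical stand-in for the active_set variable.

    The mask persists between calls, mirroring a TensorFlow variable that
    keeps its value across repeated runs of the same graph.
    """

    def __init__(self, n_rows: int, n_cols: int) -> None:
        # All False initially, like the original variable's initial value.
        self.mask: Mask = [[False] * n_cols for _ in range(n_rows)]

    def step(self, t: Matrix, threshold: float) -> Tuple[Matrix, int]:
        current = above_threshold_mask(t, threshold)
        # Logical OR: once an entry enters the active set it never leaves.
        self.mask = [[a or b for a, b in zip(arow, crow)]
                     for arow, crow in zip(self.mask, current)]
        size = sum(m for row in self.mask for m in row)
        return apply_mask(t, self.mask), size


if __name__ == "__main__":
    grads = [[0.9, -0.1], [0.2, 0.4]]
    print(threshold_by_percent_max_ref(grads, 0.5))  # [[0.9, 0.0], [0.0, 0.4]]

    active = ActiveSet(2, 2)
    print(active.step(grads, 0.5))  # ([[0.9, 0.0], [0.0, 0.4]], 2)
    print(active.step([[0.05, 0.3], [0.2, 0.1]], 0.5))
    # ([[0.05, 0.3], [0.2, 0.1]], 4)
```

In the second step, 0.05 is below its column's threshold (0.5 * 0.2 = 0.1) but is still kept, because that entry entered the active set in the first step.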