def threshold_by_percent_max(t, threshold, use_active_set=False):
    """Creates tensorflow ops that threshold a tensor by a percentage of its
    maximum absolute value. To be used when thresholding gradients.

    The maximum of ``|t|`` is taken along axis 0, so each column gets its own
    cutoff ``threshold * max(|t[:, j]|)``. Entries whose absolute value is
    below their column's cutoff are replaced by 0.

    Optionally maintains an active set: a boolean variable, initially all
    False, that is OR-ed with the "above threshold" mask each time the
    returned ops are evaluated. Once an entry has entered the active set it
    is passed through unchanged from then on, even if it later falls below
    the cutoff.

    Parameters
    ----------
    t: tensor
        The tensor to threshold by percent max. The threshold row is tiled
        with a two-element multiples vector, so ``t`` must be rank 2.
    threshold: float
        A number between 0 and 1 that specifies the threshold.
    use_active_set: bool
        Specifies whether or not to use an active set. When True, a new
        ``tf.Variable`` named "active_set" is created on every call; it has to
        be initialized before the returned ops are run.

    Returns
    -------
    If ``use_active_set`` is False: a tensor of the same shape as t that has
    had all values under the threshold set to 0.

    If ``use_active_set`` is True: a tuple ``(thresholded, active_set_size)``
    where ``thresholded`` is t with every entry outside the (updated) active
    set zeroed, and ``active_set_size`` is a float32 scalar counting the
    entries in the active set after the update. Evaluating either element
    runs the active-set update.
    """
    with tf.name_scope("threshold_by_percent_max"):
        abs_t = tf.abs(t)

        # Build the constants in the dtype of the input instead of a
        # hard-coded float32, so that e.g. float64 gradients do not trip a
        # dtype mismatch in tf.mul / tf.select.
        thresh_percentile = tf.constant(threshold, dtype=abs_t.dtype)
        zeros = tf.zeros_like(abs_t)

        maximum = tf.reduce_max(abs_t, reduction_indices=[0])

        # A tensor, the same shape as t, that has the threshold values to be
        # compared against every value.
        thresh_one_voxel = tf.expand_dims(
            tf.mul(thresh_percentile, maximum), 0)
        thresh_tensor = tf.tile(thresh_one_voxel,
                                tf.pack([tf.shape(t)[0], 1]))
        above_thresh_values = tf.greater_equal(abs_t, thresh_tensor)

        if not use_active_set:
            return tf.select(above_thresh_values, t, zeros)

        # ones == zeros is False everywhere: an all-False boolean tensor with
        # the shape of t, used as the initial (empty) active set.
        active_set = tf.Variable(tf.equal(tf.ones(tf.shape(t)),
                                          tf.zeros(tf.shape(t))),
                                 name="active_set", dtype=tf.bool)
        active_set_inc = tf.assign(active_set,
                                   tf.logical_or(active_set,
                                                 above_thresh_values),
                                   name="incremented_active_set")

        # Count from the assign op's output rather than from the raw
        # variable. Reading the variable directly has no ordering relative to
        # the assign, so the reported size could come from either before or
        # after the update within the same run.
        active_set_size = tf.reduce_sum(tf.cast(active_set_inc, tf.float32),
                                        name="size_of_active_set")
        return (tf.select(active_set_inc, t, zeros), active_set_size)
评论列表
文章目录