def kSparse(self, x, topk):
    """Keep the ``topk`` largest entries in each row of ``x``; zero the rest.

    The ``dim - topk`` smallest entries of every row are located with
    ``tf.nn.top_k`` on ``-x``; their negated values are scattered into a
    dense tensor shaped like ``x`` and added back to ``x``, which cancels
    exactly those entries.

    Args:
        x: 2-D tensor. Its second dimension must be statically known,
            because it is converted with ``int(x.get_shape()[1])``.
        topk: Number of entries to keep per row. A value larger than the
            row width triggers a warning and is clamped to the row width.

    Returns:
        A tensor with the same shape as ``x`` in which every entry outside
        the per-row ``topk`` largest has been set to zero.
    """
    # Parenthesised form prints the same text under Python 2 and Python 3.
    print('run regular k-sparse')
    dim = int(x.get_shape()[1])
    if topk > dim:
        warnings.warn('Warning: topk should not be larger than dim: %s, found: %s, using %s' % (dim, topk, dim))
        topk = dim

    # Number of entries per row that have to be zeroed out.
    k = dim - topk

    # top_k on -x picks the k smallest entries of each row of x:
    # `values` holds their negations, `indices` their column positions.
    values, indices = tf.nn.top_k(-x, k)

    # Pair every column index with its row index to get (row, col) coordinates.
    my_range = tf.expand_dims(tf.range(0, tf.shape(indices)[0]), 1)  # [N, 1] row ids
    my_range_repeated = tf.tile(my_range, [1, k])  # [N, k] row ids
    full_indices = tf.stack([my_range_repeated, indices], axis=2)  # [N, k, 2]
    full_indices = tf.reshape(full_indices, [-1, 2])  # [N * k, 2]

    # Dense tensor shaped like x holding -x at the selected coordinates and 0
    # elsewhere. validate_indices=False because the coordinates follow top_k's
    # ordering rather than sorted index order.
    # NOTE(review): tf.sparse_to_dense is deprecated in newer TensorFlow
    # releases (tf.scatter_nd is the usual replacement) - confirm against the
    # TensorFlow version this project targets.
    to_reset = tf.sparse_to_dense(full_indices, tf.shape(x), tf.reshape(values, [-1]), default_value=0., validate_indices=False)

    # x + (-x) at the selected coordinates cancels them to zero.
    res = tf.add(x, to_reset)
    return res
评论列表
文章目录