def simplified_topk(x, k):
    """Keep the ``k`` columns of ``x`` with the largest L1 score; zero the rest.

    Proof-of-concept implementation of a simplified top-k. Each column's score
    is the sum of absolute values over dim 0. Inputs with more than two
    dimensions are first reshaped to ``(x.size(0), -1)``. The selection is
    made over the columns of that 2-D view, and the result is reshaped back.

    Note: only the k-th largest score is really needed, so a selection
    algorithm (e.g. quickselect) could do this in O(n) expected time instead
    of a partial sort. O(n) is also a lower bound, because every score must
    be inspected.

    Args:
        x: Tensor with at least two dimensions. 1-D input is not supported
            by the column indexing below.
        k: Number of columns to keep. Values larger than the number of
            columns keep every column.

    Returns:
        A new tensor with the same shape as ``x``. ``x`` itself is not
        modified.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    original_size = None
    if x.dim() > 2:
        original_size = x.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(x.size(0), -1)
    num_cols = x.size(-1)
    # Clamp so that asking for more columns than exist keeps them all
    # instead of passing a negative size to topk.
    k = min(k, num_cols)
    # Per-column L1 score; detached so the selection does not enter the graph.
    scores = x.detach().abs().sum(0).view(-1)
    # Indices of the (num_cols - k) lowest-scoring columns, i.e. those to drop.
    _, drop_ids = scores.topk(num_cols - k, dim=0, largest=False)
    y = x.clone()
    # Zero the dropped columns in one vectorized assignment (no Python loop).
    y[:, drop_ids] = 0
    if original_size is not None:
        y = y.view(original_size)
    return y
评论列表
文章目录