def wsparsify(w_gpu, percentage):
    """Zero all but (roughly) the largest `percentage` percent of entries.

    The array is copied to the host only to pick a threshold value. A new
    GPU array is then built on the device that keeps every entry strictly
    greater than that threshold and sets all other entries to zero.

    Parameters
    ----------
    w_gpu : GPU array
        Input array (presumably a ``pycuda.gpuarray.GPUArray``, given the
        ``cua`` calls below -- confirm against the module imports). It is
        not modified.
    percentage : float
        Percentage of entries to keep. The number kept is
        ``floor(w_gpu.size * percentage / 100)``, clamped to
        ``[0, w_gpu.size]``. Entries are compared with a strict ``>``
        against the next-largest value, so ties at the boundary are dropped
        and fewer entries may survive.

    Returns
    -------
    A new GPU array with the same shape as `w_gpu`, in which only the
    selected entries are nonzero.
    """
    import numpy as np

    w = w_gpu.get()  # host copy, used only to choose the threshold
    n_total = w.size
    # Number of largest entries to keep; clamp so out-of-range percentages
    # cannot produce an invalid (or negative, wrap-around) index.
    n_keep = int(np.floor(n_total * percentage / 100.0))
    n_keep = min(max(n_keep, 0), n_total)

    zw_gpu = cua.zeros_like(w_gpu)  # GPU array filled with zeros
    if n_keep == n_total:
        # Everything is kept. There is no "next largest" value to use as a
        # threshold, so return an element-wise copy instead of indexing
        # past the end.
        return w_gpu + zw_gpu

    # The threshold is the (n_keep + 1)-th largest value, i.e. the element
    # at ascending position n_total - 1 - n_keep. axis=None flattens the
    # array, so multi-dimensional inputs are ranked over all entries.
    # partition finds this value in O(n) without a full sort.
    k = n_total - 1 - n_keep
    threshold = np.partition(w, k, axis=None)[k]

    tw_gpu = cua.empty_like(w_gpu)  # GPU array holding the threshold
    tw_gpu.fill(threshold)
    return cua.if_positive(w_gpu > tw_gpu, w_gpu, zw_gpu)
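# Web-page navigation labels left over from the page this snippet was copied
# from ("Comment list" / "Table of contents"); kept as a comment so the
# module still imports.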