def add_cifar_10(x, x_cifar_10, sh=True):
    """Add CIFAR-10 images as background to the zero pixels of ``x``.

    Each CIFAR-10 image is reduced to one channel by averaging its three
    colour channels (a simple grayscale conversion), center-cropped from
    32x32 to 28x28 and scaled by 1/255. The background is added only where
    ``x == 0``; non-zero pixels of ``x`` are left unchanged, and the sum is
    clipped to [0, 1].

    Args:
        x: array of ``N`` images with 28 * 28 = 784 values each, e.g. of
            shape ``(N, 784)`` or ``(N, 28, 28)``. Presumably foreground
            images scaled to [0, 1] (MNIST-like) -- TODO confirm with callers.
        x_cifar_10: array whose rows hold 3 * 32 * 32 = 3072 values each;
            every row is interpreted as ``(3, 32, 32)`` (channel-first).
        sh: if True, pick a random CIFAR-10 image (uniformly, with
            replacement) for every image of ``x``. If False, pair ``x[i]``
            with ``x_cifar_10[i]``; both must then hold the same number of
            images.

    Returns:
        Array with the same shape as ``x`` equal to
        ``clip(x + background * (x == 0), 0, 1)``.

    Raises:
        ValueError: if ``sh`` is False and ``x`` and ``x_cifar_10`` hold a
            different number of images.
    """
    sz = x.shape
    # Only pixels that are exactly zero in x receive background.
    mask = (x == 0) * 1.
    # Average the three colour channels to get a single-channel background.
    back = x_cifar_10.reshape(x_cifar_10.shape[0], 3, 32, 32).mean(1)
    back = back[:, 2:30, 2:30]  # take 28x28 from the center.
    back /= 255.
    back = back.astype(np.float32)
    if sh:
        # One background index per image, drawn uniformly with replacement.
        ind = np.random.randint(0, x_cifar_10.shape[0], sz[0])
        # NOTE: the draws are already independent and uniform, so these
        # shuffles do not change the distribution; they are kept only so
        # that seeded runs consume the global RNG exactly as before.
        for _ in range(10):
            np.random.shuffle(ind)
    else:
        # Deterministic one-to-one pairing (used only to plot images for paper).
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if x_cifar_10.shape[0] != sz[0]:
            raise ValueError(
                "sh=False requires x and x_cifar_10 to hold the same number "
                "of images, got %d and %d" % (sz[0], x_cifar_10.shape[0])
            )
        ind = np.arange(0, sz[0])
    # Reshape to x's own shape rather than (N, -1): this supports
    # unflattened inputs such as (N, 28, 28) and avoids NumPy's
    # "cannot reshape array of size 0" error on an empty batch.
    back_sh = back[ind].reshape(sz)
    # NOTE(review): `mask` is float64, so NumPy type promotion turns the
    # float32 background (and hence the output) back into float64; the
    # values still carry the float32 rounding from the cast above.
    back_ready = np.multiply(back_sh, mask)
    out = np.clip(x + back_ready, 0., 1.)
    return out
tools.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录