def __call__(self, *inputs):
    """Randomly shuffle non-overlapping square blocks of each input tensor.

    Each input is cut into ``self.blocksize`` x ``self.blocksize`` tiles over
    axes 1 and 2. The tiles are written back into a new tensor in the order
    given by ``th.randperm``. When axis 1 or 2 is not a multiple of
    ``blocksize``, the trailing strip that does not fill a whole block is left
    as zeros in the output.

    Args:
        *inputs: tensors indexed as ``t[:, rows, cols]``; presumably
            (C, H, W) images -- TODO confirm against callers.

    Returns:
        The shuffled tensor when exactly one input is given; otherwise a list
        with one shuffled tensor per input, in the same order (an empty list
        when called with no inputs).
    """
    bs = self.blocksize
    outputs = []
    for _input in inputs:
        img_height, img_width = _input.size(1), _input.size(2)
        x_blocks = img_height // bs  # number of blocks along axis 1 (rows)
        y_blocks = img_width // bs   # number of blocks along axis 2 (cols)
        # NOTE(review): every input draws its own permutation, so paired
        # inputs (e.g. image + mask) are shuffled differently -- confirm this
        # is intended.
        ind = th.randperm(x_blocks * y_blocks)
        # zeros_like keeps the input's dtype and device; th.zeros(size) would
        # always allocate a default-dtype CPU tensor.
        new = th.zeros_like(_input)
        for count, src in enumerate(ind.tolist()):
            # Destination block in row-major order over the block grid.
            i, j = divmod(count, y_blocks)
            # Decode the flat source index by the number of block *columns*;
            # dividing by x_blocks walks off the grid when x_blocks != y_blocks.
            row, column = divmod(src, y_blocks)
            new[:, i * bs:(i + 1) * bs, j * bs:(j + 1) * bs] = \
                _input[:, row * bs:(row + 1) * bs, column * bs:(column + 1) * bs]
        outputs.append(new)
    # Single input -> bare tensor; otherwise the full list (the old
    # `idx > 1` test dropped the second of exactly two inputs).
    return outputs[0] if len(outputs) == 1 else outputs
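# Comment list  (web-page navigation label captured along with the code; not part of the program)
# Table of contents  (web-page navigation label captured along with the code; not part of the program)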