distortion_transforms.py source code

python

Project: torchsample  Author: ncullen93
import torch as th


def __call__(self, *inputs):
    # Split each C x H x W input into blocksize x blocksize tiles and
    # rearrange the tiles in a random order.
    outputs = []
    b = self.blocksize
    for _input in inputs:
        img_height = _input.size(1)
        img_width = _input.size(2)

        x_blocks = img_height // b  # number of block rows
        y_blocks = img_width // b   # number of block columns
        ind = th.randperm(x_blocks * y_blocks)

        # Pixels not covered by a whole tile (right/bottom margins) stay zero.
        new = th.zeros_like(_input)
        count = 0
        for i in range(x_blocks):
            for j in range(y_blocks):
                # Convert the flat random index to (row, column) in the block grid;
                # dividing by y_blocks keeps this correct for non-square grids.
                row, column = divmod(int(ind[count]), y_blocks)
                new[:, i*b:(i+1)*b, j*b:(j+1)*b] = \
                    _input[:, row*b:(row+1)*b, column*b:(column+1)*b]
                count += 1
        outputs.append(new)
    # Return a list for multiple inputs, a single tensor for a single input.
    return outputs if len(inputs) > 1 else outputs[0]
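The subtle step in this transform is turning a flat index from the random permutation back into a (row, column) position in a grid of `x_blocks` rows by `y_blocks` columns. For row-major numbering the divisor must be the number of columns: with a 2 x 3 grid, `divmod(k, 3)` visits every cell once, whereas dividing by the row count (`k // 2`, `k % 2`) would produce a row index of 2 (out of range) and never reach column 2. The sketch below is a minimal, dependency-free illustration under stated assumptions: it uses plain Python lists for a single-channel image instead of tensors, `random.Random.shuffle` in place of `th.randperm`, and the names `block_grid_coords` and `scramble_blocks` are hypothetical, not part of torchsample.

```python
import random
from typing import List, Optional, Tuple

# Hypothetical single-channel image: a list of rows of pixel values.
Image = List[List[int]]


def block_grid_coords(k: int, y_blocks: int) -> Tuple[int, int]:
    """Map a flat, row-major block index k to (row, column) in the block grid."""
    return divmod(k, y_blocks)


def scramble_blocks(img: Image, blocksize: int,
                    rng: Optional[random.Random] = None) -> Image:
    """Return a copy of img whose blocksize x blocksize tiles are randomly permuted.

    Pixels in the right/bottom margins that do not fill a whole tile are set to 0,
    mirroring the zero-initialised output tensor in the transform above.
    """
    rng = rng or random.Random()
    height, width = len(img), len(img[0])
    x_blocks = height // blocksize  # number of block rows
    y_blocks = width // blocksize   # number of block columns

    perm = list(range(x_blocks * y_blocks))
    rng.shuffle(perm)  # stands in for th.randperm

    out = [[0] * width for _ in range(height)]
    count = 0
    for i in range(x_blocks):
        for j in range(y_blocks):
            # Source tile for destination tile (i, j).
            row, column = block_grid_coords(perm[count], y_blocks)
            for dy in range(blocksize):
                for dx in range(blocksize):
                    out[i * blocksize + dy][j * blocksize + dx] = \
                        img[row * blocksize + dy][column * blocksize + dx]
            count += 1
    return out


if __name__ == "__main__":
    # A 4 x 6 image with 2 x 2 tiles gives a non-square 2 x 3 block grid.
    image = [[r * 6 + c for c in range(6)] for r in range(4)]
    for line in scramble_blocks(image, 2, random.Random(0)):
        print(line)
```

Passing a seeded `random.Random` makes the scramble reproducible, which is convenient for testing; the tensor version above draws from PyTorch's global generator instead, so seeding it would be done with `th.manual_seed`.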