gridgen.py source file

python

Project: lr-gan.pytorch  Author: jwyang
import torch

def forward(self, input1):
    # input1 holds one 2x3 affine matrix per sample: shape (N, 2, 3).
    self.input1 = input1
    batch_size = input1.size(0)

    # Copy the base grid once per sample.
    self.batchgrid = torch.zeros(torch.Size([batch_size]) + self.grid.size())
    for i in range(batch_size):
        self.batchgrid[i] = self.grid

    if input1.is_cuda:
        self.batchgrid = self.batchgrid.cuda()

    # Flatten the grid to (N, height*width, 3) and transpose the matrices
    # to (N, 3, 2). contiguous() returns a tensor rather than working in
    # place, so its result has to be assigned.
    batchgrid_temp = self.batchgrid.view(-1, self.height * self.width, 3).contiguous()
    input_temp = torch.transpose(input1, 1, 2).contiguous()

    # Batched product: (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2).
    output_temp = torch.bmm(batchgrid_temp, input_temp)
    return output_temp.view(-1, self.height, self.width, 2).contiguous()
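
The method above depends on `self.grid`, `self.height` and `self.width`, which are set elsewhere in the class and are not shown here. To make the shape arithmetic easier to follow, here is a minimal standalone sketch of the same batched-matmul idea. The helper names `make_base_grid` and `affine_grid_sketch` are hypothetical, and the grid layout (channels ordered as x, y, 1 with coordinates in [-1, 1]) is an assumption for illustration, not necessarily the layout the project uses.

```python
import torch


def make_base_grid(height, width):
    # Hypothetical helper (assumed layout): homogeneous coordinates
    # (x, y, 1) for every pixel, with x and y spanning [-1, 1].
    ys = torch.linspace(-1.0, 1.0, height).view(height, 1).expand(height, width)
    xs = torch.linspace(-1.0, 1.0, width).view(1, width).expand(height, width)
    ones = torch.ones(height, width)
    return torch.stack([xs, ys, ones], dim=2)  # shape (H, W, 3)


def affine_grid_sketch(theta, grid):
    # theta: (N, 2, 3) affine matrices; grid: (H, W, 3) base grid.
    n = theta.size(0)
    h, w, _ = grid.shape
    # Share the base grid across the batch without a Python loop, and
    # match theta's dtype and device.
    flat = grid.reshape(1, h * w, 3).expand(n, h * w, 3).to(theta)
    # (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2)
    out = torch.bmm(flat, theta.transpose(1, 2))
    return out.view(n, h, w, 2)


# Usage: an identity transform leaves the (x, y) coordinates unchanged.
grid = make_base_grid(4, 5)
theta = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)
sampling_grid = affine_grid_sketch(theta, grid)
```

Using `expand` instead of a per-sample copy avoids allocating N copies of the grid. This is a design choice of the sketch, not something the original method does.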