roi_pool.py source code

python

Project: yolo2-pytorch  Author: longcw
def forward(self, features, rois):
        # features arrive in NCHW layout: (batch, channels, height, width).
        batch_size, num_channels, data_height, data_width = features.size()
        num_rois = rois.size()[0]
        # One pooled (pooled_height x pooled_width) map per ROI and channel.
        output = torch.zeros(num_rois, num_channels, self.pooled_height, self.pooled_width)
        # argmax is passed to the CUDA kernel only; the CPU path does not use it.
        argmax = torch.IntTensor(num_rois, num_channels, self.pooled_height, self.pooled_width).zero_()

        if not features.is_cuda:
            # The CPU kernel is given the features in NHWC layout.
            _features = features.permute(0, 2, 3, 1)
            roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale,
                                            _features, rois, output)
            # output = output.cuda()
        else:
            output = output.cuda()
            argmax = argmax.cuda()
            roi_pooling.roi_pooling_forward_cuda(self.pooled_height, self.pooled_width, self.spatial_scale,
                                                 features, rois, output, argmax)
            # Note: this state is saved on the CUDA path only.
            self.output = output
            self.argmax = argmax
            self.rois = rois
            self.feature_size = features.size()

        return output
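
The compiled `roi_pooling` extension called above is not part of this listing. As a rough illustration of what an ROI max-pooling forward pass typically computes, here is a dependency-free sketch that follows the convention commonly used in Fast R-CNN style implementations. The function name `roi_max_pool`, the `(x1, y1, x2, y2)` ROI layout, the single-channel list-of-rows input, and the rounding rule are all assumptions made for illustration; the extension used by the project may differ in these details.

```python
import math


def roi_max_pool(feature_map, roi, pooled_height, pooled_width, spatial_scale):
    """Max-pool one ROI of a single-channel feature map (a list of rows).

    Hypothetical helper: roi is assumed to be (x1, y1, x2, y2) in
    input-image coordinates, with non-negative values.
    """
    height, width = len(feature_map), len(feature_map[0])
    x1, y1, x2, y2 = roi

    # Project the ROI onto the feature-map grid (round half up).
    start_w = int(math.floor(x1 * spatial_scale + 0.5))
    start_h = int(math.floor(y1 * spatial_scale + 0.5))
    end_w = int(math.floor(x2 * spatial_scale + 0.5))
    end_h = int(math.floor(y2 * spatial_scale + 0.5))

    # Force the ROI to cover at least one cell, then split it into bins.
    roi_h = max(end_h - start_h + 1, 1)
    roi_w = max(end_w - start_w + 1, 1)
    bin_h = roi_h / pooled_height
    bin_w = roi_w / pooled_width

    output = [[0.0] * pooled_width for _ in range(pooled_height)]
    for ph in range(pooled_height):
        # Row range of this bin, clamped to the feature map.
        h0 = min(max(math.floor(ph * bin_h) + start_h, 0), height)
        h1 = min(max(math.ceil((ph + 1) * bin_h) + start_h, 0), height)
        for pw in range(pooled_width):
            # Column range of this bin, clamped to the feature map.
            w0 = min(max(math.floor(pw * bin_w) + start_w, 0), width)
            w1 = min(max(math.ceil((pw + 1) * bin_w) + start_w, 0), width)
            # Empty bins (ROI outside the map) keep the default 0.0.
            if h1 > h0 and w1 > w0:
                output[ph][pw] = max(
                    feature_map[h][w]
                    for h in range(h0, h1)
                    for w in range(w0, w1)
                )
    return output


if __name__ == "__main__":
    fmap = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
    print(roi_max_pool(fmap, (0, 0, 3, 3), 2, 2, 1.0))
```

A multi-channel, multi-ROI version would repeat this per channel and per ROI, which matches the `(num_rois, num_channels, pooled_height, pooled_width)` shape of `output` in the method above.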