densenet_efficient_multi_gpu.py source code

python

Project: efficient_densenet_pytorch  Author: gpleiss
# Imports this method relies on (legacy PyTorch 0.x API).
import torch
from torch.backends import cudnn


def forward(self, weight, bias, input):
    # Assert we're using cuDNN
    for tensor in (weight, bias, input):
        if tensor is not None and not cudnn.is_acceptable(tensor):
            raise Exception('You must be using CUDNN to use EfficientBatchNorm')

    # Create save variables (per-batch statistics that cuDNN fills in
    # during the forward pass)
    self.save_mean = self.running_mean.new()
    self.save_mean.resize_as_(self.running_mean)
    self.save_var = self.running_var.new()
    self.save_var.resize_as_(self.running_var)

    # Do forward pass - the output tensor is built on top of self.storage,
    # moved to the device that holds the weights, rather than freshly allocated
    cur_device_id = weight.get_device()
    res = type(input)(self.storage.change_device(cur_device_id)).resize_as_(input)
    assert weight.get_device() == res.get_device(), \
        "input and output should be on the same chip!"

    # Private cuDNN binding; writes the normalized output into res
    torch._C._cudnn_batch_norm_forward(
        input, res, weight, bias,
        self.running_mean, self.running_var,
        self.save_mean, self.save_var,
        self.training, self.momentum, self.eps)
    return res
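
The forward method above writes its result into a tensor backed by pre-allocated shared storage instead of allocating a new output. The sketch below illustrates that idea with public PyTorch operations only, and it runs on CPU. It is not the project's implementation: `batch_norm_into`, the flat `shared` buffer, and the 4-D `(N, C, H, W)` input layout are assumptions made for illustration, and running statistics and multi-GPU device handling are left out.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def batch_norm_into(shared, x, weight, bias, eps=1e-5):
    """Training-mode batch norm of x, written into the flat buffer `shared`.

    Hypothetical helper: x has shape (N, C, H, W); `shared` is a contiguous
    1-D tensor with the same dtype as x and at least x.numel() elements.
    """
    if shared.numel() < x.numel():
        raise ValueError("shared buffer is too small for this input")

    # View the front of the shared buffer with the shape of x (no new allocation).
    out = shared[: x.numel()].view(x.shape)

    # Per-channel batch statistics (biased variance, as used for normalization).
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

    # Write (x - mean) straight into the shared memory, then scale and shift in place.
    torch.sub(x, mean, out=out)
    out.mul_(torch.rsqrt(var + eps) * weight.view(1, -1, 1, 1))
    out.add_(bias.view(1, -1, 1, 1))
    return out


# Usage: one buffer reused across calls; the result matches F.batch_norm.
torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)
weight, bias = torch.rand(3), torch.rand(3)
shared = torch.empty(x.numel())

out = batch_norm_into(shared, x, weight, bias)
reference = F.batch_norm(x, None, None, weight, bias, training=True, eps=1e-5)
print(torch.allclose(out, reference, atol=1e-5))
print(out.data_ptr() == shared.data_ptr())
```

Because every call overwrites the same buffer, a caller that needs an earlier output again has to recompute it, which is the trade-off that shared-storage approaches accept in exchange for lower memory use.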