test_nn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
                                  inp_size, dilation, no_weight, groups=1, use_cuda=False, use_bias=True):
        """Run ``gradgradcheck`` on ``F.conv2d`` with randomly initialised inputs.

        A square input of shape ``(batch_size, chan_in, inp_size, inp_size)``,
        a square kernel of shape ``(chan_out, chan_in // groups, kern, kern)``
        and, when ``use_bias`` is true, a bias of shape ``(chan_out,)`` are
        filled from a normal distribution and marked ``requires_grad=True``.

        Args:
            kern: Side length of the square convolution kernel.
            stride, padding, dilation, groups: Forwarded unchanged to
                ``F.conv2d``.
            chan_in: Number of input channels.
            chan_out: Number of output channels.
            batch_size: Number of samples in the input batch.
            inp_size: Side length of the square input.
            no_weight: When true, the weight is captured by the checked
                function instead of being passed as one of its inputs.
            use_cuda: When true, all tensors are created from a CUDA tensor.
            use_bias: When false, the convolution is run with ``bias=None``
                and no bias is included in the checked inputs.

        Returns:
            The result of ``gradgradcheck`` for the convolution.
        """
        # Prototype tensor: everything else is allocated through ``.new`` so
        # it shares this tensor's type (and device when ``use_cuda`` is set).
        proto = torch.Tensor(1)
        if use_cuda:
            proto = proto.cuda()

        def random_leaf(*shape):
            # Allocate via the prototype, then fill in place with N(0, 1).
            leaf = Variable(proto.new(*shape), requires_grad=True)
            leaf.data.normal_()
            return leaf

        x = random_leaf(batch_size, chan_in, inp_size, inp_size)
        weight = random_leaf(chan_out, chan_in // groups, kern, kern)
        bias = random_leaf(chan_out) if use_bias else None

        def func(*inputs):
            # Positional layout is: x, [weight unless no_weight], [bias if use_bias].
            lx, *remaining = inputs
            lweight = weight if no_weight else remaining.pop(0)
            lbias = remaining.pop(0) if use_bias else None
            # We disable cudnn during forward to avoid finite difference imprecision issues
            with cudnn.flags(enabled=False):
                return F.conv2d(lx, lweight, lbias, stride=stride, padding=padding,
                                dilation=dilation, groups=groups)

        checked = [x]
        if not no_weight:
            checked.append(weight)
        if use_bias:
            checked.append(bias)
        inputs = tuple(checked)

        # A throwaway forward pass provides the output shape for grad_y.
        dummy_out = func(*inputs)
        grad_y = random_leaf(dummy_out.size())

        return gradgradcheck(func, inputs, (grad_y,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号