def test_backward_computes_backward_pass():
    """Check _EfficientConv2d forward/backward against F.conv2d with autograd.

    A reference convolution is run through autograd with an all-ones upstream
    gradient. The same weight/input are then passed directly to
    ``_EfficientConv2d.forward`` and ``_EfficientConv2d.backward`` using the
    same hyper-parameters (stride=1, padding=1, dilation=1, groups=1). The
    forward output, the gradient w.r.t. the input and the gradient w.r.t. the
    weight must all match the reference according to ``almost_equal``.

    Requires a CUDA device: every tensor is moved with ``.cuda()``.
    """
    # Weight follows the F.conv2d layout (out_channels=4, in_channels=8, 3, 3);
    # input is (batch=4, channels=8, H=4, W=4). A 3x3 kernel with stride 1 and
    # padding 1 preserves H and W, so the output is (4, 4, 4, 4).
    weight = torch.randn(4, 8, 3, 3).cuda()
    input_tensor = torch.randn(4, 8, 4, 4).cuda()

    # Reference path: plain F.conv2d differentiated by autograd.
    input_var = Variable(input_tensor, requires_grad=True)
    weight_var = Parameter(weight)
    out_var = F.conv2d(
        input=input_var,
        weight=weight_var,
        bias=None,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    )
    out = out_var.data
    # The upstream gradient must have the shape of the conv OUTPUT. Here the
    # output has 4 channels while the input has 8, so an input-shaped gradient
    # would be a shape mismatch. Clone first so `out` itself is not overwritten.
    grad_output = out.clone().fill_(1)
    out_var.backward(gradient=grad_output)
    input_grad = input_var.grad.data
    weight_grad = weight_var.grad.data

    # Efficient path, invoked directly with the same hyper-parameters.
    func = _EfficientConv2d(
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    )
    out_efficient = func.forward(weight, None, input_tensor)
    # The middle return value is ignored; presumably it is the bias gradient
    # (bias is None here) -- TODO confirm. A fresh copy of the upstream
    # gradient is passed so the reference tensor cannot be affected if the
    # efficient backward writes into its argument (unverified, defensive).
    weight_grad_efficient, _, input_grad_efficient = func.backward(
        weight, None, input_tensor, grad_output.clone())

    assert almost_equal(out, out_efficient)
    assert almost_equal(input_grad, input_grad_efficient)
    assert almost_equal(weight_grad, weight_grad_efficient)
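# Source: efficient_conv_test.py (Python source file).
# Original page metadata: 24 reads, 0 favorites, 0 likes, 0 comments;
# followed by the page's "comment list" and "table of contents" sections.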