def test_conv2d_depthwise(self):
    """Check P.conv2d_depthwise against F.conv2d with groups equal to channels.

    Compares the forward output and the gradients with respect to both the
    input and the weight against the reference grouped convolution. Then runs
    ``gradcheck`` on the fast implementation. All tensors are double precision
    and live on the GPU (``.cuda()``), so a CUDA device is required.
    """
    n = 6  # channel count; groups=n makes the reference F.conv2d depthwise
    x = Variable(torch.randn(1, n, 5, 5).double().cuda(), requires_grad=True)
    # Weight shape (n, 1, 3, 3): one 3x3 filter per channel, as groups=n requires.
    w = Variable(torch.randn(n, 1, 3, 3).double().cuda(), requires_grad=True)

    # Forward: the fast kernel must match the reference grouped convolution.
    y_fast = P.conv2d_depthwise(x, w, padding=1)
    y_ref = F.conv2d(x, w, padding=1, groups=n)
    go = torch.randn(y_fast.size()).double().cuda()
    self.assertLess((y_fast - y_ref).data.abs().max(), 1e-9)

    # Backward: push the same upstream gradient `go` through both graphs.
    y_fast.backward(go)
    gx_fast = x.grad.data.clone()
    gw_fast = w.grad.data.clone()
    # .grad accumulates across backward calls, so reset it before the reference pass.
    x.grad.data.zero_()
    w.grad.data.zero_()
    y_ref.backward(go)
    gx_ref = x.grad.data.clone()
    gw_ref = w.grad.data.clone()
    # Gradients w.r.t. the input and the weight must agree with the reference too.
    self.assertLess((gx_fast - gx_ref).abs().max(), 1e-9)
    self.assertLess((gw_fast - gw_ref).abs().max(), 1e-9)

    # Finite-difference check of the fast implementation's analytic gradients.
    self.assertTrue(gradcheck(partial(P.conv2d_depthwise, padding=1), (x, w,)))
# Web-page residue, not code: "Comment list"
# Web-page residue, not code: "Table of contents"