def test_forward_computes_forward_pass():
    """Check that ``_EfficientConv2d.forward`` matches ``F.conv2d``.

    Builds a random 3x3 kernel and a random input on the GPU and computes a
    reference convolution with ``F.conv2d`` (stride 1, padding 1, dilation 1,
    groups 1, no bias). It then runs the same tensors through
    ``_EfficientConv2d`` configured with identical hyper-parameters and
    asserts that the two outputs agree via ``almost_equal``.

    Requires a CUDA device: every tensor is moved there with ``.cuda()``.
    """
    # F.conv2d weight layout is (out_channels, in_channels / groups, kH, kW):
    # 8 input channels -> 4 output channels with a 3x3 kernel.
    weight = torch.randn(4, 8, 3, 3).cuda()
    # (batch, channels, height, width); the channel count must equal weight
    # dim 1. Named ``inputs`` rather than ``input`` so that the builtin is
    # not shadowed.
    inputs = torch.randn(4, 8, 4, 4).cuda()

    # Reference result from the stock functional op. The Variable/Parameter
    # wrappers follow the older autograd API that this file is written
    # against. ``.data`` unwraps the result to a plain tensor for comparison.
    out = F.conv2d(
        input=Variable(inputs),
        weight=Parameter(weight),
        bias=None,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ).data

    # Same hyper-parameters as the reference call above.
    func = _EfficientConv2d(
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    )
    # NOTE(review): ``None`` sits in what appears to be the bias position,
    # mirroring ``bias=None`` above. Confirm this against the signature of
    # _EfficientConv2d.forward.
    out_efficient = func.forward(weight, None, inputs)

    assert almost_equal(out, out_efficient)
efficient_conv_test.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录