efficient_conv_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:efficient_densenet_pytorch 作者: gpleiss 项目源码 文件源码
def test_forward_computes_forward_pass():
    weight = torch.randn(4, 8, 3, 3).cuda()
    input = torch.randn(4, 8, 4, 4).cuda()

    out = F.conv2d(
        input=Variable(input),
        weight=Parameter(weight),
        bias=None,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ).data

    func = _EfficientConv2d(
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    )
    out_efficient = func.forward(weight, None, input)

    assert(almost_equal(out, out_efficient))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号