efficient_batch_norm_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:efficient_densenet_pytorch 作者: gpleiss 项目源码 文件源码
def test_forward_eval_mode_computes_forward_pass():
    """Eval-mode forward of ``_EfficientBatchNorm`` must match ``F.batch_norm``.

    Two (4, 5) CUDA tensors are concatenated along dim 1 into a (4, 10)
    batch. Both the reference ``F.batch_norm`` and ``_EfficientBatchNorm``
    normalize it with ``training=False``, using the same affine parameters
    and running statistics.

    The test asserts two things:

    * the two outputs agree according to the project helper
      ``almost_equal`` (presumably a tolerance-based comparison);
    * the efficient output's backing storage is the pre-allocated
      ``storage`` buffer passed to the constructor.

    Requires a CUDA device.
    """
    num_channels = 10

    # Per-channel affine parameters and running statistics.
    gamma = torch.randn(num_channels).cuda()
    beta = torch.randn(num_channels).cuda()
    mean_stat = torch.randn(num_channels).cuda()
    # A variance must be non-negative, hence abs().
    var_stat = torch.randn(num_channels).abs().cuda()

    # Two halves of the channel dimension; concatenated they form (4, 10).
    left = torch.randn(4, 5).cuda()
    right = torch.randn(4, 5).cuda()
    # 40 elements == 4 * 10, i.e. exactly one output tensor's worth.
    out_buffer = torch.Storage(40).cuda()

    # Keyword arguments passed verbatim to both the reference and efficient paths.
    common_kwargs = dict(
        running_mean=mean_stat,
        running_var=var_stat,
        training=False,
        momentum=0.1,
        eps=1e-5,
    )

    expected = F.batch_norm(
        input=Variable(torch.cat([left, right], dim=1)),
        weight=Parameter(gamma),
        bias=Parameter(beta),
        **common_kwargs
    ).data

    # A fresh concatenation, so the efficient path never sees the tensor
    # wrapped by the reference call above.
    efficient_fn = _EfficientBatchNorm(storage=out_buffer, **common_kwargs)
    actual = efficient_fn.forward(gamma, beta, torch.cat([left, right], dim=1))

    assert almost_equal(expected, actual)
    # The efficient output is expected to live in the caller-supplied storage.
    assert actual.storage().data_ptr() == out_buffer.data_ptr()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号