def test_forward_eval_mode_computes_forward_pass():
    """Check that ``_EfficientBatchNorm`` matches ``F.batch_norm`` in eval mode.

    Two random CUDA inputs of shape (4, 5) are concatenated along dim 1 into
    a (4, 10) batch. That batch is normalized twice, with identical
    affine parameters and running statistics:

    * with the reference ``F.batch_norm`` (``training=False``), and
    * with ``_EfficientBatchNorm``, constructed over a pre-allocated storage.

    Both results must agree (per ``almost_equal``). The efficient result must
    also be backed by exactly the storage that was handed in.
    """
    momentum, eps = 0.1, 1e-5

    # Per-feature affine parameters and running statistics (10 features).
    # The draw order is fixed so the RNG stream is consumed deterministically.
    gamma = torch.randn(10).cuda()
    beta = torch.randn(10).cuda()
    mean_buf = torch.randn(10).cuda()
    var_buf = torch.randn(10).abs().cuda()  # abs(): a variance must be >= 0

    left = torch.randn(4, 5).cuda()
    right = torch.randn(4, 5).cuda()

    # 40 elements == 4 * 10, the element count of the concatenated input.
    out_storage = torch.Storage(40).cuda()

    # Reference result from the stock functional batch norm.
    reference_input = Variable(torch.cat([left, right], dim=1))
    expected = F.batch_norm(
        input=reference_input,
        running_mean=mean_buf,
        running_var=var_buf,
        weight=Parameter(gamma),
        bias=Parameter(beta),
        training=False,
        momentum=momentum,
        eps=eps,
    ).data

    # Same computation through the storage-backed implementation.
    concatenated = torch.cat([left, right], dim=1)
    efficient_bn = _EfficientBatchNorm(
        storage=out_storage,
        running_mean=mean_buf,
        running_var=var_buf,
        training=False,
        momentum=momentum,
        eps=eps,
    )
    actual = efficient_bn.forward(gamma, beta, concatenated)

    assert almost_equal(expected, actual)
    # The efficient path must write its output into the caller-provided storage.
    assert actual.storage().data_ptr() == out_storage.data_ptr()
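# Source file: efficient_batch_norm_test.py (python)
# NOTE: originally scraped from a web page; page metadata (25 views, 0 favorites,
# 0 likes, 0 comments, "comment list", "table of contents") omitted as non-code.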