def _test_InstanceNorm(self, cls, input):
    """Check the output statistics and running stats of a norm layer.

    Two properties are asserted, each to within ``delta=1e-5``:

    1. The output of ``cls(c, eps=0)`` has, per channel, a mean of 0 and
       a biased variance of 1.
    2. After one forward pass through ``cls(c, momentum=1, eps=0)``,
       ``running_mean`` equals the per-channel mean of the input, and
       ``running_var`` equals the unbiased per-sample, per-channel
       variance of the input averaged over the samples.

    Args:
        cls: Layer class. It is built as ``cls(c, eps=0)`` and
            ``cls(c, momentum=1, eps=0)``; instances are called on the
            input and must expose ``running_mean`` and ``running_var``.
        input: Tensor whose dim 0 is read as the sample count ``b`` and
            dim 1 as the channel count ``c``; the remaining dims are
            flattened.

    NOTE(review): recent PyTorch InstanceNorm layers only keep running
    statistics when built with ``track_running_stats=True`` -- confirm
    this helper matches the torch version the suite targets.
    """
    b, c = input.size(0), input.size(1)
    input_var = Variable(input)

    IN = cls(c, eps=0)
    output = IN(input_var)

    # Move channels to the front and flatten the rest so that each row
    # holds every value belonging to one channel.
    out_reshaped = output.transpose(1, 0).contiguous().view(c, -1)
    mean = out_reshaped.mean(1)
    var = out_reshaped.var(1, unbiased=False)

    self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
    self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)

    # If momentum==1 running_mean/var should be
    # equal to mean/var of the input
    IN = cls(c, momentum=1, eps=0)
    # The output is not needed here; the call is made because
    # IN.running_mean / IN.running_var are read right after it.
    IN(input_var)

    input_reshaped = input_var.transpose(1, 0).contiguous().view(c, -1)
    mean = input_reshaped.mean(1)

    # Keep the sample axis separate: variance is taken per (channel,
    # sample) pair and only averaged over samples in the assertion.
    input_reshaped = input_var.transpose(1, 0).contiguous().view(c, b, -1)
    var = input_reshaped.var(2, unbiased=True)

    self.assertAlmostEqual(
        torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
    self.assertAlmostEqual(
        torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
# Web-page residue, not part of the code: "Comment list" / "Article table of contents".