def compute_cov_a(a, classname, layer_info, fast_cnn):
    """Return the batch-averaged second-moment matrix of a layer's input.

    The input is first turned into a 2-D matrix ``A`` whose columns are
    features. The function then returns ``A.t() @ (A / batch_size)``, where
    ``batch_size`` is ``a.size(0)`` of the *original* input.

    Args:
        a: Input activations to the layer. The first dimension is the batch.
        classname: Layer type name.
            * ``'Conv2d'``: ``a`` is expanded into patches via
              ``_extract_patches(a, *layer_info)``.
            * ``'AddBias'``: ``a`` is replaced by a column of ones, one row
              per batch element, so the result is a 1x1 matrix.
            * Any other name: ``a`` is used as-is and must already be 2-D,
              because ``Tensor.t()`` only accepts tensors with at most two
              dimensions.
        layer_info: Extra positional arguments forwarded to
            ``_extract_patches``. Presumably this is (kernel_size, stride,
            padding); TODO confirm against callers.
        fast_cnn: Only used for ``'Conv2d'``.
            * If true, patches are averaged over all spatial positions, giving
              one row per batch element.
            * If false, every spatial position becomes its own row, scaled by
              ``1 / (patches.size(1) * patches.size(2))``.

    Returns:
        A square ``(features, features)`` tensor with the same device and
        dtype as the processed activations.
    """
    batch_size = a.size(0)

    if classname == 'Conv2d':
        # NOTE(review): assumes _extract_patches returns a contiguous tensor
        # shaped (N, out_h, out_w, C * kh * kw). Contiguity is required by
        # .view() below; TODO confirm.
        patches = _extract_patches(a, *layer_info)
        if fast_cnn:
            # Collapse all spatial positions and average them:
            # (N, positions, features) -> (N, features).
            a = patches.view(patches.size(0), -1, patches.size(-1)).mean(1)
        else:
            # Every spatial position becomes a separate row. Scale by dims 1
            # and 2 of the patch tensor, taken *before* flattening.
            spatial_h, spatial_w = patches.size(1), patches.size(2)
            a = patches.view(-1, patches.size(-1)) / spatial_h / spatial_w
    elif classname == 'AddBias':
        # The bias "input" is a constant 1 for each batch element. Create the
        # ones directly on a's device and in a's dtype. A bare .cuda() would
        # target the current default GPU instead of a.device, and the default
        # dtype would not match the other branches.
        a = torch.ones(a.size(0), 1, device=a.device, dtype=a.dtype)

    return a.t() @ (a / batch_size)
# NOTE(review): stray web-page residue ("comment list" / "table of contents") was here; safe to delete.