def normalized_cross_correlation(self):
    """Return a per-filter normalized cross-correlation score minus ``start_ncc``.

    The weight tensor is flattened to one row per output filter
    (``self.weight.size(0)`` rows). Each row is mean-centred and multiplied
    element-wise with ``self.t0_factor``. The products are summed per row, and
    the sum is divided by ``self.t0_norm`` times the row's L2 norm.
    Finally, ``self.start_ncc`` is subtracted.

    The layer may have a single input channel and ``sum(self.kernel_size) == 1``
    (with positive kernel dimensions this means a 1-D kernel of size 1). Then
    every row holds exactly one weight, and mean-centring would zero it out.
    In that case the raw weights are divided by the L2 norm of
    ``self.t0_norm`` instead.

    Returns:
        A tensor with one value per output filter. This assumes
        ``self.start_ncc`` is a scalar or a per-filter tensor (TODO confirm).
    """
    # One row per output filter.
    w = self.weight.view(self.weight.size(0), -1)

    # Boolean ``and`` is required here. With bitwise ``&`` the expression
    # parsed as ``in_channels == (1 & sum(kernel_size)) == 1``. That wrongly
    # matched any single-channel layer whose kernel sizes sum to an odd
    # number (e.g. a 1-D kernel of size 3).
    if self.in_channels == 1 and sum(self.kernel_size) == 1:
        # squeeze(1) rather than squeeze() keeps a 1-D result even when there
        # is only one output filter, matching the general path below.
        ncc = w.squeeze(1) / torch.norm(self.t0_norm, p=2)
        return ncc - self.start_ncc

    # NOTE(review): t_norm is the norm of the raw (uncentred) weights, while
    # textbook NCC divides by the norm of the centred weights. It is kept
    # unchanged because self.t0_norm is defined elsewhere and must follow the
    # same convention; confirm against where t0_norm/t0_factor are set.
    t_norm = torch.norm(w, p=2, dim=1)
    t_factor = w - torch.mean(w, dim=1, keepdim=True)
    cov = torch.sum(self.t0_factor * t_factor, dim=1)
    ncc = cov / (self.t0_norm * t_norm)
    return ncc - self.start_ncc
评论列表
文章目录