def sample_entropy(samples):
    """Estimate sample entropy from nearest-neighbour distances within a batch.

    For every sample, the distance to its nearest *other* sample in the batch
    is taken. The estimate is then ``(C / B) * sum_i log(d_i + 1e-4)``.
    NOTE(review): this looks like a 1-NN (Kozachenko-Leonenko-style) estimator
    with additive constants dropped — confirm the intended estimator.

    Args:
        samples: tensor assumed to be shaped (B, C), i.e. B samples of
            dimension C (as the original "B x C" comment states).

    Returns:
        A scalar tensor holding the unnormalised entropy estimate. Gradients
        flow through the pairwise distances; the diagonal offset is detached.
    """
    # Presumably (B, B) with zeros on the diagonal — verify pairwise_euclidean.
    dist_mat = pairwise_euclidean(samples)
    # Lift the diagonal (self-distances) above every off-diagonal entry so the
    # row-wise min selects the nearest *other* sample. The offset is built on
    # dist_mat's own device/dtype (works on CPU and on any GPU). It is kept as
    # a detached tensor: no gradient through the max, and no host sync.
    offset = dist_mat.max().detach() + 1
    eye = torch.eye(dist_mat.size(0), dtype=dist_mat.dtype, device=dist_mat.device)
    dist_mat_d = dist_mat + eye * offset
    # Nearest-neighbour distance per row; 1e-4 guards log(0) for duplicate samples.
    nn_dist = dist_mat_d.min(1)[0]
    entropy = (nn_dist + 1e-4).log().sum()
    # Scale by dimension / batch size (true division in Python 3).
    return entropy * (samples.size(1) / samples.size(0))
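# (Stray text from the web page this snippet was copied from:
#  "Comment list" / "Table of contents" — not part of the code.)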