losses.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:AGE 作者: DmitryUlyanov 项目源码 文件源码
def sample_entropy(samples):

        # Assume B x C input

    dist_mat = pairwise_euclidean(samples)

    # Get max and add it to diag
    m = dist_mat.max().detach()
    dist_mat_d = dist_mat + \
        Variable(torch.eye(dist_mat.size(0)) * (m.data[0] + 1)).cuda()

    entropy = (dist_mat_d.min(1)[0] + 1e-4).log().sum()

    entropy *= (samples.size(1) + 0.) / samples.size(0)

    return entropy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号