steingan_celeba.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:SteinGAN 作者: DartML 项目源码 文件源码
def rbf_kernel(X0):
    XY = T.dot(X0, X0.transpose())
    x2 = T.reshape(T.sum(T.square(X0), axis=1), (X0.shape[0], 1))
    X2e = T.repeat(x2, X0.shape[0], axis=1)
    H = T.sub(T.add(X2e, X2e.transpose()), 2 * XY)

    V = H.flatten()

    # median distance
    h = T.switch(T.eq((V.shape[0] % 2), 0),
        # if even vector
        T.mean(T.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]),
        # if odd vector
        T.sort(V)[V.shape[0] // 2])

    h = T.sqrt(0.5 * h / T.log(X0.shape[0].astype('float32') + 1.0)) / 2.

    Kxy = T.exp(-H / h ** 2 / 2.0)
    neighbors = T.argsort(H, axis=1)[:, 1]

    return Kxy, neighbors, h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号