rbm_adv.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:SteinGAN 作者: DartML 项目源码 文件源码
def rbf_kernel(X):

    XY = T.dot(X, X.T)
    x2 = T.sum(X**2, axis=1).dimshuffle(0, 'x')
    X2e = T.repeat(x2, X.shape[0], axis=1)
    H = X2e +  X2e.T - 2. * XY

    V = H.flatten()
    # median distance
    h = T.switch(T.eq((V.shape[0] % 2), 0),
        # if even vector
        T.mean(T.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]),
        # if odd vector
        T.sort(V)[V.shape[0] // 2])

    h = T.sqrt(.5 * h / T.log(H.shape[0].astype('float32') + 1.)) 

    # compute the rbf kernel
    kxy = T.exp(-H / (h ** 2) / 2.0)

    dxkxy = -T.dot(kxy, X)
    sumkxy = T.sum(kxy, axis=1).dimshuffle(0, 'x')
    dxkxy = T.add(dxkxy, T.mul(X, sumkxy)) / (h ** 2)

    return kxy, dxkxy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号