dump_significance.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def run(ae, xs):
    """Measure how much each latent dimension of ``ae`` changes its decodes.

    For every latent column ``i`` in ``range(ae.N)`` this:

    * sweeps column ``i`` of the encoded batch from 0.0 to 1.0 in steps of
      0.1 (other columns keep their encoded values), decodes every step, and
      tiles all decodes into ``ae.local("dump_significance.png")`` via
      ``plot_grid2`` with 11 images per row (one row per sample and column,
      samples outermost);
    * computes, per sample, the mean squared difference between the decode
      with column ``i`` set to 0.0 and the decode with it set to 1.0, and
      prints it.

    Args:
        ae: autoencoder exposing ``encode_binary``, ``decode_binary``, ``N``
            and ``local(filename)``.
        xs: batch of inputs accepted by ``ae.encode_binary``.

    Returns:
        numpy array of shape ``(batch, ae.N)``, where ``batch`` is the leading
        axis of the decoded arrays (presumably ``len(xs)`` -- confirm against
        ``encode_binary``). Entry ``[b, i]`` is the mean squared decode
        difference for sample ``b`` when column ``i`` goes from 0 to 1.
    """
    n_steps = 11  # sweep values 0.0, 0.1, ..., 1.0; also the grid width
    zs = ae.encode_binary(xs)
    ys = ae.decode_binary(zs)
    # Sweep on a floating-point copy: assigning j / 10.0 into an integer or
    # boolean array would silently truncate the fractional values to 0.
    # Float inputs keep their own dtype.
    sweep_dtype = np.result_type(np.asarray(zs).dtype, np.float32)
    sweeps = []
    correlations = []
    print(ys.shape)
    print("correlations:")
    # list(...) so Python 3 prints the indices rather than "range(0, n)".
    print("bit \\ image  {}".format(list(range(len(xs)))))
    for i in range(ae.N):
        mod_zs = np.array(zs, dtype=sweep_dtype)
        decoded = []
        for j in range(n_steps):
            mod_zs[:, i] = j / float(n_steps - 1)
            decoded.append(ae.decode_binary(mod_zs))
        sweeps.extend(decoded)
        # The first and last sweep steps set column i to exactly 0.0 and 1.0,
        # so reuse them instead of decoding those inputs a second time
        # (assumes decode_binary is deterministic).
        diff = decoded[0] - decoded[-1]
        # Average over every non-batch axis; equals axis=(1, 2) for 3-D
        # decodes and also handles decodes with extra (e.g. channel) axes.
        correlation = np.mean(np.square(diff), axis=tuple(range(1, diff.ndim)))
        correlations.append(correlation)
        print("{:>5} {}".format(i, correlation))
    # (batch, ae.N * n_steps, ...) -> flat list of images, batch-major.
    grid = np.stack(sweeps, axis=1).reshape((-1,) + ys.shape[1:])
    plot_grid2(grid, w=n_steps, path=ae.local("dump_significance.png"))
    return np.stack(correlations, axis=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号