state_filter_sd.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def main():
    """Filter generated latent states with the state discriminator.

    Usage: ``<script> [directory]``

    Loads the discriminator from ``<directory>/_sd`` and the
    ``ConvolutionalGumbelAE2`` from ``<directory>``, then reads the
    whitespace-separated uint8 matrix ``<directory>/generated_states.csv``
    (one state per row).  Rows are processed in batches of 500000; within
    each batch only the rows whose rounded discriminator output is positive
    are kept and appended to ``ae.local("generated_states2.csv")``.  For every
    batch with at least one surviving row, up to 20 of those rows are decoded
    by the autoencoder and plotted via ``ae.plot`` to
    ``generated_states_filtered<i>.png``.

    Pressing Ctrl-C stops the dump early; rows already written remain in the
    output file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    sd = Discriminator("{}/_sd".format(directory)).load()
    ae = ConvolutionalGumbelAE2(directory).load()

    input_path = "{}/{}".format(directory, "generated_states.csv")
    print("loading {}".format(input_path), end='...', flush=True)
    # ndmin=2 keeps a single-row file as a (1, N) matrix instead of a 1-D
    # vector, so the row slicing and 2-D fancy indexing below still work.
    states = np.loadtxt(input_path, dtype=np.uint8, ndmin=2)
    print("done.")
    total = states.shape[0]
    batch = 500000
    output = "generated_states2.csv"
    try:
        print(ae.local(output))
        with open(ae.local(output), 'wb') as f:
            print("original states:", total)
            # Step through the rows in batch-sized strides.  Unlike
            # range(total // batch + 1), this never produces a trailing empty
            # batch when total is an exact multiple of batch (or is 0).
            for i, start in enumerate(range(0, total, batch)):
                _zs = states[start:start + batch]
                _result = sd.discriminate(_zs, batch_size=5000).round().astype(np.uint8)
                # Keep only the rows whose rounded discriminator output is positive.
                _zs_filtered = _zs[np.where(_result > 0)[0], :]
                print("reduced  states:", len(_zs_filtered), "/", len(_zs))

                if len(_zs_filtered) > 0:
                    # Visual sanity check on at most 20 surviving states;
                    # skipped when nothing survived, since there is nothing to decode.
                    _xs = ae.decode_binary(_zs_filtered[:20], batch_size=5000).round().astype(np.uint8)
                    ae.plot(_xs, path="generated_states_filtered{}.png".format(i))

                np.savetxt(f, _zs_filtered, "%d", delimiter=" ")

    except KeyboardInterrupt:
        print("dump stopped")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号