action_generator_sdprune.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def main():
    """Enumerate candidate actions between generated states and keep the valid ones.

    Usage: ``<script> [directory]``.

    Loads the Discriminator from ``"<directory>_ad/"`` and the states in
    ``<directory>/generated_states.csv``. For every state ``s`` it builds the
    candidate rows ``[pre | s]`` for all states ``pre``, scores them with
    ``ad.discriminate``, and writes the rows scoring above the threshold to
    the file given by ``ad.local("generated_actions.csv")``. Ctrl-C stops the
    enumeration early; rows already written are kept because the file is
    closed by the ``with`` block before the interrupt is reported.
    """
    import sys

    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    # Plain string concatenation: a trailing slash in `directory` yields
    # "<directory>/_ad/", so the caller's slash convention matters here.
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")
    name = "generated_actions.csv"
    threshold = 0.8       # minimum discriminator score for a row to be kept
    batch_size = 400000   # batch size passed through to ad.discriminate

    states_path = "{}/generated_states.csv".format(directory)
    print("loading {}".format(states_path), end='...', flush=True)
    # ndmin=2 keeps a single-state file 2-D; otherwise states.shape[1] fails.
    states = np.loadtxt(states_path, dtype=np.uint8, ndmin=2)
    print("done.")
    total = states.shape[0]
    N = states.shape[1]
    # Each row is [pre-state | successor]: the left half holds every state and
    # stays fixed; the right half (zero-padded here) is overwritten per iteration.
    actions = np.pad(states, ((0, 0), (0, N)), "constant")

    acc = 0  # running count of rows written so far
    try:
        print(ad.local(name))
        with open(ad.local(name), 'wb') as f:
            for i, s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i, total, i * total), end=' ')
                actions[:, N:] = s
                ys = ad.discriminate(actions, batch_size=batch_size)
                # Boolean row mask over the flattened scores. This is the same as
                # np.where for 1-D ys, and still selects whole rows if ys has shape
                # (total, 1), where np.where would return (row, col) pairs and pick
                # scalars. NOTE(review): assumes one score per row — confirm against
                # Discriminator.discriminate.
                valid_actions = actions[np.ravel(ys) > threshold]
                acc += len(valid_actions)
                print(len(valid_actions), acc)
                np.savetxt(f, valid_actions, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号