action_generator_maxdiff.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def main():
    """Generate candidate actions between generated states and keep plausible ones.

    Usage: ``<script> DIRECTORY``.  Exits with a usage message when no
    directory argument is given.

    Steps:

    1. Load the action discriminator from ``DIRECTORY_ad/``.
    2. Load ``DIRECTORY/actions.csv`` and derive ``threshold`` from it with
       ``maxdiff`` (presumably the largest state difference seen among the
       known actions -- TODO confirm against ``maxdiff``).
    3. Load ``DIRECTORY/generated_states.csv``.  For every state ``s``,
       collect every generated state whose L1 distance to ``s`` is strictly
       below ``threshold``.  Build rows ``[neighbor, s]``, with the neighbor
       in the first N columns and ``s`` in the last N.  Score these rows with
       the discriminator.
    4. Append rows scoring above 0.8 to ``DIRECTORY/generated_actions.csv``
       formatted as integers.

    Ctrl-C stops the loop early.  The ``with`` block closes the output file,
    so rows written before the interrupt are kept.
    """
    import sys

    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")

    # The known actions are only used to derive the neighborhood radius.
    known_actions = load("{}/actions.csv".format(directory))
    threshold = maxdiff(known_actions)
    print("maxdiff:", threshold)

    states = load("{}/generated_states.csv".format(directory))

    path = "{}/generated_actions.csv".format(directory)

    total = states.shape[0]
    N = states.shape[1]
    acc = 0

    try:
        print(path)
        with open(path, 'wb') as f:
            for i, s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i, total, i * total), end=' ')
                # L1 distance from s to every generated state (s itself gives 0).
                diff = np.sum(np.abs(states - s), axis=1)
                # NOTE(review): strict "<" drops pairs whose distance equals the
                # threshold exactly; confirm whether maxdiff's result expects "<=".
                neighbors = states[diff < threshold]
                if len(neighbors) == 0:
                    # Nothing to score; do not hand the discriminator an empty batch.
                    print(0, 0, acc)
                    continue
                # Row layout: neighbor in columns [0, N), s in columns [N, 2N).
                candidates = np.concatenate(
                    [neighbors, np.broadcast_to(s, neighbors.shape)], axis=1)
                ys = ad.discriminate(candidates, batch_size=400000)
                # Flatten so a (k, 1) score column selects whole rows rather than
                # single elements; a (k,) score vector is unaffected.
                ys = np.asarray(ys).reshape(-1)
                accepted = candidates[ys > 0.8]
                acc += len(accepted)
                print(len(neighbors), len(accepted), acc)
                np.savetxt(f, accepted, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号