action_generator.py — file source code

python
Views: 29 · Bookmarks: 0 · Likes: 0 · Comments: 0

Project: latplan · Author: guicho271828 · Project source · File source
def main():
    """Enumerate every bit vector that a trained discriminator accepts.

    Usage: ``<script> [directory]``. Loads the discriminator stored under
    ``<directory>_ad/``, then walks all ``2**N`` 0/1 vectors of width
    ``N = discriminator.net.input_shape[1]``. The walk runs in batches of
    ``2**lowbit`` rows. Each vector scored above 0.5 is written as one
    space-separated row of integers to the path returned by
    ``discriminator.local("generated_actions.csv")``. Any existing file at
    that path is overwritten.

    Pressing Ctrl-C stops the enumeration early. The rows written so far are
    kept, because the ``with`` block closes (and flushes) the file.
    """
    import sys

    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    discriminator = Discriminator(directory_ad).load()
    name = "generated_actions.csv"

    N = discriminator.net.input_shape[1]
    # The low `lowbit` bits are enumerated inside a single batch; the
    # remaining `highbit` bits are fixed per outer iteration. Clamp so a
    # network narrower than 20 bits does not produce a negative `highbit`
    # (2**negative is a float, which range() rejects).
    lowbit = min(20, N)
    highbit = N - lowbit
    print("batch size: {}".format(2**lowbit))

    # Row r holds the bits of r, least significant bit in column 0. The high
    # columns start at zero and are overwritten on every outer iteration.
    xs = np.zeros((2**lowbit, N), dtype=int)
    xs[:, :lowbit] = (np.arange(2**lowbit)[:, None] >> np.arange(lowbit)) & 1
    high_positions = np.arange(highbit)

    try:
        print(discriminator.local(name))
        with open(discriminator.local(name), 'wb') as f:
            for i in range(2**highbit):
                print("Iteration {}/{} base: {}".format(i, 2**highbit, i * (2**lowbit)), end=' ')
                # The bits of the outer counter i fill columns lowbit..N-1 of
                # every row (broadcast across the batch).
                xs[:, lowbit:] = (i >> high_positions) & 1
                ys = discriminator.discriminate(xs, batch_size=100000)
                # Force exactly one score per row. If a (batch, 1) output were
                # indexed via np.where, the result would be (row, col) pairs
                # selecting single elements rather than whole rows. Any other
                # shape fails here instead of silently picking wrong data.
                scores = np.asarray(ys).reshape(len(xs))
                valid_xs = xs[scores > 0.5]
                print(len(valid_xs))
                np.savetxt(f, valid_xs, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
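Comments
Table of contents


Questions


Interview experiences


Articles


WeChat
Official account

Scan the QR code to follow the official account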