action_discriminator2.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def prepare(data):
    """Build a shuffled train/test split for a (pre, suc) transition discriminator.

    Each row of ``data`` is split column-wise in half: the first half is
    ``pre`` and the second half is ``suc``. The rows of ``data`` are the valid
    examples (label 1.0). Invalid examples (label 0.0) pair each ``pre`` with
    a row of ``suc`` taken from a random row permutation. Any fabricated row
    identical to a row of ``data`` is discarded, and repeated fabricated rows
    are collapsed (``np.setdiff1d`` semantics). So there may be fewer than
    ``num`` invalid examples.

    Args:
        data: 2-D array with an even number of columns, shape (num, 2*dim).

    Returns:
        Tuple ``(train_in, train_out, test_in, test_out)``.
        - Inputs have ``2*dim`` columns.
        - Outputs have one column of 1.0/0.0 labels.
        - ``int(0.9 * total_examples)`` of the shuffled examples go to training.

    Raises:
        ValueError: if ``data`` does not have an even number of columns.
    """
    # The structured-dtype .view below requires a C-contiguous last axis.
    data = np.ascontiguousarray(data)
    num = len(data)
    if data.shape[1] % 2:
        raise ValueError("prepare: data must have an even number of columns, got %d"
                         % data.shape[1])
    dim = data.shape[1]//2
    print(data.shape,num,dim)
    pre, suc = data[:,:dim], data[:,dim:]

    # np.random.shuffle permutes rows of a 2-D array correctly. The stdlib
    # random.shuffle swaps numpy row *views* and can duplicate/lose rows.
    suc_invalid = np.copy(suc)
    np.random.shuffle(suc_invalid)
    data_invalid = np.concatenate((pre,suc_invalid),axis=1)

    # View each row as a single structured record so setdiff1d compares whole
    # rows. This removes fabricated rows that coincide with valid ones.
    row_dtype = [('', data.dtype)] * (2*dim)
    ai = data_invalid.view(row_dtype)
    av = data.view(row_dtype)
    data_invalid = np.setdiff1d(ai, av).view(data.dtype).reshape((-1, 2*dim))

    inputs = np.concatenate((data,data_invalid),axis=0)
    outputs = np.concatenate((np.ones((num,1)),np.zeros((len(data_invalid),1))),axis=0)
    print(inputs.shape,outputs.shape)
    io = np.concatenate((inputs,outputs),axis=1)
    np.random.shuffle(io)

    # Split on the real number of examples. setdiff1d may have dropped
    # invalid rows, so len(io) can be smaller than 2*num.
    train_n = int(len(io)*0.9)
    train, test = io[:train_n], io[train_n:]
    train_in, train_out = train[:,:dim*2], train[:,dim*2:]
    test_in, test_out = test[:,:dim*2], test[:,dim*2:]

    return train_in, train_out, test_in, test_out


# default values
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号