action_compressor.py source code

python

Project: latplan  Author: guicho271828
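import numpy as np

def prepare(data):
    # Each row of `data` is a transition: the first half is the pre-state,
    # the second half is the successor state.
    num = len(data)
    dim = data.shape[1]//2
    print("in prepare: ",data.shape,num,dim)
    pre, suc = data[:,:dim], data[:,dim:]

    # Build "invalid" transitions by pairing each pre-state with a randomly
    # permuted successor (a few may coincide with the true successor by chance).
    # np.random.shuffle permutes rows in place; the standard-library
    # random.shuffle would duplicate rows of a 2D NumPy array.
    suc_invalid = np.copy(suc)
    np.random.shuffle(suc_invalid)

    # Represent each transition by the state difference (successor - pre).
    diff_valid   = suc         - pre
    diff_invalid = suc_invalid - pre

    # Label valid transitions 1 and invalid transitions 0.
    inputs = np.concatenate((diff_valid,diff_invalid),axis=0)
    outputs = np.concatenate((np.ones((num,1)),np.zeros((num,1))),axis=0)
    print("in prepare: ",inputs.shape,outputs.shape)

    # Shuffle inputs and labels together so each row keeps its label.
    io = np.concatenate((inputs,outputs),axis=1)
    np.random.shuffle(io)

    # 90% / 10% train/test split over the 2*num examples.
    train_n = int(2*num*0.9)
    train, test = io[:train_n], io[train_n:]
    train_in, train_out = train[:,:dim], train[:,dim:]
    test_in, test_out = test[:,:dim], test[:,dim:]
    print("in prepare: ",train_in.shape, train_out.shape, test_in.shape, test_out.shape)

    return train_in, train_out, test_in, test_out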
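Why the choice of shuffle function matters in prepare: the snippet's imports are not shown, so it is an assumption that a bare `random` refers to Python's standard-library module. If it does, `random.shuffle` swaps elements with `x[i], x[j] = x[j], x[i]`. On a 2D NumPy array, `x[i]` and `x[j]` are row views, so the first assignment overwrites row `i` and the second copies that overwritten data back into row `j`, leaving duplicated rows and losing others. The sketch below compares it with NumPy's own shuffle on a toy array. The helper names (`count_unique_rows`, `shuffle_with_stdlib`, `shuffle_with_numpy`) are hypothetical and exist only for this illustration.

```python
import random

import numpy as np


def count_unique_rows(a):
    """Return the number of distinct rows in a 2D array."""
    return len(np.unique(a, axis=0))


def shuffle_with_stdlib(a, seed=0):
    """Shuffle a copy of `a` with Python's random.shuffle (unsafe for 2D arrays)."""
    out = np.copy(a)
    random.seed(seed)
    # The tuple swap assigns through row views, so rows get overwritten.
    random.shuffle(out)
    return out


def shuffle_with_numpy(a, seed=0):
    """Shuffle a copy of `a` along axis 0 with a NumPy Generator (safe)."""
    out = np.copy(a)
    np.random.default_rng(seed).shuffle(out)
    return out


if __name__ == "__main__":
    data = np.arange(200).reshape(100, 2)  # 100 distinct rows
    print("stdlib random.shuffle:", count_unique_rows(shuffle_with_stdlib(data)))
    print("numpy Generator.shuffle:", count_unique_rows(shuffle_with_numpy(data)))
```

With the standard-library shuffle the distinct-row count falls below 100, while the NumPy shuffle keeps all 100 rows and only changes their order. This is why rows that must stay intact, such as the input/label pairs in `io`, should be shuffled with a NumPy routine.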