def prepare(data, *, train_frac=0.9, rng=None):
    """Build a shuffled train/test split of real rows vs. corrupted rows.

    Each row of ``data`` is treated as two halves: the first ``dim`` columns
    ("pre") and the last ``dim`` columns ("suc"). Negative examples are made
    by permuting the "suc" halves across rows and re-attaching them to the
    original "pre" halves. Any recombined row that already occurs in
    ``data`` is discarded, and duplicates are dropped. Real rows are
    labelled 1 and negatives 0. The combined set is shuffled and split.

    Args:
        data: Array-like of shape (num, 2*dim), with at least two columns.
        train_frac: Fraction of the combined examples placed in the training
            split (default 0.9, matching the previous hard-coded value).
        rng: Anything with a ``permutation`` method, such as
            ``np.random.default_rng(seed)``. Defaults to NumPy's global
            random state.

    Returns:
        ``(train_in, train_out, test_in, test_out)``. The ``*_in`` arrays
        have ``2*dim`` columns. The ``*_out`` arrays have one label column.

    Raises:
        ValueError: If ``data`` is not 2-D with an even, non-zero number of
            columns.
    """
    if rng is None:
        rng = np.random
    # The structured row views below require a C-contiguous buffer.
    data = np.ascontiguousarray(data)
    if data.ndim != 2 or data.shape[1] < 2 or data.shape[1] % 2:
        raise ValueError(
            f"expected a 2-D array with an even number of columns, got shape {data.shape}"
        )
    num = len(data)
    dim = data.shape[1] // 2
    width = 2 * dim
    print(data.shape, num, dim)

    pre, suc = data[:, :dim], data[:, dim:]
    # Permute row indices rather than calling stdlib random.shuffle.
    # random.shuffle swaps 2-D NumPy rows through aliasing views and
    # duplicates rows.
    suc_invalid = suc[rng.permutation(num)]
    data_invalid = np.concatenate((pre, suc_invalid), axis=1)

    # View each row as a single structured record so setdiff1d compares whole
    # rows. This drops recombined rows that coincide with real rows.
    row_dtype = [("", data.dtype)] * width
    ai = data_invalid.view(row_dtype)
    av = data.view(row_dtype)
    data_invalid = np.setdiff1d(ai, av).view(data.dtype).reshape((-1, width))

    inputs = np.concatenate((data, data_invalid), axis=0)
    outputs = np.concatenate(
        (np.ones((num, 1)), np.zeros((len(data_invalid), 1))), axis=0
    )
    print(inputs.shape, outputs.shape)

    io = np.concatenate((inputs, outputs), axis=1)
    io = io[rng.permutation(len(io))]

    # Split on the real example count: setdiff1d may have removed negatives,
    # so the total is not necessarily 2*num.
    train_n = int(len(io) * train_frac)
    train, test = io[:train_n], io[train_n:]
    return train[:, :width], train[:, width:], test[:, :width], test[:, width:]
# default values
评论列表
文章目录