def prepare_split(self, X, y, validation_data=None, validation_split=None):
    """Split data into train/dev sets and convert them to torch tensors.

    If ``validation_data`` is given, it is used as the dev set, all of
    ``(X, y)`` becomes the training set, and ``validation_split`` is ignored.
    Otherwise a random fraction ``validation_split`` of the rows of ``X``/``y``
    is held out as the dev set. The rows are shuffled with the global NumPy
    RNG (``np.random.permutation``).

    Features are converted to ``float32`` tensors and labels to ``int64``
    tensors. They are placed on the GPU when ``self.cudaEfficient`` is false
    and kept on the CPU when it is true. Presumably this lets the training loop
    move mini-batches to the GPU on demand; confirm against the caller.

    Args:
        X: Training features. Must support fancy indexing with an integer
            array (e.g. a NumPy array) when ``validation_split`` is used.
        y: Training labels, indexed like ``X``.
        validation_data: Optional ``(devX, devy)`` pair used as the dev set.
        validation_split: Fraction in ``(0, 1)`` of the rows to hold out when
            ``validation_data`` is not given.

    Returns:
        Tuple ``(trainX, trainy, devX, devy)`` of freshly allocated tensors.

    Raises:
        ValueError: If ``validation_data`` is None and ``validation_split`` is
            missing or not strictly between 0 and 1.
    """
    if validation_data is not None:
        trainX, trainy = X, y
        devX, devy = validation_data
    else:
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and a split outside (0, 1) yields a degenerate split.
        if validation_split is None or not 0 < validation_split < 1:
            raise ValueError(
                "prepare_split needs validation_data or a validation_split "
                f"in (0, 1); got validation_split={validation_split!r}"
            )
        n_dev = int(validation_split * len(X))
        permutation = np.random.permutation(len(X))
        # The first n_dev shuffled indices form the dev set; the rest are train.
        devidx, trainidx = permutation[:n_dev], permutation[n_dev:]
        trainX, trainy = X[trainidx], y[trainidx]
        devX, devy = X[devidx], y[devidx]

    device = torch.device("cpu") if self.cudaEfficient else torch.device("cuda")
    # torch.tensor always copies, so later edits to the caller's arrays do not
    # leak into the returned tensors. This matches the legacy constructors.
    trainX = torch.tensor(trainX, dtype=torch.float32, device=device)
    trainy = torch.tensor(trainy, dtype=torch.int64, device=device)
    devX = torch.tensor(devX, dtype=torch.float32, device=device)
    devy = torch.tensor(devy, dtype=torch.int64, device=device)
    return trainX, trainy, devX, devy
评论列表
文章目录