def _one_hot(labels, classes):
    """Return a ``(classes, len(labels))`` uint8 one-hot matrix.

    Column ``j`` holds a single 1 in row ``labels[j]``; all other entries are 0.
    Raises ``IndexError`` if a label is ``>= classes``.
    """
    labels = np.asarray(labels)
    n_samples = labels.shape[0]
    targets = np.zeros((classes, n_samples), dtype='uint8')
    # Fancy indexing sets one entry per column in a single vectorized
    # assignment, replacing the per-sample Python loop.
    targets[labels, np.arange(n_samples)] = 1
    return targets


def get_digits(classes=10, rng=42):
    """Load the scikit-learn digits dataset and split it into train/test sets.

    Parameters
    ----------
    classes : int
        Number of digit classes to load (forwarded to
        ``load_digits(n_class=...)``). It is also the number of rows in the
        one-hot target matrices.
    rng : int
        ``random_state`` passed to ``train_test_split``, so the 70/30 split
        is reproducible.

    Returns
    -------
    tuple
        ``(trn, tst)``: ``Instance(X_train.T, trg_train)`` and
        ``Instance(X_test.T, trg_test)``. The inputs are transposed, so each
        sample is a column. The targets are uint8 one-hot matrices of shape
        ``(classes, n_samples)``.
    """
    X, y = datasets.load_digits(n_class=classes, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=rng)
    trn = Instance(X_train.T, _one_hot(y_train, classes))
    tst = Instance(X_test.T, _one_hot(y_test, classes))
    return trn, tst
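# 评论列表 ("Comment list") — leftover web-page text, not part of the code
# 文章目录 ("Table of contents") — leftover web-page text, not part of the code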