def load_caltech101_30(folder=CALTECH101_30_DIR, tiny_problem=False):
    """Load the Caltech101-30 kernel data and split it into train/validation/test.

    Reads ``caltech101-30.matlab`` from ``folder`` and uses only its training
    part (``Ktrain``, ``tr_label``, ``tr_files``); the file's ``Ktest``,
    ``te_label`` and ``te_files`` entries are not used. The training kernel is
    factorised with an SVD into ``X = U @ diag(sqrt(s))`` (one example per
    row; if ``Ktrain`` is symmetric positive semi-definite, ``X @ X.T``
    reproduces it). The rows of ``X`` are then dealt round-robin into three
    splits: rows 0, 3, 6, ... -> training, rows 1, 4, 7, ... -> validation,
    rows 2, 5, 8, ... -> test.

    Args:
        folder: directory containing ``caltech101-30.matlab`` (str or
            path-like).
        tiny_problem: if True, keep only every 5th example among the first
            20% of the training examples (for quick experiments). Kernel,
            labels and file names are all subsampled identically.

    Returns:
        ``Datasets(train=..., validation=..., test=...)`` where each
        ``Dataset`` holds ``data`` (rows of ``X``), ``target`` (one-hot labels
        produced by ``to_one_hot_enc(labels - 1)``) and ``info={'files': ...}``
        with the file names aligned to the rows of ``data``.
    """
    import os  # local import keeps this block self-contained

    caltech = scio.loadmat(os.path.join(folder, 'caltech101-30.matlab'))
    k_train = caltech['Ktrain']
    label_tr = caltech['tr_label']
    file_tr = caltech['tr_files']

    if tiny_problem:
        pattern_step = 5
        fraction_limit = 0.2
        # Cut-off computed once, from the full-size label array, before any
        # of the arrays below are reassigned.
        limit = int(len(label_tr) * fraction_limit)
        k_train = k_train[:limit:pattern_step, :limit:pattern_step]
        label_tr = label_tr[:limit:pattern_step]
        # Subsample the file names with the same pattern so that
        # info['files'] stays aligned with data/target.
        file_tr = file_tr[:limit:pattern_step]

    U, s, _ = linalg.svd(k_train)
    # Broadcasting scales column j of U by sqrt(s[j]); this equals
    # U @ diag(sqrt(s)) without building and multiplying a dense diagonal.
    X = U * np.sqrt(s)  # examples in rows

    # Labels are shifted by -1 before encoding (presumably 1-based class ids
    # coming from MATLAB -- confirm against to_one_hot_enc's expectations).
    label_tr_enc = to_one_hot_enc(np.array(label_tr) - 1)

    n = len(X)
    train_x, val_x, test_x = (X[i:n:3, :] for i in range(3))
    train_y, val_y, test_y = (label_tr_enc[i:n:3, :] for i in range(3))
    train_file, val_file, test_file = (file_tr[i:n:3] for i in range(3))

    test_dataset = Dataset(data=test_x, target=test_y, info={'files': test_file})
    validation_dataset = Dataset(data=val_x, target=val_y, info={'files': val_file})
    training_dataset = Dataset(data=train_x, target=train_y, info={'files': train_file})
    return Datasets(train=training_dataset, validation=validation_dataset, test=test_dataset)
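# Leftover web-page navigation labels from the scraped source (not code):
# "Comment list" and "Table of contents".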