datasets.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def load_caltech101_30(folder=CALTECH101_30_DIR, tiny_problem=False):
    """Load the Caltech101-30 kernel file and build train/validation/test sets.

    The file ``caltech101-30.matlab`` inside ``folder`` is read with
    ``scio.loadmat``. Only its ``'Ktrain'``, ``'tr_label'`` and ``'tr_files'``
    entries are used: the matrix ``Ktrain`` is factorized by SVD into
    ``X = U * sqrt(s)`` (one example per row), and the rows of ``X`` are then
    dealt out round-robin (row index modulo 3) into training, validation and
    test sets. Labels and file entries are split with the same indices so the
    three stay aligned.

    NOTE(review): ``Ktrain`` is presumably a symmetric positive semi-definite
    kernel (Gram) matrix, in which case ``X.dot(X.T)`` reproduces it -- confirm
    against the data file.

    :param folder: directory containing ``caltech101-30.matlab``.
    :param tiny_problem: if True, keep only every 5th example among the first
        20% of the examples before factorizing, to get a much smaller problem.
    :return: a ``Datasets`` with ``train``, ``validation`` and ``test``
        ``Dataset`` objects; each one stores its file entries under
        ``info['files']``.
    """
    caltech = scio.loadmat(folder + '/caltech101-30.matlab')
    # The three splits are all carved out of the training-side entries; the
    # 'Ktest', 'te_label' and 'te_files' entries of the file are not needed.
    k_train = caltech['Ktrain']
    label_tr = caltech['tr_label']
    file_tr = caltech['tr_files']

    if tiny_problem:
        pattern_step = 5
        fraction_limit = 0.2
        limit = int(len(label_tr) * fraction_limit)
        k_train = k_train[:limit:pattern_step, :limit:pattern_step]
        label_tr = label_tr[:limit:pattern_step]
        # Subsample the file entries with the same slice, otherwise
        # info['files'] would describe different examples than data/target.
        # Assumes file_tr is indexed by example along its first axis, like
        # label_tr -- TODO confirm against the .matlab file layout.
        file_tr = file_tr[:limit:pattern_step]

    U, s, _ = linalg.svd(k_train)
    # Scaling column j of U by sqrt(s[j]) is the same as U.dot(diag(sqrt(s)))
    # without materializing the dense diagonal matrix.
    X = U * s ** 0.5  # examples in rows

    # Labels are shifted down by one before one-hot encoding (presumably they
    # are 1-based in the file -- verify).
    label_tr_enc = to_one_hot_enc(np.array(label_tr) - 1)

    n = len(X)
    train_x, val_x, test_x = (X[i:n:3, :] for i in range(3))
    train_y, val_y, test_y = (label_tr_enc[i:n:3, :] for i in range(3))
    train_file, val_file, test_file = (file_tr[i:n:3] for i in range(3))

    test_dataset = Dataset(data=test_x, target=test_y, info={'files': test_file})
    validation_dataset = Dataset(data=val_x, target=val_y, info={'files': val_file})
    training_dataset = Dataset(data=train_x, target=train_y, info={'files': train_file})

    return Datasets(train=training_dataset, validation=validation_dataset, test=test_dataset)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号