supervised_reduction_multiple.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:sef 作者: passalis 项目源码 文件源码
def supervised_reduction(method=None, dataset=None):
    """Project a dataset with a supervised reduction method and report SVM test accuracy.

    The dataset is loaded with ``dataset_loader`` (seed 1) and standardized.
    The scaler is fitted on the training split only. The data is then
    projected with the chosen method, and the projected features are scored
    with ``evaluate_svm``.

    Args:
        method: One of
            ``'lda'``: scikit-learn ``LinearDiscriminantAnalysis`` with
            ``n_classes - 1`` components.
            ``'s-lda'``: ``LinearSEF`` with ``n_classes - 1`` outputs.
            ``'s-lda-2x'``: ``LinearSEF`` with ``2 * (n_classes - 1)`` outputs.
        dataset: Dataset name forwarded to ``dataset_loader``. For ``'yale'``
            the SEF regularizer weight is 0.0001; for every other dataset it is 1.

    Returns:
        The accuracy returned by ``evaluate_svm``. It is also printed
        multiplied by 100, so it is presumably a fraction in [0, 1]; confirm
        against ``evaluate_svm``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported = ('lda', 's-lda', 's-lda-2x')
    if method not in supported:
        # Fail fast. Previously an unknown method (including the default None)
        # fell through every branch: the data was loaded and scaled, and then
        # the call crashed with UnboundLocalError on ``proj``.
        raise ValueError("Unknown method %r; expected one of %s" % (method, supported))

    # Seed NumPy's global RNG for reproducibility. A former call to
    # sklearn.utils.check_random_state(1) was dropped: it only builds and
    # returns a RandomState, and its result was discarded, so it seeded nothing.
    # NOTE(review): if LinearSEF trains with PyTorch (suggested by .cuda()),
    # its RNG is not seeded here.
    np.random.seed(1)

    train_data, train_labels, test_data, test_labels = dataset_loader(dataset, seed=1)

    # Fit the scaler on the training split only, so test statistics do not leak in.
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_data)
    test_data = scaler.transform(test_data)

    n_classes = len(np.unique(train_labels))

    if method == 'lda':
        # LDA yields at most n_classes - 1 discriminant directions.
        proj = LinearDiscriminantAnalysis(n_components=n_classes - 1)
        proj.fit(train_data, train_labels)
    else:
        # SEF output dimensions are not limited to n_classes - 1, so the
        # '-2x' variant doubles them.
        regularizer_weight = 0.0001 if dataset == 'yale' else 1
        multiplier = 2 if method == 's-lda-2x' else 1
        proj = LinearSEF(train_data.shape[1], output_dimensionality=multiplier * (n_classes - 1))
        proj.cuda()
        # The value returned by fit() was never used, so it is not kept.
        proj.fit(data=train_data, target_labels=train_labels, epochs=100,
                 target='supervised', batch_size=256, regularizer_weight=regularizer_weight,
                 learning_rate=0.001, verbose=False)

    acc = evaluate_svm(proj.transform(train_data), train_labels,
                       proj.transform(test_data), test_labels)

    print("Method: ", method, " Test accuracy: ", 100 * acc, " %")
    return acc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号