test_neighbors.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:FreeDiscovery 作者: FreeDiscovery 项目源码 文件源码
def test_nearest_centroid_ranker():
    """Check that NearestCentroidRanker reduces to a 1-nearest-neighbor search.

    Every sample gets its own distinct label, so after fitting each centroid
    is built from exactly one training point. Under that condition the
    scores from ``NearestCentroidRanker.decision_function`` are expected to
    match the distances returned by a brute-force ``NearestNeighbors``
    search with ``n_neighbors=1``. Both methods must also pick the same
    number of distinct training points for the test set.
    """
    from sklearn.neighbors import NearestNeighbors

    np.random.seed(0)

    n_samples, n_features = 100, 120
    features = np.random.rand(n_samples, n_features)
    # Normalize in place (copy=False); both estimators then see the same data.
    normalize(features, copy=False)

    # One distinct label per sample -> one centroid per training point.
    sample_ids = np.arange(n_samples, dtype='int')
    labels = np.arange(n_samples, dtype='int')
    ids_train, ids_test, labels_train, _ = train_test_split(sample_ids, labels)
    train_features, test_features = features[ids_train], features[ids_test]

    # Reference result: exact single nearest neighbor via brute force.
    knn = NearestNeighbors(n_neighbors=1, algorithm='brute')
    knn.fit(train_features)
    ref_dist, ref_idx = knn.kneighbors(test_features)
    nearest_ref_dist = ref_dist[:, 0]
    nearest_ref_idx = ref_idx[:, 0]

    ranker = NearestCentroidRanker()
    ranker.fit(train_features, labels_train)
    pred_dist = ranker.decision_function(test_features)
    pred_labels = ranker.predict(test_features)

    # Predicted labels are not required to coincide with neighbor indices,
    # so only the number of distinct output points is compared.
    assert np.unique(nearest_ref_idx).shape == np.unique(pred_labels).shape

    assert_allclose(pred_dist, nearest_ref_dist)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号