Python 函数 estimate_bandwidth() 的实例源码

cluster_center.py 文件源码 项目:mobike 作者: angryBird2014 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def get_position():
    """Cluster the stored positions with MeanShift and persist the centres.

    Reads the pickled dict ``dump/geohash_to_position_dict.pkl``, turns each
    of its values into one row of a matrix, estimates a MeanShift bandwidth
    (``quantile=0.3``), fits MeanShift (no bin seeding), prints the cluster
    centres and pickles them to ``dump/cluster_center.pkl`` with protocol 4.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load dumps
    # this project produced itself.
    with open('dump/geohash_to_position_dict.pkl', 'rb') as fin:
        position = pickle.load(fin)

    # One matrix row per stored value.
    position_matrix = np.array([list(pos) for pos in position.values()])

    bandwidth = cluster.estimate_bandwidth(position_matrix, quantile=0.3, n_jobs=-1)

    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=False, n_jobs=-1)
    ms.fit(position_matrix)

    cluster_center = ms.cluster_centers_

    print(cluster_center)
    # The context manager guarantees the dump is flushed and the handle closed.
    with open('dump/cluster_center.pkl', 'wb') as fout:
        pickle.dump(cluster_center, fout, protocol=4)
helper.py 文件源码 项目:UVA 作者: chiachun 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def mean_shift(df, l1, l2, c1name, qt, cluster_all, bin_seeding):
    """Re-cluster the rows labelled ``l1`` or ``l2`` with MeanShift.

    Args:
        df: DataFrame holding feature columns ``0..49`` and the label column
            ``c1name``. Modified in place (see below).
        l1, l2: labels whose rows are re-clustered. When ``l1 == l2`` a single
            label group is re-clustered.
        c1name: name of the label column.
        qt: quantile passed to ``estimate_bandwidth``.
        cluster_all, bin_seeding: forwarded to ``MeanShift``.

    Returns:
        ``(nc, nl, bandwidth)``: the number of distinct MeanShift labels, the
        per-row MeanShift labels, and the bandwidth used.

    Side effect: each selected row's label becomes ``old_label * 1000 + nl``.
    """
    # The mask is computed once; df is not modified until the final relabel.
    mask = df[c1name].isin([l1, l2])
    df1 = df.loc[mask]
    pccols = list(range(50))
    # .values instead of DataFrame.as_matrix(), which was removed in pandas 1.0.
    xp = df1[pccols].values
    if l1 == l2:
        bandwidth = estimate_bandwidth(xp, quantile=qt)
    else:
        # Estimate per label and use the larger of the two bandwidths.
        xp1 = df1.loc[df1[c1name] == l1, pccols].values
        xp2 = df1.loc[df1[c1name] == l2, pccols].values
        bandwidth = max(estimate_bandwidth(xp1, quantile=qt),
                        estimate_bandwidth(xp2, quantile=qt))
    logging.info("compare (%d, %d) with width=%f", l1, l2, bandwidth)
    ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all,
                   bin_seeding=bin_seeding)
    ms.fit(xp)
    nl = ms.labels_
    nc = len(np.unique(nl))
    # Encode the sub-cluster into the label: old_label * 1000 + sub_label.
    df.loc[mask, c1name] = df.loc[mask, c1name] * 1000 + nl
    return nc, nl, bandwidth
clustering.py 文件源码 项目:eezzy 作者: 3Blades 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def cluster_meanshift(X_train, model_args=None, gridsearch=True, estimate_bandwidth_samples=None):
    """Wrap scikit-learn's MeanShift in a ModelWrapper.

    Args:
        X_train: training data.
        model_args: keyword arguments for ``MeanShift``. ``None`` means no
            arguments. When no ``'bandwidth'`` key is present, the bandwidth
            is estimated from ``X_train``. The caller's dict is not modified.
        gridsearch: must be ``False``; hyperparameter search is not
            implemented for this model.
        estimate_bandwidth_samples: ``n_samples`` passed to
            ``estimate_bandwidth`` (``None`` lets it use all samples).

    Raises:
        NotImplementedError: if ``gridsearch`` is ``True``.
    """
    from sklearn.cluster import MeanShift, estimate_bandwidth
    print('MeanShift')

    if gridsearch is True:
        # TODO: add hyperparameter searching. MeanShift has no scoring
        # method, so grid search cannot easily be used.
        raise NotImplementedError('No hyperparameter optimization available yet for this model. Set gridsearch to False')
    param_grid = None

    # Copy: makes the None default usable and keeps the caller's dict intact.
    model_args = {} if model_args is None else dict(model_args)

    if 'bandwidth' not in model_args:
        print('Calculating the bandwidth')
        model_args['bandwidth'] = estimate_bandwidth(X_train, n_samples=estimate_bandwidth_samples)

    return ModelWrapper(MeanShift, X=X_train, model_args=model_args, param_grid=param_grid, unsupervised=True)
train_mlp.py 文件源码 项目:taxi 作者: xuguanggen 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def cluster():
    """Cluster the DESTINATION values of the training CSV with MeanShift.

    Reads ``Train_CSV_Path``, parses every ``DESTINATION`` cell into a list,
    estimates the bandwidth (``quantile=0.1``, 1000 samples), fits MeanShift
    with bin seeding (``min_bin_freq=5``), writes the centres to dataset
    ``'cluster'`` of ``cluster.h5`` and returns them.
    """
    import ast

    df_train = pd.read_csv(Train_CSV_Path, header=0)
    # ast.literal_eval instead of eval(): the cell text comes from a data file,
    # and literal_eval only accepts Python literals, never executing code.
    # NOTE(review): assumes cells hold list/tuple literals such as "[x, y]".
    destination = np.array(
        [list(ast.literal_eval(cell)) for cell in df_train['DESTINATION']])

    bw = estimate_bandwidth(destination, quantile=0.1, n_samples=1000)
    ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
    ms.fit(destination)
    cluster_centers = ms.cluster_centers_
    with h5py.File('cluster.h5', 'w') as f:
        f.create_dataset('cluster', data=cluster_centers)
    return cluster_centers
sklearn_basic.py 文件源码 项目:base_function 作者: Rockyzsu 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def mean_shift(fig):
    """Draw a 3-D scatter of X_iris coloured by MeanShift cluster on *fig*.

    Reads the module-level ``X_iris`` and ``geo``; the subplot index is
    ``geo + 4``. The bandwidth is estimated with ``quantile=0.2`` over 50
    samples, and MeanShift runs with bin seeding. Returns the cluster labels.
    """
    subplot = fig.add_subplot(geo + 4, projection='3d', title='mean_shift')
    width = cluster.estimate_bandwidth(X_iris, quantile=0.2, n_samples=50)
    model = cluster.MeanShift(bandwidth=width, bin_seeding=True)
    model.fit(X_iris)
    labels = model.labels_

    palette = 'bgrcmyk'
    # One scatter call per sample, plotting its first three features.
    for sample, label in zip(X_iris, labels):
        subplot.scatter(*sample[:3], c=palette[label % 7], marker='o')

    subplot.set_xlabel('X Label')
    subplot.set_ylabel('Y Label')
    subplot.set_zlabel('Z Label')
    return labels
alg.py 文件源码 项目:image-segmentation 作者: alexlouden 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def cluster_means_shift(self, image_cols):
        """Cluster the rows of *image_cols* with MeanShift.

        The bandwidth comes from ``estimate_bandwidth`` with
        ``self.params.quantile`` over 400 samples. MeanShift uses bin seeding
        with ``min_bin_freq=50``. Stores the number of distinct labels in
        ``self.number_of_clusters`` and returns ``ms.cluster_centers_``.
        """
        # Single pre-formatted print(...) arguments produce the same text on
        # Python 2 and 3; the former print statements were a SyntaxError on 3.
        print('Means shifting')

        bandwidth = estimate_bandwidth(image_cols, quantile=self.params.quantile, n_samples=400)
        print('%s %s' % (self.params.quantile, bandwidth))
        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, min_bin_freq=50)
        ms.fit(image_cols)

        self.number_of_clusters = len(np.unique(ms.labels_))

        print('number of clusters %s' % (self.number_of_clusters,))

        return ms.cluster_centers_
utils.py 文件源码 项目:errorgeopy 作者: alpha-beta-soup 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def mean_shift(location, location_callback, bandwidth=None):
    """Returns one or more clusters of a set of points, using a mean shift
    algorithm.

    Clusters are emitted in the order of ``MeanShift.cluster_centers_``. That
    order is intended to put the largest cluster first; it relies on
    scikit-learn's ordering of the centres.

    Args:
        location: object providing ``_tuple_points()`` (the coordinates to
            cluster) and ``locations``, indexed in parallel with those points.
        location_callback: called with the list of ``location.locations``
            entries of one cluster. Its result fills the tuple's ``location``
            field.

    Kwargs:
        bandwidth (float): If bandwidth is falsy, a value is detected
        automatically from the input using estimate_bandwidth.

    Returns:
        A list of NamedTuples (see get_cluster_named_tuple for a definition
        of the tuple). Returns None when there are no points or any
        coordinate is NaN or infinite.
    """
    pts = location._tuple_points()
    if not pts:
        return None
    X = np.array(pts).reshape((len(pts), len(pts[0])))
    # isfinite is False for NaN as well as +/-inf, so one check covers both.
    if not np.all(np.isfinite(X)):
        return None
    # The former Imputer().fit_transform step was a no-op here: only finite
    # data reaches this point. Imputer was also removed in scikit-learn 0.22.
    X = X.astype(np.float32)
    if not bandwidth:
        bandwidth = estimate_bandwidth(X, quantile=0.3)
    # A zero estimate is passed as None so MeanShift estimates it itself.
    ms = MeanShift(bandwidth=bandwidth or None, bin_seeding=False).fit(X)

    # Group locations by label in one pass, instead of rescanning every label
    # for every cluster centre.
    members = {}
    for idx, label in enumerate(ms.labels_):
        members.setdefault(int(label), []).append(location.locations[idx])

    clusters = []
    for cluster_id, cluster_centre in enumerate(ms.cluster_centers_):
        locations = members.get(cluster_id)
        if not locations:
            continue
        clusters.append(cluster_named_tuple()(label=cluster_id,
                                              centroid=Point(cluster_centre),
                                              location=location_callback(
                                                  locations)))
    return clusters
test_mean_shift.py 文件源码 项目:Parallel-SGD 作者: angadgill 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_estimate_bandwidth():
    """estimate_bandwidth on the module fixture X with n_samples=200 lies in [0.9, 1.5]."""
    estimated = estimate_bandwidth(X, n_samples=200)
    assert_true(estimated >= 0.9 and estimated <= 1.5)
clusters.py 文件源码 项目:extract 作者: dblalock 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def makeMeanShift(X, k=-1):
    """Return an unfitted MeanShift whose bandwidth is estimated from *X*.

    The bandwidth uses ``quantile=0.3`` and the model uses bin seeding.
    ``k`` is accepted but unused (presumably for signature parity with other
    ``make*`` factories; confirm).
    """
    width = cluster.estimate_bandwidth(X, quantile=0.3)
    model = cluster.MeanShift(bandwidth=width, bin_seeding=True)
    return model
Classification.py 文件源码 项目:UVA 作者: chiachun 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def run(self):
        """Classify ``self.dfp`` with alternating KMeans / MeanShift passes.

        1. KMeans with ``len(self.dfp) // cfg.avg_clsize`` clusters.
        2. MeanShift over all rows, seeded with those KMeans centres, with a
           bandwidth estimated on the largest KMeans group. Its number of
           distinct labels ``nc`` is kept.
        3. KMeans again with ``nc`` clusters, stored in column ``label1``.
        4. Each ``label1`` group is re-split by ``mean_shift`` (labels become
           ``label * 1000 + sub_label``).
        5. Classes with at most ``self.min_counts`` rows are dropped, then
           ``mean_shift_merge`` is called.
        """
        cfg = self.cfg
        self.prepare()

        # Step 1: KMeans. Its centres seed the MeanShift below, which may
        # split a KMeans cluster into sub-groups. Floor division keeps
        # n_clusters an int on Python 3 (Python 2's int '/' already floored);
        # at least one cluster is requested.
        n_clusters = max(1, int(len(self.dfp) // cfg.avg_clsize))
        labels, centers = self.run_kmean(n_clusters)
        self.dfp['label1'] = labels

        # Use the largest KMeans group to estimate the MeanShift bandwidth.
        idxmax = self.dfp.label1.value_counts().idxmax()
        df_ = self.dfp.loc[self.dfp['label1'] == idxmax]
        # .values replaces DataFrame.as_matrix(), removed in pandas 1.0.
        xp_ = df_[self.pccols].values
        bandwidth = estimate_bandwidth(xp_, quantile=0.3)

        # Step 2: MeanShift seeded with the KMeans centres.
        ms = MeanShift(bandwidth=bandwidth, seeds=centers,
                       cluster_all=True)
        xp = self.dfp[self.pccols].values
        ms.fit(xp)
        nc = len(np.unique(ms.labels_))

        # Step 3: KMeans again using the number of clusters MeanShift found.
        labels, centers = self.run_kmean(nc)
        self.dfp['label1'] = labels
        kvals = np.unique(self.dfp['label1'].values)
        print("Classes after the second Kmean:  %s" % (kvals,))

        # Step 4: split every KMeans cluster with MeanShift. mean_shift()
        # relabels rows in place as kval * 1000 + sub_label. Processing labels
        # in DESCENDING order ensures a freshly written label never equals a
        # kval still to be processed. In ascending order, cluster 0's
        # sub-labels 1, 2, ... were swept into clusters 1, 2, ....
        # (Assumes fewer than 1000 KMeans clusters.)
        # New classes that pass the minimum count survive the cut below.
        for kval in kvals[::-1]:
            mean_shift(self.dfp, kval, kval, 'label1', 0.3, True, False)
        print("Classification result before merging")
        print("class  counts")
        print(self.dfp['label1'].value_counts())

        # Step 5: count cut.
        cnts = self.dfp['label1'].value_counts()
        passed_cnts = cnts[cnts > self.min_counts].index.tolist()
        self.dfp = self.dfp[self.dfp['label1'].isin(passed_cnts)]

        # NOTE(review): the working column above is 'label1'; confirm that
        # 'label' is the intended argument here.
        self.mean_shift_merge('label')
ClasteringCalculator.py 文件源码 项目:TextStageProcessor 作者: mhyhre 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def make_mean_shift_clustering(self, short_filenames, input_texts):
        """Cluster *input_texts* with MeanShift and report/save the result.

        Pipeline:
        - optional TF-IDF report;
        - CountVectorizer counts;
        - TruncatedSVD(2) followed by a Normalizer;
        - bandwidth estimation;
        - MeanShift with bin seeding.

        The per-cluster file listing is emitted via ``self.signals.PrintInfo``
        and written to ``self.output_dir + 'mean_shift/clusters.txt'``; the
        points are drawn with ``draw_clusters_plot``.

        Args:
            short_filenames: display names, parallel to *input_texts*.
            input_texts: documents to cluster.
        """
        output_dir = self.output_dir + 'mean_shift/'
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if self.need_tf_idf:
            self.signals.PrintInfo.emit("?????? TF-IDF...")
            idf_filename = output_dir + 'tf_idf.csv'
            msg = self.calculate_and_write_tf_idf(idf_filename, input_texts)
            self.signals.PrintInfo.emit(msg)

        vectorizer = CountVectorizer()
        X = vectorizer.fit_transform(input_texts)

        svd = TruncatedSVD(2)
        normalizer = Normalizer(copy=False)
        lsa = make_pipeline(svd, normalizer)
        X = lsa.fit_transform(X)

        # Make n_texts * quantile >= 1 so estimate_bandwidth gets at least one
        # neighbour. Work on a local copy: the configured
        # self.mean_shift_quantile used to be overwritten permanently by one
        # small corpus. Cap at 1.0 (all samples).
        quantile = self.mean_shift_quantile
        if (len(input_texts) * quantile) < 1.0:
            quantile = min(1.0, (1.0 / len(input_texts)) + 0.05)

        bandwidth = estimate_bandwidth(X, quantile=quantile)
        if bandwidth == 0:
            # Fall back to a small fixed bandwidth when the estimate is zero.
            bandwidth = 0.1

        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        predict_result = ms.fit_predict(X)
        self.signals.PrintInfo.emit('\n??????? ?? ??????????:\n')

        # Group document names by cluster in one pass, then build the report
        # with join instead of repeated string concatenation.
        n_clusters = max(predict_result) + 1
        members = [[] for _ in range(n_clusters)]
        for predict, document in zip(predict_result, short_filenames):
            if 0 <= predict < n_clusters:
                members[predict].append(document)

        parts = []
        for claster_index, documents in enumerate(members):
            parts.append('??????? ' + str(claster_index) + ':\n')
            parts.extend('  ' + str(document) + '\n' for document in documents)
            parts.append('\n')
        clasters_output = ''.join(parts)

        self.signals.PrintInfo.emit(clasters_output)
        self.signals.PrintInfo.emit('????????? ?:' + str(output_dir + 'clusters.txt'))
        writeStringToFile(clasters_output, output_dir + 'clusters.txt')

        self.draw_clusters_plot(X, predict_result, short_filenames)
mean_shift_segmentation.py 文件源码 项目:computer-vision-algorithms 作者: aleju 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def main():
    """Load image, collect pixels, cluster, create segment images, plot.

    Each pixel of the resized HSV image becomes the feature
    ``(h, s, v, 2*y/height, 2*x/width)``. MeanShift clusters them. Up to 8
    segments, largest first, are shown with their pixels' red channel set
    to 1 over a darkened (x0.25) copy of the RGB image.
    """
    # load image
    img_rgb = data.coffee()
    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3; this line
    # needs an old SciPy or a replacement such as skimage.transform.resize.
    img_rgb = misc.imresize(img_rgb, (256, 256)) / 255.0
    img = color.rgb2hsv(img_rgb)
    height, width, channels = img.shape
    print("Image shape is: ", img.shape)

    # Collect pixels as rows of (h, s, v, y, x); img is HSV after rgb2hsv.
    # A row-major reshape matches the y-then-x scan order, so pixel (y, x)
    # is row y*width + x.
    print("Collecting pixels...")
    rows, cols = np.mgrid[0:height, 0:width]
    pixels = np.column_stack([
        img[..., :3].reshape(-1, 3),
        ((rows / height) * 2.0).ravel(),
        ((cols / width) * 2.0).ravel(),
    ])
    print("Found %d pixels to cluster" % (len(pixels)))

    # cluster the pixels using mean shift
    print("Clustering...")
    bandwidth = estimate_bandwidth(pixels, quantile=0.05, n_samples=500)
    clusterer = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    labels = clusterer.fit_predict(pixels)

    # Order labels by pixel count, largest first. The previous sort key
    # labels_counts[l] was a (label, count) tuple, so it actually sorted by
    # label id, and the sorted order was never used for plotting.
    counts = np.bincount(labels)
    labels_unique = [l for l in np.argsort(-counts, kind='stable') if counts[l] > 0]
    nb_clusters = len(labels_unique)
    print("Found %d clusters" % (nb_clusters))
    print(labels.shape)

    # Only the segments that are plotted (at most 8) are materialised.
    print("Creating images of segments...")
    label_img = labels.reshape(height, width)
    shown = labels_unique[:8]
    img_segments = []
    for label in shown:
        segment = np.copy(img_rgb) * 0.25
        segment[label_img == label, 0] = 1.0
        img_segments.append(segment)

    print("Plotting...")
    images = [img_rgb] + img_segments
    titles = ["Image"] + ["Segment %d" % (label) for label in shown]

    plot_images(images, titles)
cluster.py 文件源码 项目:StarData 作者: TorchCraft 项目源码 文件源码 阅读 65 收藏 0 点赞 0 评论 0
def _cluster(infn, outfn):
    """Cluster the points parsed from *infn* with MeanShift and write results.

    Returns immediately if ``outfn + '.lock'`` or ``outfn + '.txt'`` exists.
    Otherwise it:
    - creates the lock file and parses *infn*;
    - runs MeanShift with ``args.bandwidth``, or an estimated bandwidth when
      that is negative;
    - writes battles via ``extract_battles`` to ``outfn + '.txt'``;
    - saves a 3-D scatter plot to ``outfn + '.png'``.

    The lock file is removed on every exit path, including invalid input and
    exceptions, which are re-raised.
    """
    if path.exists(outfn + '.lock') or path.exists(outfn + '.txt'):
        return
    open(outfn + '.lock', 'w').close()
    try:
        print("doing " + infn)
        data, xyt, transform, untransform, valid, maxes = parse_file(infn)
        if not valid:
            return
        bandwidth = args.bandwidth
        if bandwidth < 0:
            bandwidth = estimate_bandwidth(xyt, quantile=0.2, n_samples=500)

        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        ms.fit(xyt)
        centers = untransform(ms.cluster_centers_)
        xyt = untransform(xyt)

        labels = ms.labels_
        n_clusters_ = len(np.unique(labels))
        # Clusters with fewer than args.min_deaths members are drawn in black
        # without a centre marker.
        few = np.bincount(labels) < args.min_deaths
        extract_battles(outfn + '.txt', data, ms, maxes, xyt, transform, untransform)

        _plot_clusters(xyt, labels, centers, few, n_clusters_, outfn)
    finally:
        # Previously the lock leaked on invalid input or any exception.
        # Removal is best-effort; only filesystem errors are ignored.
        try:
            os.remove(outfn + '.lock')
        except OSError:
            pass


def _plot_clusters(xyt, labels, centers, few, n_clusters_, outfn):
    """Save (and optionally show) a 3-D scatter of the clusters to outfn.png."""
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
        for k, too_few in zip(range(n_clusters_), few):
            my_members = labels == k
            cluster_center = centers[k]
            if too_few:
                col = 'black'
            else:
                col = next(colors)
                # marker='o' by keyword: Axes3D.scatter's 4th positional
                # parameter is zdir, not the marker.
                ax.scatter(cluster_center[0], cluster_center[1], cluster_center[2],
                           marker='o', c=col, s=100)
            ax.scatter(xyt[my_members, 0], xyt[my_members, 1], xyt[my_members, 2], c=col)
        # Save before showing: once the interactive window is closed the
        # figure may be torn down, and a later savefig can write a blank image.
        fig.savefig(outfn + ".png")
        if args.show:
            plt.show()
    finally:
        plt.close(fig)


问题


面经


文章

微信
公众号

扫码关注公众号