import pickle

import numpy as np
from sklearn import cluster

def get_position():
    position = pickle.load(open('dump/geohash_to_position_dict.pkl', 'rb'))
    position_set = position.values()
    position_array = [list(pos) for pos in position_set]
    position_matrix = np.array(position_array)
    bandwidth = cluster.estimate_bandwidth(position_matrix, quantile=0.3, n_jobs=-1)
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=False, n_jobs=-1)
    ms.fit(position_matrix)
    cluster_center = ms.cluster_centers_
    print(cluster_center)
    pickle.dump(cluster_center, open('dump/cluster_center.pkl', 'wb'), protocol=4)
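The saved centers can be loaded back the same way; a minimal sketch using the path and protocol from the snippet above:

import pickle

with open('dump/cluster_center.pkl', 'rb') as f:
    centers = pickle.load(f)
print(centers.shape)  # (n_clusters, n_features)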
import logging

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

def mean_shift(df, l1, l2, c1name, qt, cluster_all, bin_seeding):
    df1 = df.loc[df[c1name].isin([l1, l2])]
    pccols = [i for i in range(0, 50)]
    xp = df1[pccols].values
    bandwidth = 0
    if l1 == l2:
        bandwidth = estimate_bandwidth(xp, quantile=qt)
    else:
        xp1 = df1.loc[df1[c1name] == l1, pccols].values
        xp2 = df1.loc[df1[c1name] == l2, pccols].values
        bandwidth1 = estimate_bandwidth(xp1, quantile=qt)
        bandwidth2 = estimate_bandwidth(xp2, quantile=qt)
        bandwidth = max(bandwidth1, bandwidth2)
    logging.info("compare (%d, %d) with width=%f", l1, l2, bandwidth)
    ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all,
                   bin_seeding=bin_seeding)
    ms.fit(xp)
    mslabels_unique = np.unique(ms.labels_)
    nc = len(mslabels_unique)
    nl = ms.labels_
    # encode the mean-shift sub-cluster into the original label:
    # new_label = old_label * 1000 + subcluster_id
    df.loc[df[c1name].isin([l1, l2]), c1name] = df.loc[df[c1name].isin([l1, l2]), c1name] * 1000 + nl
    return nc, nl, bandwidth
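A minimal sketch (with hypothetical random data) of the relabeling scheme above: rows whose original label is 7 and which land in mean-shift sub-cluster 2 end up labeled 7002.

import numpy as np
import pandas as pd

df_demo = pd.DataFrame(np.random.rand(200, 50))   # integer columns 0..49 match pccols
df_demo['label1'] = 7
nc, nl, bw = mean_shift(df_demo, 7, 7, 'label1', 0.3, True, False)
print(nc, sorted(df_demo['label1'].unique()))     # e.g. 2 [7000, 7001]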
def cluster_meanshift(X_train, model_args=None, gridsearch=True, estimate_bandwidth_samples=None):
    from sklearn.cluster import MeanShift, estimate_bandwidth
    print('MeanShift')
    if model_args is None:  # guard: the 'bandwidth' membership check below needs a dict
        model_args = {}
    if gridsearch is True:
        # TODO: add hyperparameter searching. No scoring method is available
        # for this model, so we can't easily use grid searching.
        raise NotImplementedError('No hyperparameter optimization available yet for this model. Set gridsearch to False')
        # prune(param_grid, model_args)
    else:
        param_grid = None
    if 'bandwidth' not in model_args:
        print('Calculating the bandwidth')
        bandwidth = estimate_bandwidth(X_train, n_samples=estimate_bandwidth_samples)
        model_args['bandwidth'] = bandwidth
    return ModelWrapper(MeanShift, X=X_train, model_args=model_args, param_grid=param_grid, unsupervised=True)
import h5py
import numpy as np
import pandas as pd
from sklearn.cluster import MeanShift, estimate_bandwidth

def cluster():
    df_train = pd.read_csv(Train_CSV_Path, header=0)  # Train_CSV_Path: module-level constant
    destination = []
    for i in range(len(df_train)):
        # the DESTINATION column holds a stringified [lon, lat] pair
        destination.append(list(eval(df_train['DESTINATION'][i])))
    destination = np.array(destination)
    bw = estimate_bandwidth(
        destination,
        quantile=0.1,
        n_samples=1000
    )
    ms = MeanShift(
        bandwidth=bw,
        bin_seeding=True,
        min_bin_freq=5
    )
    ms.fit(destination)
    cluster_centers = ms.cluster_centers_
    with h5py.File('cluster.h5', 'w') as f:
        f.create_dataset('cluster', data=cluster_centers)
    return cluster_centers
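Reading the centres back out of cluster.h5 (same dataset name the snippet above writes):

import h5py

with h5py.File('cluster.h5', 'r') as f:
    cluster_centers = f['cluster'][:]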
def test_clusterer_enforcement(self):
    """
    Assert that only clustering estimators can be passed to cluster viz
    """
    nomodels = [
        SVC, SVR, Ridge, RidgeCV, LinearRegression, RandomForestClassifier
    ]
    for nomodel in nomodels:
        with self.assertRaises(YellowbrickTypeError):
            visualizer = ClusteringScoreVisualizer(nomodel())

    models = [
        KMeans, MiniBatchKMeans, AffinityPropagation, MeanShift, DBSCAN, Birch
    ]
    for model in models:
        try:
            visualizer = ClusteringScoreVisualizer(model())
        except YellowbrickTypeError:
            self.fail("could not pass clustering estimator to visualizer")
import collections

from sklearn import metrics
from sklearn.cluster import MeanShift

def mean_shift(data):
    mean_shift = MeanShift(cluster_all=False, n_jobs=1).fit(data)
    print('Mean Shift')
    print(metrics.silhouette_score(data, mean_shift.labels_))
    print(collections.Counter(mean_shift.labels_))
def mean_shift(fig):
    global X_iris, geo
    ax = fig.add_subplot(geo + 4, projection='3d', title='mean_shift')
    bandwidth = cluster.estimate_bandwidth(X_iris, quantile=0.2, n_samples=50)
    mean_shift = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    mean_shift.fit(X_iris)
    res = mean_shift.labels_
    for n, i in enumerate(X_iris):
        ax.scatter(*i[:3], c='bgrcmyk'[res[n] % 7], marker='o')
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    return res
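A hypothetical driver for the helper above; X_iris and geo are the module globals it reads (geo encodes a subplot grid, so geo + 4 picks a cell):

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D   # noqa: registers the 3d projection on older matplotlib
from sklearn import cluster, datasets

X_iris = datasets.load_iris().data
geo = 220                       # 2x2 grid; geo + 4 -> subplot 224
fig = plt.figure()
labels = mean_shift(fig)
plt.show()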
def cluster_means_shift(self, image_cols):
    print('Means shifting')
    bandwidth = estimate_bandwidth(image_cols, quantile=self.params.quantile, n_samples=400)
    print(self.params.quantile, bandwidth)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, min_bin_freq=50)
    ms.fit(image_cols)
    # from IPython import embed; embed(); import ipdb; ipdb.set_trace()
    self.number_of_clusters = len(np.unique(ms.labels_))
    print('number of clusters', self.number_of_clusters)
    return ms.cluster_centers_
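A hypothetical follow-up: map each pixel row to its nearest centre returned by the method above (simple colour quantization):

import numpy as np
from scipy.spatial.distance import cdist

def quantize(image_cols, centers):
    # index of the nearest centre for every row, then gather those centres
    nearest = cdist(image_cols, centers).argmin(axis=1)
    return centers[nearest]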
def mean_shift(location, location_callback, bandwidth=None):
    """Returns one or more clusters of a set of points, using a mean shift
    algorithm.

    The result is sorted with the first value being the largest cluster.

    Kwargs:
        bandwidth (float): If bandwidth is None, a value is detected
        automatically from the input using estimate_bandwidth.

    Returns:
        A list of NamedTuples (see get_cluster_named_tuple for a definition
        of the tuple).
    """
    pts = location._tuple_points()
    if not pts:
        return None
    X = np.array(pts).reshape((len(pts), len(pts[0])))
    if np.any(np.isnan(X)) or not np.all(np.isfinite(X)):
        return None
    # Imputer was removed in scikit-learn 0.22; use sklearn.impute.SimpleImputer there.
    X = Imputer().fit_transform(X)
    X = X.astype(np.float32)
    if not bandwidth:
        bandwidth = estimate_bandwidth(X, quantile=0.3)
    # "bandwidth or None" lets MeanShift fall back to its own estimate
    # if the estimated bandwidth came out as 0.
    ms = MeanShift(bandwidth=bandwidth or None, bin_seeding=False).fit(X)
    clusters = []
    for cluster_id, cluster_centre in enumerate(ms.cluster_centers_):
        locations = []
        for j, label in enumerate(ms.labels_):
            if not label == cluster_id:
                continue
            locations.append(location.locations[j])
        if not locations:
            continue
        clusters.append(cluster_named_tuple()(label=cluster_id,
                                              centroid=Point(cluster_centre),
                                              location=location_callback(locations)))
    return clusters
import pickle

import numpy as np
from sklearn.cluster import KMeans

def cluster(tr_lonlat_list, num_clusters, thiskey):
    dict_fileName = r'pkl/dict_' + thiskey + ".pkl"
    tr_lonlat_list = np.array(tr_lonlat_list)
    kmeans = KMeans(n_clusters=num_clusters, n_jobs=-1).fit(tr_lonlat_list)
    #mf = MeanShift(bandwidth=0.001,bin_seeding=True,min_bin_freq=5).fit(tr_lonlat_list)
    lonlat_cluster_dict = {}
    for i in range(tr_lonlat_list.shape[0]):
        key = str(tr_lonlat_list[i][0]) + ":" + str(tr_lonlat_list[i][1])
        lonlat_cluster_dict[key] = kmeans.labels_[i]
    # pickle needs a binary file handle ('wb', not 'w')
    with open(dict_fileName, 'wb') as f_w:
        pickle.dump(lonlat_cluster_dict, f_w)
    return lonlat_cluster_dict
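Loading the dictionary back; a sketch assuming the same thiskey used when saving (binary mode, matching the fix above):

import pickle

with open(r'pkl/dict_' + thiskey + '.pkl', 'rb') as f_r:
    lonlat_cluster_dict = pickle.load(f_r)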
import numpy as np
from sklearn.cluster import MeanShift

def cluster(lonlat_list):
    #dic_lonlat = {}
    #for i in range(len(tr_lonlat_list)):
    #    lon = tr_lonlat_list[i][0]
    #    lat = tr_lonlat_list[i][1]
    #    key = str(lon)+":"+str(lat)
    #    if key not in dic_lonlat.keys():
    #        lonlat_list.append([lon,lat])
    #        dic_lonlat[key] = 1
    #for i in range(len(te_lonlat_list)):
    #    lon = te_lonlat_list[i][0]
    #    lat = te_lonlat_list[i][1]
    #    key = str(lon)+":"+str(lat)
    #    if key not in dic_lonlat.keys():
    #        lonlat_list.append([lon,lat])
    #        dic_lonlat[key] = 1
    #lonlat_list = np.array(lonlat_list)
    #kmeans = KMeans(n_clusters=NUM_CLUSTERS,n_jobs=-1).fit(lonlat_list)
    lonlat_list = np.array(lonlat_list)  # needed: .shape is used below
    mf = MeanShift().fit(lonlat_list)
    lonlat_cluster_dict = {}
    for i in range(lonlat_list.shape[0]):
        key = str(lonlat_list[i][0]) + ":" + str(lonlat_list[i][1])
        lonlat_cluster_dict[key] = mf.labels_[i]
    #for i in range(NUM_CLUSTERS):
    #    count = 0
    #    for k,v in lonlat_cluster_dict.items():
    #        if i == v:
    #            count += 1
    #    print('cluster:'+str(i)+'\tcount:'+str(count))
    return lonlat_cluster_dict
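The commented-out counting loop above has a compact equivalent; a sketch over the dict this function returns:

from collections import Counter

counts = Counter(cluster(lonlat_list).values())   # {cluster_id: n_points}
for cid, n in counts.most_common():
    print('cluster:' + str(cid) + '\tcount:' + str(n))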
def makeMeanShift(X, k=-1):
    # estimate bandwidth for mean shift (k is unused; MeanShift finds the
    # number of clusters itself)
    bandwidth = cluster.estimate_bandwidth(X, quantile=0.3)
    return cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
def makeClusterers(X, k=2):
    return [('MiniBatchKMeans', makeKMeans(X, k)),
            ('AffinityPropagation', makeAffinityProp()),
            ('MeanShift', makeMeanShift(X)),
            ('SpectralClustering', makeSpectral(X, k)),
            ('Ward', makeWard(X, k)),
            ('AgglomerativeAvg', makeAvgLinkage(X, k)),
            ('AgglomerativeMax', makeMaxLinkage(X, k)),
            ('AgglomerativeWard', makeWardLinkage(X, k)),
            ('DBSCAN', makeDBScan())]
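A hypothetical driver that fits every clusterer from the list above on the same data and reports how many clusters each one finds:

import numpy as np

for name, clf in makeClusterers(X, k=3):
    labels = clf.fit_predict(X)
    print(name, len(np.unique(labels)))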
def run(self):
    cfg = self.cfg
    self.prepare()
    # Run KMeans clustering. The resulting cluster centers will be used as
    # seeds for the later MeanShift clustering, which will split the KMeans
    # clusters into subclusters if MeanShift finds subgroups.
    n_clusters = len(self.dfp) // cfg.avg_clsize   # KMeans needs an integer count
    labels, centers = self.run_kmean(n_clusters)
    self.dfp['label1'] = labels
    kvals = np.unique(self.dfp.label1.values)
    # Use the largest KMeans group to estimate the MeanShift bandwidth
    idxmax = self.dfp.label1.value_counts().idxmax()
    df_ = self.dfp.loc[self.dfp['label1'] == idxmax]
    xp_ = df_[self.pccols].values
    bandwidth = estimate_bandwidth(xp_, quantile=0.3)
    # run MeanShift using the centers found by KMeans
    ms = MeanShift(bandwidth=bandwidth, seeds=centers,
                   cluster_all=True)
    xp = self.dfp[self.pccols].values
    ms.fit(xp)
    mslabels_unique = np.unique(ms.labels_)
    nc = len(mslabels_unique)
    # run KMeans again using the number of clusters found by MeanShift
    labels, centers = self.run_kmean(nc)
    self.dfp['label1'] = labels
    kvals = np.unique(self.dfp['label1'].values)
    print("Classes after the second KMeans: ", kvals)
    # Run mean_shift to analyze the KMeans clusters. Samples classified into
    # other clusters are assigned new labels. New classes whose counts pass
    # the minimum threshold are kept in the analysis chain; those that don't
    # pass are ignored.
    for kval in kvals:
        __, __, bandwidth = mean_shift(self.dfp, kval, kval, 'label1',
                                       0.3, True, False)
    print("Classification result before merging")
    print("class counts")
    print(self.dfp['label1'].value_counts())
    # count cut
    cnts = self.dfp['label1'].value_counts()
    passed_cnts = cnts[cnts > self.min_counts].index.tolist()
    self.dfp = self.dfp[self.dfp['label1'].isin(passed_cnts)]
    self.mean_shift_merge('label')
def make_mean_shift_clustering(self, short_filenames, input_texts):
    output_dir = self.output_dir + 'mean_shift/'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if self.need_tf_idf:
        self.signals.PrintInfo.emit("Calculating TF-IDF...")
        idf_filename = output_dir + 'tf_idf.csv'
        msg = self.calculate_and_write_tf_idf(idf_filename, input_texts)
        self.signals.PrintInfo.emit(msg)

    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(input_texts)

    svd = TruncatedSVD(2)
    normalizer = Normalizer(copy=False)
    lsa = make_pipeline(svd, normalizer)
    X = lsa.fit_transform(X)

    # estimate_bandwidth needs n_samples * quantile >= 1, so bump the
    # quantile for very small corpora
    if (len(input_texts) * self.mean_shift_quantile) < 1.0:
        self.mean_shift_quantile = (1.0 / len(input_texts)) + 0.05

    bandwidth = estimate_bandwidth(X, quantile=self.mean_shift_quantile)
    if bandwidth == 0:
        bandwidth = 0.1

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    predict_result = ms.fit_predict(X)

    self.signals.PrintInfo.emit('\nDocuments by cluster:\n')
    clasters_output = ''
    for claster_index in range(max(predict_result) + 1):
        clasters_output += ('Cluster ' + str(claster_index) + ':\n')
        for predict, document in zip(predict_result, short_filenames):
            if predict == claster_index:
                clasters_output += ('  ' + str(document) + '\n')
        clasters_output += '\n'
    self.signals.PrintInfo.emit(clasters_output)
    self.signals.PrintInfo.emit('Saved to: ' + str(output_dir + 'clusters.txt'))
    writeStringToFile(clasters_output, output_dir + 'clusters.txt')
    self.draw_clusters_plot(X, predict_result, short_filenames)
# mean_shift_segmentation.py (project: computer-vision-algorithms, author: aleju)
import numpy as np
from scipy import misc                      # note: imresize was removed in SciPy 1.3
from skimage import color, data
from sklearn.cluster import MeanShift, estimate_bandwidth

def main():
    """Load image, collect pixels, cluster, create segment images, plot."""
    # load image
    img_rgb = data.coffee()
    img_rgb = misc.imresize(img_rgb, (256, 256)) / 255.0
    img = color.rgb2hsv(img_rgb)
    height, width, channels = img.shape
    print("Image shape is: ", img.shape)

    # collect pixels as tuples of (h, s, v, y, x)
    print("Collecting pixels...")
    pixels = []
    for y in range(height):
        for x in range(width):
            pixel = img[y, x, ...]
            pixels.append([pixel[0], pixel[1], pixel[2], (y/height)*2.0, (x/width)*2.0])
    pixels = np.array(pixels)
    print("Found %d pixels to cluster" % (len(pixels)))

    # cluster the pixels using mean shift
    print("Clustering...")
    bandwidth = estimate_bandwidth(pixels, quantile=0.05, n_samples=500)
    clusterer = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    labels = clusterer.fit_predict(pixels)

    # process labels generated during clustering
    labels_unique = set(labels)
    labels_counts = [(lu, len([l for l in labels if l == lu])) for lu in labels_unique]
    # sort labels by cluster size (labels_counts[l] is the (label, count) pair)
    labels_unique = sorted(list(labels_unique), key=lambda l: labels_counts[l][1], reverse=True)
    nb_clusters = len(labels_unique)
    print("Found %d clusters" % (nb_clusters))
    print(labels.shape)

    print("Creating images of segments...")
    img_segments = [np.copy(img_rgb)*0.25 for label in labels_unique]
    for y in range(height):
        for x in range(width):
            pixel_idx = (y*width) + x
            label = labels[pixel_idx]
            img_segments[label][y, x, 0] = 1.0

    print("Plotting...")
    images = [img_rgb]
    titles = ["Image"]
    for i in range(min(8, nb_clusters)):
        images.append(img_segments[i])
        titles.append("Segment %d" % (i))
    plot_images(images, titles)  # plot_images: project-level helper
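Since fit_predict returns labels in the same row-major order the pixels were collected in, the per-pixel double loop above can be replaced by a reshape; a sketch:

label_img = labels.reshape(height, width)          # per-pixel label image
for lab in range(nb_clusters):
    img_segments[lab][label_img == lab, 0] = 1.0   # highlight segment lab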
import os
from os import path
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

def _cluster(infn, outfn):
    if path.exists(outfn + '.lock') or path.exists(outfn + '.txt'):
        return
    open(outfn + '.lock', 'w').close()
    print("doing " + infn)
    data, xyt, transform, untransform, valid, maxes = parse_file(infn)
    if not valid:
        return  # note: the .lock file is left in place for invalid inputs
    bandwidth = args.bandwidth
    if bandwidth < 0:
        bandwidth = estimate_bandwidth(xyt, quantile=0.2, n_samples=500)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(xyt)
    centers = ms.cluster_centers_
    radius = bandwidth
    centers = untransform(centers)
    radius = untransform(radius)
    xyt = untransform(xyt)
    labels = ms.labels_
    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)
    few = np.bincount(labels) < args.min_deaths
    extract_battles(outfn + '.txt', data, ms, maxes, xyt, transform, untransform)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
    for k, too_few in zip(range(n_clusters_), few):
        my_members = labels == k
        cluster_center = centers[k]
        if too_few:
            col = 'black'
        else:
            col = next(colors)
        ax.scatter(cluster_center[0], cluster_center[1], cluster_center[2], 'o', c=col, s=100)
        ax.scatter(xyt[my_members, 0], xyt[my_members, 1], xyt[my_members, 2], c=col)
    if args.show:
        plt.show()
    plt.savefig(outfn + ".png")
    plt.close(fig)
    try:
        os.remove(outfn + '.lock')
    except OSError:
        pass
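_cluster() reads a module-level args object; a hypothetical stand-in with the three attributes it uses:

import argparse

args = argparse.Namespace(bandwidth=-1.0,   # < 0 -> estimate from the data
                          min_deaths=5,     # minimum cluster size to colour
                          show=False)       # show the 3D plot interactively
# _cluster('input.txt', 'output_prefix')    # hypothetical file names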