def _collect_pixels(img):
    """Build the mean-shift feature matrix for an (H, W, C) image.

    Returns an (H*W, 5) float array whose rows are
    ``[c0, c1, c2, (y / H) * 2.0, (x / W) * 2.0]``, in row-major order:
    row ``y * W + x`` describes pixel ``img[y, x]``. Only the first three
    channels are used.
    """
    height, width = img.shape[:2]
    ys, xs = np.mgrid[0:height, 0:width]
    # float() keeps this a true division under Python 2 as well.
    return np.column_stack([
        img[..., :3].reshape(-1, 3),
        ys.ravel() / float(height) * 2.0,
        xs.ravel() / float(width) * 2.0,
    ])


def _labels_by_size(labels):
    """Return the distinct values in ``labels``, most frequent first.

    Ties are broken by ascending label value (stable sort).
    """
    unique_labels, counts = np.unique(labels, return_counts=True)
    return unique_labels[np.argsort(-counts, kind="mergesort")]


def _segment_images(img_rgb, labels, selected_labels):
    """Return one highlight image per label in ``selected_labels``.

    Each image is ``img_rgb`` dimmed to 25 % with channel 0 (red) set to 1.0
    at every pixel carrying that label. ``labels`` must hold one label per
    pixel in row-major order (as produced from ``_collect_pixels`` rows).
    """
    label_img = np.asarray(labels).reshape(img_rgb.shape[:2])
    segments = []
    for label in selected_labels:
        segment = img_rgb * 0.25  # new array; img_rgb is left untouched
        segment[label_img == label, 0] = 1.0
        segments.append(segment)
    return segments


def main():
    """Load image, collect pixels, cluster, create segment images, plot.

    Pixels are clustered with mean shift on (h, s, v, y, x) features. The
    original image is then plotted next to up to ``max_segments`` segment
    images, ordered from the largest cluster to the smallest.
    """
    # scipy.misc.imresize was removed in SciPy 1.3. scikit-image's resize
    # returns floats in [0, 1] for uint8 input, matching the old
    # `imresize(...) / 255.0`. Assumes `data` / `color` are the scikit-image
    # modules, so the package is already a dependency.
    from skimage.transform import resize

    max_segments = 8  # only this many segment images are built and plotted

    # load image
    img_rgb = resize(data.coffee(), (256, 256))
    img = color.rgb2hsv(img_rgb)
    print("Image shape is: ", img.shape)

    # collect pixels as rows of (h, s, v, y, x)
    print("Collecting pixels...")
    pixels = _collect_pixels(img)
    print("Found %d pixels to cluster" % (len(pixels)))

    # cluster the pixels using mean shift
    print("Clustering...")
    bandwidth = estimate_bandwidth(pixels, quantile=0.05, n_samples=500)
    clusterer = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    labels = clusterer.fit_predict(pixels)

    # order clusters by pixel count so the biggest segments are shown first
    labels_by_size = _labels_by_size(labels)
    nb_clusters = len(labels_by_size)
    print("Found %d clusters" % (nb_clusters))
    print(labels.shape)

    print("Creating images of segments...")
    segments = _segment_images(img_rgb, labels, labels_by_size[:max_segments])

    print("Plotting...")
    images = [img_rgb] + segments
    titles = ["Image"] + ["Segment %d" % i for i in range(len(segments))]
    plot_images(images, titles)
mean_shift_segmentation.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录