def color_reduction(image, n_colors, method='kmeans', palette=None):
    """Reduce the number of colors in ``image`` to ``n_colors`` using ``method``.

    Args:
        image: Color image in BGR channel order. It is converted with
            ``cv2.COLOR_BGR2LAB`` and channel-reversed to RGB for PIL.
            Presumably a ``uint8`` array of shape (h, w, 3), since the PIL
            branch builds an ``'RGB'`` image from it. TODO confirm with callers.
        n_colors: Target number of colors; clamped to the range [2, 128].
        method: ``'kmeans'`` (clustering in LAB space via ``kmeans``) or one
            of ``'linear'``, ``'max'``, ``'median'``, ``'octree'``, which are
            mapped to a PIL quantize method by ``get_quantize_method``.
            Matching is case-insensitive. Any other value falls back to
            ``'kmeans'``.
        palette: Optional sequence of RGB triples. When given, quantized
            color index ``i`` is rendered as ``palette[i]``. In the k-means
            branch, pixel labels index directly into the palette, so it
            presumably needs at least ``n_colors`` entries. NOTE(review):
            confirm against callers.

    Returns:
        The color-reduced image in BGR channel order, with the same height
        and width as ``image``.
    """
    method = method.lower()
    if method not in ('kmeans', 'linear', 'max', 'median', 'octree'):
        method = 'kmeans'
    # Keep the requested color count inside the supported range.
    n_colors = min(max(n_colors, 2), 128)

    if method == 'kmeans':
        h, w = image.shape[:2]
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        pixels = lab.reshape((-1, 3))  # one row per pixel: (h * w, 3)
        centers, labels = kmeans(pixels, n_colors)
        if palette is not None:
            # The palette comes in RGB. cv2.cvtColor only accepts 8-bit,
            # 16-bit or float32 input, and a plain list of ints would become
            # an int64 array, so force uint8 (as the PIL branch does).
            centers = cv2.cvtColor(np.array([palette], dtype=np.uint8),
                                   cv2.COLOR_RGB2LAB)[0]
        quant = centers[labels].reshape((h, w, 3))
        if image.dtype == np.uint8 and quant.dtype != np.uint8:
            # 8-bit LAB from cvtColor is scaled to 0..255, but float LAB is
            # interpreted with L in 0..100. Centers computed on 8-bit data must
            # therefore be returned to uint8 before converting back to BGR.
            quant = np.clip(np.rint(quant), 0, 255).astype(np.uint8)
        return cv2.cvtColor(quant, cv2.COLOR_LAB2BGR)

    # PIL expects RGB, so reverse the BGR channel axis.
    rgb_image = PIL.Image.fromarray(image[:, :, ::-1], mode='RGB')
    quant = rgb_image.quantize(colors=n_colors,
                               method=get_quantize_method(method))
    if palette is not None:
        palette = np.array(palette, dtype=np.uint8)
        # putpalette is documented to take a flat list (or string) of values.
        quant.putpalette(palette.flatten().tolist())
    # Back to BGR. Copy into a contiguous buffer, because negative-stride
    # views are rejected by many OpenCV functions downstream.
    bgr = np.array(quant.convert('RGB'), dtype=np.uint8)[:, :, ::-1]
    return np.ascontiguousarray(bgr)
评论列表
文章目录