multimodalart HF staff commited on
Commit
18b98ac
·
1 Parent(s): fb8832c

Update sketch_helper.py

Browse files
Files changed (1) hide show
  1. sketch_helper.py +12 -13
sketch_helper.py CHANGED
@@ -6,20 +6,19 @@ from skimage.color import lab2rgb
6
  from sklearn.cluster import KMeans
7
 
8
  def color_quantization(image, n_colors):
9
- # Convert image to LAB color space
10
- lab_image = rgb2lab(image)
11
- # Reshape image to 2D array of pixels
12
- pixels = lab_image.reshape(-1, 3)
13
- # Perform K-means clustering
14
- kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(pixels)
15
- # Replace each pixel with the closest color
16
- labels = kmeans.predict(pixels)
17
  colors = kmeans.cluster_centers_
18
- quantized_pixels = colors[labels]
19
- # Convert quantized image back to RGB color space
20
- quantized_lab_image = quantized_pixels.reshape(lab_image.shape)
21
- quantized_rgb_image = lab2rgb(quantized_lab_image)
22
- return (quantized_rgb_image * 255).astype(np.uint8)
 
23
 
24
  def get_high_freq_colors(image):
25
  im = image.getcolors(maxcolors=1024*1024)
 
6
  from sklearn.cluster import KMeans
7
 
8
  def color_quantization(image, n_colors):
9
+ # Determine the number of bins dynamically
10
+ unique_colors = np.unique(image.reshape(-1, 3), axis=0)
11
+ n_bins = int(np.ceil(np.sqrt(unique_colors.shape[0])))
12
+
13
+ # Cluster the colors using k-means
14
+ kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(unique_colors)
 
 
15
  colors = kmeans.cluster_centers_
16
+
17
+ # Replace each pixel with the closest color
18
+ dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
19
+ labels = np.argmin(dists, axis=1)
20
+ return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
21
+
22
 
23
  def get_high_freq_colors(image):
24
  im = image.getcolors(maxcolors=1024*1024)