huzey committed on
Commit
a16a404
·
1 Parent(s): c8dc051

update sub-cluster

Browse files
Files changed (1) hide show
  1. app.py +216 -3
app.py CHANGED
@@ -148,7 +148,7 @@ def compute_ncut(
148
  logging_str = ""
149
 
150
  num_nodes = np.prod(features.shape[:-1])
151
- if num_nodes / 2 < num_eig:
152
  # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
153
  gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
154
  num_eig = num_nodes // 2 - 1
@@ -2196,7 +2196,7 @@ demo = gr.Blocks(
2196
  css=custom_css,
2197
  )
2198
  with demo:
2199
-
2200
  with gr.Tab('PlayGround'):
2201
  eigvecs = gr.State(np.array([]))
2202
  tsne3d_rgb = gr.State(np.array([]))
@@ -4246,6 +4246,219 @@ with demo:
4246
  fg_contrast_slider, bg_contrast_slider],
4247
  outputs=[mask_gallery, crop_gallery])
4248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4249
 
4250
  # with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
4251
  # eigvecs = gr.State(np.array([]))
@@ -4569,7 +4782,7 @@ with demo:
4569
 
4570
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
4571
  hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
4572
- tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa]
4573
  hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
4574
 
4575
  with gr.Row():
 
148
  logging_str = ""
149
 
150
  num_nodes = np.prod(features.shape[:-1])
151
+ if num_nodes / 2 - 1 < num_eig:
152
  # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
153
  gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
154
  num_eig = num_nodes // 2 - 1
 
2196
  css=custom_css,
2197
  )
2198
  with demo:
2199
+
2200
  with gr.Tab('PlayGround'):
2201
  eigvecs = gr.State(np.array([]))
2202
  tsne3d_rgb = gr.State(np.array([]))
 
4246
  fg_contrast_slider, bg_contrast_slider],
4247
  outputs=[mask_gallery, crop_gallery])
4248
 
4249
+
4250
+ with gr.Tab('Sub-cluster (dev)') as sub_cluster_tab:
4251
+ with gr.Row():
4252
+ with gr.Column(scale=5, min_width=200):
4253
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
4254
+ num_images_slider.value = 300
4255
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
4256
+
4257
+ with gr.Column(scale=5, min_width=200):
4258
+ output_gallery = make_output_images_section()
4259
+ add_download_button(output_gallery)
4260
+ with gr.Accordion("Sub-cluster", open=True):
4261
+ num_image_cluster_slider = gr.Slider(1, 20, step=1, label="Number of clusters (image-level)", value=6, elem_id="num_image_cluster", info="Image-level clustering before pixel-level clustering")
4262
+ show_class_label_checkbox = gr.Checkbox(label="Show cluster label", value=True, elem_id="show_class_label", info="Show image-level clustering label on the output")
4263
+ overlay_original_image_checkbox = gr.Checkbox(label="Overlay original image", value=False, elem_id="overlay_original_image", info="Overlay original image on the output")
4264
+ [
4265
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
4266
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4267
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4268
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
4269
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
4270
+ ] = make_parameters_section()
4271
+
4272
+ with gr.Row():
4273
+ image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
4274
+
4275
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
4276
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
4277
+
4278
def cluster_image_level(image_gallery, n_clusters=6):
    """Group gallery images into clusters using DINO CLS features and NCUT.

    Each image is resized to 224x224, embedded with the DINO ViT-B/16 model
    (L2-normalized CLS token of the last of 12 layers), and the embeddings are
    clustered with NCUT followed by k-way discretization.  A 2D t-SNE plot of
    the NCUT eigenvectors, decorated with the image thumbnails, is rendered
    as well.

    Args:
        image_gallery: Sequence of gallery entries; the first element of each
            entry is either a PIL image or a file path.  Only the first entry
            is inspected to decide whether paths must be opened.
        n_clusters: Number of NCUT eigenvectors, and therefore clusters.

    Returns:
        A tuple ``(cluster_labels, tsne_image_plot, resized_images)`` where
        ``cluster_labels`` is a NumPy array with one cluster index per image,
        ``tsne_image_plot`` is the PIL image produced by
        ``plot_2d_tsne_with_images`` and ``resized_images`` is the list of
        224x224 PIL images.

    Raises:
        gr.Error: If the gallery is empty.
    """
    from ncut_pytorch import NCUT, kway_ncut
    from sklearn.manifold import TSNE
    from torch.utils.data import DataLoader, TensorDataset

    if not image_gallery:
        # Fail with a readable message instead of an IndexError on images[0].
        raise gr.Error("No images selected.")
    n_clusters = int(n_clusters)

    images = [entry[0] for entry in image_gallery]
    if isinstance(images[0], str):
        images = [Image.open(path) for path in images]
    resized_images = [image.resize((224, 224), Image.LANCZOS) for image in images]
    image_batch = torch.stack([
        transform_image(image, resolution=(224, 224), stablediffusion=False)
        for image in resized_images
    ])

    is_cuda = torch.cuda.is_available()
    device = 'cuda' if is_cuda else 'cpu'

    # NOTE(review): the model is re-loaded through torch.hub on every call.
    dino = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').eval()
    dino = dino.to(device)

    dataloader = DataLoader(TensorDataset(image_batch), batch_size=128, shuffle=False)
    features = []
    with torch.no_grad():
        for batch in dataloader:
            chunk = batch[0].to(device)
            # Last of the 12 requested intermediate layers; token 0 is kept.
            tokens = dino.get_intermediate_layers(chunk, 12)[11]
            cls_feature = torch.nn.functional.normalize(tokens[:, 0], dim=-1)
            features.append(cls_feature.cpu())
    class_features = torch.cat(features, dim=0)

    class_ncut = NCUT(num_eig=n_clusters, move_output_to_cpu=True, distance='cosine', knn=1, device=device)
    class_eigvecs, _ = class_ncut.fit_transform(class_features)

    # Only move to the GPU when one exists; the original unconditional
    # ``.cuda()`` crashed on CPU-only hosts.
    # NOTE(review): assumes kway_ncut accepts CPU tensors - confirm.
    kway_input = class_eigvecs.cuda() if is_cuda else class_eigvecs
    kway_eigvecs = kway_ncut(kway_input)  # (n, k), one-hot
    cluster_labels = torch.argmax(kway_eigvecs, dim=1).cpu().numpy()

    # t-SNE needs a strictly positive perplexity; for very small galleries
    # ``len / 2 - 1`` drops to zero or below, so clamp it to at least 1.
    perplexity = max(1, min(30, int(len(cluster_labels) / 2 - 1)))
    tsne_2d = TSNE(n_components=2, perplexity=perplexity, random_state=0).fit_transform(
        class_eigvecs.cpu().numpy()
    )

    tsne_image_plot = plot_2d_tsne_with_images(cluster_labels, tsne_2d, resized_images)

    return cluster_labels, tsne_image_plot, resized_images
4324
+
4325
def plot_2d_tsne_with_images(cluster_labels, tsne_2d, images, max_num_images=300):
    """Render a 2D t-SNE scatter plot with image thumbnails at each point.

    Every thumbnail gets a 20-pixel border coloured by its cluster label
    (``tab20`` colormap, label / 20).  When there are more points than
    ``max_num_images`` a random subset of thumbnails is drawn; the underlying
    scatter still shows every point.

    Args:
        cluster_labels: Per-image integer cluster labels.
        tsne_2d: Array of shape (n, 2) with the t-SNE coordinates.
        images: Per-image PIL images or H x W x C arrays.
        max_num_images: Maximum number of thumbnails drawn on the plot.

    Returns:
        The rendered plot as a PIL image.
    """
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage

    border_width = 20

    def to_rgb_array(image):
        # RGBA / grayscale PIL images would not accept a 3-component border.
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        return np.array(image)

    def pad_image_with_border(image, border_color, border_width):
        height, width, channels = image.shape
        new_image = np.ones((height + 2 * border_width, width + 2 * border_width, channels), dtype=image.dtype)
        new_image[:, :] = border_color
        new_image[border_width:-border_width, border_width:-border_width] = image
        return new_image

    def padded_limits(values):
        # Pad by a fraction of the span; scaling min/max by 1.1 crops points
        # whenever the minimum is positive or the maximum is negative.
        low, high = values.min(), values.max()
        pad = 0.05 * (high - low)
        if pad == 0:
            pad = 1.0
        return low - pad, high + pad

    padded_images = []
    for label, image, _ in zip(cluster_labels, images, tsne_2d):
        border_color = np.array(plt.cm.tab20(label / 20))[:3] * 255
        padded_images.append(pad_image_with_border(to_rgb_array(image), border_color, border_width))

    # use dark background
    plt.style.use('dark_background')
    fig = None
    try:
        # Plot the t-SNE points
        fig, ax = plt.subplots(1, 1, figsize=(15, 15))
        # vmin/vmax make the dot colours agree with the ``label / 20`` borders.
        ax.scatter(tsne_2d[:, 0], tsne_2d[:, 1], s=20, c=cluster_labels, cmap='tab20', vmin=0, vmax=20)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        ax.set_xlim(*padded_limits(tsne_2d[:, 0]))
        ax.set_ylim(*padded_limits(tsne_2d[:, 1]))

        if len(tsne_2d) > max_num_images:
            random_indices = np.random.choice(len(tsne_2d), max_num_images, replace=False)
            tsne_2d = tsne_2d[random_indices]
            padded_images = [padded_images[i] for i in random_indices]

        # Add the bordered thumbnails to the scatter plot
        for (x, y), thumbnail in zip(tsne_2d, padded_images):
            imgbox = OffsetImage(np.array(thumbnail), zoom=0.15)
            ax.add_artist(AnnotationBbox(imgbox, (x, y), frameon=False))

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        # Save this figure (not whichever figure happens to be current) to an
        # in-memory buffer and decode it before the buffer is closed.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            image = np.array(Image.open(buf))
    finally:
        # Always release the figure and reset the style, even on failure.
        if fig is not None:
            plt.close(fig)
        plt.style.use('default')

    return Image.fromarray(image)
4390
+
4391
def add_boarder_and_overlay_image(ncut_images, original_images, cluster_labels, is_overlay_image, is_add_border):
    """Optionally blend NCUT images with their originals and add a cluster border.

    The list ``ncut_images`` is updated in place and also returned.

    Args:
        ncut_images: List of PIL images (they are resized with ``.resize``).
        original_images: Per-image originals (PIL images or arrays), indexed
            in parallel with ``ncut_images``; used only when overlaying.
        cluster_labels: Per-image integer labels selecting the border colour
            from the ``tab20`` colormap (label / 20).
        is_overlay_image: If true, replace each NCUT image by a 50/50 blend
            with its original, at the original's size.
        is_add_border: If true, resize each image to 224x224 and prepend a
            20-pixel coloured strip on its left side.

    Returns:
        The same ``ncut_images`` list with the processed PIL images.
    """
    def to_rgb_array(image):
        # Force three channels so blending and border colouring always line
        # up, even for RGBA or grayscale PIL inputs.
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        return np.array(image)

    def pad_image_with_left_border(image, border_color, border_width):
        height, width, channels = image.shape
        new_image = np.ones((height, width + border_width, channels), dtype=image.dtype)
        new_image[:, :] = border_color
        new_image[:, border_width:] = image
        return new_image

    if is_overlay_image:
        for i, ncut_image in enumerate(ncut_images):
            original_image = to_rgb_array(original_images[i])
            # PIL sizes are (width, height), hence the reversed array shape.
            resized = ncut_image.resize(original_image.shape[:2][::-1], Image.LANCZOS)
            blend_image = 0.5 * original_image + 0.5 * to_rgb_array(resized)
            ncut_images[i] = Image.fromarray(blend_image.astype(np.uint8))

    if is_add_border:
        border_width = 20
        for i, ncut_image in enumerate(ncut_images):
            border_color = np.array(plt.cm.tab20(cluster_labels[i] / 20))[:3] * 255
            resized = to_rgb_array(ncut_image.resize((224, 224), Image.LANCZOS))
            bordered = pad_image_with_left_border(resized, border_color, border_width)
            ncut_images[i] = Image.fromarray(bordered)

    return ncut_images
4417
+
4418
def sub_cluster_run_fn(images, n_clusters, is_overlay_image, is_add_border, *args, **kwargs):
    """Cluster the images, then run ``run_fn`` separately on each cluster.

    Images are first grouped by ``cluster_image_level``.  ``run_fn`` is then
    called once per non-empty cluster with that cluster's images plus the
    forwarded ``*args`` / ``**kwargs``; its first return value is scattered
    back to the original image positions and its second return value is
    discarded.  Finally the outputs are post-processed by
    ``add_boarder_and_overlay_image``.

    Returns:
        A tuple ``(output_images, tsne_image_plot)``.
    """
    cluster_labels, tsne_image_plot, resized_images = cluster_image_level(images, n_clusters)

    # One slot per input image, filled cluster by cluster.
    output_images = [None] * len(images)
    for cluster_id in range(n_clusters):
        member_indices = np.where(cluster_labels == cluster_id)[0]
        if len(member_indices) > 0:
            cluster_inputs = [images[member] for member in member_indices]
            cluster_outputs, _unused_log = run_fn(cluster_inputs, *args, **kwargs)
            for position, member in enumerate(member_indices):
                output_images[member] = cluster_outputs[position]

    output_images = add_boarder_and_overlay_image(
        output_images, resized_images, cluster_labels, is_overlay_image, is_add_border
    )
    return output_images, tsne_image_plot
4433
+
4434
+ if USE_HUGGINGFACE_ZEROGPU:
4435
+ @spaces.GPU(duration=120)
4436
+ def sub_cluster_run_fn_wrap(*args, **kwargs):
4437
+ return sub_cluster_run_fn(*args, **kwargs)
4438
+ else:
4439
+ sub_cluster_run_fn_wrap = sub_cluster_run_fn
4440
+
4441
+ # image_cluster_indices = gr.State(np.array([]))
4442
+
4443
+ # submit_button.click(
4444
+ # cluster_image_level,
4445
+ # inputs=[input_gallery, num_image_cluster_slider],
4446
+ # outputs=[image_cluster_indices, image_cluster_plot],
4447
+ # )
4448
+
4449
+ submit_button.click(
4450
+ partial(sub_cluster_run_fn_wrap, n_ret=1, plot_clusters=False),
4451
+ inputs=[
4452
+ input_gallery, num_image_cluster_slider, overlay_original_image_checkbox, show_class_label_checkbox,
4453
+ model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
4454
+ positive_prompt, negative_prompt,
4455
+ false_placeholder, no_prompt, no_prompt, no_prompt,
4456
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4457
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4458
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
4459
+ ],
4460
+ outputs=[output_gallery, image_cluster_plot],
4461
+ )
4462
 
4463
  # with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
4464
  # eigvecs = gr.State(np.array([]))
 
4782
 
4783
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
4784
  hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
4785
+ tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa, sub_cluster_tab]
4786
  hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
4787
 
4788
  with gr.Row():