huzey committed
Commit d48a41d · 1 Parent(s): 19c8c49
Files changed (1)
  1. app.py +131 -30
app.py CHANGED
@@ -409,34 +409,40 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
     return blended.astype(np.uint8)
 
 
-def segment_fg_bg(images):
+def segment_fg_bg(images, hw=224, i_layer=6, batch_size=4, transform_images=True):
+
+    assert hw % 16 == 0, "The height and width of the image must be divisible by 16."
+    psz = hw // 16
+    center_xy = (psz-1) // 2
 
-    images = F.interpolate(images, (224, 224), mode="bilinear")
+    images = F.interpolate(images, (hw, hw), mode="bilinear")
 
     # model = load_alignedthreemodel()
     model = load_model("CLIP(ViT-B-16/openai)")
     from ncut_pytorch.backbone import resample_position_embeddings
     pos_embed = model.model.visual.positional_embedding
-    pos_embed = resample_position_embeddings(pos_embed, 14, 14)
+    pos_embed = resample_position_embeddings(pos_embed, psz, psz)
     model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)
 
-    batch_size = 4
     chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     model.to(device)
-    means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
-    stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+
+    if transform_images:
+        means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+        stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
 
     fg_acts, bg_acts = [], []
     for chunk_idx in chunk_idxs:
         with torch.no_grad():
             input_images = images[chunk_idx].to(device)
             # transform the input images
-            input_images = (input_images - means) / stds
+            if transform_images:
+                input_images = (input_images - means) / stds
             # output = model(input_images)[:, 5]
-            output = model(input_images)['attn'][6]  # [B, H=14, W=14, C]
-            fg_act = output[:, 6, 6].mean(0)
+            output = model(input_images)['attn'][i_layer]  # [B, H=14, W=14, C]
+            fg_act = output[:, center_xy, center_xy].mean(0)
             bg_act = output[:, 0, 0].mean(0)
             fg_acts.append(fg_act)
             bg_acts.append(bg_act)
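Note: this hunk generalizes `segment_fg_bg` from a fixed 224x224 input to any `hw` divisible by 16. CLIP ViT-B/16 is pretrained with a 14x14 patch grid (224/16), so the positional embeddings are resampled to `psz = hw // 16` per side and the foreground anchor moves to the center patch `center_xy = (psz-1) // 2`. Below is a minimal sketch of what `resample_position_embeddings` from `ncut_pytorch.backbone` plausibly does; the function body, including the class-token split and bicubic mode, is an assumption for illustration, not the library's actual code.

import torch
import torch.nn.functional as F

def resample_position_embeddings_sketch(pos_embed, h, w):
    # pos_embed: [1 + H0*W0, C] -- class token first, then a square patch grid
    cls_token, grid = pos_embed[:1], pos_embed[1:]
    h0 = w0 = int(grid.shape[0] ** 0.5)
    grid = grid.reshape(1, h0, w0, -1).permute(0, 3, 1, 2)        # [1, C, H0, W0]
    grid = F.interpolate(grid, (h, w), mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(h * w, -1)            # [h*w, C]
    return torch.cat([cls_token, grid], dim=0)

# hw=448 gives psz = 448 // 16 = 28: the pretrained 14x14 grid is upsampled to 28x28.
pos = torch.randn(1 + 14 * 14, 768)
print(resample_position_embeddings_sketch(pos, 28, 28).shape)     # torch.Size([785, 768])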
@@ -445,21 +451,6 @@ def segment_fg_bg(images):
     fg_act = F.normalize(fg_act, dim=-1)
     bg_act = F.normalize(bg_act, dim=-1)
 
-    # ref_image = default_images[0]
-    # image = Image.open(ref_image).convert("RGB").resize((224, 224), Image.Resampling.BILINEAR)
-    # image = torch.tensor(np.array(image)).permute(2, 0, 1).float().to(device)
-    # image = (image / 255.0 - means) / stds
-    # output = model(image)['attn'][6][0]
-    # # print(output.shape)
-    # # bg on the center
-    # fg_act = output[5, 5]
-    # # bg on the bottom left
-    # bg_act = output[0, 0]
-    # fg_act = F.normalize(fg_act, dim=-1)
-    # bg_act = F.normalize(bg_act, dim=-1)
-
-    # print(images.mean(), images.std())
-
     fg_act, bg_act = fg_act.to(device), bg_act.to(device)
     chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
     heatmap_fgs, heatmap_bgs = [], []
@@ -467,9 +458,10 @@ def segment_fg_bg(images):
         with torch.no_grad():
             input_images = images[chunk_idx].to(device)
             # transform the input images
-            input_images = (input_images - means) / stds
+            if transform_images:
+                input_images = (input_images - means) / stds
             # output = model(input_images)[:, 5]
-            output = model(input_images)['attn'][6]
+            output = model(input_images)['attn'][i_layer]
             output = F.normalize(output, dim=-1)
             heatmap_fg = output @ fg_act[:, None]  # [B, H, W, 1]
             heatmap_bg = output @ bg_act[:, None]  # [B, H, W, 1]
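Note: both loops apply the same scoring rule: every patch feature is compared against the batch-averaged center-patch activation (taken as foreground) and the corner-patch activation at (0, 0) (taken as background). Because the features and anchors are L2-normalized, the matrix products are cosine similarities. A self-contained restatement with random features standing in for the CLIP attention output:

import torch
import torch.nn.functional as F

# Random stand-in for one CLIP attention-layer feature map: [B, H, W, C]
B, H, W, C = 2, 14, 14, 768
output = F.normalize(torch.randn(B, H, W, C), dim=-1)

# Anchors picked like segment_fg_bg: center patch ~ foreground, corner (0, 0) ~ background,
# each averaged over the batch and re-normalized.
fg_act = F.normalize(output[:, (H - 1) // 2, (W - 1) // 2].mean(0), dim=-1)  # [C]
bg_act = F.normalize(output[:, 0, 0].mean(0), dim=-1)                        # [C]

# Unit norm on both sides, so these matmuls are cosine similarities in [-1, 1].
heatmap_fg = output @ fg_act[:, None]  # [B, H, W, 1]
heatmap_bg = output @ bg_act[:, None]  # [B, H, W, 1]
print(heatmap_fg.shape, float(heatmap_fg.max()))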
@@ -868,6 +860,71 @@ def ncut_run(
         return to_pil_images(rgb), logging_str
 
 
+    # fg-bg separated
+    separate_fg_bg = kwargs.get("separate_fg_bg", False)
+    if separate_fg_bg:
+        fg_threshold = kwargs.get("fg_threshold", 0.5)
+        feature_hw = features.shape[1]
+        progress(0.4, desc="Segmenting FG-BG")
+        heatmap_fg, heatmap_bg = segment_fg_bg(images, hw=448, transform_images=False, i_layer=4)
+        heatmap_fg = 1 - heatmap_fg
+        heatmap_bg = 1 - heatmap_bg
+        b, h, w, c = features.shape
+        heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
+        heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
+        is_cuda = torch.cuda.is_available()
+        heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bicubic")
+        heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bicubic")
+        heatmap_fg = heatmap_fg.flatten()
+        heatmap_bg = heatmap_bg.flatten()
+        fg_minus_bg = heatmap_fg - heatmap_bg
+
+        def _to_mask(heatmap, threshold, gamma=0.5):
+            heatmap = (heatmap - heatmap.mean()) / heatmap.std()
+            heatmap = heatmap.double()
+            heatmap = torch.exp(heatmap)
+            heatmap = 1 / heatmap ** gamma
+            if heatmap.shape[0] > 10000:
+                np.random.seed(0)
+                random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
+                vmin, vmax = heatmap[random_idx].quantile(0.01), heatmap[random_idx].quantile(0.99)
+            else:
+                vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
+            heatmap = (heatmap - vmin) / (vmax - vmin)
+            heatmap = heatmap.reshape(b, h, w)
+            mask = heatmap > threshold
+            return mask
+
+        fg_mask = _to_mask(fg_minus_bg, fg_threshold)
+        features_fg = features.flatten(0, 2)[fg_mask.flatten()]
+
+        progress(0.4, desc="NCut FG")
+        rgb, _logging_str, eigvecs = compute_ncut(
+            features_fg,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=indirect_connection,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+            only_eigvecs=False,
+        )
+
+        rgb_all = torch.zeros(b, h, w, 3)
+        rgb_all_flat = rgb_all.flatten(0, 2)
+        rgb_all_flat[fg_mask.flatten()] = rgb
+        rgb_all = rgb_all_flat.reshape(b, h, w, 3)
+
+        return to_pil_images(rgb_all), logging_str
 
     # ailgnedcut
     if not directed:
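Note on the masking logic: since `heatmap_fg` and `heatmap_bg` are each flipped with `1 - x` before differencing, `fg_minus_bg` is really `heatmap_bg - heatmap_fg`; `_to_mask` then applies z-scoring followed by `1 / exp(z) ** gamma`, i.e. `exp(-gamma * z)`, a monotone decreasing map, so the two inversions cancel and high values end up on patches more similar to the foreground anchor. The 1%/99% quantiles (estimated on a fixed 10k subsample for large inputs) give an outlier-robust rescale to roughly [0, 1] before thresholding. (`feature_hw` and `is_cuda` are computed but unused.) A standalone copy of the helper for experimentation; `to_mask_sketch` and its explicit `shape` argument are illustrative, the original closes over `b, h, w`:

import numpy as np
import torch

def to_mask_sketch(heatmap, threshold, shape, gamma=0.5):
    # z-score, then exp, then 1/x**gamma -- together exp(-gamma * z), monotone decreasing
    heatmap = (heatmap - heatmap.mean()) / heatmap.std()
    heatmap = 1 / torch.exp(heatmap.double()) ** gamma
    if heatmap.shape[0] > 10000:
        # estimate robust bounds on a fixed 10k subsample to keep quantile() cheap
        np.random.seed(0)
        idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
        vmin, vmax = heatmap[idx].quantile(0.01), heatmap[idx].quantile(0.99)
    else:
        vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
    heatmap = (heatmap - vmin) / (vmax - vmin)    # outlier-robust rescale to ~[0, 1]
    return heatmap.reshape(*shape) > threshold

scores = torch.randn(2 * 8 * 8)                   # stand-in for the flattened fg_minus_bg
mask = to_mask_sketch(scores, 0.5, (2, 8, 8))
print(mask.shape, mask.float().mean().item())     # fraction of pixels kept as foreground

After NCut runs on the foreground features only, `rgb_all_flat[fg_mask.flatten()] = rgb` scatters the per-pixel colors back into a zero-initialized canvas, so background pixels stay black in the returned images.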
@@ -1037,9 +1094,9 @@ def _ncut_run(*args, **kwargs):
         torch.cuda.empty_cache()
         return *(None for _ in range(n_ret)), "Error: " + str(e)
 
-    # ret = ncut_run(*args, **kwargs)
-    # ret = list(ret)[:n_ret] + [ret[-1]]
-    # return ret
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=30)
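Note: re-enabling these previously commented-out lines makes `_ncut_run` normalize whatever tuple `ncut_run` returns to exactly `n_ret` payload values plus the trailing logging string, mirroring the error path above. A toy illustration (the tuple contents are hypothetical):

# Hypothetical ncut_run return: payload values followed by a logging string.
ret = ("rgb_images", "eigvecs", "norm_plot", "logging text")
n_ret = 1
ret = list(ret)[:n_ret] + [ret[-1]]
print(ret)  # ['rgb_images', 'logging text'] -- n_ret payloads + the log, like the error path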
@@ -1250,6 +1307,7 @@ def run_fn(
     node_type2="k",
     head_index_text='all',
     make_symmetric=False,
+    fg_threshold=0.5,
     n_ret=1,
     plot_clusters=False,
     alignedcut_eig_norm_plot=False,
@@ -1258,6 +1316,7 @@ def run_fn(
     only_eigvecs=False,
     return_eigvec_and_rgb=False,
     normalize_eigvec_return=False,
+    separate_fg_bg=False,
 ):
     # print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
@@ -1390,6 +1449,7 @@ def run_fn(
         "lisa_prompt2": lisa_prompt2,
         "lisa_prompt3": lisa_prompt3,
         "is_lisa": is_lisa,
+        "fg_threshold": fg_threshold,
         "n_ret": n_ret,
         "plot_clusters": plot_clusters,
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
@@ -1401,6 +1461,7 @@ def run_fn(
         "only_eigvecs": only_eigvecs,
         "return_eigvec_and_rgb": return_eigvec_and_rgb,
         "normalize_eigvec_return": normalize_eigvec_return,
+        "separate_fg_bg": separate_fg_bg,
     }
     # print(kwargs)
 
@@ -2196,7 +2257,8 @@ demo = gr.Blocks(
     css=custom_css,
 )
 with demo:
-
+
+
     with gr.Tab('PlayGround'):
         eigvecs = gr.State(np.array([]))
         tsne3d_rgb = gr.State(np.array([]))
@@ -4247,6 +4309,45 @@ with demo:
             outputs=[mask_gallery, crop_gallery])
 
 
+    with gr.Tab('FG'):
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+
+            with gr.Column(scale=5, min_width=200):
+                output_gallery = make_output_images_section()
+                fg_threshold_slider = gr.Slider(0.01, 1, step=0.01, label="Foreground threshold", value=0.5, elem_id="fg_threshold", info="increase for more foreground")
+                # cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+        submit_button.click(
+            partial(run_fn, n_ret=1, plot_clusters=False, separate_fg_bg=True),
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
+                *[false_placeholder]*12,
+                fg_threshold_slider
+            ],
+            outputs=[output_gallery, logging_text],
+        )
+
+
     with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
         with gr.Row():
             image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
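Note: the new FG tab reuses the generic `run_fn` handler by pre-binding `separate_fg_bg=True` with `functools.partial`, and passes `fg_threshold_slider` as the final positional input so its value lands on the `fg_threshold` parameter added to `run_fn` above. A minimal self-contained sketch of the same wiring pattern; the `process` handler is a toy stand-in, not the app's `run_fn`:

from functools import partial

import gradio as gr

def process(threshold, separate_fg_bg=False):
    return f"separate_fg_bg={separate_fg_bg}, threshold={threshold}"

with gr.Blocks() as demo:
    slider = gr.Slider(0.01, 1, step=0.01, value=0.5, label="Foreground threshold")
    out = gr.Textbox(label="Logging")
    btn = gr.Button("Run")
    # partial() pre-binds the keyword-only config; the slider value arrives positionally.
    btn.click(partial(process, separate_fg_bg=True), inputs=[slider], outputs=[out])

demo.launch()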
 