huzey commited on
Commit
032a59d
·
1 Parent(s): 5a2bda7

update playground

Browse files
Files changed (1) hide show
  1. app.py +242 -264
app.py CHANGED
@@ -1702,7 +1702,7 @@ def load_and_append(existing_images, *args, **kwargs):
1702
  gr.Info(f"Total images: {len(existing_images)}")
1703
  return existing_images
1704
 
1705
- def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
1706
  if markdown:
1707
  gr.Markdown('### Input Images')
1708
  input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
@@ -1750,7 +1750,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1750
  # ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
1751
  # ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
1752
  # ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
1753
- ("BrainMRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
1754
  ("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
1755
  ]
1756
  for name, images, dataset_name in example_items:
@@ -2073,7 +2073,7 @@ def add_download_button(gallery, filename_prefix="output"):
2073
  def make_output_images_section(markdown=True, button=True):
2074
  if markdown:
2075
  gr.Markdown('### Output Images')
2076
- output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
2077
  if button:
2078
  add_rotate_flip_buttons(output_gallery)
2079
  return output_gallery
@@ -2202,7 +2202,7 @@ with demo:
2202
  with gr.Column(scale=5, min_width=200):
2203
  # gr.Markdown("### Step 2a: Run Backbone and NCUT")
2204
  # with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
2205
- output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
2206
  def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
2207
  with gr.Row():
2208
  rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
@@ -2414,7 +2414,7 @@ with demo:
2414
  gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
2415
  with gr.Accordion("Outputs", open=True):
2416
  gr.Markdown("""
2417
- 1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked dot.
2418
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2419
  """)
2420
  with gr.Column(scale=5, min_width=200):
@@ -2448,8 +2448,10 @@ with demo:
2448
  with gr.Column(scale=5, min_width=200):
2449
  output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
2450
  text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
 
2451
  with gr.Column(scale=10, min_width=200):
2452
- heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
 
2453
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
2454
 
2455
  gr.Markdown('---')
@@ -2600,6 +2602,193 @@ with demo:
2600
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
2601
  )
2602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2603
  with gr.Tab('AlignedCut'):
2604
 
2605
  with gr.Row():
@@ -2610,7 +2799,7 @@ with demo:
2610
 
2611
  with gr.Column(scale=5, min_width=200):
2612
  output_gallery = make_output_images_section()
2613
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2614
  [
2615
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2616
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2623,7 +2812,7 @@ with demo:
2623
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
2624
 
2625
  submit_button.click(
2626
- partial(run_fn, n_ret=2, plot_clusters=True),
2627
  inputs=[
2628
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2629
  positive_prompt, negative_prompt,
@@ -2632,7 +2821,7 @@ with demo:
2632
  embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2633
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
2634
  ],
2635
- outputs=[output_gallery, cluster_gallery, logging_text],
2636
  api_name="API_AlignedCut",
2637
  scroll_to_output=True,
2638
  )
@@ -2648,9 +2837,9 @@ with demo:
2648
  with gr.Column(scale=5, min_width=200):
2649
  output_gallery = make_output_images_section()
2650
  add_download_button(output_gallery, "ncut_embed")
2651
- norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2652
  add_download_button(norm_gallery, "eig_norm")
2653
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2654
  add_download_button(cluster_gallery, "clusters")
2655
  [
2656
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
@@ -2740,7 +2929,7 @@ with demo:
2740
  api_name="API_NCut",
2741
  )
2742
 
2743
- with gr.Tab('Recursive Cut'):
2744
  gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
2745
  gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')
2746
 
@@ -2753,15 +2942,15 @@ with demo:
2753
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
2754
  with gr.Column(scale=5, min_width=200):
2755
  gr.Markdown('### Output (Recursion #1)')
2756
- l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2757
  add_rotate_flip_buttons(l1_gallery)
2758
  with gr.Column(scale=5, min_width=200):
2759
  gr.Markdown('### Output (Recursion #2)')
2760
- l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2761
  add_rotate_flip_buttons(l2_gallery)
2762
  with gr.Column(scale=5, min_width=200):
2763
  gr.Markdown('### Output (Recursion #3)')
2764
- l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2765
  add_rotate_flip_buttons(l3_gallery)
2766
  with gr.Row():
2767
 
@@ -2809,7 +2998,7 @@ with demo:
2809
  api_name="API_RecursiveCut"
2810
  )
2811
 
2812
- with gr.Tab('Recursive Cut (Advanced)', visible=False) as tab_recursivecut_advanced:
2813
 
2814
  with gr.Row():
2815
  with gr.Column(scale=5, min_width=200):
@@ -2818,30 +3007,30 @@ with demo:
2818
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
2819
  with gr.Column(scale=5, min_width=200):
2820
  gr.Markdown('### Output (Recursion #1)')
2821
- l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2822
  add_rotate_flip_buttons(l1_gallery)
2823
  add_download_button(l1_gallery, "ncut_embed_recur1")
2824
- l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2825
  add_download_button(l1_norm_gallery, "eig_norm_recur1")
2826
- l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
2827
  add_download_button(l1_cluster_gallery, "clusters_recur1")
2828
  with gr.Column(scale=5, min_width=200):
2829
  gr.Markdown('### Output (Recursion #2)')
2830
- l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2831
  add_rotate_flip_buttons(l2_gallery)
2832
  add_download_button(l2_gallery, "ncut_embed_recur2")
2833
- l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2834
  add_download_button(l2_norm_gallery, "eig_norm_recur2")
2835
- l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
2836
  add_download_button(l2_cluster_gallery, "clusters_recur2")
2837
  with gr.Column(scale=5, min_width=200):
2838
  gr.Markdown('### Output (Recursion #3)')
2839
- l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2840
  add_rotate_flip_buttons(l3_gallery)
2841
  add_download_button(l3_gallery, "ncut_embed_recur3")
2842
- l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2843
  add_download_button(l3_norm_gallery, "eig_norm_recur3")
2844
- l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
2845
  add_download_button(l3_cluster_gallery, "clusters_recur3")
2846
 
2847
  with gr.Row():
@@ -2889,7 +3078,7 @@ with demo:
2889
  )
2890
 
2891
 
2892
- with gr.Tab('Video'):
2893
  with gr.Row():
2894
  with gr.Column(scale=5, min_width=200):
2895
  video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
@@ -2937,7 +3126,7 @@ with demo:
2937
  from draft_gradio_app_text import make_demo
2938
  make_demo()
2939
 
2940
- with gr.Tab('Vision-Language'):
2941
  gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
2942
  gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
2943
  gr.Markdown('This page aims to see how the text prompt affects the image features')
@@ -2946,15 +3135,15 @@ with demo:
2946
  with gr.Row():
2947
  with gr.Column(scale=5, min_width=200):
2948
  gr.Markdown('### Output (Prompt #1)')
2949
- l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2950
  prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
2951
  with gr.Column(scale=5, min_width=200):
2952
  gr.Markdown('### Output (Prompt #2)')
2953
- l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2954
  prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
2955
  with gr.Column(scale=5, min_width=200):
2956
  gr.Markdown('### Output (Prompt #3)')
2957
- l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2958
  prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
2959
 
2960
  with gr.Row():
@@ -2986,7 +3175,7 @@ with demo:
2986
  outputs=galleries + [logging_text],
2987
  )
2988
 
2989
- with gr.Tab('Model Aligned'):
2990
  gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
2991
  gr.Markdown('---')
2992
  gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
@@ -3049,30 +3238,17 @@ with demo:
3049
  gr.Markdown('')
3050
  gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
3051
  gr.Markdown('---')
3052
-
3053
- # with gr.Row():
3054
- # with gr.Column(scale=5, min_width=200):
3055
- # gr.Markdown('### Output (Recursion #1)')
3056
- # l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3057
- # add_output_images_buttons(l1_gallery)
3058
- # with gr.Column(scale=5, min_width=200):
3059
- # gr.Markdown('### Output (Recursion #2)')
3060
- # l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3061
- # add_output_images_buttons(l2_gallery)
3062
- # with gr.Column(scale=5, min_width=200):
3063
- # gr.Markdown('### Output (Recursion #3)')
3064
- # l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3065
- # add_output_images_buttons(l3_gallery)
3066
  gr.Markdown('### Output (Recursion #1)')
3067
- l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
3068
  add_rotate_flip_buttons(l1_gallery)
3069
  add_download_button(l1_gallery, "modelaligned_recur1")
3070
  gr.Markdown('### Output (Recursion #2)')
3071
- l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
3072
  add_rotate_flip_buttons(l2_gallery)
3073
  add_download_button(l2_gallery, "modelaligned_recur2")
3074
  gr.Markdown('### Output (Recursion #3)')
3075
- l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
3076
  add_rotate_flip_buttons(l3_gallery)
3077
  add_download_button(l3_gallery, "modelaligned_recur3")
3078
 
@@ -3141,7 +3317,7 @@ with demo:
3141
  def add_one_model(i_model=1):
3142
  with gr.Column(scale=5, min_width=200) as col:
3143
  gr.Markdown(f'### Output Images')
3144
- output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3145
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3146
  add_rotate_flip_buttons(output_gallery)
3147
  [
@@ -3248,16 +3424,16 @@ with demo:
3248
  def add_one_model(i_model=1):
3249
  with gr.Column(scale=5, min_width=200) as col:
3250
  gr.Markdown(f'### Output Images')
3251
- output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3252
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3253
  add_rotate_flip_buttons(output_gallery)
3254
  add_download_button(output_gallery, f"ncut_embed")
3255
- mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3256
  add_mlp_fitting_buttons(output_gallery, mlp_gallery)
3257
  add_download_button(mlp_gallery, f"mlp_color_align")
3258
- norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
3259
  add_download_button(norm_gallery, f"eig_norm")
3260
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
3261
  add_download_button(cluster_gallery, f"clusters")
3262
  [
3263
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
@@ -3463,16 +3639,16 @@ with demo:
3463
  def add_one_model(i_model=1):
3464
  with gr.Column(scale=5, min_width=200) as col:
3465
  gr.Markdown(f'### Output Images')
3466
- output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3467
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3468
  add_rotate_flip_buttons(output_gallery)
3469
  add_download_button(output_gallery, f"ncut_embed")
3470
- mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
3471
  add_mlp_fitting_buttons(output_gallery, mlp_gallery)
3472
  add_download_button(mlp_gallery, f"mlp_color_align")
3473
- norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
3474
  add_download_button(norm_gallery, f"eig_norm")
3475
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
3476
  add_download_button(cluster_gallery, f"clusters")
3477
  [
3478
  model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
@@ -3655,7 +3831,7 @@ with demo:
3655
 
3656
  with gr.Column(scale=5, min_width=200):
3657
  gr.Markdown("### Step 3: Segment and Crop")
3658
- mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3659
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3660
  add_download_button(mask_gallery, "mask")
3661
  distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
@@ -3665,7 +3841,7 @@ with demo:
3665
  overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
3666
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3667
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
3668
- crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3669
  add_download_button(crop_gallery, "cropped")
3670
  crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
3671
  area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
@@ -4143,192 +4319,6 @@ with demo:
4143
  outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
4144
  )
4145
 
4146
- with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
4147
- eigvecs = gr.State(np.array([]))
4148
- with gr.Row():
4149
- with gr.Column(scale=5, min_width=200):
4150
- gr.Markdown("### Step 1: Load Images")
4151
- input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
4152
- submit_button.visible = False
4153
- num_images_slider.value = 30
4154
-
4155
- false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
4156
- no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
4157
-
4158
-
4159
- with gr.Column(scale=5, min_width=200):
4160
- gr.Markdown("### Step 2a: Run Backbone and NCUT")
4161
- with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
4162
- [
4163
- model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
4164
- affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4165
- embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4166
- perplexity_slider, n_neighbors_slider, min_dist_slider,
4167
- sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
4168
- ] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
4169
- num_eig_slider.value = 1024
4170
- num_eig_slider.visible = False
4171
- submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
4172
- logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
4173
- submit_button.click(
4174
- partial(run_fn, n_ret=1, only_eigvecs=True),
4175
- inputs=[
4176
- input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
4177
- positive_prompt, negative_prompt,
4178
- false_placeholder, no_prompt, no_prompt, no_prompt,
4179
- affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4180
- embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4181
- perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
4182
- ],
4183
- outputs=[eigvecs, logging_text],
4184
- )
4185
- gr.Markdown("### Step 2b: Pick an Image")
4186
- from gradio_image_prompter import ImagePrompter
4187
- with gr.Row():
4188
- image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
4189
- load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
4190
- gr.Markdown("### Step 2c: Draw a Point")
4191
- gr.Markdown("""
4192
- <h5>
4193
- 🖱️ Left Click: Foreground </br>
4194
- </h5>
4195
- """)
4196
- prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
4197
- def update_prompt_image(original_images, index):
4198
- images = original_images
4199
- if images is None:
4200
- return
4201
- total_len = len(images)
4202
- if total_len == 0:
4203
- return
4204
- if index >= total_len:
4205
- index = total_len - 1
4206
-
4207
- return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
4208
- # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
4209
- load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
4210
-
4211
- child_idx = gr.State([])
4212
- current_idx = gr.State(None)
4213
- n_eig = gr.State(64)
4214
- with gr.Column(scale=5, min_width=200):
4215
- gr.Markdown("### Step 3: Check groupping")
4216
- child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
4217
- child_distance_slider.visible = False
4218
- overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
4219
- n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
4220
- run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
4221
- with gr.Row():
4222
- doublue_eigs_button = gr.Button("⬇️ +eigvecs", elem_id="doublue_eigs_button", variant='secondary')
4223
- half_eigs_button = gr.Button("⬆️ -eigvecs", elem_id="half_eigs_button", variant='secondary')
4224
- current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
4225
-
4226
- def relative_xy(prompts):
4227
- image = prompts['image']
4228
- points = np.asarray(prompts['points'])
4229
- if points.shape[0] == 0:
4230
- return [], []
4231
- is_point = points[:, 5] == 4.0
4232
- points = points[is_point]
4233
- is_positive = points[:, 2] == 1.0
4234
- is_negative = points[:, 2] == 0.0
4235
- xy = points[:, :2].tolist()
4236
- if isinstance(image, str):
4237
- image = Image.open(image)
4238
- image = np.array(image)
4239
- h, w = image.shape[:2]
4240
- new_xy = [(x/w, y/h) for x, y in xy]
4241
- # print(new_xy)
4242
- return new_xy, is_positive
4243
-
4244
- def xy_eigvec(prompts, image_idx, eigvecs):
4245
- eigvec = eigvecs[image_idx]
4246
- xy, is_positive = relative_xy(prompts)
4247
- for i, (x, y) in enumerate(xy):
4248
- if not is_positive[i]:
4249
- continue
4250
- x = int(x * eigvec.shape[1])
4251
- y = int(y * eigvec.shape[0])
4252
- return eigvec[y, x], (y, x)
4253
-
4254
- from ncut_pytorch.ncut_pytorch import _transform_heatmap
4255
- def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
4256
- left = eigvecs[..., :n_eig]
4257
- if flat_idx is not None:
4258
- right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
4259
- y, x = None, None
4260
- else:
4261
- right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
4262
- right = right[:n_eig]
4263
- left = F.normalize(left, p=2, dim=-1)
4264
- _right = F.normalize(right, p=2, dim=-1)
4265
- heatmap = left @ _right.unsqueeze(-1)
4266
- heatmap = heatmap.squeeze(-1)
4267
- # heatmap = 1 - heatmap
4268
- # heatmap = _transform_heatmap(heatmap)
4269
- if raw_heatmap:
4270
- return heatmap
4271
- # apply hot colormap and covert to PIL image 256x256
4272
- # gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
4273
- heatmap = heatmap.cpu().numpy()
4274
- hot_map = matplotlib.colormaps['hot']
4275
- heatmap = hot_map(heatmap)
4276
- pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
4277
- if overlay_image:
4278
- overlaied_images = []
4279
- for i_image in range(len(images)):
4280
- rgb_image = images[i_image].resize((256, 256))
4281
- rgb_image = np.array(rgb_image)
4282
- heatmap_image = np.array(pil_images[i_image])[..., :3]
4283
- blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
4284
- blend_image = Image.fromarray(blend_image.astype(np.uint8))
4285
- overlaied_images.append(blend_image)
4286
- pil_images = overlaied_images
4287
- return pil_images, (y, x)
4288
-
4289
- @torch.no_grad()
4290
- def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
4291
- gr.Info(f"current number of eigenvectors: {n_eig}", 2)
4292
- eigvecs = torch.tensor(eigvecs)
4293
- image1_slider = min(image1_slider, len(images)-1)
4294
- images = [image[0] for image in images]
4295
- if isinstance(images[0], str):
4296
- images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
4297
-
4298
- current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
4299
-
4300
- return current_heatmap
4301
-
4302
- def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
4303
- n_eig = int(n_eig*2)
4304
- n_eig = min(n_eig, eigvecs.shape[-1])
4305
- n_eig = max(n_eig, 1)
4306
- return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
4307
-
4308
- def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
4309
- n_eig = int(n_eig/2)
4310
- n_eig = min(n_eig, eigvecs.shape[-1])
4311
- n_eig = max(n_eig, 1)
4312
- return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
4313
-
4314
- none_placeholder = gr.State(None)
4315
- run_button.click(
4316
- run_heatmap,
4317
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
4318
- outputs=[current_plot],
4319
- )
4320
-
4321
- doublue_eigs_button.click(
4322
- doublue_eigs_wrapper,
4323
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
4324
- outputs=[n_eig_slider, current_plot],
4325
- )
4326
-
4327
- half_eigs_button.click(
4328
- half_eigs_wrapper,
4329
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
4330
- outputs=[n_eig_slider, current_plot],
4331
- )
4332
 
4333
  with gr.Tab('📄About'):
4334
  with gr.Column():
@@ -4372,30 +4362,18 @@ with demo:
4372
  label = "".join(label)
4373
  return n_smiles, gr.update(label=label, value=False)
4374
 
4375
- def unlock_tabs_with_info(n_smiles):
4376
- if n_smiles == unlock_value:
4377
- gr.Info("🔓 unlocked tabs", 2)
4378
- return gr.update(visible=True)
4379
- return gr.update()
4380
 
4381
- def unlock_tabs(n_smiles):
4382
  if n_smiles == unlock_value:
4383
- return gr.update(visible=True)
4384
- return gr.update()
 
4385
 
4386
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
4387
- hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
4388
- hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
4389
- hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
4390
- hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
4391
- hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
4392
- hidden_button.change(unlock_tabs, n_smiles, test_playground_tab)
4393
 
4394
- # with gr.Row():
4395
- # with gr.Column():
4396
- # gr.Markdown("##### This demo is for `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/) ")
4397
- # with gr.Column():
4398
- # gr.Markdown("###### Running out of GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
4399
  with gr.Row():
4400
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")
4401
 
 
1702
  gr.Info(f"Total images: {len(existing_images)}")
1703
  return existing_images
1704
 
1705
+ def make_input_images_section(rows=1, cols=3, height="450px", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
1706
  if markdown:
1707
  gr.Markdown('### Input Images')
1708
  input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
 
1750
  # ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
1751
  # ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
1752
  # ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
1753
+ ("MRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
1754
  ("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
1755
  ]
1756
  for name, images, dataset_name in example_items:
 
2073
  def make_output_images_section(markdown=True, button=True):
2074
  if markdown:
2075
  gr.Markdown('### Output Images')
2076
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
2077
  if button:
2078
  add_rotate_flip_buttons(output_gallery)
2079
  return output_gallery
 
2202
  with gr.Column(scale=5, min_width=200):
2203
  # gr.Markdown("### Step 2a: Run Backbone and NCUT")
2204
  # with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
2205
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
2206
  def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
2207
  with gr.Row():
2208
  rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
 
2414
  gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
2415
  with gr.Accordion("Outputs", open=True):
2416
  gr.Markdown("""
2417
+ 1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked .
2418
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2419
  """)
2420
  with gr.Column(scale=5, min_width=200):
 
2448
  with gr.Column(scale=5, min_width=200):
2449
  output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
2450
  text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
2451
+ delete_button = gr.Button("❌ Delete", elem_id=f"delete_button_{i_row}", variant='secondary')
2452
  with gr.Column(scale=10, min_width=200):
2453
+ heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="500px", show_share_button=True, interactive=False)
2454
+ delete_button.click(lambda: gr.update(visible=False), outputs=[inspect_output_row])
2455
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
2456
 
2457
  gr.Markdown('---')
 
2602
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
2603
  )
2604
 
2605
+ with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
2606
+ eigvecs = gr.State(np.array([]))
2607
+ with gr.Row():
2608
+ with gr.Column(scale=5, min_width=200):
2609
+ gr.Markdown("### Step 1: Load Images")
2610
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
2611
+ submit_button.visible = False
2612
+ num_images_slider.value = 30
2613
+
2614
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
2615
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
2616
+
2617
+
2618
+ with gr.Column(scale=5, min_width=200):
2619
+ gr.Markdown("### Step 2a: Run Backbone and NCUT")
2620
+ with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
2621
+ [
2622
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2623
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
2624
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2625
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
2626
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
2627
+ ] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
2628
+ num_eig_slider.value = 1024
2629
+ num_eig_slider.visible = False
2630
+ submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
2631
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
2632
+ submit_button.click(
2633
+ partial(run_fn, n_ret=1, only_eigvecs=True),
2634
+ inputs=[
2635
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2636
+ positive_prompt, negative_prompt,
2637
+ false_placeholder, no_prompt, no_prompt, no_prompt,
2638
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
2639
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2640
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
2641
+ ],
2642
+ outputs=[eigvecs, logging_text],
2643
+ )
2644
+ gr.Markdown("### Step 2b: Pick an Image and Draw a Point")
2645
+ from gradio_image_prompter import ImagePrompter
2646
+ image1_slider = gr.Slider(0, 0, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
2647
+ load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
2648
+ gr.Markdown("""
2649
+ <h5>
2650
+ 🖱️ Left Click: Foreground </br>
2651
+ </h5>
2652
+ """)
2653
+ prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
2654
+ def update_prompt_image(original_images, index):
2655
+ images = original_images
2656
+ if images is None:
2657
+ return gr.update()
2658
+ total_len = len(images)
2659
+ if total_len == 0:
2660
+ return gr.update()
2661
+ if index >= total_len:
2662
+ index = total_len - 1
2663
+ return gr.update(value={'image': images[index][0], 'points': []}, interactive=True)
2664
+ load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
2665
+ input_gallery.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
2666
+ input_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1), inputs=[input_gallery], outputs=[image1_slider])
2667
+ image1_slider.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
2668
+
2669
+ child_idx = gr.State([])
2670
+ current_idx = gr.State(None)
2671
+ n_eig = gr.State(64)
2672
+ with gr.Column(scale=5, min_width=200):
2673
+ gr.Markdown("### Step 3: Check groupping")
2674
+ child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
2675
+ child_distance_slider.visible = False
2676
+ overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
2677
+ n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
2678
+ run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
2679
+ gr.Markdown("1. 🔴 RUN </br>2. repeat: [+num_eigvecs] / [-num_eigvecs]")
2680
+ with gr.Row():
2681
+ doublue_eigs_button = gr.Button("[+num_eigvecs]", elem_id="doublue_eigs_button", variant='secondary')
2682
+ half_eigs_button = gr.Button("[-num_eigvecs]", elem_id="half_eigs_button", variant='secondary')
2683
+ current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
2684
+
2685
+ def relative_xy(prompts):
2686
+ image = prompts['image']
2687
+ points = np.asarray(prompts['points'])
2688
+ if points.shape[0] == 0:
2689
+ return [], []
2690
+ is_point = points[:, 5] == 4.0
2691
+ points = points[is_point]
2692
+ is_positive = points[:, 2] == 1.0
2693
+ is_negative = points[:, 2] == 0.0
2694
+ xy = points[:, :2].tolist()
2695
+ if isinstance(image, str):
2696
+ image = Image.open(image)
2697
+ image = np.array(image)
2698
+ h, w = image.shape[:2]
2699
+ new_xy = [(x/w, y/h) for x, y in xy]
2700
+ # print(new_xy)
2701
+ return new_xy, is_positive
2702
+
2703
+ def xy_eigvec(prompts, image_idx, eigvecs):
2704
+ eigvec = eigvecs[image_idx]
2705
+ xy, is_positive = relative_xy(prompts)
2706
+ for i, (x, y) in enumerate(xy):
2707
+ if not is_positive[i]:
2708
+ continue
2709
+ x = int(x * eigvec.shape[1])
2710
+ y = int(y * eigvec.shape[0])
2711
+ return eigvec[y, x], (y, x)
2712
+
2713
+ from ncut_pytorch.ncut_pytorch import _transform_heatmap
2714
+ def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
2715
+ left = eigvecs[..., :n_eig]
2716
+ if flat_idx is not None:
2717
+ right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
2718
+ y, x = None, None
2719
+ else:
2720
+ right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
2721
+ right = right[:n_eig]
2722
+ left = F.normalize(left, p=2, dim=-1)
2723
+ _right = F.normalize(right, p=2, dim=-1)
2724
+ heatmap = left @ _right.unsqueeze(-1)
2725
+ heatmap = heatmap.squeeze(-1)
2726
+ # heatmap = 1 - heatmap
2727
+ # heatmap = _transform_heatmap(heatmap)
2728
+ if raw_heatmap:
2729
+ return heatmap
2730
+ # apply hot colormap and covert to PIL image 256x256
2731
+ # gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
2732
+ heatmap = heatmap.cpu().numpy()
2733
+ hot_map = matplotlib.colormaps['hot']
2734
+ heatmap = hot_map(heatmap)
2735
+ pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
2736
+ if overlay_image:
2737
+ overlaied_images = []
2738
+ for i_image in range(len(images)):
2739
+ rgb_image = images[i_image].resize((256, 256))
2740
+ rgb_image = np.array(rgb_image)
2741
+ heatmap_image = np.array(pil_images[i_image])[..., :3]
2742
+ blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
2743
+ blend_image = Image.fromarray(blend_image.astype(np.uint8))
2744
+ overlaied_images.append(blend_image)
2745
+ pil_images = overlaied_images
2746
+ return pil_images, (y, x)
2747
+
2748
+ @torch.no_grad()
2749
+ def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
2750
+ gr.Info(f"current number of eigenvectors: {n_eig}", 2)
2751
+ eigvecs = torch.tensor(eigvecs)
2752
+ image1_slider = min(image1_slider, len(images)-1)
2753
+ images = [image[0] for image in images]
2754
+ if isinstance(images[0], str):
2755
+ images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
2756
+
2757
+ current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
2758
+
2759
+ return current_heatmap
2760
+
2761
+ def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
2762
+ n_eig = int(n_eig*2)
2763
+ n_eig = min(n_eig, eigvecs.shape[-1])
2764
+ n_eig = max(n_eig, 1)
2765
+ return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
2766
+
2767
+ def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
2768
+ n_eig = int(n_eig/2)
2769
+ n_eig = min(n_eig, eigvecs.shape[-1])
2770
+ n_eig = max(n_eig, 1)
2771
+ return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
2772
+
2773
+ none_placeholder = gr.State(None)
2774
+ run_button.click(
2775
+ run_heatmap,
2776
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
2777
+ outputs=[current_plot],
2778
+ )
2779
+
2780
+ doublue_eigs_button.click(
2781
+ doublue_eigs_wrapper,
2782
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
2783
+ outputs=[n_eig_slider, current_plot],
2784
+ )
2785
+
2786
+ half_eigs_button.click(
2787
+ half_eigs_wrapper,
2788
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
2789
+ outputs=[n_eig_slider, current_plot],
2790
+ )
2791
+
2792
  with gr.Tab('AlignedCut'):
2793
 
2794
  with gr.Row():
 
2799
 
2800
  with gr.Column(scale=5, min_width=200):
2801
  output_gallery = make_output_images_section()
2802
+ # cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2803
  [
2804
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2805
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
2812
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
2813
 
2814
  submit_button.click(
2815
+ partial(run_fn, n_ret=1, plot_clusters=False),
2816
  inputs=[
2817
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2818
  positive_prompt, negative_prompt,
 
2821
  embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2822
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
2823
  ],
2824
+ outputs=[output_gallery, logging_text],
2825
  api_name="API_AlignedCut",
2826
  scroll_to_output=True,
2827
  )
 
2837
  with gr.Column(scale=5, min_width=200):
2838
  output_gallery = make_output_images_section()
2839
  add_download_button(output_gallery, "ncut_embed")
2840
+ norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
2841
  add_download_button(norm_gallery, "eig_norm")
2842
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
2843
  add_download_button(cluster_gallery, "clusters")
2844
  [
2845
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
 
2929
  api_name="API_NCut",
2930
  )
2931
 
2932
+ with gr.Tab('RecursiveCut'):
2933
  gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
2934
  gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')
2935
 
 
2942
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
2943
  with gr.Column(scale=5, min_width=200):
2944
  gr.Markdown('### Output (Recursion #1)')
2945
+ l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
2946
  add_rotate_flip_buttons(l1_gallery)
2947
  with gr.Column(scale=5, min_width=200):
2948
  gr.Markdown('### Output (Recursion #2)')
2949
+ l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
2950
  add_rotate_flip_buttons(l2_gallery)
2951
  with gr.Column(scale=5, min_width=200):
2952
  gr.Markdown('### Output (Recursion #3)')
2953
+ l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
2954
  add_rotate_flip_buttons(l3_gallery)
2955
  with gr.Row():
2956
 
 
2998
  api_name="API_RecursiveCut"
2999
  )
3000
 
3001
+ with gr.Tab('RecursiveCut (Advanced)', visible=False) as tab_recursivecut_advanced:
3002
 
3003
  with gr.Row():
3004
  with gr.Column(scale=5, min_width=200):
 
3007
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
3008
  with gr.Column(scale=5, min_width=200):
3009
  gr.Markdown('### Output (Recursion #1)')
3010
+ l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3011
  add_rotate_flip_buttons(l1_gallery)
3012
  add_download_button(l1_gallery, "ncut_embed_recur1")
3013
+ l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3014
  add_download_button(l1_norm_gallery, "eig_norm_recur1")
3015
+ l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
3016
  add_download_button(l1_cluster_gallery, "clusters_recur1")
3017
  with gr.Column(scale=5, min_width=200):
3018
  gr.Markdown('### Output (Recursion #2)')
3019
+ l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3020
  add_rotate_flip_buttons(l2_gallery)
3021
  add_download_button(l2_gallery, "ncut_embed_recur2")
3022
+ l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3023
  add_download_button(l2_norm_gallery, "eig_norm_recur2")
3024
+ l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
3025
  add_download_button(l2_cluster_gallery, "clusters_recur2")
3026
  with gr.Column(scale=5, min_width=200):
3027
  gr.Markdown('### Output (Recursion #3)')
3028
+ l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3029
  add_rotate_flip_buttons(l3_gallery)
3030
  add_download_button(l3_gallery, "ncut_embed_recur3")
3031
+ l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3032
  add_download_button(l3_norm_gallery, "eig_norm_recur3")
3033
+ l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
3034
  add_download_button(l3_cluster_gallery, "clusters_recur3")
3035
 
3036
  with gr.Row():
 
3078
  )
3079
 
3080
 
3081
+ with gr.Tab('Video', visible=True) as tab_video:
3082
  with gr.Row():
3083
  with gr.Column(scale=5, min_width=200):
3084
  video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
 
3126
  from draft_gradio_app_text import make_demo
3127
  make_demo()
3128
 
3129
+ with gr.Tab('Vision-Language', visible=False) as tab_lisa:
3130
  gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
3131
  gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
3132
  gr.Markdown('This page aims to see how the text prompt affects the image features')
 
3135
  with gr.Row():
3136
  with gr.Column(scale=5, min_width=200):
3137
  gr.Markdown('### Output (Prompt #1)')
3138
+ l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3139
  prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
3140
  with gr.Column(scale=5, min_width=200):
3141
  gr.Markdown('### Output (Prompt #2)')
3142
+ l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3143
  prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
3144
  with gr.Column(scale=5, min_width=200):
3145
  gr.Markdown('### Output (Prompt #3)')
3146
+ l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3147
  prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
3148
 
3149
  with gr.Row():
 
3175
  outputs=galleries + [logging_text],
3176
  )
3177
 
3178
+ with gr.Tab('Model Aligned', visible=False) as tab_aligned:
3179
  gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
3180
  gr.Markdown('---')
3181
  gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
 
3238
  gr.Markdown('')
3239
  gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
3240
  gr.Markdown('---')
3241
+
 
 
 
 
 
 
 
 
 
 
 
 
 
3242
  gr.Markdown('### Output (Recursion #1)')
3243
+ l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
3244
  add_rotate_flip_buttons(l1_gallery)
3245
  add_download_button(l1_gallery, "modelaligned_recur1")
3246
  gr.Markdown('### Output (Recursion #2)')
3247
+ l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
3248
  add_rotate_flip_buttons(l2_gallery)
3249
  add_download_button(l2_gallery, "modelaligned_recur2")
3250
  gr.Markdown('### Output (Recursion #3)')
3251
+ l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
3252
  add_rotate_flip_buttons(l3_gallery)
3253
  add_download_button(l3_gallery, "modelaligned_recur3")
3254
 
 
3317
  def add_one_model(i_model=1):
3318
  with gr.Column(scale=5, min_width=200) as col:
3319
  gr.Markdown(f'### Output Images')
3320
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3321
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3322
  add_rotate_flip_buttons(output_gallery)
3323
  [
 
3424
  def add_one_model(i_model=1):
3425
  with gr.Column(scale=5, min_width=200) as col:
3426
  gr.Markdown(f'### Output Images')
3427
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3428
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3429
  add_rotate_flip_buttons(output_gallery)
3430
  add_download_button(output_gallery, f"ncut_embed")
3431
+ mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3432
  add_mlp_fitting_buttons(output_gallery, mlp_gallery)
3433
  add_download_button(mlp_gallery, f"mlp_color_align")
3434
+ norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3435
  add_download_button(norm_gallery, f"eig_norm")
3436
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3437
  add_download_button(cluster_gallery, f"clusters")
3438
  [
3439
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
 
3639
  def add_one_model(i_model=1):
3640
  with gr.Column(scale=5, min_width=200) as col:
3641
  gr.Markdown(f'### Output Images')
3642
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3643
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
3644
  add_rotate_flip_buttons(output_gallery)
3645
  add_download_button(output_gallery, f"ncut_embed")
3646
+ mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
3647
  add_mlp_fitting_buttons(output_gallery, mlp_gallery)
3648
  add_download_button(mlp_gallery, f"mlp_color_align")
3649
+ norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3650
  add_download_button(norm_gallery, f"eig_norm")
3651
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
3652
  add_download_button(cluster_gallery, f"clusters")
3653
  [
3654
  model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
 
3831
 
3832
  with gr.Column(scale=5, min_width=200):
3833
  gr.Markdown("### Step 3: Segment and Crop")
3834
+ mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
3835
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3836
  add_download_button(mask_gallery, "mask")
3837
  distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
 
3841
  overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
3842
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3843
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
3844
+ crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
3845
  add_download_button(crop_gallery, "cropped")
3846
  crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
3847
  area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
 
4319
  outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
4320
  )
4321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4322
 
4323
  with gr.Tab('📄About'):
4324
  with gr.Column():
 
4362
  label = "".join(label)
4363
  return n_smiles, gr.update(label=label, value=False)
4364
 
 
 
 
 
 
4365
 
4366
+ def unlock_tabs(n_smiles, n_tab=1):
4367
  if n_smiles == unlock_value:
4368
+ gr.Info("🔓 unlocked tabs", 2)
4369
+ return [gr.update(visible=True)] * n_tab
4370
+ return [gr.update()] * n_tab
4371
 
4372
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
4373
+ hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
4374
+ tab_compare_models_advanced, tab_directed_ncut, test_playground_tab, tab_aligned, tab_lisa]
4375
+ hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
 
 
 
4376
 
 
 
 
 
 
4377
  with gr.Row():
4378
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")
4379