Spaces:
Running
on
Zero
Running
on
Zero
update playground
Browse files
app.py
CHANGED
|
@@ -1702,7 +1702,7 @@ def load_and_append(existing_images, *args, **kwargs):
|
|
| 1702 |
gr.Info(f"Total images: {len(existing_images)}")
|
| 1703 |
return existing_images
|
| 1704 |
|
| 1705 |
-
def make_input_images_section(rows=1, cols=3, height="
|
| 1706 |
if markdown:
|
| 1707 |
gr.Markdown('### Input Images')
|
| 1708 |
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
|
|
@@ -1750,7 +1750,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
| 1750 |
# ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
|
| 1751 |
# ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
|
| 1752 |
# ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
|
| 1753 |
-
("
|
| 1754 |
("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
|
| 1755 |
]
|
| 1756 |
for name, images, dataset_name in example_items:
|
|
@@ -2073,7 +2073,7 @@ def add_download_button(gallery, filename_prefix="output"):
|
|
| 2073 |
def make_output_images_section(markdown=True, button=True):
|
| 2074 |
if markdown:
|
| 2075 |
gr.Markdown('### Output Images')
|
| 2076 |
-
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="
|
| 2077 |
if button:
|
| 2078 |
add_rotate_flip_buttons(output_gallery)
|
| 2079 |
return output_gallery
|
|
@@ -2202,7 +2202,7 @@ with demo:
|
|
| 2202 |
with gr.Column(scale=5, min_width=200):
|
| 2203 |
# gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
| 2204 |
# with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
| 2205 |
-
output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="
|
| 2206 |
def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
|
| 2207 |
with gr.Row():
|
| 2208 |
rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
|
|
@@ -2414,7 +2414,7 @@ with demo:
|
|
| 2414 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
| 2415 |
with gr.Accordion("Outputs", open=True):
|
| 2416 |
gr.Markdown("""
|
| 2417 |
-
1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked
|
| 2418 |
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
| 2419 |
""")
|
| 2420 |
with gr.Column(scale=5, min_width=200):
|
|
@@ -2448,8 +2448,10 @@ with demo:
|
|
| 2448 |
with gr.Column(scale=5, min_width=200):
|
| 2449 |
output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
|
| 2450 |
text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
|
|
|
|
| 2451 |
with gr.Column(scale=10, min_width=200):
|
| 2452 |
-
heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="
|
|
|
|
| 2453 |
return inspect_output_row, output_tree_image, heatmap_gallery, text_block
|
| 2454 |
|
| 2455 |
gr.Markdown('---')
|
|
@@ -2600,6 +2602,193 @@ with demo:
|
|
| 2600 |
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
|
| 2601 |
)
|
| 2602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2603 |
with gr.Tab('AlignedCut'):
|
| 2604 |
|
| 2605 |
with gr.Row():
|
|
@@ -2610,7 +2799,7 @@ with demo:
|
|
| 2610 |
|
| 2611 |
with gr.Column(scale=5, min_width=200):
|
| 2612 |
output_gallery = make_output_images_section()
|
| 2613 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 2614 |
[
|
| 2615 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 2616 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
@@ -2623,7 +2812,7 @@ with demo:
|
|
| 2623 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 2624 |
|
| 2625 |
submit_button.click(
|
| 2626 |
-
partial(run_fn, n_ret=
|
| 2627 |
inputs=[
|
| 2628 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 2629 |
positive_prompt, negative_prompt,
|
|
@@ -2632,7 +2821,7 @@ with demo:
|
|
| 2632 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 2633 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
| 2634 |
],
|
| 2635 |
-
outputs=[output_gallery,
|
| 2636 |
api_name="API_AlignedCut",
|
| 2637 |
scroll_to_output=True,
|
| 2638 |
)
|
|
@@ -2648,9 +2837,9 @@ with demo:
|
|
| 2648 |
with gr.Column(scale=5, min_width=200):
|
| 2649 |
output_gallery = make_output_images_section()
|
| 2650 |
add_download_button(output_gallery, "ncut_embed")
|
| 2651 |
-
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="
|
| 2652 |
add_download_button(norm_gallery, "eig_norm")
|
| 2653 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="
|
| 2654 |
add_download_button(cluster_gallery, "clusters")
|
| 2655 |
[
|
| 2656 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
|
@@ -2740,7 +2929,7 @@ with demo:
|
|
| 2740 |
api_name="API_NCut",
|
| 2741 |
)
|
| 2742 |
|
| 2743 |
-
with gr.Tab('
|
| 2744 |
gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
|
| 2745 |
gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')
|
| 2746 |
|
|
@@ -2753,15 +2942,15 @@ with demo:
|
|
| 2753 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 2754 |
with gr.Column(scale=5, min_width=200):
|
| 2755 |
gr.Markdown('### Output (Recursion #1)')
|
| 2756 |
-
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="
|
| 2757 |
add_rotate_flip_buttons(l1_gallery)
|
| 2758 |
with gr.Column(scale=5, min_width=200):
|
| 2759 |
gr.Markdown('### Output (Recursion #2)')
|
| 2760 |
-
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="
|
| 2761 |
add_rotate_flip_buttons(l2_gallery)
|
| 2762 |
with gr.Column(scale=5, min_width=200):
|
| 2763 |
gr.Markdown('### Output (Recursion #3)')
|
| 2764 |
-
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="
|
| 2765 |
add_rotate_flip_buttons(l3_gallery)
|
| 2766 |
with gr.Row():
|
| 2767 |
|
|
@@ -2809,7 +2998,7 @@ with demo:
|
|
| 2809 |
api_name="API_RecursiveCut"
|
| 2810 |
)
|
| 2811 |
|
| 2812 |
-
with gr.Tab('
|
| 2813 |
|
| 2814 |
with gr.Row():
|
| 2815 |
with gr.Column(scale=5, min_width=200):
|
|
@@ -2818,30 +3007,30 @@ with demo:
|
|
| 2818 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
|
| 2819 |
with gr.Column(scale=5, min_width=200):
|
| 2820 |
gr.Markdown('### Output (Recursion #1)')
|
| 2821 |
-
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="
|
| 2822 |
add_rotate_flip_buttons(l1_gallery)
|
| 2823 |
add_download_button(l1_gallery, "ncut_embed_recur1")
|
| 2824 |
-
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="
|
| 2825 |
add_download_button(l1_norm_gallery, "eig_norm_recur1")
|
| 2826 |
-
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='
|
| 2827 |
add_download_button(l1_cluster_gallery, "clusters_recur1")
|
| 2828 |
with gr.Column(scale=5, min_width=200):
|
| 2829 |
gr.Markdown('### Output (Recursion #2)')
|
| 2830 |
-
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="
|
| 2831 |
add_rotate_flip_buttons(l2_gallery)
|
| 2832 |
add_download_button(l2_gallery, "ncut_embed_recur2")
|
| 2833 |
-
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="
|
| 2834 |
add_download_button(l2_norm_gallery, "eig_norm_recur2")
|
| 2835 |
-
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='
|
| 2836 |
add_download_button(l2_cluster_gallery, "clusters_recur2")
|
| 2837 |
with gr.Column(scale=5, min_width=200):
|
| 2838 |
gr.Markdown('### Output (Recursion #3)')
|
| 2839 |
-
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="
|
| 2840 |
add_rotate_flip_buttons(l3_gallery)
|
| 2841 |
add_download_button(l3_gallery, "ncut_embed_recur3")
|
| 2842 |
-
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="
|
| 2843 |
add_download_button(l3_norm_gallery, "eig_norm_recur3")
|
| 2844 |
-
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='
|
| 2845 |
add_download_button(l3_cluster_gallery, "clusters_recur3")
|
| 2846 |
|
| 2847 |
with gr.Row():
|
|
@@ -2889,7 +3078,7 @@ with demo:
|
|
| 2889 |
)
|
| 2890 |
|
| 2891 |
|
| 2892 |
-
with gr.Tab('Video'):
|
| 2893 |
with gr.Row():
|
| 2894 |
with gr.Column(scale=5, min_width=200):
|
| 2895 |
video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
|
|
@@ -2937,7 +3126,7 @@ with demo:
|
|
| 2937 |
from draft_gradio_app_text import make_demo
|
| 2938 |
make_demo()
|
| 2939 |
|
| 2940 |
-
with gr.Tab('Vision-Language'):
|
| 2941 |
gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
|
| 2942 |
gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
|
| 2943 |
gr.Markdown('This page aims to see how the text prompt affects the image features')
|
|
@@ -2946,15 +3135,15 @@ with demo:
|
|
| 2946 |
with gr.Row():
|
| 2947 |
with gr.Column(scale=5, min_width=200):
|
| 2948 |
gr.Markdown('### Output (Prompt #1)')
|
| 2949 |
-
l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="
|
| 2950 |
prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
|
| 2951 |
with gr.Column(scale=5, min_width=200):
|
| 2952 |
gr.Markdown('### Output (Prompt #2)')
|
| 2953 |
-
l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="
|
| 2954 |
prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
|
| 2955 |
with gr.Column(scale=5, min_width=200):
|
| 2956 |
gr.Markdown('### Output (Prompt #3)')
|
| 2957 |
-
l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="
|
| 2958 |
prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
|
| 2959 |
|
| 2960 |
with gr.Row():
|
|
@@ -2986,7 +3175,7 @@ with demo:
|
|
| 2986 |
outputs=galleries + [logging_text],
|
| 2987 |
)
|
| 2988 |
|
| 2989 |
-
with gr.Tab('Model Aligned'):
|
| 2990 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 2991 |
gr.Markdown('---')
|
| 2992 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
@@ -3049,30 +3238,17 @@ with demo:
|
|
| 3049 |
gr.Markdown('')
|
| 3050 |
gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
|
| 3051 |
gr.Markdown('---')
|
| 3052 |
-
|
| 3053 |
-
# with gr.Row():
|
| 3054 |
-
# with gr.Column(scale=5, min_width=200):
|
| 3055 |
-
# gr.Markdown('### Output (Recursion #1)')
|
| 3056 |
-
# l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 3057 |
-
# add_output_images_buttons(l1_gallery)
|
| 3058 |
-
# with gr.Column(scale=5, min_width=200):
|
| 3059 |
-
# gr.Markdown('### Output (Recursion #2)')
|
| 3060 |
-
# l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 3061 |
-
# add_output_images_buttons(l2_gallery)
|
| 3062 |
-
# with gr.Column(scale=5, min_width=200):
|
| 3063 |
-
# gr.Markdown('### Output (Recursion #3)')
|
| 3064 |
-
# l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 3065 |
-
# add_output_images_buttons(l3_gallery)
|
| 3066 |
gr.Markdown('### Output (Recursion #1)')
|
| 3067 |
-
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="
|
| 3068 |
add_rotate_flip_buttons(l1_gallery)
|
| 3069 |
add_download_button(l1_gallery, "modelaligned_recur1")
|
| 3070 |
gr.Markdown('### Output (Recursion #2)')
|
| 3071 |
-
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="
|
| 3072 |
add_rotate_flip_buttons(l2_gallery)
|
| 3073 |
add_download_button(l2_gallery, "modelaligned_recur2")
|
| 3074 |
gr.Markdown('### Output (Recursion #3)')
|
| 3075 |
-
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="
|
| 3076 |
add_rotate_flip_buttons(l3_gallery)
|
| 3077 |
add_download_button(l3_gallery, "modelaligned_recur3")
|
| 3078 |
|
|
@@ -3141,7 +3317,7 @@ with demo:
|
|
| 3141 |
def add_one_model(i_model=1):
|
| 3142 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3143 |
gr.Markdown(f'### Output Images')
|
| 3144 |
-
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3145 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3146 |
add_rotate_flip_buttons(output_gallery)
|
| 3147 |
[
|
|
@@ -3248,16 +3424,16 @@ with demo:
|
|
| 3248 |
def add_one_model(i_model=1):
|
| 3249 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3250 |
gr.Markdown(f'### Output Images')
|
| 3251 |
-
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3252 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3253 |
add_rotate_flip_buttons(output_gallery)
|
| 3254 |
add_download_button(output_gallery, f"ncut_embed")
|
| 3255 |
-
mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3256 |
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
|
| 3257 |
add_download_button(mlp_gallery, f"mlp_color_align")
|
| 3258 |
-
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3259 |
add_download_button(norm_gallery, f"eig_norm")
|
| 3260 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="
|
| 3261 |
add_download_button(cluster_gallery, f"clusters")
|
| 3262 |
[
|
| 3263 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
|
@@ -3463,16 +3639,16 @@ with demo:
|
|
| 3463 |
def add_one_model(i_model=1):
|
| 3464 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3465 |
gr.Markdown(f'### Output Images')
|
| 3466 |
-
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3467 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3468 |
add_rotate_flip_buttons(output_gallery)
|
| 3469 |
add_download_button(output_gallery, f"ncut_embed")
|
| 3470 |
-
mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3471 |
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
|
| 3472 |
add_download_button(mlp_gallery, f"mlp_color_align")
|
| 3473 |
-
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="
|
| 3474 |
add_download_button(norm_gallery, f"eig_norm")
|
| 3475 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="
|
| 3476 |
add_download_button(cluster_gallery, f"clusters")
|
| 3477 |
[
|
| 3478 |
model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
|
|
@@ -3655,7 +3831,7 @@ with demo:
|
|
| 3655 |
|
| 3656 |
with gr.Column(scale=5, min_width=200):
|
| 3657 |
gr.Markdown("### Step 3: Segment and Crop")
|
| 3658 |
-
mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="
|
| 3659 |
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
|
| 3660 |
add_download_button(mask_gallery, "mask")
|
| 3661 |
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
|
|
@@ -3665,7 +3841,7 @@ with demo:
|
|
| 3665 |
overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
|
| 3666 |
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
|
| 3667 |
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
|
| 3668 |
-
crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="
|
| 3669 |
add_download_button(crop_gallery, "cropped")
|
| 3670 |
crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
|
| 3671 |
area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
|
|
@@ -4143,192 +4319,6 @@ with demo:
|
|
| 4143 |
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
|
| 4144 |
)
|
| 4145 |
|
| 4146 |
-
with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
|
| 4147 |
-
eigvecs = gr.State(np.array([]))
|
| 4148 |
-
with gr.Row():
|
| 4149 |
-
with gr.Column(scale=5, min_width=200):
|
| 4150 |
-
gr.Markdown("### Step 1: Load Images")
|
| 4151 |
-
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
|
| 4152 |
-
submit_button.visible = False
|
| 4153 |
-
num_images_slider.value = 30
|
| 4154 |
-
|
| 4155 |
-
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 4156 |
-
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 4157 |
-
|
| 4158 |
-
|
| 4159 |
-
with gr.Column(scale=5, min_width=200):
|
| 4160 |
-
gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
| 4161 |
-
with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
| 4162 |
-
[
|
| 4163 |
-
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 4164 |
-
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4165 |
-
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4166 |
-
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 4167 |
-
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 4168 |
-
] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
|
| 4169 |
-
num_eig_slider.value = 1024
|
| 4170 |
-
num_eig_slider.visible = False
|
| 4171 |
-
submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
|
| 4172 |
-
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 4173 |
-
submit_button.click(
|
| 4174 |
-
partial(run_fn, n_ret=1, only_eigvecs=True),
|
| 4175 |
-
inputs=[
|
| 4176 |
-
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 4177 |
-
positive_prompt, negative_prompt,
|
| 4178 |
-
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 4179 |
-
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4180 |
-
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4181 |
-
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
| 4182 |
-
],
|
| 4183 |
-
outputs=[eigvecs, logging_text],
|
| 4184 |
-
)
|
| 4185 |
-
gr.Markdown("### Step 2b: Pick an Image")
|
| 4186 |
-
from gradio_image_prompter import ImagePrompter
|
| 4187 |
-
with gr.Row():
|
| 4188 |
-
image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
|
| 4189 |
-
load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
|
| 4190 |
-
gr.Markdown("### Step 2c: Draw a Point")
|
| 4191 |
-
gr.Markdown("""
|
| 4192 |
-
<h5>
|
| 4193 |
-
🖱️ Left Click: Foreground </br>
|
| 4194 |
-
</h5>
|
| 4195 |
-
""")
|
| 4196 |
-
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
| 4197 |
-
def update_prompt_image(original_images, index):
|
| 4198 |
-
images = original_images
|
| 4199 |
-
if images is None:
|
| 4200 |
-
return
|
| 4201 |
-
total_len = len(images)
|
| 4202 |
-
if total_len == 0:
|
| 4203 |
-
return
|
| 4204 |
-
if index >= total_len:
|
| 4205 |
-
index = total_len - 1
|
| 4206 |
-
|
| 4207 |
-
return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
|
| 4208 |
-
# return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
|
| 4209 |
-
load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
| 4210 |
-
|
| 4211 |
-
child_idx = gr.State([])
|
| 4212 |
-
current_idx = gr.State(None)
|
| 4213 |
-
n_eig = gr.State(64)
|
| 4214 |
-
with gr.Column(scale=5, min_width=200):
|
| 4215 |
-
gr.Markdown("### Step 3: Check groupping")
|
| 4216 |
-
child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
|
| 4217 |
-
child_distance_slider.visible = False
|
| 4218 |
-
overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
|
| 4219 |
-
n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
|
| 4220 |
-
run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
|
| 4221 |
-
with gr.Row():
|
| 4222 |
-
doublue_eigs_button = gr.Button("⬇️ +eigvecs", elem_id="doublue_eigs_button", variant='secondary')
|
| 4223 |
-
half_eigs_button = gr.Button("⬆️ -eigvecs", elem_id="half_eigs_button", variant='secondary')
|
| 4224 |
-
current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
|
| 4225 |
-
|
| 4226 |
-
def relative_xy(prompts):
|
| 4227 |
-
image = prompts['image']
|
| 4228 |
-
points = np.asarray(prompts['points'])
|
| 4229 |
-
if points.shape[0] == 0:
|
| 4230 |
-
return [], []
|
| 4231 |
-
is_point = points[:, 5] == 4.0
|
| 4232 |
-
points = points[is_point]
|
| 4233 |
-
is_positive = points[:, 2] == 1.0
|
| 4234 |
-
is_negative = points[:, 2] == 0.0
|
| 4235 |
-
xy = points[:, :2].tolist()
|
| 4236 |
-
if isinstance(image, str):
|
| 4237 |
-
image = Image.open(image)
|
| 4238 |
-
image = np.array(image)
|
| 4239 |
-
h, w = image.shape[:2]
|
| 4240 |
-
new_xy = [(x/w, y/h) for x, y in xy]
|
| 4241 |
-
# print(new_xy)
|
| 4242 |
-
return new_xy, is_positive
|
| 4243 |
-
|
| 4244 |
-
def xy_eigvec(prompts, image_idx, eigvecs):
|
| 4245 |
-
eigvec = eigvecs[image_idx]
|
| 4246 |
-
xy, is_positive = relative_xy(prompts)
|
| 4247 |
-
for i, (x, y) in enumerate(xy):
|
| 4248 |
-
if not is_positive[i]:
|
| 4249 |
-
continue
|
| 4250 |
-
x = int(x * eigvec.shape[1])
|
| 4251 |
-
y = int(y * eigvec.shape[0])
|
| 4252 |
-
return eigvec[y, x], (y, x)
|
| 4253 |
-
|
| 4254 |
-
from ncut_pytorch.ncut_pytorch import _transform_heatmap
|
| 4255 |
-
def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
|
| 4256 |
-
left = eigvecs[..., :n_eig]
|
| 4257 |
-
if flat_idx is not None:
|
| 4258 |
-
right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
|
| 4259 |
-
y, x = None, None
|
| 4260 |
-
else:
|
| 4261 |
-
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
| 4262 |
-
right = right[:n_eig]
|
| 4263 |
-
left = F.normalize(left, p=2, dim=-1)
|
| 4264 |
-
_right = F.normalize(right, p=2, dim=-1)
|
| 4265 |
-
heatmap = left @ _right.unsqueeze(-1)
|
| 4266 |
-
heatmap = heatmap.squeeze(-1)
|
| 4267 |
-
# heatmap = 1 - heatmap
|
| 4268 |
-
# heatmap = _transform_heatmap(heatmap)
|
| 4269 |
-
if raw_heatmap:
|
| 4270 |
-
return heatmap
|
| 4271 |
-
# apply hot colormap and covert to PIL image 256x256
|
| 4272 |
-
# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
|
| 4273 |
-
heatmap = heatmap.cpu().numpy()
|
| 4274 |
-
hot_map = matplotlib.colormaps['hot']
|
| 4275 |
-
heatmap = hot_map(heatmap)
|
| 4276 |
-
pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
| 4277 |
-
if overlay_image:
|
| 4278 |
-
overlaied_images = []
|
| 4279 |
-
for i_image in range(len(images)):
|
| 4280 |
-
rgb_image = images[i_image].resize((256, 256))
|
| 4281 |
-
rgb_image = np.array(rgb_image)
|
| 4282 |
-
heatmap_image = np.array(pil_images[i_image])[..., :3]
|
| 4283 |
-
blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
|
| 4284 |
-
blend_image = Image.fromarray(blend_image.astype(np.uint8))
|
| 4285 |
-
overlaied_images.append(blend_image)
|
| 4286 |
-
pil_images = overlaied_images
|
| 4287 |
-
return pil_images, (y, x)
|
| 4288 |
-
|
| 4289 |
-
@torch.no_grad()
|
| 4290 |
-
def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 4291 |
-
gr.Info(f"current number of eigenvectors: {n_eig}", 2)
|
| 4292 |
-
eigvecs = torch.tensor(eigvecs)
|
| 4293 |
-
image1_slider = min(image1_slider, len(images)-1)
|
| 4294 |
-
images = [image[0] for image in images]
|
| 4295 |
-
if isinstance(images[0], str):
|
| 4296 |
-
images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
|
| 4297 |
-
|
| 4298 |
-
current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
|
| 4299 |
-
|
| 4300 |
-
return current_heatmap
|
| 4301 |
-
|
| 4302 |
-
def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 4303 |
-
n_eig = int(n_eig*2)
|
| 4304 |
-
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 4305 |
-
n_eig = max(n_eig, 1)
|
| 4306 |
-
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
|
| 4307 |
-
|
| 4308 |
-
def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
|
| 4309 |
-
n_eig = int(n_eig/2)
|
| 4310 |
-
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 4311 |
-
n_eig = max(n_eig, 1)
|
| 4312 |
-
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
|
| 4313 |
-
|
| 4314 |
-
none_placeholder = gr.State(None)
|
| 4315 |
-
run_button.click(
|
| 4316 |
-
run_heatmap,
|
| 4317 |
-
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 4318 |
-
outputs=[current_plot],
|
| 4319 |
-
)
|
| 4320 |
-
|
| 4321 |
-
doublue_eigs_button.click(
|
| 4322 |
-
doublue_eigs_wrapper,
|
| 4323 |
-
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 4324 |
-
outputs=[n_eig_slider, current_plot],
|
| 4325 |
-
)
|
| 4326 |
-
|
| 4327 |
-
half_eigs_button.click(
|
| 4328 |
-
half_eigs_wrapper,
|
| 4329 |
-
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
|
| 4330 |
-
outputs=[n_eig_slider, current_plot],
|
| 4331 |
-
)
|
| 4332 |
|
| 4333 |
with gr.Tab('📄About'):
|
| 4334 |
with gr.Column():
|
|
@@ -4372,30 +4362,18 @@ with demo:
|
|
| 4372 |
label = "".join(label)
|
| 4373 |
return n_smiles, gr.update(label=label, value=False)
|
| 4374 |
|
| 4375 |
-
def unlock_tabs_with_info(n_smiles):
|
| 4376 |
-
if n_smiles == unlock_value:
|
| 4377 |
-
gr.Info("🔓 unlocked tabs", 2)
|
| 4378 |
-
return gr.update(visible=True)
|
| 4379 |
-
return gr.update()
|
| 4380 |
|
| 4381 |
-
def unlock_tabs(n_smiles):
|
| 4382 |
if n_smiles == unlock_value:
|
| 4383 |
-
|
| 4384 |
-
|
|
|
|
| 4385 |
|
| 4386 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
| 4387 |
-
|
| 4388 |
-
|
| 4389 |
-
hidden_button.change(unlock_tabs, n_smiles,
|
| 4390 |
-
hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
|
| 4391 |
-
hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
|
| 4392 |
-
hidden_button.change(unlock_tabs, n_smiles, test_playground_tab)
|
| 4393 |
|
| 4394 |
-
# with gr.Row():
|
| 4395 |
-
# with gr.Column():
|
| 4396 |
-
# gr.Markdown("##### This demo is for `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/) ")
|
| 4397 |
-
# with gr.Column():
|
| 4398 |
-
# gr.Markdown("###### Running out of GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
|
| 4399 |
with gr.Row():
|
| 4400 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|
| 4401 |
|
|
|
|
| 1702 |
gr.Info(f"Total images: {len(existing_images)}")
|
| 1703 |
return existing_images
|
| 1704 |
|
| 1705 |
+
def make_input_images_section(rows=1, cols=3, height="450px", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
|
| 1706 |
if markdown:
|
| 1707 |
gr.Markdown('### Input Images')
|
| 1708 |
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
|
|
|
|
| 1750 |
# ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
|
| 1751 |
# ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
|
| 1752 |
# ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
|
| 1753 |
+
("MRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
|
| 1754 |
("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
|
| 1755 |
]
|
| 1756 |
for name, images, dataset_name in example_items:
|
|
|
|
| 2073 |
def make_output_images_section(markdown=True, button=True):
|
| 2074 |
if markdown:
|
| 2075 |
gr.Markdown('### Output Images')
|
| 2076 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
| 2077 |
if button:
|
| 2078 |
add_rotate_flip_buttons(output_gallery)
|
| 2079 |
return output_gallery
|
|
|
|
| 2202 |
with gr.Column(scale=5, min_width=200):
|
| 2203 |
# gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
| 2204 |
# with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
| 2205 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
| 2206 |
def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
|
| 2207 |
with gr.Row():
|
| 2208 |
rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
|
|
|
|
| 2414 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
| 2415 |
with gr.Accordion("Outputs", open=True):
|
| 2416 |
gr.Markdown("""
|
| 2417 |
+
1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked .
|
| 2418 |
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
| 2419 |
""")
|
| 2420 |
with gr.Column(scale=5, min_width=200):
|
|
|
|
| 2448 |
with gr.Column(scale=5, min_width=200):
|
| 2449 |
output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
|
| 2450 |
text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
|
| 2451 |
+
delete_button = gr.Button("❌ Delete", elem_id=f"delete_button_{i_row}", variant='secondary')
|
| 2452 |
with gr.Column(scale=10, min_width=200):
|
| 2453 |
+
heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="500px", show_share_button=True, interactive=False)
|
| 2454 |
+
delete_button.click(lambda: gr.update(visible=False), outputs=[inspect_output_row])
|
| 2455 |
return inspect_output_row, output_tree_image, heatmap_gallery, text_block
|
| 2456 |
|
| 2457 |
gr.Markdown('---')
|
|
|
|
| 2602 |
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
|
| 2603 |
)
|
| 2604 |
|
| 2605 |
+
with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
|
| 2606 |
+
eigvecs = gr.State(np.array([]))
|
| 2607 |
+
with gr.Row():
|
| 2608 |
+
with gr.Column(scale=5, min_width=200):
|
| 2609 |
+
gr.Markdown("### Step 1: Load Images")
|
| 2610 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
|
| 2611 |
+
submit_button.visible = False
|
| 2612 |
+
num_images_slider.value = 30
|
| 2613 |
+
|
| 2614 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 2615 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 2616 |
+
|
| 2617 |
+
|
| 2618 |
+
with gr.Column(scale=5, min_width=200):
|
| 2619 |
+
gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
| 2620 |
+
with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
| 2621 |
+
[
|
| 2622 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 2623 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 2624 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 2625 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 2626 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 2627 |
+
] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
|
| 2628 |
+
num_eig_slider.value = 1024
|
| 2629 |
+
num_eig_slider.visible = False
|
| 2630 |
+
submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
|
| 2631 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 2632 |
+
submit_button.click(
|
| 2633 |
+
partial(run_fn, n_ret=1, only_eigvecs=True),
|
| 2634 |
+
inputs=[
|
| 2635 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 2636 |
+
positive_prompt, negative_prompt,
|
| 2637 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 2638 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 2639 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 2640 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
| 2641 |
+
],
|
| 2642 |
+
outputs=[eigvecs, logging_text],
|
| 2643 |
+
)
|
| 2644 |
+
gr.Markdown("### Step 2b: Pick an Image and Draw a Point")
|
| 2645 |
+
from gradio_image_prompter import ImagePrompter
|
| 2646 |
+
image1_slider = gr.Slider(0, 0, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
|
| 2647 |
+
load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
|
| 2648 |
+
gr.Markdown("""
|
| 2649 |
+
<h5>
|
| 2650 |
+
🖱️ Left Click: Foreground </br>
|
| 2651 |
+
</h5>
|
| 2652 |
+
""")
|
| 2653 |
+
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
| 2654 |
+
def update_prompt_image(original_images, index):
|
| 2655 |
+
images = original_images
|
| 2656 |
+
if images is None:
|
| 2657 |
+
return gr.update()
|
| 2658 |
+
total_len = len(images)
|
| 2659 |
+
if total_len == 0:
|
| 2660 |
+
return gr.update()
|
| 2661 |
+
if index >= total_len:
|
| 2662 |
+
index = total_len - 1
|
| 2663 |
+
return gr.update(value={'image': images[index][0], 'points': []}, interactive=True)
|
| 2664 |
+
load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
| 2665 |
+
input_gallery.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
| 2666 |
+
input_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1), inputs=[input_gallery], outputs=[image1_slider])
|
| 2667 |
+
image1_slider.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
| 2668 |
+
|
| 2669 |
+
child_idx = gr.State([])
|
| 2670 |
+
current_idx = gr.State(None)
|
| 2671 |
+
n_eig = gr.State(64)
|
| 2672 |
+
with gr.Column(scale=5, min_width=200):
|
| 2673 |
+
gr.Markdown("### Step 3: Check groupping")
|
| 2674 |
+
child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
|
| 2675 |
+
child_distance_slider.visible = False
|
| 2676 |
+
overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
|
| 2677 |
+
n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
|
| 2678 |
+
run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
|
| 2679 |
+
gr.Markdown("1. 🔴 RUN </br>2. repeat: [+num_eigvecs] / [-num_eigvecs]")
|
| 2680 |
+
with gr.Row():
|
| 2681 |
+
doublue_eigs_button = gr.Button("[+num_eigvecs]", elem_id="doublue_eigs_button", variant='secondary')
|
| 2682 |
+
half_eigs_button = gr.Button("[-num_eigvecs]", elem_id="half_eigs_button", variant='secondary')
|
| 2683 |
+
current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
|
| 2684 |
+
|
| 2685 |
+
def relative_xy(prompts):
|
| 2686 |
+
image = prompts['image']
|
| 2687 |
+
points = np.asarray(prompts['points'])
|
| 2688 |
+
if points.shape[0] == 0:
|
| 2689 |
+
return [], []
|
| 2690 |
+
is_point = points[:, 5] == 4.0
|
| 2691 |
+
points = points[is_point]
|
| 2692 |
+
is_positive = points[:, 2] == 1.0
|
| 2693 |
+
is_negative = points[:, 2] == 0.0
|
| 2694 |
+
xy = points[:, :2].tolist()
|
| 2695 |
+
if isinstance(image, str):
|
| 2696 |
+
image = Image.open(image)
|
| 2697 |
+
image = np.array(image)
|
| 2698 |
+
h, w = image.shape[:2]
|
| 2699 |
+
new_xy = [(x/w, y/h) for x, y in xy]
|
| 2700 |
+
# print(new_xy)
|
| 2701 |
+
return new_xy, is_positive
|
| 2702 |
+
|
| 2703 |
+
def xy_eigvec(prompts, image_idx, eigvecs):
|
| 2704 |
+
eigvec = eigvecs[image_idx]
|
| 2705 |
+
xy, is_positive = relative_xy(prompts)
|
| 2706 |
+
for i, (x, y) in enumerate(xy):
|
| 2707 |
+
if not is_positive[i]:
|
| 2708 |
+
continue
|
| 2709 |
+
x = int(x * eigvec.shape[1])
|
| 2710 |
+
y = int(y * eigvec.shape[0])
|
| 2711 |
+
return eigvec[y, x], (y, x)
|
| 2712 |
+
|
| 2713 |
+
from ncut_pytorch.ncut_pytorch import _transform_heatmap
|
| 2714 |
+
def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
|
| 2715 |
+
left = eigvecs[..., :n_eig]
|
| 2716 |
+
if flat_idx is not None:
|
| 2717 |
+
right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
|
| 2718 |
+
y, x = None, None
|
| 2719 |
+
else:
|
| 2720 |
+
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
| 2721 |
+
right = right[:n_eig]
|
| 2722 |
+
left = F.normalize(left, p=2, dim=-1)
|
| 2723 |
+
_right = F.normalize(right, p=2, dim=-1)
|
| 2724 |
+
heatmap = left @ _right.unsqueeze(-1)
|
| 2725 |
+
heatmap = heatmap.squeeze(-1)
|
| 2726 |
+
# heatmap = 1 - heatmap
|
| 2727 |
+
# heatmap = _transform_heatmap(heatmap)
|
| 2728 |
+
if raw_heatmap:
|
| 2729 |
+
return heatmap
|
| 2730 |
+
# apply hot colormap and covert to PIL image 256x256
|
| 2731 |
+
# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
|
| 2732 |
+
heatmap = heatmap.cpu().numpy()
|
| 2733 |
+
hot_map = matplotlib.colormaps['hot']
|
| 2734 |
+
heatmap = hot_map(heatmap)
|
| 2735 |
+
pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
| 2736 |
+
if overlay_image:
|
| 2737 |
+
overlaied_images = []
|
| 2738 |
+
for i_image in range(len(images)):
|
| 2739 |
+
rgb_image = images[i_image].resize((256, 256))
|
| 2740 |
+
rgb_image = np.array(rgb_image)
|
| 2741 |
+
heatmap_image = np.array(pil_images[i_image])[..., :3]
|
| 2742 |
+
blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
|
| 2743 |
+
blend_image = Image.fromarray(blend_image.astype(np.uint8))
|
| 2744 |
+
overlaied_images.append(blend_image)
|
| 2745 |
+
pil_images = overlaied_images
|
| 2746 |
+
return pil_images, (y, x)
|
| 2747 |
+
|
| 2748 |
+
@torch.no_grad()
|
| 2749 |
+
def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 2750 |
+
gr.Info(f"current number of eigenvectors: {n_eig}", 2)
|
| 2751 |
+
eigvecs = torch.tensor(eigvecs)
|
| 2752 |
+
image1_slider = min(image1_slider, len(images)-1)
|
| 2753 |
+
images = [image[0] for image in images]
|
| 2754 |
+
if isinstance(images[0], str):
|
| 2755 |
+
images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
|
| 2756 |
+
|
| 2757 |
+
current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
|
| 2758 |
+
|
| 2759 |
+
return current_heatmap
|
| 2760 |
+
|
| 2761 |
+
def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 2762 |
+
n_eig = int(n_eig*2)
|
| 2763 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 2764 |
+
n_eig = max(n_eig, 1)
|
| 2765 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
|
| 2766 |
+
|
| 2767 |
+
def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
|
| 2768 |
+
n_eig = int(n_eig/2)
|
| 2769 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 2770 |
+
n_eig = max(n_eig, 1)
|
| 2771 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
|
| 2772 |
+
|
| 2773 |
+
none_placeholder = gr.State(None)
|
| 2774 |
+
run_button.click(
|
| 2775 |
+
run_heatmap,
|
| 2776 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 2777 |
+
outputs=[current_plot],
|
| 2778 |
+
)
|
| 2779 |
+
|
| 2780 |
+
doublue_eigs_button.click(
|
| 2781 |
+
doublue_eigs_wrapper,
|
| 2782 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 2783 |
+
outputs=[n_eig_slider, current_plot],
|
| 2784 |
+
)
|
| 2785 |
+
|
| 2786 |
+
half_eigs_button.click(
|
| 2787 |
+
half_eigs_wrapper,
|
| 2788 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
|
| 2789 |
+
outputs=[n_eig_slider, current_plot],
|
| 2790 |
+
)
|
| 2791 |
+
|
| 2792 |
with gr.Tab('AlignedCut'):
|
| 2793 |
|
| 2794 |
with gr.Row():
|
|
|
|
| 2799 |
|
| 2800 |
with gr.Column(scale=5, min_width=200):
|
| 2801 |
output_gallery = make_output_images_section()
|
| 2802 |
+
# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 2803 |
[
|
| 2804 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 2805 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
|
| 2812 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 2813 |
|
| 2814 |
submit_button.click(
|
| 2815 |
+
partial(run_fn, n_ret=1, plot_clusters=False),
|
| 2816 |
inputs=[
|
| 2817 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 2818 |
positive_prompt, negative_prompt,
|
|
|
|
| 2821 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 2822 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
| 2823 |
],
|
| 2824 |
+
outputs=[output_gallery, logging_text],
|
| 2825 |
api_name="API_AlignedCut",
|
| 2826 |
scroll_to_output=True,
|
| 2827 |
)
|
|
|
|
| 2837 |
with gr.Column(scale=5, min_width=200):
|
| 2838 |
output_gallery = make_output_images_section()
|
| 2839 |
add_download_button(output_gallery, "ncut_embed")
|
| 2840 |
+
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 2841 |
add_download_button(norm_gallery, "eig_norm")
|
| 2842 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 2843 |
add_download_button(cluster_gallery, "clusters")
|
| 2844 |
[
|
| 2845 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
|
|
|
| 2929 |
api_name="API_NCut",
|
| 2930 |
)
|
| 2931 |
|
| 2932 |
+
with gr.Tab('RecursiveCut'):
|
| 2933 |
gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
|
| 2934 |
gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')
|
| 2935 |
|
|
|
|
| 2942 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 2943 |
with gr.Column(scale=5, min_width=200):
|
| 2944 |
gr.Markdown('### Output (Recursion #1)')
|
| 2945 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 2946 |
add_rotate_flip_buttons(l1_gallery)
|
| 2947 |
with gr.Column(scale=5, min_width=200):
|
| 2948 |
gr.Markdown('### Output (Recursion #2)')
|
| 2949 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 2950 |
add_rotate_flip_buttons(l2_gallery)
|
| 2951 |
with gr.Column(scale=5, min_width=200):
|
| 2952 |
gr.Markdown('### Output (Recursion #3)')
|
| 2953 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 2954 |
add_rotate_flip_buttons(l3_gallery)
|
| 2955 |
with gr.Row():
|
| 2956 |
|
|
|
|
| 2998 |
api_name="API_RecursiveCut"
|
| 2999 |
)
|
| 3000 |
|
| 3001 |
+
with gr.Tab('RecursiveCut (Advanced)', visible=False) as tab_recursivecut_advanced:
|
| 3002 |
|
| 3003 |
with gr.Row():
|
| 3004 |
with gr.Column(scale=5, min_width=200):
|
|
|
|
| 3007 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
|
| 3008 |
with gr.Column(scale=5, min_width=200):
|
| 3009 |
gr.Markdown('### Output (Recursion #1)')
|
| 3010 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3011 |
add_rotate_flip_buttons(l1_gallery)
|
| 3012 |
add_download_button(l1_gallery, "ncut_embed_recur1")
|
| 3013 |
+
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3014 |
add_download_button(l1_norm_gallery, "eig_norm_recur1")
|
| 3015 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
| 3016 |
add_download_button(l1_cluster_gallery, "clusters_recur1")
|
| 3017 |
with gr.Column(scale=5, min_width=200):
|
| 3018 |
gr.Markdown('### Output (Recursion #2)')
|
| 3019 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3020 |
add_rotate_flip_buttons(l2_gallery)
|
| 3021 |
add_download_button(l2_gallery, "ncut_embed_recur2")
|
| 3022 |
+
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3023 |
add_download_button(l2_norm_gallery, "eig_norm_recur2")
|
| 3024 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
| 3025 |
add_download_button(l2_cluster_gallery, "clusters_recur2")
|
| 3026 |
with gr.Column(scale=5, min_width=200):
|
| 3027 |
gr.Markdown('### Output (Recursion #3)')
|
| 3028 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3029 |
add_rotate_flip_buttons(l3_gallery)
|
| 3030 |
add_download_button(l3_gallery, "ncut_embed_recur3")
|
| 3031 |
+
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3032 |
add_download_button(l3_norm_gallery, "eig_norm_recur3")
|
| 3033 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
| 3034 |
add_download_button(l3_cluster_gallery, "clusters_recur3")
|
| 3035 |
|
| 3036 |
with gr.Row():
|
|
|
|
| 3078 |
)
|
| 3079 |
|
| 3080 |
|
| 3081 |
+
with gr.Tab('Video', visible=True) as tab_video:
|
| 3082 |
with gr.Row():
|
| 3083 |
with gr.Column(scale=5, min_width=200):
|
| 3084 |
video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
|
|
|
|
| 3126 |
from draft_gradio_app_text import make_demo
|
| 3127 |
make_demo()
|
| 3128 |
|
| 3129 |
+
with gr.Tab('Vision-Language', visible=False) as tab_lisa:
|
| 3130 |
gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
|
| 3131 |
gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
|
| 3132 |
gr.Markdown('This page aims to see how the text prompt affects the image features')
|
|
|
|
| 3135 |
with gr.Row():
|
| 3136 |
with gr.Column(scale=5, min_width=200):
|
| 3137 |
gr.Markdown('### Output (Prompt #1)')
|
| 3138 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3139 |
prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
|
| 3140 |
with gr.Column(scale=5, min_width=200):
|
| 3141 |
gr.Markdown('### Output (Prompt #2)')
|
| 3142 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3143 |
prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
|
| 3144 |
with gr.Column(scale=5, min_width=200):
|
| 3145 |
gr.Markdown('### Output (Prompt #3)')
|
| 3146 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3147 |
prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
|
| 3148 |
|
| 3149 |
with gr.Row():
|
|
|
|
| 3175 |
outputs=galleries + [logging_text],
|
| 3176 |
)
|
| 3177 |
|
| 3178 |
+
with gr.Tab('Model Aligned', visible=False) as tab_aligned:
|
| 3179 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 3180 |
gr.Markdown('---')
|
| 3181 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
|
|
| 3238 |
gr.Markdown('')
|
| 3239 |
gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
|
| 3240 |
gr.Markdown('---')
|
| 3241 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3242 |
gr.Markdown('### Output (Recursion #1)')
|
| 3243 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
|
| 3244 |
add_rotate_flip_buttons(l1_gallery)
|
| 3245 |
add_download_button(l1_gallery, "modelaligned_recur1")
|
| 3246 |
gr.Markdown('### Output (Recursion #2)')
|
| 3247 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
|
| 3248 |
add_rotate_flip_buttons(l2_gallery)
|
| 3249 |
add_download_button(l2_gallery, "modelaligned_recur2")
|
| 3250 |
gr.Markdown('### Output (Recursion #3)')
|
| 3251 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
|
| 3252 |
add_rotate_flip_buttons(l3_gallery)
|
| 3253 |
add_download_button(l3_gallery, "modelaligned_recur3")
|
| 3254 |
|
|
|
|
| 3317 |
def add_one_model(i_model=1):
|
| 3318 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3319 |
gr.Markdown(f'### Output Images')
|
| 3320 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3321 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3322 |
add_rotate_flip_buttons(output_gallery)
|
| 3323 |
[
|
|
|
|
| 3424 |
def add_one_model(i_model=1):
|
| 3425 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3426 |
gr.Markdown(f'### Output Images')
|
| 3427 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3428 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3429 |
add_rotate_flip_buttons(output_gallery)
|
| 3430 |
add_download_button(output_gallery, f"ncut_embed")
|
| 3431 |
+
mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3432 |
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
|
| 3433 |
add_download_button(mlp_gallery, f"mlp_color_align")
|
| 3434 |
+
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3435 |
add_download_button(norm_gallery, f"eig_norm")
|
| 3436 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3437 |
add_download_button(cluster_gallery, f"clusters")
|
| 3438 |
[
|
| 3439 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
|
|
|
| 3639 |
def add_one_model(i_model=1):
|
| 3640 |
with gr.Column(scale=5, min_width=200) as col:
|
| 3641 |
gr.Markdown(f'### Output Images')
|
| 3642 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3643 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 3644 |
add_rotate_flip_buttons(output_gallery)
|
| 3645 |
add_download_button(output_gallery, f"ncut_embed")
|
| 3646 |
+
mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
| 3647 |
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
|
| 3648 |
add_download_button(mlp_gallery, f"mlp_color_align")
|
| 3649 |
+
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3650 |
add_download_button(norm_gallery, f"eig_norm")
|
| 3651 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
| 3652 |
add_download_button(cluster_gallery, f"clusters")
|
| 3653 |
[
|
| 3654 |
model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
|
|
|
|
| 3831 |
|
| 3832 |
with gr.Column(scale=5, min_width=200):
|
| 3833 |
gr.Markdown("### Step 3: Segment and Crop")
|
| 3834 |
+
mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
| 3835 |
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
|
| 3836 |
add_download_button(mask_gallery, "mask")
|
| 3837 |
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
|
|
|
|
| 3841 |
overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
|
| 3842 |
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
|
| 3843 |
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
|
| 3844 |
+
crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
| 3845 |
add_download_button(crop_gallery, "cropped")
|
| 3846 |
crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
|
| 3847 |
area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
|
|
|
|
| 4319 |
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
|
| 4320 |
)
|
| 4321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4322 |
|
| 4323 |
with gr.Tab('📄About'):
|
| 4324 |
with gr.Column():
|
|
|
|
| 4362 |
label = "".join(label)
|
| 4363 |
return n_smiles, gr.update(label=label, value=False)
|
| 4364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4365 |
|
| 4366 |
+
def unlock_tabs(n_smiles, n_tab=1):
|
| 4367 |
if n_smiles == unlock_value:
|
| 4368 |
+
gr.Info("🔓 unlocked tabs", 2)
|
| 4369 |
+
return [gr.update(visible=True)] * n_tab
|
| 4370 |
+
return [gr.update()] * n_tab
|
| 4371 |
|
| 4372 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
| 4373 |
+
hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
|
| 4374 |
+
tab_compare_models_advanced, tab_directed_ncut, test_playground_tab, tab_aligned, tab_lisa]
|
| 4375 |
+
hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
|
|
|
|
|
|
|
|
|
|
| 4376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4377 |
with gr.Row():
|
| 4378 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|
| 4379 |
|