Spaces: update playground
app.py CHANGED
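In summary, this commit fixes the output-gallery heights from `height="auto"` to fixed values (`"450px"`, and `"500px"` for the cluster heatmap) across all tabs, adds an MRI example dataset, adds a per-row ❌ Delete button to the inspection rows, moves the 'PlayGround (eig)' tab up next to 'AlignedCut' and rewires it (`gr.update()` returns, gallery/slider `change` bindings, relabeled eigvecs buttons), gives the 'RecursiveCut (Advanced)', 'Video', 'Vision-Language', and 'Model Aligned' tabs explicit `visible` flags and handles, removes `unlock_tabs_with_info`, and simplifies the AlignedCut outputs.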
@@ -1702,7 +1702,7 @@ def load_and_append(existing_images, *args, **kwargs):
gr.Info(f"Total images: {len(existing_images)}")
return existing_images

-def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
+def make_input_images_section(rows=1, cols=3, height="450px", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
if markdown:
gr.Markdown('### Input Images')
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
@@ -1750,7 +1750,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
# ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
# ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
# ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
-("
+("MRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
]
for name, images, dataset_name in example_items:
@@ -2073,7 +2073,7 @@ def add_download_button(gallery, filename_prefix="output"):
def make_output_images_section(markdown=True, button=True):
if markdown:
gr.Markdown('### Output Images')
-output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
if button:
add_rotate_flip_buttons(output_gallery)
return output_gallery
@@ -2202,7 +2202,7 @@ with demo:
with gr.Column(scale=5, min_width=200):
# gr.Markdown("### Step 2a: Run Backbone and NCUT")
# with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
-output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
with gr.Row():
rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
@@ -2414,7 +2414,7 @@ with demo:
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
with gr.Accordion("Outputs", open=True):
gr.Markdown("""
-1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked
+1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked .
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
""")
with gr.Column(scale=5, min_width=200):
@@ -2448,8 +2448,10 @@ with demo:
with gr.Column(scale=5, min_width=200):
output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
+delete_button = gr.Button("❌ Delete", elem_id=f"delete_button_{i_row}", variant='secondary')
with gr.Column(scale=10, min_width=200):
-heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="500px", show_share_button=True, interactive=False)
+delete_button.click(lambda: gr.update(visible=False), outputs=[inspect_output_row])
return inspect_output_row, output_tree_image, heatmap_gallery, text_block

gr.Markdown('---')
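Review note: the new ❌ Delete button hides its whole inspection row by returning a visibility update targeted at the row container, rather than removing components. A minimal self-contained sketch of the same pattern (component names here are illustrative, not from app.py):

```python
import gradio as gr

with gr.Blocks() as demo:
    # A row that can be hidden; visible=False keeps the components alive,
    # so a later gr.update(visible=True) could bring the row back.
    with gr.Row(visible=True) as inspect_row:
        gr.Textbox("row content", show_label=False)
        delete_button = gr.Button("❌ Delete", variant="secondary")
    # The click handler takes no inputs and returns a visibility update for the Row.
    delete_button.click(lambda: gr.update(visible=False), outputs=[inspect_row])

demo.launch()
```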
@@ -2600,6 +2602,193 @@ with demo:
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
)

+with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
+eigvecs = gr.State(np.array([]))
+with gr.Row():
+with gr.Column(scale=5, min_width=200):
+gr.Markdown("### Step 1: Load Images")
+input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
+submit_button.visible = False
+num_images_slider.value = 30
+
+false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+
+with gr.Column(scale=5, min_width=200):
+gr.Markdown("### Step 2a: Run Backbone and NCUT")
+with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
+[
+model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+perplexity_slider, n_neighbors_slider, min_dist_slider,
+sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
+num_eig_slider.value = 1024
+num_eig_slider.visible = False
+submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
+logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+submit_button.click(
+partial(run_fn, n_ret=1, only_eigvecs=True),
+inputs=[
+input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+positive_prompt, negative_prompt,
+false_placeholder, no_prompt, no_prompt, no_prompt,
+affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+],
+outputs=[eigvecs, logging_text],
+)
+gr.Markdown("### Step 2b: Pick an Image and Draw a Point")
+from gradio_image_prompter import ImagePrompter
+image1_slider = gr.Slider(0, 0, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
+load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
+gr.Markdown("""
+<h5>
+🖱️ Left Click: Foreground </br>
+</h5>
+""")
+prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
+def update_prompt_image(original_images, index):
+images = original_images
+if images is None:
+return gr.update()
+total_len = len(images)
+if total_len == 0:
+return gr.update()
+if index >= total_len:
+index = total_len - 1
+return gr.update(value={'image': images[index][0], 'points': []}, interactive=True)
+load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
+input_gallery.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
+input_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1), inputs=[input_gallery], outputs=[image1_slider])
+image1_slider.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
+
+child_idx = gr.State([])
+current_idx = gr.State(None)
+n_eig = gr.State(64)
+with gr.Column(scale=5, min_width=200):
+gr.Markdown("### Step 3: Check groupping")
+child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
+child_distance_slider.visible = False
+overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
+n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
+run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
+gr.Markdown("1. 🔴 RUN </br>2. repeat: [+num_eigvecs] / [-num_eigvecs]")
+with gr.Row():
+doublue_eigs_button = gr.Button("[+num_eigvecs]", elem_id="doublue_eigs_button", variant='secondary')
+half_eigs_button = gr.Button("[-num_eigvecs]", elem_id="half_eigs_button", variant='secondary')
+current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
+
+def relative_xy(prompts):
+image = prompts['image']
+points = np.asarray(prompts['points'])
+if points.shape[0] == 0:
+return [], []
+is_point = points[:, 5] == 4.0
+points = points[is_point]
+is_positive = points[:, 2] == 1.0
+is_negative = points[:, 2] == 0.0
+xy = points[:, :2].tolist()
+if isinstance(image, str):
+image = Image.open(image)
+image = np.array(image)
+h, w = image.shape[:2]
+new_xy = [(x/w, y/h) for x, y in xy]
+# print(new_xy)
+return new_xy, is_positive
+
+def xy_eigvec(prompts, image_idx, eigvecs):
+eigvec = eigvecs[image_idx]
+xy, is_positive = relative_xy(prompts)
+for i, (x, y) in enumerate(xy):
+if not is_positive[i]:
+continue
+x = int(x * eigvec.shape[1])
+y = int(y * eigvec.shape[0])
+return eigvec[y, x], (y, x)
+
+from ncut_pytorch.ncut_pytorch import _transform_heatmap
+def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
+left = eigvecs[..., :n_eig]
+if flat_idx is not None:
+right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
+y, x = None, None
+else:
+right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
+right = right[:n_eig]
+left = F.normalize(left, p=2, dim=-1)
+_right = F.normalize(right, p=2, dim=-1)
+heatmap = left @ _right.unsqueeze(-1)
+heatmap = heatmap.squeeze(-1)
+# heatmap = 1 - heatmap
+# heatmap = _transform_heatmap(heatmap)
+if raw_heatmap:
+return heatmap
+# apply hot colormap and covert to PIL image 256x256
+# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
+heatmap = heatmap.cpu().numpy()
+hot_map = matplotlib.colormaps['hot']
+heatmap = hot_map(heatmap)
+pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
+if overlay_image:
+overlaied_images = []
+for i_image in range(len(images)):
+rgb_image = images[i_image].resize((256, 256))
+rgb_image = np.array(rgb_image)
+heatmap_image = np.array(pil_images[i_image])[..., :3]
+blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
+blend_image = Image.fromarray(blend_image.astype(np.uint8))
+overlaied_images.append(blend_image)
+pil_images = overlaied_images
+return pil_images, (y, x)
+
+@torch.no_grad()
+def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
+gr.Info(f"current number of eigenvectors: {n_eig}", 2)
+eigvecs = torch.tensor(eigvecs)
+image1_slider = min(image1_slider, len(images)-1)
+images = [image[0] for image in images]
+if isinstance(images[0], str):
+images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
+
+current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
+
+return current_heatmap
+
+def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
+n_eig = int(n_eig*2)
+n_eig = min(n_eig, eigvecs.shape[-1])
+n_eig = max(n_eig, 1)
+return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
+
+def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
+n_eig = int(n_eig/2)
+n_eig = min(n_eig, eigvecs.shape[-1])
+n_eig = max(n_eig, 1)
+return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
+
+none_placeholder = gr.State(None)
+run_button.click(
+run_heatmap,
+inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
+outputs=[current_plot],
+)
+
+doublue_eigs_button.click(
+doublue_eigs_wrapper,
+inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
+outputs=[n_eig_slider, current_plot],
+)
+
+half_eigs_button.click(
+half_eigs_wrapper,
+inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
+outputs=[n_eig_slider, current_plot],
+)
+
with gr.Tab('AlignedCut'):

with gr.Row():
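Review note: the grouping heatmap in the tab above is a cosine similarity between the clicked patch's first `n_eig` eigenvector entries and those of every patch, which is what `_run_heatmap_fn` computes before colormapping. A minimal sketch of that core computation, under assumed shapes (the `(n_images, H, W, D)` layout and the function name are illustrative, not from app.py):

```python
import torch
import torch.nn.functional as F

def cosine_heatmap(eigvecs: torch.Tensor, img_idx: int, y: int, x: int, n_eig: int) -> torch.Tensor:
    # eigvecs: (n_images, H, W, D) NCUT eigenvectors; keep only the first n_eig dims
    left = F.normalize(eigvecs[..., :n_eig], p=2, dim=-1)             # all patches
    right = F.normalize(eigvecs[img_idx, y, x, :n_eig], p=2, dim=-1)  # clicked patch
    return (left @ right.unsqueeze(-1)).squeeze(-1)                   # (n_images, H, W)

# Example: 30 images on a 16x16 patch grid with 1024 eigenvectors, using the first 256
heat = cosine_heatmap(torch.randn(30, 16, 16, 1024), img_idx=0, y=4, x=7, n_eig=256)
print(heat.shape)  # torch.Size([30, 16, 16])
```

The [+num_eigvecs] / [-num_eigvecs] buttons double or halve `n_eig` and rerun this; fewer eigenvectors generally merge groups and more eigenvectors split them.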
@@ -2610,7 +2799,7 @@ with demo:

with gr.Column(scale=5, min_width=200):
output_gallery = make_output_images_section()
-cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
[
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2623,7 +2812,7 @@ with demo:
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)

submit_button.click(
-partial(run_fn, n_ret=
+partial(run_fn, n_ret=1, plot_clusters=False),
inputs=[
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
positive_prompt, negative_prompt,
@@ -2632,7 +2821,7 @@ with demo:
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
],
-outputs=[output_gallery,
+outputs=[output_gallery, logging_text],
api_name="API_AlignedCut",
scroll_to_output=True,
)
@@ -2648,9 +2837,9 @@ with demo:
with gr.Column(scale=5, min_width=200):
output_gallery = make_output_images_section()
add_download_button(output_gallery, "ncut_embed")
-norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(norm_gallery, "eig_norm")
-cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(cluster_gallery, "clusters")
[
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
@@ -2740,7 +2929,7 @@ with demo:
api_name="API_NCut",
)

-with gr.Tab('
+with gr.Tab('RecursiveCut'):
gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')

@@ -2753,15 +2942,15 @@ with demo:
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #1)')
-l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l1_gallery)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #2)')
-l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l2_gallery)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #3)')
-l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l3_gallery)
with gr.Row():

@@ -2809,7 +2998,7 @@ with demo:
api_name="API_RecursiveCut"
)

-with gr.Tab('
+with gr.Tab('RecursiveCut (Advanced)', visible=False) as tab_recursivecut_advanced:

with gr.Row():
with gr.Column(scale=5, min_width=200):
@@ -2818,30 +3007,30 @@ with demo:
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #1)')
-l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l1_gallery)
add_download_button(l1_gallery, "ncut_embed_recur1")
-l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(l1_norm_gallery, "eig_norm_recur1")
-l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
+l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
add_download_button(l1_cluster_gallery, "clusters_recur1")
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #2)')
-l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l2_gallery)
add_download_button(l2_gallery, "ncut_embed_recur2")
-l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(l2_norm_gallery, "eig_norm_recur2")
-l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
+l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
add_download_button(l2_cluster_gallery, "clusters_recur2")
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Recursion #3)')
-l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_rotate_flip_buttons(l3_gallery)
add_download_button(l3_gallery, "ncut_embed_recur3")
-l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(l3_norm_gallery, "eig_norm_recur3")
-l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='auto', show_share_button=True, preview=False, interactive=False)
+l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
add_download_button(l3_cluster_gallery, "clusters_recur3")

with gr.Row():
@@ -2889,7 +3078,7 @@ with demo:
)


-with gr.Tab('Video'):
+with gr.Tab('Video', visible=True) as tab_video:
with gr.Row():
with gr.Column(scale=5, min_width=200):
video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
@@ -2937,7 +3126,7 @@ with demo:
from draft_gradio_app_text import make_demo
make_demo()

-with gr.Tab('Vision-Language'):
+with gr.Tab('Vision-Language', visible=False) as tab_lisa:
gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
gr.Markdown('This page aims to see how the text prompt affects the image features')
@@ -2946,15 +3135,15 @@ with demo:
with gr.Row():
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #1)')
-l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #2)')
-l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #3)')
-l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)

with gr.Row():
@@ -2986,7 +3175,7 @@ with demo:
outputs=galleries + [logging_text],
)

-with gr.Tab('Model Aligned'):
+with gr.Tab('Model Aligned', visible=False) as tab_aligned:
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
gr.Markdown('---')
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
@@ -3049,30 +3238,17 @@ with demo:
gr.Markdown('')
gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
gr.Markdown('---')
-
+
-# with gr.Row():
-# with gr.Column(scale=5, min_width=200):
-# gr.Markdown('### Output (Recursion #1)')
-# l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-# add_output_images_buttons(l1_gallery)
-# with gr.Column(scale=5, min_width=200):
-# gr.Markdown('### Output (Recursion #2)')
-# l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-# add_output_images_buttons(l2_gallery)
-# with gr.Column(scale=5, min_width=200):
-# gr.Markdown('### Output (Recursion #3)')
-# l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-# add_output_images_buttons(l3_gallery)
gr.Markdown('### Output (Recursion #1)')
-l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l1_gallery)
add_download_button(l1_gallery, "modelaligned_recur1")
gr.Markdown('### Output (Recursion #2)')
-l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l2_gallery)
add_download_button(l2_gallery, "modelaligned_recur2")
gr.Markdown('### Output (Recursion #3)')
-l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l3_gallery)
add_download_button(l3_gallery, "modelaligned_recur3")

@@ -3141,7 +3317,7 @@ with demo:
def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
-output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
[
@@ -3248,16 +3424,16 @@ with demo:
def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
-output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
add_download_button(output_gallery, f"ncut_embed")
-mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
add_download_button(mlp_gallery, f"mlp_color_align")
-norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(norm_gallery, f"eig_norm")
-cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(cluster_gallery, f"clusters")
[
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
@@ -3463,16 +3639,16 @@ with demo:
def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
-output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
add_download_button(output_gallery, f"ncut_embed")
-mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
add_download_button(mlp_gallery, f"mlp_color_align")
-norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(norm_gallery, f"eig_norm")
-cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(cluster_gallery, f"clusters")
[
model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
@@ -3655,7 +3831,7 @@ with demo:

with gr.Column(scale=5, min_width=200):
gr.Markdown("### Step 3: Segment and Crop")
-mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
add_download_button(mask_gallery, "mask")
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
@@ -3665,7 +3841,7 @@ with demo:
overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
-crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(crop_gallery, "cropped")
crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
@@ -4143,192 +4319,6 @@ with demo:
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
)

-with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
-eigvecs = gr.State(np.array([]))
-with gr.Row():
-with gr.Column(scale=5, min_width=200):
-gr.Markdown("### Step 1: Load Images")
-input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
-submit_button.visible = False
-num_images_slider.value = 30
-
-false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
-no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
-
-
-with gr.Column(scale=5, min_width=200):
-gr.Markdown("### Step 2a: Run Backbone and NCUT")
-with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
-[
-model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
-affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
-embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
-perplexity_slider, n_neighbors_slider, min_dist_slider,
-sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
-] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
-num_eig_slider.value = 1024
-num_eig_slider.visible = False
-submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
-logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
-submit_button.click(
-partial(run_fn, n_ret=1, only_eigvecs=True),
-inputs=[
-input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
-positive_prompt, negative_prompt,
-false_placeholder, no_prompt, no_prompt, no_prompt,
-affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
-embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
-perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
-],
-outputs=[eigvecs, logging_text],
-)
-gr.Markdown("### Step 2b: Pick an Image")
-from gradio_image_prompter import ImagePrompter
-with gr.Row():
-image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
-load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
-gr.Markdown("### Step 2c: Draw a Point")
-gr.Markdown("""
-<h5>
-🖱️ Left Click: Foreground </br>
-</h5>
-""")
-prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
-def update_prompt_image(original_images, index):
-images = original_images
-if images is None:
-return
-total_len = len(images)
-if total_len == 0:
-return
-if index >= total_len:
-index = total_len - 1
-
-return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
-# return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
-load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
-
-child_idx = gr.State([])
-current_idx = gr.State(None)
-n_eig = gr.State(64)
-with gr.Column(scale=5, min_width=200):
-gr.Markdown("### Step 3: Check groupping")
-child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
-child_distance_slider.visible = False
-overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
-n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
-run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
-with gr.Row():
-doublue_eigs_button = gr.Button("⬇️ +eigvecs", elem_id="doublue_eigs_button", variant='secondary')
-half_eigs_button = gr.Button("⬆️ -eigvecs", elem_id="half_eigs_button", variant='secondary')
-current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
-
-def relative_xy(prompts):
-image = prompts['image']
-points = np.asarray(prompts['points'])
-if points.shape[0] == 0:
-return [], []
-is_point = points[:, 5] == 4.0
-points = points[is_point]
-is_positive = points[:, 2] == 1.0
-is_negative = points[:, 2] == 0.0
-xy = points[:, :2].tolist()
-if isinstance(image, str):
-image = Image.open(image)
-image = np.array(image)
-h, w = image.shape[:2]
-new_xy = [(x/w, y/h) for x, y in xy]
-# print(new_xy)
-return new_xy, is_positive
-
-def xy_eigvec(prompts, image_idx, eigvecs):
-eigvec = eigvecs[image_idx]
-xy, is_positive = relative_xy(prompts)
-for i, (x, y) in enumerate(xy):
-if not is_positive[i]:
-continue
-x = int(x * eigvec.shape[1])
-y = int(y * eigvec.shape[0])
-return eigvec[y, x], (y, x)
-
-from ncut_pytorch.ncut_pytorch import _transform_heatmap
-def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
-left = eigvecs[..., :n_eig]
-if flat_idx is not None:
-right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
-y, x = None, None
-else:
-right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
-right = right[:n_eig]
-left = F.normalize(left, p=2, dim=-1)
-_right = F.normalize(right, p=2, dim=-1)
-heatmap = left @ _right.unsqueeze(-1)
-heatmap = heatmap.squeeze(-1)
-# heatmap = 1 - heatmap
-# heatmap = _transform_heatmap(heatmap)
-if raw_heatmap:
-return heatmap
-# apply hot colormap and covert to PIL image 256x256
-# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
-heatmap = heatmap.cpu().numpy()
-hot_map = matplotlib.colormaps['hot']
-heatmap = hot_map(heatmap)
-pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
-if overlay_image:
-overlaied_images = []
-for i_image in range(len(images)):
-rgb_image = images[i_image].resize((256, 256))
-rgb_image = np.array(rgb_image)
-heatmap_image = np.array(pil_images[i_image])[..., :3]
-blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
-blend_image = Image.fromarray(blend_image.astype(np.uint8))
-overlaied_images.append(blend_image)
-pil_images = overlaied_images
-return pil_images, (y, x)
-
-@torch.no_grad()
-def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
-gr.Info(f"current number of eigenvectors: {n_eig}", 2)
-eigvecs = torch.tensor(eigvecs)
-image1_slider = min(image1_slider, len(images)-1)
-images = [image[0] for image in images]
-if isinstance(images[0], str):
-images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
-
-current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
-
-return current_heatmap
-
-def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
-n_eig = int(n_eig*2)
-n_eig = min(n_eig, eigvecs.shape[-1])
-n_eig = max(n_eig, 1)
-return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
-
-def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
-n_eig = int(n_eig/2)
-n_eig = min(n_eig, eigvecs.shape[-1])
-n_eig = max(n_eig, 1)
-return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
-
-none_placeholder = gr.State(None)
-run_button.click(
-run_heatmap,
-inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
-outputs=[current_plot],
-)
-
-doublue_eigs_button.click(
-doublue_eigs_wrapper,
-inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
-outputs=[n_eig_slider, current_plot],
-)
-
-half_eigs_button.click(
-half_eigs_wrapper,
-inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
-outputs=[n_eig_slider, current_plot],
-)

with gr.Tab('📄About'):
with gr.Column():
@@ -4372,30 +4362,18 @@ with demo:
label = "".join(label)
return n_smiles, gr.update(label=label, value=False)

-def unlock_tabs_with_info(n_smiles):
-if n_smiles == unlock_value:
-gr.Info("🔓 unlocked tabs", 2)
-return gr.update(visible=True)
-return gr.update()

-def unlock_tabs(n_smiles):
if n_smiles == unlock_value:
-
-

hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
-
-
-hidden_button.change(unlock_tabs, n_smiles,
-hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
-hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
-hidden_button.change(unlock_tabs, n_smiles, test_playground_tab)

-# with gr.Row():
-# with gr.Column():
-# gr.Markdown("##### This demo is for `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/) ")
-# with gr.Column():
-# gr.Markdown("###### Running out of GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
with gr.Row():
gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")
1702 |
gr.Info(f"Total images: {len(existing_images)}")
|
1703 |
return existing_images
|
1704 |
|
1705 |
+
def make_input_images_section(rows=1, cols=3, height="450px", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
|
1706 |
if markdown:
|
1707 |
gr.Markdown('### Input Images')
|
1708 |
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
|
|
|
1750 |
# ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
|
1751 |
# ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
|
1752 |
# ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
|
1753 |
+
("MRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
|
1754 |
("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
|
1755 |
]
|
1756 |
for name, images, dataset_name in example_items:
|
|
|
2073 |
def make_output_images_section(markdown=True, button=True):
|
2074 |
if markdown:
|
2075 |
gr.Markdown('### Output Images')
|
2076 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
2077 |
if button:
|
2078 |
add_rotate_flip_buttons(output_gallery)
|
2079 |
return output_gallery
|
|
|
2202 |
with gr.Column(scale=5, min_width=200):
|
2203 |
# gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
2204 |
# with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
2205 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
|
2206 |
def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
|
2207 |
with gr.Row():
|
2208 |
rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
|
|
|
2414 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
2415 |
with gr.Accordion("Outputs", open=True):
|
2416 |
gr.Markdown("""
|
2417 |
+
1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked .
|
2418 |
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
2419 |
""")
|
2420 |
with gr.Column(scale=5, min_width=200):
|
|
|
2448 |
with gr.Column(scale=5, min_width=200):
|
2449 |
output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
|
2450 |
text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
|
2451 |
+
delete_button = gr.Button("❌ Delete", elem_id=f"delete_button_{i_row}", variant='secondary')
|
2452 |
with gr.Column(scale=10, min_width=200):
|
2453 |
+
heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="500px", show_share_button=True, interactive=False)
|
2454 |
+
delete_button.click(lambda: gr.update(visible=False), outputs=[inspect_output_row])
|
2455 |
return inspect_output_row, output_tree_image, heatmap_gallery, text_block
|
2456 |
|
2457 |
gr.Markdown('---')
|
|
|
2602 |
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
|
2603 |
)
|
2604 |
|
2605 |
+
with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2:
|
2606 |
+
eigvecs = gr.State(np.array([]))
|
2607 |
+
with gr.Row():
|
2608 |
+
with gr.Column(scale=5, min_width=200):
|
2609 |
+
gr.Markdown("### Step 1: Load Images")
|
2610 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
|
2611 |
+
submit_button.visible = False
|
2612 |
+
num_images_slider.value = 30
|
2613 |
+
|
2614 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
2615 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
2616 |
+
|
2617 |
+
|
2618 |
+
with gr.Column(scale=5, min_width=200):
|
2619 |
+
gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
2620 |
+
with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
2621 |
+
[
|
2622 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
2623 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
2624 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
2625 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
2626 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
2627 |
+
] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
|
2628 |
+
num_eig_slider.value = 1024
|
2629 |
+
num_eig_slider.visible = False
|
2630 |
+
submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
|
2631 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
2632 |
+
submit_button.click(
|
2633 |
+
partial(run_fn, n_ret=1, only_eigvecs=True),
|
2634 |
+
inputs=[
|
2635 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
2636 |
+
positive_prompt, negative_prompt,
|
2637 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
2638 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
2639 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
2640 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
2641 |
+
],
|
2642 |
+
outputs=[eigvecs, logging_text],
|
2643 |
+
)
|
2644 |
+
gr.Markdown("### Step 2b: Pick an Image and Draw a Point")
|
2645 |
+
from gradio_image_prompter import ImagePrompter
|
2646 |
+
image1_slider = gr.Slider(0, 0, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
|
2647 |
+
load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
|
2648 |
+
gr.Markdown("""
|
2649 |
+
<h5>
|
2650 |
+
🖱️ Left Click: Foreground </br>
|
2651 |
+
</h5>
|
2652 |
+
""")
|
2653 |
+
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
2654 |
+
def update_prompt_image(original_images, index):
|
2655 |
+
images = original_images
|
2656 |
+
if images is None:
|
2657 |
+
return gr.update()
|
2658 |
+
total_len = len(images)
|
2659 |
+
if total_len == 0:
|
2660 |
+
return gr.update()
|
2661 |
+
if index >= total_len:
|
2662 |
+
index = total_len - 1
|
2663 |
+
return gr.update(value={'image': images[index][0], 'points': []}, interactive=True)
|
2664 |
+
load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
2665 |
+
input_gallery.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
2666 |
+
input_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1), inputs=[input_gallery], outputs=[image1_slider])
|
2667 |
+
image1_slider.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
2668 |
+
|
2669 |
+
child_idx = gr.State([])
|
2670 |
+
current_idx = gr.State(None)
|
2671 |
+
n_eig = gr.State(64)
|
2672 |
+
with gr.Column(scale=5, min_width=200):
|
2673 |
+
gr.Markdown("### Step 3: Check groupping")
|
2674 |
+
child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
|
2675 |
+
child_distance_slider.visible = False
|
2676 |
+
overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
|
2677 |
+
n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
|
2678 |
+
run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
|
2679 |
+
gr.Markdown("1. 🔴 RUN </br>2. repeat: [+num_eigvecs] / [-num_eigvecs]")
|
2680 |
+
with gr.Row():
|
2681 |
+
doublue_eigs_button = gr.Button("[+num_eigvecs]", elem_id="doublue_eigs_button", variant='secondary')
|
2682 |
+
half_eigs_button = gr.Button("[-num_eigvecs]", elem_id="half_eigs_button", variant='secondary')
|
2683 |
+
current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
|
2684 |
+
|
2685 |
+
def relative_xy(prompts):
|
2686 |
+
image = prompts['image']
|
2687 |
+
points = np.asarray(prompts['points'])
|
2688 |
+
if points.shape[0] == 0:
|
2689 |
+
return [], []
|
2690 |
+
is_point = points[:, 5] == 4.0
|
2691 |
+
points = points[is_point]
|
2692 |
+
is_positive = points[:, 2] == 1.0
|
2693 |
+
is_negative = points[:, 2] == 0.0
|
2694 |
+
xy = points[:, :2].tolist()
|
2695 |
+
if isinstance(image, str):
|
2696 |
+
image = Image.open(image)
|
2697 |
+
image = np.array(image)
|
2698 |
+
h, w = image.shape[:2]
|
2699 |
+
new_xy = [(x/w, y/h) for x, y in xy]
|
2700 |
+
# print(new_xy)
|
2701 |
+
return new_xy, is_positive
|
2702 |
+
|
2703 |
+
def xy_eigvec(prompts, image_idx, eigvecs):
|
2704 |
+
eigvec = eigvecs[image_idx]
|
2705 |
+
xy, is_positive = relative_xy(prompts)
|
2706 |
+
for i, (x, y) in enumerate(xy):
|
2707 |
+
if not is_positive[i]:
|
2708 |
+
continue
|
2709 |
+
x = int(x * eigvec.shape[1])
|
2710 |
+
y = int(y * eigvec.shape[0])
|
2711 |
+
return eigvec[y, x], (y, x)
|
2712 |
+
|
2713 |
+
from ncut_pytorch.ncut_pytorch import _transform_heatmap
|
2714 |
+
def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
|
2715 |
+
left = eigvecs[..., :n_eig]
|
2716 |
+
if flat_idx is not None:
|
2717 |
+
right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
|
2718 |
+
y, x = None, None
|
2719 |
+
else:
|
2720 |
+
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
2721 |
+
right = right[:n_eig]
|
2722 |
+
left = F.normalize(left, p=2, dim=-1)
|
2723 |
+
_right = F.normalize(right, p=2, dim=-1)
|
2724 |
+
heatmap = left @ _right.unsqueeze(-1)
|
2725 |
+
heatmap = heatmap.squeeze(-1)
|
2726 |
+
# heatmap = 1 - heatmap
|
2727 |
+
# heatmap = _transform_heatmap(heatmap)
|
2728 |
+
if raw_heatmap:
|
2729 |
+
return heatmap
|
2730 |
+
# apply hot colormap and covert to PIL image 256x256
|
2731 |
+
# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
|
2732 |
+
heatmap = heatmap.cpu().numpy()
|
2733 |
+
hot_map = matplotlib.colormaps['hot']
|
2734 |
+
heatmap = hot_map(heatmap)
|
2735 |
+
pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
2736 |
+
if overlay_image:
|
2737 |
+
overlaied_images = []
|
2738 |
+
for i_image in range(len(images)):
|
2739 |
+
rgb_image = images[i_image].resize((256, 256))
|
2740 |
+
rgb_image = np.array(rgb_image)
|
2741 |
+
heatmap_image = np.array(pil_images[i_image])[..., :3]
|
2742 |
+
blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
|
2743 |
+
blend_image = Image.fromarray(blend_image.astype(np.uint8))
|
2744 |
+
overlaied_images.append(blend_image)
|
2745 |
+
pil_images = overlaied_images
|
2746 |
+
return pil_images, (y, x)
|
2747 |
+
|
2748 |
+
@torch.no_grad()
|
2749 |
+
def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
2750 |
+
gr.Info(f"current number of eigenvectors: {n_eig}", 2)
|
2751 |
+
eigvecs = torch.tensor(eigvecs)
|
2752 |
+
image1_slider = min(image1_slider, len(images)-1)
|
2753 |
+
images = [image[0] for image in images]
|
2754 |
+
if isinstance(images[0], str):
|
2755 |
+
images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
|
2756 |
+
|
2757 |
+
current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
|
2758 |
+
|
2759 |
+
return current_heatmap
|
2760 |
+
|
2761 |
+
def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
2762 |
+
n_eig = int(n_eig*2)
|
2763 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
2764 |
+
n_eig = max(n_eig, 1)
|
2765 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
|
2766 |
+
|
2767 |
+
def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
|
2768 |
+
n_eig = int(n_eig/2)
|
2769 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
2770 |
+
n_eig = max(n_eig, 1)
|
2771 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
|
2772 |
+
|
2773 |
+
none_placeholder = gr.State(None)
|
2774 |
+
run_button.click(
|
2775 |
+
run_heatmap,
|
2776 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
2777 |
+
outputs=[current_plot],
|
2778 |
+
)
|
2779 |
+
|
2780 |
+
doublue_eigs_button.click(
|
2781 |
+
doublue_eigs_wrapper,
|
2782 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
2783 |
+
outputs=[n_eig_slider, current_plot],
|
2784 |
+
)
|
2785 |
+
|
2786 |
+
half_eigs_button.click(
|
2787 |
+
half_eigs_wrapper,
|
2788 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
|
2789 |
+
outputs=[n_eig_slider, current_plot],
|
2790 |
+
)
|
2791 |
+
|
2792 |
with gr.Tab('AlignedCut'):
|
2793 |
|
2794 |
with gr.Row():
|
|
|
2799 |
|
2800 |
with gr.Column(scale=5, min_width=200):
|
2801 |
output_gallery = make_output_images_section()
|
2802 |
+
# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
2803 |
[
|
2804 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
2805 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
2812 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
2813 |
|
2814 |
submit_button.click(
|
2815 |
+
partial(run_fn, n_ret=1, plot_clusters=False),
|
2816 |
inputs=[
|
2817 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
2818 |
positive_prompt, negative_prompt,
|
|
|
2821 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
2822 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
2823 |
],
|
2824 |
+
outputs=[output_gallery, logging_text],
|
2825 |
api_name="API_AlignedCut",
|
2826 |
scroll_to_output=True,
|
2827 |
)
|
|
|
2837 |
with gr.Column(scale=5, min_width=200):
|
2838 |
output_gallery = make_output_images_section()
|
2839 |
add_download_button(output_gallery, "ncut_embed")
|
2840 |
+
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
2841 |
add_download_button(norm_gallery, "eig_norm")
|
2842 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
2843 |
add_download_button(cluster_gallery, "clusters")
|
2844 |
[
|
2845 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
|
|
2929 |
api_name="API_NCut",
|
2930 |
)
|
2931 |
|
2932 |
+
with gr.Tab('RecursiveCut'):
|
2933 |
gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
|
2934 |
gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')
|
2935 |
|
|
|
2942 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
2943 |
with gr.Column(scale=5, min_width=200):
|
2944 |
gr.Markdown('### Output (Recursion #1)')
|
2945 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
2946 |
add_rotate_flip_buttons(l1_gallery)
|
2947 |
with gr.Column(scale=5, min_width=200):
|
2948 |
gr.Markdown('### Output (Recursion #2)')
|
2949 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
2950 |
add_rotate_flip_buttons(l2_gallery)
|
2951 |
with gr.Column(scale=5, min_width=200):
|
2952 |
gr.Markdown('### Output (Recursion #3)')
|
2953 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
2954 |
add_rotate_flip_buttons(l3_gallery)
|
2955 |
with gr.Row():
|
2956 |
|
|
|
2998 |
api_name="API_RecursiveCut"
|
2999 |
)
|
3000 |
|
3001 |
+
with gr.Tab('RecursiveCut (Advanced)', visible=False) as tab_recursivecut_advanced:
|
3002 |
|
3003 |
with gr.Row():
|
3004 |
with gr.Column(scale=5, min_width=200):
|
|
|
3007 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
|
3008 |
with gr.Column(scale=5, min_width=200):
|
3009 |
gr.Markdown('### Output (Recursion #1)')
|
3010 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
3011 |
add_rotate_flip_buttons(l1_gallery)
|
3012 |
add_download_button(l1_gallery, "ncut_embed_recur1")
|
3013 |
+
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
3014 |
add_download_button(l1_norm_gallery, "eig_norm_recur1")
|
3015 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
3016 |
add_download_button(l1_cluster_gallery, "clusters_recur1")
|
3017 |
with gr.Column(scale=5, min_width=200):
|
3018 |
gr.Markdown('### Output (Recursion #2)')
|
3019 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
3020 |
add_rotate_flip_buttons(l2_gallery)
|
3021 |
add_download_button(l2_gallery, "ncut_embed_recur2")
|
3022 |
+
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
3023 |
add_download_button(l2_norm_gallery, "eig_norm_recur2")
|
3024 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
3025 |
add_download_button(l2_cluster_gallery, "clusters_recur2")
|
3026 |
with gr.Column(scale=5, min_width=200):
|
3027 |
gr.Markdown('### Output (Recursion #3)')
|
3028 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
|
3029 |
add_rotate_flip_buttons(l3_gallery)
|
3030 |
add_download_button(l3_gallery, "ncut_embed_recur3")
|
3031 |
+
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
|
3032 |
add_download_button(l3_norm_gallery, "eig_norm_recur3")
|
3033 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
|
3034 |
add_download_button(l3_cluster_gallery, "clusters_recur3")
|
3035 |
|
3036 |
with gr.Row():
|
|
|
3078 |
)
|
3079 |
|
3080 |
|
3081 |
+
with gr.Tab('Video', visible=True) as tab_video:
|
3082 |
with gr.Row():
|
3083 |
with gr.Column(scale=5, min_width=200):
|
3084 |
video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section()
|
|
|
3126 |
from draft_gradio_app_text import make_demo
|
3127 |
make_demo()
|
3128 |
|
3129 |
+
with gr.Tab('Vision-Language', visible=False) as tab_lisa:
|
3130 |
gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
|
3131 |
gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
|
3132 |
gr.Markdown('This page aims to see how the text prompt affects the image features')

with gr.Row():
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #1)')
+l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #2)')
+l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Paul guitar", lines=3)
with gr.Column(scale=5, min_width=200):
gr.Markdown('### Output (Prompt #3)')
+l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
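The click wiring for this tab is elided from the diff; judging from the visible `outputs=galleries + [logging_text]` below, it presumably follows the standard Gradio fan-out pattern, one run function feeding all three galleries. A purely hypothetical sketch (`run_lisa`, `input_gallery`, and `logging_text` are assumed names, not the app's actual functions):

    # Hypothetical wiring sketch; the real function and its signature are not shown.
    galleries = [l1_gallery, l2_gallery, l3_gallery]

    def run_lisa(images, p1, p2, p3):
        out1, out2, out3 = [], [], []  # one list of result images per prompt
        logs = "ran LISA with 3 prompts"
        return out1, out2, out3, logs

    submit_button.click(
        run_lisa,
        inputs=[input_gallery, prompt1, prompt2, prompt3],
        outputs=galleries + [logging_text],
    )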

with gr.Row():

outputs=galleries + [logging_text],
)

+with gr.Tab('Model Aligned', visible=False) as tab_aligned:
gr.Markdown('This page reproduces the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
gr.Markdown('---')
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer; the learning signal comes from 1) fMRI brain activation and 2) segmentation-preserving eigen-constraints.')
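The alignment transform itself is not part of this diff; as a rough illustration of "a linear alignment transform trained for each model/layer", one could regress each model/layer's features onto a shared target space. All names and shapes below are assumptions:

    # Illustrative only: least-squares fit of a per-model/layer linear map W
    # taking features X (n, d_model) into a shared aligned space Y (n, d_aligned).
    import torch

    def fit_linear_alignment(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # solve min_W ||X W - Y||^2 in closed form
        return torch.linalg.lstsq(X, Y).solution  # shape (d_model, d_aligned)

    X = torch.randn(1024, 768)   # one model/layer's features
    Y = torch.randn(1024, 256)   # shared target space (e.g. fMRI-derived)
    W = fit_linear_alignment(X, Y)
    aligned = X @ W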

gr.Markdown('')
gr.Markdown("To see a good pattern, you will need to load 100–1000 images; 100 images take about 10 seconds on an RTX 4090. Running out of HuggingFace GPU quota? Try the [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
gr.Markdown('---')
+
gr.Markdown('### Output (Recursion #1)')
+l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l1_gallery)
add_download_button(l1_gallery, "modelaligned_recur1")
gr.Markdown('### Output (Recursion #2)')
+l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l2_gallery)
add_download_button(l2_gallery, "modelaligned_recur2")
gr.Markdown('### Output (Recursion #3)')
+l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True)
add_rotate_flip_buttons(l3_gallery)
add_download_button(l3_gallery, "modelaligned_recur3")
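`add_rotate_flip_buttons` is another helper defined earlier in app.py. A plausible minimal version, assuming (PIL.Image, caption) tuples in the gallery (a sketch, not the app's code):

    from PIL import Image
    import gradio as gr

    def sketch_add_rotate_flip_buttons(gallery):
        def rotate(images):
            # rotate every gallery image 90 degrees clockwise
            return [(img.transpose(Image.Transpose.ROTATE_270), cap) for img, cap in images]

        def flip(images):
            # mirror every gallery image left-right
            return [(img.transpose(Image.Transpose.FLIP_LEFT_RIGHT), cap) for img, cap in images]

        with gr.Row():
            gr.Button("🔄 Rotate").click(rotate, inputs=[gallery], outputs=[gallery])
            gr.Button("↔️ Flip").click(flip, inputs=[gallery], outputs=[gallery])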

def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
[

def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
add_download_button(output_gallery, f"ncut_embed")
+mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
add_download_button(mlp_gallery, f"mlp_color_align")
+norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(norm_gallery, f"eig_norm")
+cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(cluster_gallery, f"clusters")
[
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
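"MLP color align" fits a small network so one model's NCUT coloring matches a reference; `add_mlp_fitting_buttons` is not shown in this diff, so the sketch below only illustrates the general idea of fitting an RGB-to-RGB map with an MLP. Every detail is an assumption, not the app's routine:

    # Assumed illustration: fit a tiny MLP mapping source RGB pixels to target
    # RGB so two NCUT colorings become visually comparable.
    import numpy as np
    from sklearn.neural_network import MLPRegressor

    def fit_color_align(src_rgb: np.ndarray, tgt_rgb: np.ndarray) -> MLPRegressor:
        # src_rgb, tgt_rgb: (n_pixels, 3) arrays in [0, 1]
        mlp = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=500)
        mlp.fit(src_rgb, tgt_rgb)
        return mlp

    src = np.random.rand(10_000, 3)
    tgt = np.random.rand(10_000, 3)
    aligned = fit_color_align(src, tgt).predict(src)  # (n_pixels, 3)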

def add_one_model(i_model=1):
with gr.Column(scale=5, min_width=200) as col:
gr.Markdown(f'### Output Images')
+output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
add_rotate_flip_buttons(output_gallery)
add_download_button(output_gallery, f"ncut_embed")
+mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
add_download_button(mlp_gallery, f"mlp_color_align")
+norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(norm_gallery, f"eig_norm")
+cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
add_download_button(cluster_gallery, f"clusters")
[
model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,

with gr.Column(scale=5, min_width=200):
gr.Markdown("### Step 3: Segment and Crop")
+mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
add_download_button(mask_gallery, "mask")
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")

overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
+crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, interactive=False)
add_download_button(crop_gallery, "cropped")
crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
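These sliders parameterize a segment-then-crop flow: threshold a foreground heatmap ("Mask Threshold"), drop connected components below a few percent of the image area ("Area Threshold"), then expand each component's bounding box ("Crop bbox Expand Factor") before cropping. A minimal sketch under those assumptions; the app's real implementation is not shown in this diff:

    import numpy as np
    from scipy import ndimage
    from PIL import Image

    def segment_and_crop(image: Image.Image, heatmap: np.ndarray,
                         fg_threshold=0.9, area_threshold_pct=3.0, expand=1.0):
        mask = heatmap > fg_threshold                   # Mask Threshold (FG)
        labels, n = ndimage.label(mask)                 # connected components
        crops = []
        for i in range(1, n + 1):
            comp = labels == i
            if comp.mean() * 100 < area_threshold_pct:  # Area Threshold (%)
                continue
            ys, xs = np.nonzero(comp)
            cy, cx = ys.mean(), xs.mean()
            h = (ys.max() - ys.min() + 1) * expand / 2  # Crop bbox Expand Factor
            w = (xs.max() - xs.min() + 1) * expand / 2
            box = (int(cx - w), int(cy - h), int(cx + w), int(cy + h))
            crops.append(image.crop(box))
        return mask, crops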

outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
)

with gr.Tab('📄About'):
with gr.Column():

label = "".join(label)
return n_smiles, gr.update(label=label, value=False)

+def unlock_tabs(n_smiles, n_tab=1):
if n_smiles == unlock_value:
+gr.Info("🔓 unlocked tabs", 2)
+return [gr.update(visible=True)] * n_tab
+return [gr.update()] * n_tab

hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
+hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
+tab_compare_models_advanced, tab_directed_ncut, test_playground_tab, tab_aligned, tab_lisa]
+hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
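This is the easter-egg unlock: `update_smile` counts interactions with a hidden button, and once `n_smiles` hits `unlock_value`, `unlock_tabs` flips every hidden tab to `visible=True` in one fan-out call; `gr.update()` leaves a tab untouched. A self-contained toy version of the same wiring, with illustrative names and a made-up threshold:

    import gradio as gr
    from functools import partial

    UNLOCK_VALUE = 3  # illustrative threshold

    def count_clicks(n):
        return n + 1

    def unlock(n, n_tab=1):
        if n == UNLOCK_VALUE:
            gr.Info("🔓 unlocked tabs")
            return [gr.update(visible=True)] * n_tab
        return [gr.update()] * n_tab  # no-op until the threshold is reached

    with gr.Blocks() as demo:
        with gr.Tab("Public"):
            clicks = gr.State(0)
            secret = gr.Button("🙂")
        hidden = [gr.Tab("Hidden A", visible=False), gr.Tab("Hidden B", visible=False)]
        secret.click(count_clicks, clicks, clicks).then(
            partial(unlock, n_tab=len(hidden)), clicks, hidden)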

with gr.Row():
gr.Markdown("**This demo is for the Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**")