Spaces:
Running
on
Zero
Running
on
Zero
add kway
Browse files
app.py
CHANGED
|
@@ -927,6 +927,44 @@ def ncut_run(
|
|
| 927 |
|
| 928 |
return to_pil_images(rgb_all), logging_str
|
| 929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
# ailgnedcut
|
| 931 |
if not directed:
|
| 932 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
|
@@ -1318,6 +1356,7 @@ def run_fn(
|
|
| 1318 |
return_eigvec_and_rgb=False,
|
| 1319 |
normalize_eigvec_return=False,
|
| 1320 |
separate_fg_bg=False,
|
|
|
|
| 1321 |
):
|
| 1322 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1323 |
progress=gr.Progress()
|
|
@@ -1463,6 +1502,7 @@ def run_fn(
|
|
| 1463 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
| 1464 |
"normalize_eigvec_return": normalize_eigvec_return,
|
| 1465 |
"separate_fg_bg": separate_fg_bg,
|
|
|
|
| 1466 |
}
|
| 1467 |
# print(kwargs)
|
| 1468 |
|
|
@@ -4348,6 +4388,44 @@ with demo:
|
|
| 4348 |
outputs=[output_gallery, logging_text],
|
| 4349 |
)
|
| 4350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4351 |
|
| 4352 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
| 4353 |
with gr.Row():
|
|
|
|
| 927 |
|
| 928 |
return to_pil_images(rgb_all), logging_str
|
| 929 |
|
| 930 |
+
kway = kwargs.get("kway", False)
|
| 931 |
+
if kway:
|
| 932 |
+
only_eigvecs = True
|
| 933 |
+
rgb, _logging_str, eigvecs = compute_ncut(
|
| 934 |
+
features,
|
| 935 |
+
num_eig=num_eig,
|
| 936 |
+
num_sample_ncut=num_sample_ncut,
|
| 937 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
| 938 |
+
knn_ncut=knn_ncut,
|
| 939 |
+
knn_tsne=knn_tsne,
|
| 940 |
+
num_sample_tsne=num_sample_tsne,
|
| 941 |
+
embedding_method=embedding_method,
|
| 942 |
+
embedding_metric=embedding_metric,
|
| 943 |
+
perplexity=perplexity,
|
| 944 |
+
n_neighbors=n_neighbors,
|
| 945 |
+
min_dist=min_dist,
|
| 946 |
+
sampling_method=sampling_method,
|
| 947 |
+
indirect_connection=indirect_connection,
|
| 948 |
+
make_orthogonal=make_orthogonal,
|
| 949 |
+
metric=ncut_metric,
|
| 950 |
+
only_eigvecs=only_eigvecs,
|
| 951 |
+
)
|
| 952 |
+
from ncut_pytorch import kway_ncut
|
| 953 |
+
kway_onehot = kway_ncut(eigvecs) # [N, K]
|
| 954 |
+
kway_indices = kway_onehot.argmax(dim=-1) # [N]
|
| 955 |
+
kway_indices = kway_indices.cpu().numpy()
|
| 956 |
+
if kway_indices.max() > 10:
|
| 957 |
+
cm = plt.colormaps['tab20']
|
| 958 |
+
rgb = cm(kway_indices / 20)
|
| 959 |
+
else:
|
| 960 |
+
cm = plt.colormaps['tab10']
|
| 961 |
+
rgb = cm(kway_indices / 10)
|
| 962 |
+
if kway_indices.max() > 20:
|
| 963 |
+
gr.Error("Too many clusters for kway_ncut")
|
| 964 |
+
rgb = rgb[:, :3]
|
| 965 |
+
rgb = rgb.reshape(*features.shape[:-1], 3)
|
| 966 |
+
return to_pil_images(rgb), logging_str
|
| 967 |
+
|
| 968 |
# ailgnedcut
|
| 969 |
if not directed:
|
| 970 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
|
|
|
| 1356 |
return_eigvec_and_rgb=False,
|
| 1357 |
normalize_eigvec_return=False,
|
| 1358 |
separate_fg_bg=False,
|
| 1359 |
+
kway=False,
|
| 1360 |
):
|
| 1361 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1362 |
progress=gr.Progress()
|
|
|
|
| 1502 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
| 1503 |
"normalize_eigvec_return": normalize_eigvec_return,
|
| 1504 |
"separate_fg_bg": separate_fg_bg,
|
| 1505 |
+
"kway": kway,
|
| 1506 |
}
|
| 1507 |
# print(kwargs)
|
| 1508 |
|
|
|
|
| 4388 |
outputs=[output_gallery, logging_text],
|
| 4389 |
)
|
| 4390 |
|
| 4391 |
+
|
| 4392 |
+
with gr.Tab('K-way'):
|
| 4393 |
+
|
| 4394 |
+
with gr.Row():
|
| 4395 |
+
with gr.Column(scale=5, min_width=200):
|
| 4396 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
|
| 4397 |
+
num_images_slider.value = 30
|
| 4398 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 4399 |
+
|
| 4400 |
+
with gr.Column(scale=5, min_width=200):
|
| 4401 |
+
output_gallery = make_output_images_section()
|
| 4402 |
+
# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 4403 |
+
[
|
| 4404 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 4405 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4406 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4407 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 4408 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 4409 |
+
] = make_parameters_section()
|
| 4410 |
+
num_eig_slider.value = 6
|
| 4411 |
+
num_eig_slider.maximum = 20
|
| 4412 |
+
|
| 4413 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 4414 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 4415 |
+
|
| 4416 |
+
submit_button.click(
|
| 4417 |
+
partial(run_fn, n_ret=1, plot_clusters=False, kway=True),
|
| 4418 |
+
inputs=[
|
| 4419 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 4420 |
+
positive_prompt, negative_prompt,
|
| 4421 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 4422 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4423 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4424 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
| 4425 |
+
*[false_placeholder]*12,
|
| 4426 |
+
],
|
| 4427 |
+
outputs=[output_gallery, logging_text],
|
| 4428 |
+
)
|
| 4429 |
|
| 4430 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
| 4431 |
with gr.Row():
|