Spaces: Running on Zero
add fg
app.py CHANGED
@@ -409,34 +409,40 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
     return blended.astype(np.uint8)
 
 
-def segment_fg_bg(images):
+def segment_fg_bg(images, hw=224, i_layer=6, batch_size=4, transform_images=True):
+
+    assert hw % 16 == 0, "The height and width of the image must be divisible by 16."
+    psz = hw // 16
+    center_xy = (psz-1) // 2
 
-    images = F.interpolate(images, (
+    images = F.interpolate(images, (hw, hw), mode="bilinear")
 
     # model = load_alignedthreemodel()
     model = load_model("CLIP(ViT-B-16/openai)")
     from ncut_pytorch.backbone import resample_position_embeddings
     pos_embed = model.model.visual.positional_embedding
-    pos_embed = resample_position_embeddings(pos_embed,
+    pos_embed = resample_position_embeddings(pos_embed, psz, psz)
     model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)
 
-    batch_size = 4
     chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     model.to(device)
-
-
+
+    if transform_images:
+        means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+        stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
 
     fg_acts, bg_acts = [], []
     for chunk_idx in chunk_idxs:
         with torch.no_grad():
             input_images = images[chunk_idx].to(device)
             # transform the input images
-
+            if transform_images:
+                input_images = (input_images - means) / stds
             # output = model(input_images)[:, 5]
-            output = model(input_images)['attn'][
-            fg_act = output[:,
+            output = model(input_images)['attn'][i_layer] # [B, H=14, W=14, C]
+            fg_act = output[:, center_xy, center_xy].mean(0)
             bg_act = output[:, 0, 0].mean(0)
             fg_acts.append(fg_act)
             bg_acts.append(bg_act)
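The positional-embedding resampling above is what lets segment_fg_bg accept any hw divisible by 16: CLIP ViT-B/16 ships embeddings for a 14x14 patch grid (224 px input), so a larger input needs an interpolated grid. A minimal sketch of what resample_position_embeddings from ncut_pytorch.backbone presumably does; the class-token-first layout and bicubic mode are my assumptions:

import torch
import torch.nn.functional as F

def resample_pos_embed(pos_embed: torch.Tensor, new_hw: int) -> torch.Tensor:
    # [1 + old_hw*old_hw, C] -> [1 + new_hw*new_hw, C]; token 0 assumed to be the class token
    cls_tok, grid = pos_embed[:1], pos_embed[1:]
    old_hw = int(grid.shape[0] ** 0.5)
    grid = grid.reshape(1, old_hw, old_hw, -1).permute(0, 3, 1, 2)  # [1, C, H, W]
    grid = F.interpolate(grid, (new_hw, new_hw), mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(new_hw * new_hw, -1)
    return torch.cat([cls_tok, grid], dim=0)

pos = torch.randn(1 + 14 * 14, 768)       # ViT-B/16 at 224 px
print(resample_pos_embed(pos, 28).shape)  # torch.Size([785, 768]), i.e. the grid for hw=448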
@@ -445,21 +451,6 @@ def segment_fg_bg(images):
     fg_act = F.normalize(fg_act, dim=-1)
     bg_act = F.normalize(bg_act, dim=-1)
 
-    # ref_image = default_images[0]
-    # image = Image.open(ref_image).convert("RGB").resize((224, 224), Image.Resampling.BILINEAR)
-    # image = torch.tensor(np.array(image)).permute(2, 0, 1).float().to(device)
-    # image = (image / 255.0 - means) / stds
-    # output = model(image)['attn'][6][0]
-    # # print(output.shape)
-    # # bg on the center
-    # fg_act = output[5, 5]
-    # # bg on the bottom left
-    # bg_act = output[0, 0]
-    # fg_act = F.normalize(fg_act, dim=-1)
-    # bg_act = F.normalize(bg_act, dim=-1)
-
-    # print(images.mean(), images.std())
-
     fg_act, bg_act = fg_act.to(device), bg_act.to(device)
     chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
     heatmap_fgs, heatmap_bgs = [], []
@@ -467,9 +458,10 @@ def segment_fg_bg(images):
         with torch.no_grad():
             input_images = images[chunk_idx].to(device)
             # transform the input images
-
+            if transform_images:
+                input_images = (input_images - means) / stds
             # output = model(input_images)[:, 5]
-            output = model(input_images)['attn'][
+            output = model(input_images)['attn'][i_layer]
             output = F.normalize(output, dim=-1)
             heatmap_fg = output @ fg_act[:, None] # [B, H, W, 1]
             heatmap_bg = output @ bg_act[:, None] # [B, H, W, 1]
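Both loops read the same attention layer, and since output and the anchors are L2-normalized, the matrix product against a [C, 1] anchor is per-patch cosine similarity. A toy check with invented shapes:

import torch
import torch.nn.functional as F

B, H, W, C = 2, 14, 14, 512
output = F.normalize(torch.randn(B, H, W, C), dim=-1)
fg_act = F.normalize(torch.randn(C), dim=-1)

heatmap = output @ fg_act[:, None]  # [B, H, W, 1], values in [-1, 1]
ref = F.cosine_similarity(output, fg_act.expand(B, H, W, C), dim=-1)
print(torch.allclose(heatmap.squeeze(-1), ref, atol=1e-6))  # True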
@@ -868,6 +860,71 @@ def ncut_run(
         return to_pil_images(rgb), logging_str
 
 
+    # fg-bg separated
+    separate_fg_bg = kwargs.get("separate_fg_bg", False)
+    if separate_fg_bg:
+        fg_threshold = kwargs.get("fg_threshold", 0.5)
+        feature_hw = features.shape[1]
+        progress(0.4, desc="Segmenting FG-BG")
+        heatmap_fg, heatmap_bg = segment_fg_bg(images, hw=448, transform_images=False, i_layer=4)
+        heatmap_fg = 1 - heatmap_fg
+        heatmap_bg = 1 - heatmap_bg
+        b, h, w, c = features.shape
+        heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
+        heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
+        is_cuda = torch.cuda.is_available()
+        heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bicubic")
+        heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bicubic")
+        heatmap_fg = heatmap_fg.flatten()
+        heatmap_bg = heatmap_bg.flatten()
+        fg_minus_bg = heatmap_fg - heatmap_bg
+
+        def _to_mask(heatmap, threshold, gamma=0.5):
+            heatmap = (heatmap - heatmap.mean()) / heatmap.std()
+            heatmap = heatmap.double()
+            heatmap = torch.exp(heatmap)
+            heatmap = 1 / heatmap ** gamma
+            if heatmap.shape[0] > 10000:
+                np.random.seed(0)
+                random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
+                vmin, vmax = heatmap[random_idx].quantile(0.01), heatmap[random_idx].quantile(0.99)
+            else:
+                vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
+            heatmap = (heatmap - vmin) / (vmax - vmin)
+            heatmap = heatmap.reshape(b, h, w)
+            mask = heatmap > threshold
+            return mask
+
+        fg_mask = _to_mask(fg_minus_bg, fg_threshold)
+        features_fg = features.flatten(0, 2)[fg_mask.flatten()]
+
+        progress(0.4, desc="NCut FG")
+        rgb, _logging_str, eigvecs = compute_ncut(
+            features_fg,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=indirect_connection,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+            only_eigvecs=False,
+        )
+
+        rgb_all = torch.zeros(b, h, w, 3)
+        rgb_all_flat = rgb_all.flatten(0, 2)
+        rgb_all_flat[fg_mask.flatten()] = rgb
+        rgb_all = rgb_all_flat.reshape(b, h, w, 3)
+
+        return to_pil_images(rgb_all), logging_str
 
     # ailgnedcut
     if not directed:
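A note on _to_mask: because heatmap_fg and heatmap_bg were each flipped with 1 - x, fg_minus_bg is largest where the background anchor wins, and the chain standardize -> exp -> 1 / x**gamma is just exp(-gamma * z), a decreasing map that flips it back; after the quantile rescale, thresholding keeps foreground pixels. A quick check of that identity on toy data:

import torch

z = torch.randn(10_000, dtype=torch.float64)  # stand-in for standardized fg_minus_bg
gamma = 0.5
a = 1 / torch.exp(z) ** gamma  # the form used in _to_mask
b = torch.exp(-gamma * z)      # equivalent closed form
print(torch.allclose(a, b))    # True

vmin, vmax = a.quantile(0.01), a.quantile(0.99)
mask = (a - vmin) / (vmax - vmin) > 0.5  # same rescale-then-threshold as _to_mask
print(mask.float().mean().item())        # fraction of elements kept as foreground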
@@ -1037,9 +1094,9 @@ def _ncut_run(*args, **kwargs):
         torch.cuda.empty_cache()
         return *(None for _ in range(n_ret)), "Error: " + str(e)
 
-
-
-
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=30)
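The replacement tail of _ncut_run normalizes the return shape: keep the first n_ret payload outputs, then always re-attach the trailing logging string, mirroring the error path above. A tiny illustration with made-up outputs:

ret = ("rgb_images", "cluster_plot", "eig_norm_plot", "logging text")
n_ret = 1
print(list(ret)[:n_ret] + [ret[-1]])  # ['rgb_images', 'logging text']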
@@ -1250,6 +1307,7 @@ def run_fn(
     node_type2="k",
     head_index_text='all',
     make_symmetric=False,
+    fg_threshold=0.5,
     n_ret=1,
     plot_clusters=False,
     alignedcut_eig_norm_plot=False,
@@ -1258,6 +1316,7 @@ def run_fn(
     only_eigvecs=False,
     return_eigvec_and_rgb=False,
     normalize_eigvec_return=False,
+    separate_fg_bg=False,
 ):
     # print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
@@ -1390,6 +1449,7 @@ def run_fn(
         "lisa_prompt2": lisa_prompt2,
         "lisa_prompt3": lisa_prompt3,
         "is_lisa": is_lisa,
+        "fg_threshold": fg_threshold,
         "n_ret": n_ret,
         "plot_clusters": plot_clusters,
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
@@ -1401,6 +1461,7 @@ def run_fn(
         "only_eigvecs": only_eigvecs,
         "return_eigvec_and_rgb": return_eigvec_and_rgb,
         "normalize_eigvec_return": normalize_eigvec_return,
+        "separate_fg_bg": separate_fg_bg,
     }
     # print(kwargs)
 
@@ -2196,7 +2257,8 @@ demo = gr.Blocks(
     css=custom_css,
 )
 with demo:
-
+
+
     with gr.Tab('PlayGround'):
         eigvecs = gr.State(np.array([]))
         tsne3d_rgb = gr.State(np.array([]))
@@ -4247,6 +4309,45 @@ with demo:
             outputs=[mask_gallery, crop_gallery])
 
 
+    with gr.Tab('FG'):
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+
+            with gr.Column(scale=5, min_width=200):
+                output_gallery = make_output_images_section()
+                fg_threshold_slider = gr.Slider(0.01, 1, step=0.01, label="Foreground threshold", value=0.5, elem_id="fg_threshold", info="increase for more foreground")
+                # cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+        submit_button.click(
+            partial(run_fn, n_ret=1, plot_clusters=False, separate_fg_bg=True),
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
+                *[false_placeholder]*12,
+                fg_threshold_slider
+            ],
+            outputs=[output_gallery, logging_text],
+        )
+
+
     with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
         with gr.Row():
             image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
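The new FG tab reuses the shared input and parameter sections and pre-binds separate_fg_bg=True via functools.partial, so the click handler only has to supply UI values such as fg_threshold_slider; run_fn then forwards both flags into ncut_run's kwargs. A stripped-down sketch of that wiring, with a simplified stand-in signature rather than the app's full argument list:

from functools import partial

def run_fn(images, fg_threshold=0.5, separate_fg_bg=False):
    # stand-in for the real run_fn, which forwards these into ncut_run's kwargs
    return f"separate_fg_bg={separate_fg_bg}, fg_threshold={fg_threshold}"

handler = partial(run_fn, separate_fg_bg=True)  # what submit_button.click binds
print(handler("imgs", 0.3))  # separate_fg_bg=True, fg_threshold=0.3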