Spaces:
Running
on
Zero
Running
on
Zero
update sub-cluster
Browse files
app.py
CHANGED
@@ -148,7 +148,7 @@ def compute_ncut(
|
|
148 |
logging_str = ""
|
149 |
|
150 |
num_nodes = np.prod(features.shape[:-1])
|
151 |
-
if num_nodes / 2 < num_eig:
|
152 |
# raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
|
153 |
gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
|
154 |
num_eig = num_nodes // 2 - 1
|
@@ -2196,7 +2196,7 @@ demo = gr.Blocks(
|
|
2196 |
css=custom_css,
|
2197 |
)
|
2198 |
with demo:
|
2199 |
-
|
2200 |
with gr.Tab('PlayGround'):
|
2201 |
eigvecs = gr.State(np.array([]))
|
2202 |
tsne3d_rgb = gr.State(np.array([]))
|
@@ -4246,6 +4246,219 @@ with demo:
|
|
4246 |
fg_contrast_slider, bg_contrast_slider],
|
4247 |
outputs=[mask_gallery, crop_gallery])
|
4248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4249 |
|
4250 |
# with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
|
4251 |
# eigvecs = gr.State(np.array([]))
|
@@ -4569,7 +4782,7 @@ with demo:
|
|
4569 |
|
4570 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
4571 |
hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
|
4572 |
-
tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa]
|
4573 |
hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
|
4574 |
|
4575 |
with gr.Row():
|
|
|
148 |
logging_str = ""
|
149 |
|
150 |
num_nodes = np.prod(features.shape[:-1])
|
151 |
+
if num_nodes / 2 - 1 < num_eig:
|
152 |
# raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
|
153 |
gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
|
154 |
num_eig = num_nodes // 2 - 1
|
|
|
2196 |
css=custom_css,
|
2197 |
)
|
2198 |
with demo:
|
2199 |
+
|
2200 |
with gr.Tab('PlayGround'):
|
2201 |
eigvecs = gr.State(np.array([]))
|
2202 |
tsne3d_rgb = gr.State(np.array([]))
|
|
|
4246 |
fg_contrast_slider, bg_contrast_slider],
|
4247 |
outputs=[mask_gallery, crop_gallery])
|
4248 |
|
4249 |
+
|
4250 |
+
with gr.Tab('Sub-cluster (dev)') as sub_cluster_tab:
|
4251 |
+
with gr.Row():
|
4252 |
+
with gr.Column(scale=5, min_width=200):
|
4253 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
|
4254 |
+
num_images_slider.value = 300
|
4255 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
4256 |
+
|
4257 |
+
with gr.Column(scale=5, min_width=200):
|
4258 |
+
output_gallery = make_output_images_section()
|
4259 |
+
add_download_button(output_gallery)
|
4260 |
+
with gr.Accordion("Sub-cluster", open=True):
|
4261 |
+
num_image_cluster_slider = gr.Slider(1, 20, step=1, label="Number of clusters (image-level)", value=6, elem_id="num_image_cluster", info="Image-level clustering before pixel-level clustering")
|
4262 |
+
show_class_label_checkbox = gr.Checkbox(label="Show cluster label", value=True, elem_id="show_class_label", info="Show image-level clustering label on the output")
|
4263 |
+
overlay_original_image_checkbox = gr.Checkbox(label="Overlay original image", value=False, elem_id="overlay_original_image", info="Overlay original image on the output")
|
4264 |
+
[
|
4265 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
4266 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
4267 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
4268 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
4269 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
4270 |
+
] = make_parameters_section()
|
4271 |
+
|
4272 |
+
with gr.Row():
|
4273 |
+
image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
|
4274 |
+
|
4275 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
4276 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
4277 |
+
|
4278 |
+
def cluster_image_level(image_gallery, n_clusters=6):
|
4279 |
+
images = [tup[0] for tup in image_gallery]
|
4280 |
+
if isinstance(images[0], str):
|
4281 |
+
images = [Image.open(image) for image in images]
|
4282 |
+
resized_images = [image.resize((224, 224), Image.LANCZOS) for image in images]
|
4283 |
+
images = [transform_image(image, resolution=(224, 224), stablediffusion=False) for image in resized_images]
|
4284 |
+
images = torch.stack(images)
|
4285 |
+
|
4286 |
+
is_cuda = torch.cuda.is_available()
|
4287 |
+
|
4288 |
+
dino = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').eval()
|
4289 |
+
from torch.utils.data import DataLoader, TensorDataset
|
4290 |
+
dataset = TensorDataset(images)
|
4291 |
+
dataloader = DataLoader(dataset, batch_size=128, shuffle=False)
|
4292 |
+
features = []
|
4293 |
+
if is_cuda:
|
4294 |
+
dino = dino.cuda()
|
4295 |
+
with torch.no_grad():
|
4296 |
+
for _images in dataloader:
|
4297 |
+
_images = _images[0]
|
4298 |
+
_images = _images.cuda() if is_cuda else _images
|
4299 |
+
feature = dino.get_intermediate_layers(_images, 12)[11]
|
4300 |
+
feature = feature[:, 0]
|
4301 |
+
# normalize the feature
|
4302 |
+
feature = torch.nn.functional.normalize(feature, dim=-1)
|
4303 |
+
features.append(feature.cpu())
|
4304 |
+
class_features = torch.cat(features, dim=0)
|
4305 |
+
|
4306 |
+
from ncut_pytorch import NCUT
|
4307 |
+
inp = class_features
|
4308 |
+
b, c = inp.shape
|
4309 |
+
inp = inp.reshape(b, c)
|
4310 |
+
class_ncut = NCUT(num_eig=n_clusters, move_output_to_cpu=True, distance='cosine', knn=1, device='cuda' if is_cuda else 'cpu')
|
4311 |
+
class_eigvecs, _ = class_ncut.fit_transform(inp)
|
4312 |
+
from ncut_pytorch import kway_ncut
|
4313 |
+
kway_eigvecs = kway_ncut(class_eigvecs.cuda()) # (n, k), one-hot
|
4314 |
+
kway_indices = torch.argmax(kway_eigvecs, dim=1)
|
4315 |
+
cluster_labels = kway_indices.cpu().numpy()
|
4316 |
+
|
4317 |
+
from sklearn.manifold import TSNE
|
4318 |
+
perplexity = min(30, int(len(cluster_labels)/2 - 1))
|
4319 |
+
tsne_2d = TSNE(n_components=2, perplexity=perplexity, random_state=0).fit_transform(class_eigvecs.cpu())
|
4320 |
+
|
4321 |
+
tsne_image_plot = plot_2d_tsne_with_images(cluster_labels, tsne_2d, resized_images)
|
4322 |
+
|
4323 |
+
return cluster_labels, tsne_image_plot, resized_images
|
4324 |
+
|
4325 |
+
def plot_2d_tsne_with_images(cluster_labels, tsne_2d, images, max_num_images=300):
|
4326 |
+
# use dark background
|
4327 |
+
plt.style.use('dark_background')
|
4328 |
+
|
4329 |
+
def pad_image_with_border(image, border_color, border_width):
|
4330 |
+
new_image = np.ones((image.shape[0] + 2 * border_width, image.shape[1] + 2 * border_width, image.shape[2]), dtype=image.dtype)
|
4331 |
+
new_image[:, :] = border_color
|
4332 |
+
new_image[border_width:-border_width, border_width:-border_width] = image
|
4333 |
+
return new_image
|
4334 |
+
|
4335 |
+
padded_images = []
|
4336 |
+
for i in range(len(tsne_2d)):
|
4337 |
+
image = images[i]
|
4338 |
+
image = np.array(image)
|
4339 |
+
|
4340 |
+
border_color = np.array(plt.cm.tab20(cluster_labels[i] / 20))[:3] * 255
|
4341 |
+
border_width = 20
|
4342 |
+
padded_image = pad_image_with_border(image, border_color, border_width)
|
4343 |
+
padded_images.append(padded_image)
|
4344 |
+
|
4345 |
+
# Plot the t-SNE points
|
4346 |
+
fig, ax = plt.subplots(1, 1, figsize=(15, 15))
|
4347 |
+
ax.scatter(tsne_2d[:, 0], tsne_2d[:, 1], s=20, c=cluster_labels, cmap='tab20')
|
4348 |
+
ax.set_xticks([])
|
4349 |
+
ax.set_yticks([])
|
4350 |
+
ax.axis('off')
|
4351 |
+
ax.set_xlim(tsne_2d[:, 0].min()*1.1, tsne_2d[:, 0].max()*1.1)
|
4352 |
+
ax.set_ylim(tsne_2d[:, 1].min()*1.1, tsne_2d[:, 1].max()*1.1)
|
4353 |
+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
|
4354 |
+
|
4355 |
+
if len(tsne_2d) > max_num_images:
|
4356 |
+
random_indices = np.random.choice(len(tsne_2d), max_num_images, replace=False)
|
4357 |
+
tsne_2d = tsne_2d[random_indices]
|
4358 |
+
padded_images = [padded_images[i] for i in random_indices]
|
4359 |
+
|
4360 |
+
|
4361 |
+
# Add the top1_image_blended to the scatter plot
|
4362 |
+
for i, (x, y) in enumerate(tsne_2d):
|
4363 |
+
img = padded_images[i]
|
4364 |
+
img = np.array(img)
|
4365 |
+
imgbox = OffsetImage(img, zoom=0.15)
|
4366 |
+
ab = AnnotationBbox(imgbox, (x, y), frameon=False)
|
4367 |
+
ax.add_artist(ab)
|
4368 |
+
|
4369 |
+
# Remove the white space around the plot
|
4370 |
+
fig.tight_layout(pad=0)
|
4371 |
+
|
4372 |
+
# Save the plot to an in-memory buffer
|
4373 |
+
buf = io.BytesIO()
|
4374 |
+
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
|
4375 |
+
buf.seek(0)
|
4376 |
+
|
4377 |
+
# Load the image into a NumPy array
|
4378 |
+
image = np.array(Image.open(buf))
|
4379 |
+
|
4380 |
+
# Close the buffer and plot
|
4381 |
+
buf.close()
|
4382 |
+
plt.close(fig)
|
4383 |
+
|
4384 |
+
pil_image = Image.fromarray(image)
|
4385 |
+
|
4386 |
+
# Reset the style
|
4387 |
+
plt.style.use('default')
|
4388 |
+
|
4389 |
+
return pil_image
|
4390 |
+
|
4391 |
+
def add_boarder_and_overlay_image(ncut_images, original_images, cluster_labels, is_overlay_image, is_add_border):
|
4392 |
+
def pad_image_with_left_border(image, border_color, border_width):
|
4393 |
+
new_image = np.ones((image.shape[0], image.shape[1] + border_width, image.shape[2]), dtype=image.dtype)
|
4394 |
+
new_image[:, :] = border_color
|
4395 |
+
new_image[:, border_width:] = image
|
4396 |
+
return new_image
|
4397 |
+
|
4398 |
+
if is_overlay_image:
|
4399 |
+
for i, ncut_image in enumerate(ncut_images):
|
4400 |
+
original_image = original_images[i]
|
4401 |
+
original_image = np.array(original_image)
|
4402 |
+
ncut_image = ncut_image.resize(original_image.shape[:2][::-1], Image.LANCZOS)
|
4403 |
+
ncut_image = np.array(ncut_image)
|
4404 |
+
blend_image = 0.5 * original_image + 0.5 * ncut_image
|
4405 |
+
blend_image = blend_image.astype(np.uint8)
|
4406 |
+
ncut_images[i] = Image.fromarray(blend_image)
|
4407 |
+
if is_add_border:
|
4408 |
+
for i, ncut_image in enumerate(ncut_images):
|
4409 |
+
border_color = np.array(plt.cm.tab20(cluster_labels[i] / 20))[:3] * 255
|
4410 |
+
border_width = 20
|
4411 |
+
ncut_image = ncut_image.resize((224, 224), Image.LANCZOS)
|
4412 |
+
ncut_image = np.array(ncut_image)
|
4413 |
+
ncut_image = pad_image_with_left_border(ncut_image, border_color, border_width)
|
4414 |
+
ncut_images[i] = Image.fromarray(ncut_image)
|
4415 |
+
|
4416 |
+
return ncut_images
|
4417 |
+
|
4418 |
+
def sub_cluster_run_fn(images, n_clusters, is_overlay_image, is_add_border, *args, **kwargs):
|
4419 |
+
cluster_labels, tsne_image_plot, resized_images = cluster_image_level(images, n_clusters)
|
4420 |
+
|
4421 |
+
output_images = [None] * len(images)
|
4422 |
+
for i in range(n_clusters):
|
4423 |
+
indices = np.where(cluster_labels == i)[0]
|
4424 |
+
if len(indices) == 0:
|
4425 |
+
continue
|
4426 |
+
input_images = [images[j] for j in indices]
|
4427 |
+
_i_output_images, logging_text = run_fn(input_images, *args, **kwargs)
|
4428 |
+
for _i, _idx in enumerate(indices):
|
4429 |
+
output_images[_idx] = _i_output_images[_i]
|
4430 |
+
|
4431 |
+
output_images = add_boarder_and_overlay_image(output_images, resized_images, cluster_labels, is_overlay_image, is_add_border)
|
4432 |
+
return output_images, tsne_image_plot
|
4433 |
+
|
4434 |
+
if USE_HUGGINGFACE_ZEROGPU:
|
4435 |
+
@spaces.GPU(duration=120)
|
4436 |
+
def sub_cluster_run_fn_wrap(*args, **kwargs):
|
4437 |
+
return sub_cluster_run_fn(*args, **kwargs)
|
4438 |
+
else:
|
4439 |
+
sub_cluster_run_fn_wrap = sub_cluster_run_fn
|
4440 |
+
|
4441 |
+
# image_cluster_indices = gr.State(np.array([]))
|
4442 |
+
|
4443 |
+
# submit_button.click(
|
4444 |
+
# cluster_image_level,
|
4445 |
+
# inputs=[input_gallery, num_image_cluster_slider],
|
4446 |
+
# outputs=[image_cluster_indices, image_cluster_plot],
|
4447 |
+
# )
|
4448 |
+
|
4449 |
+
submit_button.click(
|
4450 |
+
partial(sub_cluster_run_fn_wrap, n_ret=1, plot_clusters=False),
|
4451 |
+
inputs=[
|
4452 |
+
input_gallery, num_image_cluster_slider, overlay_original_image_checkbox, show_class_label_checkbox,
|
4453 |
+
model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
4454 |
+
positive_prompt, negative_prompt,
|
4455 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
4456 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
4457 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
4458 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
4459 |
+
],
|
4460 |
+
outputs=[output_gallery, image_cluster_plot],
|
4461 |
+
)
|
4462 |
|
4463 |
# with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
|
4464 |
# eigvecs = gr.State(np.array([]))
|
|
|
4782 |
|
4783 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
4784 |
hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
|
4785 |
+
tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa, sub_cluster_tab]
|
4786 |
hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
|
4787 |
|
4788 |
with gr.Row():
|