huzey committed on
Commit cb75a81 · 1 Parent(s): 032a59d

update playground

Files changed (1)
  1. app.py +243 -243
app.py CHANGED
@@ -2455,7 +2455,7 @@ with demo:
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
 
  gr.Markdown('---')
- MAX_ROWS = 100
  current_output_row = gr.State(MAX_ROWS-1)
  inspect_output_rows, output_tree_images, heatmap_galleries, text_blocks = [], [], [], []
  for i_row in range(MAX_ROWS, 0, -1):
@@ -4049,275 +4049,275 @@ with demo:
  outputs=[mask_gallery, crop_gallery])
 
 
- with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
- eigvecs = gr.State(np.array([]))
- with gr.Row():
- with gr.Column(scale=5, min_width=200):
- gr.Markdown("### Step 1: Load Images and Run NCUT")
- input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
- # submit_button.visible = False
- num_images_slider.value = 30
- [
- model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
- affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
- embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
- perplexity_slider, n_neighbors_slider, min_dist_slider,
- sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
- ] = make_parameters_section(ncut_parameter_dropdown=False)
- num_eig_slider.value = 1000
- num_eig_slider.visible = False
- logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
- false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
- no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
- submit_button.click(
- partial(run_fn, n_ret=1, only_eigvecs=True),
- inputs=[
- input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
- positive_prompt, negative_prompt,
- false_placeholder, no_prompt, no_prompt, no_prompt,
- affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
- embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
- perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
- ],
- outputs=[eigvecs, logging_text],
- )
 
- with gr.Column(scale=5, min_width=200):
- gr.Markdown("### Step 2a: Pick an Image")
- from gradio_image_prompter import ImagePrompter
- with gr.Row():
- image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
- load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
- gr.Markdown("### Step 2b: Draw a Point")
- gr.Markdown("""
- <h5>
- 🖱️ Left Click: Foreground </br>
- </h5>
- """)
- prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
- def update_prompt_image(original_images, index):
- images = original_images
- if images is None:
- return
- total_len = len(images)
- if total_len == 0:
- return
- if index >= total_len:
- index = total_len - 1
 
- return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
- # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
- load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
 
- child_idx = gr.State([])
- current_idx = gr.State(None)
- n_eig = gr.State(64)
- with gr.Column(scale=5, min_width=200):
- gr.Markdown("### Step 3: Check groupping")
- child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
- overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
- run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
- parent_plot = gr.Gallery(value=None, label="Parent", show_label=True, elem_id="parent_plot", interactive=False, rows=[1], columns=[2])
- parent_button = gr.Button("Use Parent", elem_id="run_parent")
- current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
- with gr.Column(scale=5, min_width=200):
- child_plots = []
- child_buttons = []
- for i in range(4):
- child_plots.append(gr.Gallery(value=None, label=f"Child {i}", show_label=True, elem_id=f"child_plot_{i}", interactive=False, rows=[1], columns=[2]))
- child_buttons.append(gr.Button(f"Use Child {i}", elem_id=f"run_child_{i}"))
 
- def relative_xy(prompts):
- image = prompts['image']
- points = np.asarray(prompts['points'])
- if points.shape[0] == 0:
- return [], []
- is_point = points[:, 5] == 4.0
- points = points[is_point]
- is_positive = points[:, 2] == 1.0
- is_negative = points[:, 2] == 0.0
- xy = points[:, :2].tolist()
- if isinstance(image, str):
- image = Image.open(image)
- image = np.array(image)
- h, w = image.shape[:2]
- new_xy = [(x/w, y/h) for x, y in xy]
- # print(new_xy)
- return new_xy, is_positive
 
- def xy_eigvec(prompts, image_idx, eigvecs):
- eigvec = eigvecs[image_idx]
- xy, is_positive = relative_xy(prompts)
- for i, (x, y) in enumerate(xy):
- if not is_positive[i]:
- continue
- x = int(x * eigvec.shape[1])
- y = int(y * eigvec.shape[0])
- return eigvec[y, x], (y, x)
 
- from ncut_pytorch.ncut_pytorch import _transform_heatmap
- def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
- left = eigvecs[..., :n_eig]
- if flat_idx is not None:
- right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
- y, x = None, None
- else:
- right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
- right = right[:n_eig]
- left = F.normalize(left, p=2, dim=-1)
- _right = F.normalize(right, p=2, dim=-1)
- heatmap = left @ _right.unsqueeze(-1)
- heatmap = heatmap.squeeze(-1)
- heatmap = 1 - heatmap
- heatmap = _transform_heatmap(heatmap)
- if raw_heatmap:
- return heatmap
- # apply hot colormap and covert to PIL image 256x256
- heatmap = heatmap.cpu().numpy()
- hot_map = matplotlib.colormaps['hot']
- heatmap = hot_map(heatmap)
- pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
- if overlay_image:
- overlaied_images = []
- for i_image in range(len(images)):
- rgb_image = images[i_image].resize((256, 256))
- rgb_image = np.array(rgb_image)
- heatmap_image = np.array(pil_images[i_image])[..., :3]
- blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
- blend_image = Image.fromarray(blend_image.astype(np.uint8))
- overlaied_images.append(blend_image)
- pil_images = overlaied_images
- return pil_images, (y, x)
-
- def _farthest_point_sampling(
- features,
- start_feature,
- num_sample=300,
- h=9,
- ):
- import fpsample
-
- h = min(h, int(np.log2(features.shape[0])))
 
- inp = features.cpu().numpy()
- inp = np.concatenate([inp, start_feature[None, :]], axis=0)
 
- kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
- inp, num_sample, h, start_idx=inp.shape[0] - 1
- ).astype(np.int64)
- return kdline_fps_samples_idx
 
- @torch.no_grad()
- def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
- gr.Info(f"current number of eigenvectors: {n_eig}")
- eigvecs = torch.tensor(eigvecs)
- image1_slider = min(image1_slider, len(images)-1)
- images = [image[0] for image in images]
- if isinstance(images[0], str):
- images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
 
- current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
- parent_heatmap, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig/2), flat_idx, overlay_image=overlay_image)
 
- # find childs
- # pca_eigvecs
- _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
- u, s, v = torch.pca_lowrank(_eigvecs, q=8)
- _n = _eigvecs.shape[0]
- s /= math.sqrt(_n)
- _eigvecs = u @ torch.diag(s)
 
- if flat_idx is None:
- _picked_eigvec = _eigvecs.reshape(*eigvecs.shape[:-1], 8)[image1_slider, y, x]
- else:
- _picked_eigvec = _eigvecs[flat_idx]
- l2_distance = torch.norm(_eigvecs - _picked_eigvec, dim=-1)
- average_distance = l2_distance.mean()
- distance_threshold = distance_slider * average_distance
- distance_mask = l2_distance < distance_threshold
- masked_eigvecs = _eigvecs[distance_mask]
- num_childs = min(4, masked_eigvecs.shape[0])
- assert num_childs > 0
 
- child_idx = _farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
- child_idx = np.sort(child_idx)[:-1]
 
- # convert child_idx to flat_idx
- dummy_idx = torch.zeros(_eigvecs.shape[0], dtype=torch.bool)
- dummy_idx2 = torch.zeros(int(distance_mask.sum().item()), dtype=torch.bool)
- dummy_idx2[child_idx] = True
- dummy_idx[distance_mask] = dummy_idx2
- child_idx = torch.where(dummy_idx)[0]
 
 
- # current_child heatmap, for contrast
- current_child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), flat_idx, raw_heatmap=True, overlay_image=overlay_image)
 
- # child_heatmaps, contrast mean of current clicked point
- child_heatmaps = []
- for idx in child_idx:
- child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, raw_heatmap=True, overlay_image=overlay_image)
- heatmap = child_heatmap - current_child_heatmap
- # convert [-1, 1] to [0, 1]
- heatmap = (heatmap + 1) / 2
- heatmap = heatmap.cpu().numpy()
- cm = matplotlib.colormaps['bwr']
- heatmap = cm(heatmap)
- # bwr with contrast
- pil_images1 = to_pil_images(torch.tensor(heatmap), resize=256)
- # no contrast
- pil_images2, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, overlay_image=overlay_image)
 
- # combine contrast and no contrast
- pil_images = []
- for i in range(len(pil_images1)):
- pil_images.append(pil_images2[i])
- pil_images.append(pil_images1[i])
 
 
- child_heatmaps.append(pil_images)
 
- return parent_heatmap, current_heatmap, *child_heatmaps, child_idx.tolist()
 
- # def debug_fn(eigvecs):
- # shape = eigvecs.shape
- # gr.Info(f"eigvecs shape: {shape}")
 
- # run_button.click(
- # debug_fn,
- # inputs=[eigvecs],
- # outputs=[],
- # )
- none_placeholder = gr.State(None)
- run_button.click(
- run_heatmap,
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, none_placeholder, overlay_image_checkbox],
- outputs=[parent_plot, current_plot, *child_plots, child_idx],
- )
 
- def run_paraent(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
- n_eig = int(n_eig/2)
- return n_eig, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image)
 
- parent_button.click(
- run_paraent,
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, current_idx, overlay_image_checkbox],
- outputs=[n_eig, parent_plot, current_plot, *child_plots, child_idx],
- )
 
- def run_child(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, child_idx=[], overlay_image=True, i_child=0):
- n_eig = int(n_eig*2)
- flat_idx = child_idx[i_child]
- return n_eig, flat_idx, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image)
-
- for i in range(4):
- child_buttons[i].click(
- partial(run_child, i_child=i),
- inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
- outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
- )
 
 
  with gr.Tab('📄About'):
@@ -4371,7 +4371,7 @@ with demo:
 
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
  hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
- tab_compare_models_advanced, tab_directed_ncut, test_playground_tab, tab_aligned, tab_lisa]
  hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
 
  with gr.Row():
 
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
 
  gr.Markdown('---')
+ MAX_ROWS = 10
  current_output_row = gr.State(MAX_ROWS-1)
  inspect_output_rows, output_tree_images, heatmap_galleries, text_blocks = [], [], [], []
  for i_row in range(MAX_ROWS, 0, -1):
 
  outputs=[mask_gallery, crop_gallery])
 
 
+ # with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
+ # eigvecs = gr.State(np.array([]))
+ # with gr.Row():
+ # with gr.Column(scale=5, min_width=200):
+ # gr.Markdown("### Step 1: Load Images and Run NCUT")
+ # input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
+ # # submit_button.visible = False
+ # num_images_slider.value = 30
+ # [
+ # model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+ # affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+ # embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+ # perplexity_slider, n_neighbors_slider, min_dist_slider,
+ # sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+ # ] = make_parameters_section(ncut_parameter_dropdown=False)
+ # num_eig_slider.value = 1000
+ # num_eig_slider.visible = False
+ # logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
+ # false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+ # no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
+ # submit_button.click(
+ # partial(run_fn, n_ret=1, only_eigvecs=True),
+ # inputs=[
+ # input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+ # positive_prompt, negative_prompt,
+ # false_placeholder, no_prompt, no_prompt, no_prompt,
+ # affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+ # embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+ # perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+ # ],
+ # outputs=[eigvecs, logging_text],
+ # )
 
+ # with gr.Column(scale=5, min_width=200):
+ # gr.Markdown("### Step 2a: Pick an Image")
+ # from gradio_image_prompter import ImagePrompter
+ # with gr.Row():
+ # image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
+ # load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
+ # gr.Markdown("### Step 2b: Draw a Point")
+ # gr.Markdown("""
+ # <h5>
+ # 🖱️ Left Click: Foreground </br>
+ # </h5>
+ # """)
+ # prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
+ # def update_prompt_image(original_images, index):
+ # images = original_images
+ # if images is None:
+ # return
+ # total_len = len(images)
+ # if total_len == 0:
+ # return
+ # if index >= total_len:
+ # index = total_len - 1
 
+ # return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
+ # # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
+ # load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
 
+ # child_idx = gr.State([])
+ # current_idx = gr.State(None)
+ # n_eig = gr.State(64)
+ # with gr.Column(scale=5, min_width=200):
+ # gr.Markdown("### Step 3: Check groupping")
+ # child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
+ # overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
+ # run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
+ # parent_plot = gr.Gallery(value=None, label="Parent", show_label=True, elem_id="parent_plot", interactive=False, rows=[1], columns=[2])
+ # parent_button = gr.Button("Use Parent", elem_id="run_parent")
+ # current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
+ # with gr.Column(scale=5, min_width=200):
+ # child_plots = []
+ # child_buttons = []
+ # for i in range(4):
+ # child_plots.append(gr.Gallery(value=None, label=f"Child {i}", show_label=True, elem_id=f"child_plot_{i}", interactive=False, rows=[1], columns=[2]))
+ # child_buttons.append(gr.Button(f"Use Child {i}", elem_id=f"run_child_{i}"))
 
+ # def relative_xy(prompts):
+ # image = prompts['image']
+ # points = np.asarray(prompts['points'])
+ # if points.shape[0] == 0:
+ # return [], []
+ # is_point = points[:, 5] == 4.0
+ # points = points[is_point]
+ # is_positive = points[:, 2] == 1.0
+ # is_negative = points[:, 2] == 0.0
+ # xy = points[:, :2].tolist()
+ # if isinstance(image, str):
+ # image = Image.open(image)
+ # image = np.array(image)
+ # h, w = image.shape[:2]
+ # new_xy = [(x/w, y/h) for x, y in xy]
+ # # print(new_xy)
+ # return new_xy, is_positive
 
+ # def xy_eigvec(prompts, image_idx, eigvecs):
+ # eigvec = eigvecs[image_idx]
+ # xy, is_positive = relative_xy(prompts)
+ # for i, (x, y) in enumerate(xy):
+ # if not is_positive[i]:
+ # continue
+ # x = int(x * eigvec.shape[1])
+ # y = int(y * eigvec.shape[0])
+ # return eigvec[y, x], (y, x)
 
+ # from ncut_pytorch.ncut_pytorch import _transform_heatmap
+ # def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
+ # left = eigvecs[..., :n_eig]
+ # if flat_idx is not None:
+ # right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
+ # y, x = None, None
+ # else:
+ # right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
+ # right = right[:n_eig]
+ # left = F.normalize(left, p=2, dim=-1)
+ # _right = F.normalize(right, p=2, dim=-1)
+ # heatmap = left @ _right.unsqueeze(-1)
+ # heatmap = heatmap.squeeze(-1)
+ # heatmap = 1 - heatmap
+ # heatmap = _transform_heatmap(heatmap)
+ # if raw_heatmap:
+ # return heatmap
+ # # apply hot colormap and covert to PIL image 256x256
+ # heatmap = heatmap.cpu().numpy()
+ # hot_map = matplotlib.colormaps['hot']
+ # heatmap = hot_map(heatmap)
+ # pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
+ # if overlay_image:
+ # overlaied_images = []
+ # for i_image in range(len(images)):
+ # rgb_image = images[i_image].resize((256, 256))
+ # rgb_image = np.array(rgb_image)
+ # heatmap_image = np.array(pil_images[i_image])[..., :3]
+ # blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
+ # blend_image = Image.fromarray(blend_image.astype(np.uint8))
+ # overlaied_images.append(blend_image)
+ # pil_images = overlaied_images
+ # return pil_images, (y, x)
+
+ # def _farthest_point_sampling(
+ # features,
+ # start_feature,
+ # num_sample=300,
+ # h=9,
+ # ):
+ # import fpsample
+
+ # h = min(h, int(np.log2(features.shape[0])))
 
+ # inp = features.cpu().numpy()
+ # inp = np.concatenate([inp, start_feature[None, :]], axis=0)
 
+ # kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
+ # inp, num_sample, h, start_idx=inp.shape[0] - 1
+ # ).astype(np.int64)
+ # return kdline_fps_samples_idx
 
+ # @torch.no_grad()
+ # def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
+ # gr.Info(f"current number of eigenvectors: {n_eig}")
+ # eigvecs = torch.tensor(eigvecs)
+ # image1_slider = min(image1_slider, len(images)-1)
+ # images = [image[0] for image in images]
+ # if isinstance(images[0], str):
+ # images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
 
+ # current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
+ # parent_heatmap, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig/2), flat_idx, overlay_image=overlay_image)
 
+ # # find childs
+ # # pca_eigvecs
+ # _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
+ # u, s, v = torch.pca_lowrank(_eigvecs, q=8)
+ # _n = _eigvecs.shape[0]
+ # s /= math.sqrt(_n)
+ # _eigvecs = u @ torch.diag(s)
 
+ # if flat_idx is None:
+ # _picked_eigvec = _eigvecs.reshape(*eigvecs.shape[:-1], 8)[image1_slider, y, x]
+ # else:
+ # _picked_eigvec = _eigvecs[flat_idx]
+ # l2_distance = torch.norm(_eigvecs - _picked_eigvec, dim=-1)
+ # average_distance = l2_distance.mean()
+ # distance_threshold = distance_slider * average_distance
+ # distance_mask = l2_distance < distance_threshold
+ # masked_eigvecs = _eigvecs[distance_mask]
+ # num_childs = min(4, masked_eigvecs.shape[0])
+ # assert num_childs > 0
 
+ # child_idx = _farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
+ # child_idx = np.sort(child_idx)[:-1]
 
+ # # convert child_idx to flat_idx
+ # dummy_idx = torch.zeros(_eigvecs.shape[0], dtype=torch.bool)
+ # dummy_idx2 = torch.zeros(int(distance_mask.sum().item()), dtype=torch.bool)
+ # dummy_idx2[child_idx] = True
+ # dummy_idx[distance_mask] = dummy_idx2
+ # child_idx = torch.where(dummy_idx)[0]
 
 
+ # # current_child heatmap, for contrast
+ # current_child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), flat_idx, raw_heatmap=True, overlay_image=overlay_image)
 
+ # # child_heatmaps, contrast mean of current clicked point
+ # child_heatmaps = []
+ # for idx in child_idx:
+ # child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, raw_heatmap=True, overlay_image=overlay_image)
+ # heatmap = child_heatmap - current_child_heatmap
+ # # convert [-1, 1] to [0, 1]
+ # heatmap = (heatmap + 1) / 2
+ # heatmap = heatmap.cpu().numpy()
+ # cm = matplotlib.colormaps['bwr']
+ # heatmap = cm(heatmap)
+ # # bwr with contrast
+ # pil_images1 = to_pil_images(torch.tensor(heatmap), resize=256)
+ # # no contrast
+ # pil_images2, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, overlay_image=overlay_image)
 
+ # # combine contrast and no contrast
+ # pil_images = []
+ # for i in range(len(pil_images1)):
+ # pil_images.append(pil_images2[i])
+ # pil_images.append(pil_images1[i])
 
 
+ # child_heatmaps.append(pil_images)
 
+ # return parent_heatmap, current_heatmap, *child_heatmaps, child_idx.tolist()
 
+ # # def debug_fn(eigvecs):
+ # # shape = eigvecs.shape
+ # # gr.Info(f"eigvecs shape: {shape}")
 
+ # # run_button.click(
+ # # debug_fn,
+ # # inputs=[eigvecs],
+ # # outputs=[],
+ # # )
+ # none_placeholder = gr.State(None)
+ # run_button.click(
+ # run_heatmap,
+ # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, none_placeholder, overlay_image_checkbox],
+ # outputs=[parent_plot, current_plot, *child_plots, child_idx],
+ # )
 
+ # def run_paraent(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
+ # n_eig = int(n_eig/2)
+ # return n_eig, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image)
 
+ # parent_button.click(
+ # run_paraent,
+ # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, current_idx, overlay_image_checkbox],
+ # outputs=[n_eig, parent_plot, current_plot, *child_plots, child_idx],
+ # )
 
+ # def run_child(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, child_idx=[], overlay_image=True, i_child=0):
+ # n_eig = int(n_eig*2)
+ # flat_idx = child_idx[i_child]
+ # return n_eig, flat_idx, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image)
+
+ # for i in range(4):
+ # child_buttons[i].click(
+ # partial(run_child, i_child=i),
+ # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
+ # outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
+ # )
 
 
  with gr.Tab('📄About'):
 
 
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
  hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced,
+ tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa]
  hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs)
 
  with gr.Row():