huzey commited on
Commit
094608c
·
1 Parent(s): 5afcac2

add playground

Browse files
Files changed (1) hide show
  1. app.py +205 -18
app.py CHANGED
@@ -993,26 +993,28 @@ def ncut_run(
993
 
994
  def _ncut_run(*args, **kwargs):
995
  n_ret = kwargs.get("n_ret", 1)
996
- # try:
997
- # if torch.cuda.is_available():
998
- # torch.cuda.empty_cache()
 
999
 
1000
- # ret = ncut_run(*args, **kwargs)
1001
 
1002
- # if torch.cuda.is_available():
1003
- # torch.cuda.empty_cache()
1004
 
1005
- # ret = list(ret)[:n_ret] + [ret[-1]]
1006
- # return ret
1007
- # except Exception as e:
1008
- # gr.Error(str(e))
1009
- # if torch.cuda.is_available():
1010
- # torch.cuda.empty_cache()
1011
- # return *(None for _ in range(n_ret)), "Error: " + str(e)
 
1012
 
1013
- ret = ncut_run(*args, **kwargs)
1014
- ret = list(ret)[:n_ret] + [ret[-1]]
1015
- return ret
1016
 
1017
  if USE_HUGGINGFACE_ZEROGPU:
1018
  @spaces.GPU(duration=30)
@@ -3557,8 +3559,8 @@ with demo:
3557
  else:
3558
  right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
3559
  right = right[:n_eig]
3560
- left = F.normalize(left, p=2, dim=1)
3561
- _right = F.normalize(right, p=2, dim=0)
3562
  heatmap = left @ _right.unsqueeze(-1)
3563
  heatmap = heatmap.squeeze(-1)
3564
  heatmap = 1 - heatmap
@@ -3707,7 +3709,192 @@ with demo:
3707
  inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
3708
  outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
3709
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3710
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3711
  with gr.Tab('📄About'):
3712
  with gr.Column():
3713
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
 
993
 
994
  def _ncut_run(*args, **kwargs):
995
  n_ret = kwargs.get("n_ret", 1)
996
+ try:
997
+ gr.Info("NCUT Run Started", 2)
998
+ if torch.cuda.is_available():
999
+ torch.cuda.empty_cache()
1000
 
1001
+ ret = ncut_run(*args, **kwargs)
1002
 
1003
+ if torch.cuda.is_available():
1004
+ torch.cuda.empty_cache()
1005
 
1006
+ ret = list(ret)[:n_ret] + [ret[-1]]
1007
+ gr.Info("NCUT Run Finished", 2)
1008
+ return ret
1009
+ except Exception as e:
1010
+ gr.Error(str(e))
1011
+ if torch.cuda.is_available():
1012
+ torch.cuda.empty_cache()
1013
+ return *(None for _ in range(n_ret)), "Error: " + str(e)
1014
 
1015
+ # ret = ncut_run(*args, **kwargs)
1016
+ # ret = list(ret)[:n_ret] + [ret[-1]]
1017
+ # return ret
1018
 
1019
  if USE_HUGGINGFACE_ZEROGPU:
1020
  @spaces.GPU(duration=30)
 
3559
  else:
3560
  right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
3561
  right = right[:n_eig]
3562
+ left = F.normalize(left, p=2, dim=-1)
3563
+ _right = F.normalize(right, p=2, dim=-1)
3564
  heatmap = left @ _right.unsqueeze(-1)
3565
  heatmap = heatmap.squeeze(-1)
3566
  heatmap = 1 - heatmap
 
3709
  inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
3710
  outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
3711
  )
3712
+
3713
+ with gr.Tab('PlayGround', visible=True) as test_playground_tab2:
3714
+ eigvecs = gr.State(torch.tensor([]))
3715
+ with gr.Row():
3716
+ with gr.Column(scale=5, min_width=200):
3717
+ gr.Markdown("### Step 1: Load Images")
3718
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
3719
+ submit_button.visible = False
3720
+ num_images_slider.value = 30
3721
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
3722
+
3723
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
3724
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
3725
+
3726
+
3727
+ with gr.Column(scale=5, min_width=200):
3728
+ gr.Markdown("### Step 2a: Run Backbone and NCUT")
3729
+ with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
3730
+ [
3731
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
3732
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
3733
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
3734
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
3735
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
3736
+ ] = make_parameters_section(parameter_dropdown=False)
3737
+ num_eig_slider.value = 1024
3738
+ num_eig_slider.visible = False
3739
+ submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
3740
+ submit_button.click(
3741
+ partial(run_fn, n_ret=1, only_eigvecs=True),
3742
+ inputs=[
3743
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
3744
+ positive_prompt, negative_prompt,
3745
+ false_placeholder, no_prompt, no_prompt, no_prompt,
3746
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
3747
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
3748
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
3749
+ ],
3750
+ outputs=[eigvecs, logging_text],
3751
+ )
3752
+ gr.Markdown("### Step 2b: Pick an Image")
3753
+ from gradio_image_prompter import ImagePrompter
3754
+ with gr.Row():
3755
+ image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
3756
+ load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
3757
+ gr.Markdown("### Step 2c: Draw a Point")
3758
+ gr.Markdown("""
3759
+ <h5>
3760
+ 🖱️ Left Click: Foreground </br>
3761
+ </h5>
3762
+ """)
3763
+ prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
3764
+ def update_prompt_image(original_images, index):
3765
+ images = original_images
3766
+ if images is None:
3767
+ return
3768
+ total_len = len(images)
3769
+ if total_len == 0:
3770
+ return
3771
+ if index >= total_len:
3772
+ index = total_len - 1
3773
+
3774
+ return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
3775
+ # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
3776
+ load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
3777
+
3778
+ child_idx = gr.State([])
3779
+ current_idx = gr.State(None)
3780
+ n_eig = gr.State(64)
3781
+ with gr.Column(scale=5, min_width=200):
3782
+ gr.Markdown("### Step 3: Check groupping")
3783
+ child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
3784
+ child_distance_slider.visible = False
3785
+ overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
3786
+ n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
3787
+ run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
3788
+ current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
3789
+ with gr.Row():
3790
+ doublue_eigs_button = gr.Button("⬇️ +eigvecs", elem_id="doublue_eigs_button", variant='secondary')
3791
+ half_eigs_button = gr.Button("⬆️ -eigvecs", elem_id="half_eigs_button", variant='secondary')
3792
 
3793
+ def relative_xy(prompts):
3794
+ image = prompts['image']
3795
+ points = np.asarray(prompts['points'])
3796
+ if points.shape[0] == 0:
3797
+ return [], []
3798
+ is_point = points[:, 5] == 4.0
3799
+ points = points[is_point]
3800
+ is_positive = points[:, 2] == 1.0
3801
+ is_negative = points[:, 2] == 0.0
3802
+ xy = points[:, :2].tolist()
3803
+ if isinstance(image, str):
3804
+ image = Image.open(image)
3805
+ image = np.array(image)
3806
+ h, w = image.shape[:2]
3807
+ new_xy = [(x/w, y/h) for x, y in xy]
3808
+ # print(new_xy)
3809
+ return new_xy, is_positive
3810
+
3811
+ def xy_eigvec(prompts, image_idx, eigvecs):
3812
+ eigvec = eigvecs[image_idx]
3813
+ xy, is_positive = relative_xy(prompts)
3814
+ for i, (x, y) in enumerate(xy):
3815
+ if not is_positive[i]:
3816
+ continue
3817
+ x = int(x * eigvec.shape[1])
3818
+ y = int(y * eigvec.shape[0])
3819
+ return eigvec[y, x], (y, x)
3820
+
3821
+ from ncut_pytorch.ncut_pytorch import _transform_heatmap
3822
+ def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
3823
+ left = eigvecs[..., :n_eig]
3824
+ if flat_idx is not None:
3825
+ right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
3826
+ y, x = None, None
3827
+ else:
3828
+ right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
3829
+ right = right[:n_eig]
3830
+ left = F.normalize(left, p=2, dim=-1)
3831
+ _right = F.normalize(right, p=2, dim=-1)
3832
+ heatmap = left @ _right.unsqueeze(-1)
3833
+ heatmap = heatmap.squeeze(-1)
3834
+ # heatmap = 1 - heatmap
3835
+ # heatmap = _transform_heatmap(heatmap)
3836
+ if raw_heatmap:
3837
+ return heatmap
3838
+ # apply hot colormap and covert to PIL image 256x256
3839
+ # gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
3840
+ heatmap = heatmap.cpu().numpy()
3841
+ hot_map = matplotlib.cm.get_cmap('hot')
3842
+ heatmap = hot_map(heatmap)
3843
+ pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
3844
+ if overlay_image:
3845
+ overlaied_images = []
3846
+ for i_image in range(len(images)):
3847
+ rgb_image = images[i_image].resize((256, 256))
3848
+ rgb_image = np.array(rgb_image)
3849
+ heatmap_image = np.array(pil_images[i_image])[..., :3]
3850
+ blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
3851
+ blend_image = Image.fromarray(blend_image.astype(np.uint8))
3852
+ overlaied_images.append(blend_image)
3853
+ pil_images = overlaied_images
3854
+ return pil_images, (y, x)
3855
+
3856
+ @torch.no_grad()
3857
+ def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
3858
+ gr.Info(f"current number of eigenvectors: {n_eig}", 2)
3859
+ images = [image[0] for image in images]
3860
+ if isinstance(images[0], str):
3861
+ images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
3862
+
3863
+ current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
3864
+
3865
+ return current_heatmap
3866
+
3867
+ def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
3868
+ n_eig = int(n_eig*2)
3869
+ n_eig = min(n_eig, eigvecs.shape[-1])
3870
+ n_eig = max(n_eig, 1)
3871
+ return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
3872
+
3873
+ def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
3874
+ n_eig = int(n_eig/2)
3875
+ n_eig = min(n_eig, eigvecs.shape[-1])
3876
+ n_eig = max(n_eig, 1)
3877
+ return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
3878
+
3879
+ none_placeholder = gr.State(None)
3880
+ run_button.click(
3881
+ run_heatmap,
3882
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
3883
+ outputs=[current_plot],
3884
+ )
3885
+
3886
+ doublue_eigs_button.click(
3887
+ doublue_eigs_wrapper,
3888
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
3889
+ outputs=[n_eig_slider, current_plot],
3890
+ )
3891
+
3892
+ half_eigs_button.click(
3893
+ half_eigs_wrapper,
3894
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
3895
+ outputs=[n_eig_slider, current_plot],
3896
+ )
3897
+
3898
  with gr.Tab('📄About'):
3899
  with gr.Column():
3900
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")