Spaces: Running on Zero
add directed ncut (test)

- app.py +382 -39
- directed_ncut.py +287 -0
- requirements.txt +1 -1
app.py CHANGED
@@ -183,6 +183,84 @@ def compute_ncut(
     return rgb, logging_str, eigvecs
 
 
+def compute_ncut_directed(
+    features_1,
+    features_2,
+    num_eig=100,
+    num_sample_ncut=10000,
+    affinity_focal_gamma=0.3,
+    knn_ncut=10,
+    knn_tsne=10,
+    embedding_method="UMAP",
+    embedding_metric='euclidean',
+    num_sample_tsne=300,
+    perplexity=150,
+    n_neighbors=150,
+    min_dist=0.1,
+    sampling_method="QuickFPS",
+    metric="cosine",
+    indirect_connection=False,
+    make_orthogonal=False,
+    make_symmetric=False,
+    progess_start=0.4,
+):
+    print("Using directed_ncut")
+    print("features_1.shape", features_1.shape)
+    print("features_2.shape", features_2.shape)
+    from directed_ncut import nystrom_ncut
+    progress = gr.Progress()
+    logging_str = ""
+
+    num_nodes = np.prod(features_1.shape[:-2])
+    if num_nodes / 2 < num_eig:
+        # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
+        gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
+        num_eig = num_nodes // 2 - 1
+        logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
+
+    start = time.time()
+    progress(progess_start+0.0, desc="NCut")
+    n_features = features_1.shape[-2]
+    _features_1 = rearrange(features_1, "b h w d c -> (b h w) (d c)")
+    _features_2 = rearrange(features_2, "b h w d c -> (b h w) (d c)")
+    eigvecs, eigvals, _ = nystrom_ncut(
+        _features_1,
+        features_B=_features_2,
+        num_eig=num_eig,
+        num_sample=num_sample_ncut,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        affinity_focal_gamma=affinity_focal_gamma,
+        knn=knn_ncut,
+        sample_method=sampling_method,
+        distance=metric,
+        normalize_features=False,
+        indirect_connection=indirect_connection,
+        make_orthogonal=make_orthogonal,
+        make_symmetric=make_symmetric,
+        n_features=n_features,
+    )
+    # print(f"NCUT time: {time.time() - start:.2f}s")
+    logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
+
+    start = time.time()
+    progress(progess_start+0.01, desc="spectral-tSNE")
+    _, rgb = eigenvector_to_rgb(
+        eigvecs,
+        method=embedding_method,
+        metric=embedding_metric,
+        num_sample=num_sample_tsne,
+        perplexity=perplexity,
+        n_neighbors=n_neighbors,
+        min_distance=min_dist,
+        knn=knn_tsne,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"
+
+    rgb = rgb.reshape(features_1.shape[:3] + (3,))
+    return rgb, logging_str, eigvecs
+
+
 def dont_use_too_much_green(image_rgb):
     # make sure the foval 40% of the image is red leading
     x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
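Note on the hunk above: unlike compute_ncut, compute_ncut_directed builds the affinity from two feature sets (features_1 on one side, features_2 passed as features_B), so the graph is directed; the UI added later in this commit says the left-side SVD eigenvectors are taken in the directed case, with an optional A = (A + A.T) / 2 symmetrization. Below is a minimal dense sketch of that idea on toy tensors, not the Space's Nyström implementation; toy_directed_ncut and everything inside it are illustrative names.

import torch

def toy_directed_ncut(feats_row, feats_col, num_eig=8, make_symmetric=False):
    # asymmetric affinity: rows indexed by one feature set, columns by the other
    r = torch.nn.functional.normalize(feats_row, dim=-1)
    c = torch.nn.functional.normalize(feats_col, dim=-1)
    A = (r @ c.T).clamp(min=0)             # (n, n); in general A != A.T
    if make_symmetric:
        A = (A + A.T) / 2                  # the "Make Symmetric" option in the UI
    d_row = A.sum(1).clamp(min=1e-8)
    d_col = A.sum(0).clamp(min=1e-8)
    L = torch.diag(d_row.rsqrt()) @ A @ torch.diag(d_col.rsqrt())  # normalized affinity
    if make_symmetric:
        _, eigvecs = torch.linalg.eigh(L)  # symmetric case: ordinary eigenvectors
        return eigvecs[:, -num_eig:]
    U, S, Vh = torch.linalg.svd(L)         # directed case: take left singular vectors
    return U[:, :num_eig]

eigvecs = toy_directed_ncut(torch.randn(100, 64), torch.randn(100, 64))
print(eigvecs.shape)  # torch.Size([100, 8])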
@@ -592,6 +670,8 @@ def ncut_run(
     **kwargs,
 ):
     advanced = kwargs.get("advanced", False)
+    directed = kwargs.get("directed", False)
+
     progress = gr.Progress()
     progress(0.2, desc="Feature Extraction")
 
@@ -640,6 +720,11 @@ def ncut_run(
         features = extract_features(
             images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
         )
+        if directed:
+            node_type2 = kwargs.get("node_type2", None)
+            features_B = extract_features(
+                images, model, node_type=node_type2, layer=layer-1, batch_size=BATCH_SIZE
+            )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
     del model
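In the directed branch above, a second feature set is extracted from the same images with a different node_type (e.g. attention queries on one side, keys on the other). A hedged toy of that pairing, with a stand-in linear qkv projection in place of the app's extract_features and backbones (all sizes made up):

import torch

torch.manual_seed(0)
B, HW, n_heads, d_head = 2, 196, 6, 64
x = torch.randn(B, HW, n_heads * d_head)                # token features
qkv = torch.nn.Linear(n_heads * d_head, 3 * n_heads * d_head)
q, k, v = qkv(x).chunk(3, dim=-1)
# reshape to the (batch, h, w, n_heads, d) layout the diff's code expects
h = w = int(HW ** 0.5)
features   = q.reshape(B, h, w, n_heads, d_head)        # node_type  = 'q'
features_B = k.reshape(B, h, w, n_heads, d_head)        # node_type2 = 'k'
print(features.shape, features_B.shape)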
@@ -768,25 +853,59 @@ def ncut_run(
 
 
     # ailgnedcut
+    if not directed:
+        rgb, _logging_str, eigvecs = compute_ncut(
+            features,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=indirect_connection,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+        )
+    if directed:
+        head_index_text = kwargs.get("head_index_text", None)
+        n_heads = features.shape[-2]   # (batch, h, w, n_heads, d)
+        if head_index_text == 'all':
+            head_idx = torch.arange(n_heads)
+        else:
+            _idxs = head_index_text.split(",")
+            head_idx = torch.tensor([int(idx) for idx in _idxs])
+        features_A = features[:, :, :, head_idx, :]
+        features_B = features_B[:, :, :, head_idx, :]
+
+        rgb, _logging_str, eigvecs = compute_ncut_directed(
+            features_A,
+            features_B,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=False,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+            make_symmetric=kwargs.get("make_symmetric", None),
+        )
+
+
+
     logging_str += _logging_str
 
     if "AlignedThreeModelAttnNodes" == model_name:
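The head_index_text handling above is plain string splitting over the head axis: 'all' keeps every head, otherwise comma-separated indices select a subset. A self-contained rerun of the same logic on dummy data (the shapes here are made up for illustration):

import torch

features = torch.randn(2, 14, 14, 6, 64)        # (batch, h, w, n_heads, d)
head_index_text = "0,2,5"                        # UI text, or 'all'
if head_index_text == 'all':
    head_idx = torch.arange(features.shape[-2])
else:
    head_idx = torch.tensor([int(i) for i in head_index_text.split(",")])
features_A = features[:, :, :, head_idx, :]      # keep only the selected heads
print(features_A.shape)                          # torch.Size([2, 14, 14, 3, 64])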
@@ -858,26 +977,26 @@ def ncut_run(
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
+    # try:
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
+    #     ret = ncut_run(*args, **kwargs)
 
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
+    #     ret = list(ret)[:n_ret] + [ret[-1]]
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return *(None for _ in range(n_ret)), "Error: " + str(e)
 
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=30)
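The try/except that previously caught exceptions (and returned None placeholders plus an error string) is commented out here, so failures now propagate directly while the directed path is being tested. The surviving return-trimming line behaves like this; the tuple contents below are made-up stand-ins, assuming the last element of ncut_run's return is the logging string:

ret = ("rgb_gallery", "cluster_plot", "norm_plot", "logging text")
n_ret = 2
ret = list(ret)[:n_ret] + [ret[-1]]
print(ret)  # ['rgb_gallery', 'cluster_plot', 'logging text']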
@@ -1085,12 +1204,16 @@ def run_fn(
     recursion_l1_gamma=0.5,
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
+    node_type2="k",
+    head_index_text='all',
+    make_symmetric=False,
     n_ret=1,
     plot_clusters=False,
     alignedcut_eig_norm_plot=False,
     advanced=False,
+    directed=False,
 ):
+    print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
     progress(0, desc="Starting")
 
@@ -1222,6 +1345,10 @@ def run_fn(
         "plot_clusters": plot_clusters,
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
         "advanced": advanced,
+        "directed": directed,
+        "node_type2": node_type2,
+        "head_index_text": head_index_text,
+        "make_symmetric": make_symmetric,
     }
     # print(kwargs)
 
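These dict entries are how the new UI controls reach ncut_run: run_fn packs them into kwargs, and ncut_run reads them back with kwargs.get, as in the earlier hunks. A stripped-down sketch of that round trip (ncut_run_stub is a hypothetical stand-in, its body reduced to the lookups shown in this diff):

def ncut_run_stub(*args, **kwargs):
    directed = kwargs.get("directed", False)
    node_type2 = kwargs.get("node_type2", None)
    head_index_text = kwargs.get("head_index_text", None)
    make_symmetric = kwargs.get("make_symmetric", None)
    return directed, node_type2, head_index_text, make_symmetric

kwargs = {"directed": True, "node_type2": "k", "head_index_text": "all", "make_symmetric": False}
print(ncut_run_stub(**kwargs))  # (True, 'k', 'all', False)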
@@ -1379,7 +1506,7 @@ def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitti
     # Train the model
     trainer.fit(mlp, dataloader)
 
+    mlp.progress(0.99, desc="Applying MLP")
     results = trainer.predict(mlp, data_loader)
     A_transformed = torch.cat(results, dim=0)
 
| 2861 | 
         
             
                    buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         
     | 
| 2862 | 
         
             
                    buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
         
     | 
| 2863 | 
         | 
| 2864 | 
         
            +
             
     | 
| 2865 | 
         
            +
                with gr.Tab('Directed (experimental)', visible=True) as tab_directed_ncut: 
         
     | 
| 2866 | 
         | 
| 2867 | 
         
            +
                    target_images = gr.State([])
         
     | 
| 2868 | 
         
            +
                    input_images = gr.State([])
         
     | 
| 2869 | 
         
            +
                    def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
         
     | 
| 2870 | 
         
            +
                        with gr.Row():
         
     | 
| 2871 | 
         
            +
                            # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
         
     | 
| 2872 | 
         
            +
                            # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
         
     | 
| 2873 | 
         
            +
                            mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
         
     | 
| 2874 | 
         
            +
                        fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
         
     | 
| 2875 | 
         
            +
                        def mark_fn(images, text="target"):
         
     | 
| 2876 | 
         
            +
                            if images is None:
         
     | 
| 2877 | 
         
            +
                                raise gr.Error("No images selected")
         
     | 
| 2878 | 
         
            +
                            if len(images) == 0:
         
     | 
| 2879 | 
         
            +
                                raise gr.Error("No images selected")
         
     | 
| 2880 | 
         
            +
                            num_images = len(images)
         
     | 
| 2881 | 
         
            +
                            gr.Info(f"Marked {num_images} images as {text}")
         
     | 
| 2882 | 
         
            +
                            images = [(Image.open(tup[0]), []) for tup in images]
         
     | 
| 2883 | 
         
            +
                            return images
         
     | 
| 2884 | 
         
            +
                        mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
         
     | 
| 2885 | 
         
            +
                        # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
         
     | 
| 2886 | 
         
            +
                        
         
     | 
| 2887 | 
         
            +
                        with gr.Accordion("➡️ MLP Parameters", open=False):
         
     | 
| 2888 | 
         
            +
                            num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
         
     | 
| 2889 | 
         
            +
                            width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
         
     | 
| 2890 | 
         
            +
                            batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
         
     | 
| 2891 | 
         
            +
                            lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
         
     | 
| 2892 | 
         
            +
                            fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
         
     | 
| 2893 | 
         
            +
                            fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
         
     | 
| 2894 | 
         
            +
                            segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
         
     | 
| 2895 | 
         
            +
                            
         
     | 
| 2896 | 
         
            +
                        fit_to_target_button.click(
         
     | 
| 2897 | 
         
            +
                            run_mlp_fit,
         
     | 
| 2898 | 
         
            +
                            inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
         
     | 
| 2899 | 
         
            +
                            outputs=[mlp_gallery],
         
     | 
| 2900 | 
         
            +
                        )
         
     | 
| 2901 | 
         
            +
             
     | 
| 2902 | 
         
            +
                    def make_parameters_section_2model(model_ratio=True):
         
     | 
| 2903 | 
         
            +
                        gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
         
     | 
| 2904 | 
         
            +
                        from ncut_pytorch.backbone import list_models, get_demo_model_names
         
     | 
| 2905 | 
         
            +
                        model_names = list_models()
         
     | 
| 2906 | 
         
            +
                        model_names = sorted(model_names)
         
     | 
| 2907 | 
         
            +
                        # only CLIP DINO MAE is implemented for q k v
         
     | 
| 2908 | 
         
            +
                        ok_models = ["CLIP(ViT", "DiNO(", "MAE("]
         
     | 
| 2909 | 
         
            +
                        model_names = [m for m in model_names if any(ok in m for ok in ok_models)]
         
     | 
| 2910 | 
         
            +
                        
         
     | 
| 2911 | 
         
            +
                        def get_filtered_model_names(name):
         
     | 
| 2912 | 
         
            +
                            return [m for m in model_names if name.lower() in m.lower()]
         
     | 
| 2913 | 
         
            +
                        def get_default_model_name(name):
         
     | 
| 2914 | 
         
            +
                            lst = get_filtered_model_names(name)
         
     | 
| 2915 | 
         
            +
                            if len(lst) > 1:
         
     | 
| 2916 | 
         
            +
                                return lst[1]
         
     | 
| 2917 | 
         
            +
                            return lst[0]
         
     | 
| 2918 | 
         
            +
                        
         
     | 
| 2919 | 
         
            +
             
     | 
| 2920 | 
         
            +
                        model_radio = gr.Radio(["CLIP", "DiNO", "MAE"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
         
     | 
| 2921 | 
         
            +
                        model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
         
     | 
| 2922 | 
         
            +
                        model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
         
     | 
| 2923 | 
         
            +
                        layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
         
     | 
| 2924 | 
         
            +
                        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
         
     | 
| 2925 | 
         
            +
                        positive_prompt.visible = False
         
     | 
| 2926 | 
         
            +
                        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
         
     | 
| 2927 | 
         
            +
                        negative_prompt.visible = False
         
     | 
| 2928 | 
         
            +
                        node_type_dropdown = gr.Dropdown(['q', 'k', 'v'], 
         
     | 
| 2929 | 
         
            +
                                                        label="Left-side Node Type", value="q", elem_id="node_type", info="In directed case, left-side SVD eigenvector is taken")
         
     | 
| 2930 | 
         
            +
                        node_type_dropdown2 = gr.Dropdown(['q', 'k', 'v'], 
         
     | 
| 2931 | 
         
            +
                                                        label="Right-side Node Type", value="k", elem_id="node_type2")
         
     | 
| 2932 | 
         
            +
                        head_index_text = gr.Textbox(value='all', label="Head Index", elem_id="head_index", type="text", info="which attention heads to use, comma separated, e.g. 0,1,2")
         
     | 
| 2933 | 
         
            +
                        make_symmetric = gr.Checkbox(label="Make Symmetric", value=False, elem_id="make_symmetric", info="make the graph symmetric by A = (A + A.T) / 2")
         
     | 
| 2934 | 
         
            +
                        
         
     | 
| 2935 | 
         
            +
                        num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
         
     | 
| 2936 | 
         
            +
             
     | 
| 2937 | 
         
            +
                        def change_layer_slider(model_name):
         
     | 
| 2938 | 
         
            +
                            # SD2, UNET
         
     | 
| 2939 | 
         
            +
                            if "stable" in model_name.lower() and "diffusion" in model_name.lower():
         
     | 
| 2940 | 
         
            +
                                from ncut_pytorch.backbone import SD_KEY_DICT
         
     | 
| 2941 | 
         
            +
                                default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23'
         
     | 
| 2942 | 
         
            +
                                return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
         
     | 
| 2943 | 
         
            +
                                        gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
         
     | 
| 2944 | 
         
            +
                            
         
     | 
| 2945 | 
         
            +
                            if model_name == "LISSL(xinlai/LISSL-7B-v1)":
         
     | 
| 2946 | 
         
            +
                                layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
         
     | 
| 2947 | 
         
            +
                                default_layer = "dec_1_block"
         
     | 
| 2948 | 
         
            +
                                return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""),
         
     | 
| 2949 | 
         
            +
                                        gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
         
     | 
| 2950 | 
         
            +
             
     | 
| 2951 | 
         
            +
                            layer_dict = LAYER_DICT
         
     | 
| 2952 | 
         
            +
                            if model_name in layer_dict:
         
     | 
| 2953 | 
         
            +
                                value = layer_dict[model_name]
         
     | 
| 2954 | 
         
            +
                                return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
         
     | 
| 2955 | 
         
            +
                            else:
         
     | 
| 2956 | 
         
            +
                                value = 12
         
     | 
| 2957 | 
         
            +
                                return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")            
         
     | 
| 2958 | 
         
            +
                        model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
         
     | 
| 2959 | 
         
            +
                        
         
     | 
| 2960 | 
         
            +
                        def change_prompt_text(model_name):
         
     | 
| 2961 | 
         
            +
                            if model_name in promptable_diffusion_models:
         
     | 
| 2962 | 
         
            +
                                return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
         
     | 
| 2963 | 
         
            +
                                        gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
         
     | 
| 2964 | 
         
            +
                            return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
         
     | 
| 2965 | 
         
            +
                                    gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
         
     | 
| 2966 | 
         
            +
                        model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
         
     | 
| 2967 | 
         
            +
                        
         
     | 
| 2968 | 
         
            +
                        with gr.Accordion("Advanced Parameters: NCUT", open=False):
         
     | 
| 2969 | 
         
            +
                            gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
         
     | 
| 2970 | 
         
            +
                            affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
         
     | 
| 2971 | 
         
            +
                            num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
         
     | 
| 2972 | 
         
            +
                            # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
         
     | 
| 2973 | 
         
            +
                            sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
         
     | 
| 2974 | 
         
            +
                            # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
         
     | 
| 2975 | 
         
            +
                            ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
         
     | 
| 2976 | 
         
            +
                            ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
         
     | 
| 2977 | 
         
            +
                            ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
         
     | 
| 2978 | 
         
            +
                            ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
         
     | 
| 2979 | 
         
            +
                        with gr.Accordion("Advanced Parameters: Visualization", open=False):
         
     | 
| 2980 | 
         
            +
                            # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
         
     | 
| 2981 | 
         
            +
                            embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
         
     | 
| 2982 | 
         
            +
                            # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
         
+                embedding_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="t-SNE/UMAP: metric", value="euclidean", elem_id="embedding_metric")
+                num_sample_tsne_slider = gr.Slider(100, 10000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
+                knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
+                perplexity_slider = gr.Slider(10, 1000, step=10, label="t-SNE: perplexity", value=150, elem_id="perplexity")
+                n_neighbors_slider = gr.Slider(10, 1000, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
+                min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
+            return [model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]
+
+        def add_one_model(i_model=1):
+            with gr.Column(scale=5, min_width=200) as col:
+                gr.Markdown('### Output Images')
+                output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                add_download_button(output_gallery, "ncut_embed")
+                mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                add_mlp_fitting_buttons(output_gallery, mlp_gallery)
+                add_download_button(mlp_gallery, "mlp_color_align")
+                norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(norm_gallery, "eig_norm")
+                cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(cluster_gallery, "clusters")
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section_2model()
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True, directed=True),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
+                        *[false_placeholder for _ in range(9)],
+                        node_type_dropdown2, head_index_text, make_symmetric
+                    ],
+                    outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
+                )
+
+                output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
+
+                return output_gallery
+
+        galleries = []
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
+                submit_button.visible = False
+
+            for i in range(3):
+                g = add_one_model()
+                galleries.append(g)
+
+        # Create rows and buttons in a loop
+        rows = []
+        buttons = []
+
+        for i in range(4):
+            row = gr.Row(visible=False)
+            rows.append(row)
+
+            with row:
+                for j in range(4):
+                    with gr.Column(scale=5, min_width=200):
+                        g = add_one_model()
+                        galleries.append(g)
+
+            button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
+            buttons.append(button)
+
+            if i > 0:
+                # Reveal the current row and next button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row)
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button)
+
+                # Hide the current button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1])
+
+        # The last button only reveals the last row and hides itself
+        buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
+        buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
+
     with gr.Tab('📄About'):
         with gr.Column():
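The "Add Compare" wiring above is a standard Gradio progressive-reveal chain: each button shows its hidden row, reveals the next button, and hides itself. A minimal standalone sketch of the same pattern (hypothetical demo content, not the app's actual components; assumes plain Gradio):

import gradio as gr

with gr.Blocks() as demo:
    rows, buttons = [], []
    for i in range(3):
        with gr.Row(visible=False) as row:
            gr.Markdown(f"comparison slot {i}")  # placeholder content
        rows.append(row)
        buttons.append(gr.Button("➕ Add Compare", visible=(i == 0)))

    for i, btn in enumerate(buttons):
        # each click reveals its row, hides itself, and shows the next button
        btn.click(lambda: gr.update(visible=True), outputs=rows[i])
        btn.click(lambda: gr.update(visible=False), outputs=btn)
        if i + 1 < len(buttons):
            btn.click(lambda: gr.update(visible=True), outputs=buttons[i + 1])

demo.launch()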
    	
directed_ncut.py
ADDED

@@ -0,0 +1,287 @@
# %%
import torch
import torch.nn.functional as F

def affinity_from_features(
    features,
    features_B=None,
    affinity_focal_gamma=1.0,
    distance="cosine",
    normalize_features=False,
    fill_diagonal=False,
    n_features=1,
):
    """Compute affinity matrix from input features.

    Args:
        features (torch.Tensor): input features, shape (n_samples, n_features)
        features_B (torch.Tensor, optional): if not None, compute affinity between features and features_B
        affinity_focal_gamma (float): affinity matrix temperature, lower gamma reduces the edge weights
            on weak connections, default 1.0
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        normalize_features (bool): normalize input features before computing the affinity matrix,
            default False
        fill_diagonal (bool): set the diagonal of the distance matrix to 0 before the exponential,
            default False
        n_features (int): number of heads packed into the feature dimension; distances are
            averaged over heads, default 1

    Returns:
        (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
    """
    # compute affinity matrix from input features
    features = features.clone()
    if features_B is not None:
        features_B = features_B.clone()

    # if features_B is not provided, compute affinity matrix on features x features
    # if features_B is provided, compute affinity matrix on features x features_B
    if features_B is not None:
        assert not fill_diagonal, "fill_diagonal should be False when features_B is not None"
    features_B = features if features_B is None else features_B

    if normalize_features:
        features = F.normalize(features, dim=-1)
        features_B = F.normalize(features_B, dim=-1)

    if distance == "cosine":
        # TODO: make sure features are normalized within each head
        features = F.normalize(features, dim=-1)
        features_B = F.normalize(features_B, dim=-1)
        A = 1 - (features @ features_B.T) / n_features
    elif distance == "euclidean":
        A = torch.cdist(features, features_B, p=2) / n_features
    else:
        raise ValueError("distance should be 'cosine' or 'euclidean'")

    if fill_diagonal:
        A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0

    # torch.exp makes the affinity matrix entries positive,
    # lower affinity_focal_gamma reduces the weak edge weights
    A = torch.exp(-(A / affinity_focal_gamma))
    return A
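For reference, a small usage sketch of affinity_from_features as defined above (random tensors, illustrative shapes only; passing features_B yields the rectangular, directed affinity):

import torch

feats_q = torch.randn(100, 64)   # e.g. "query" node features
feats_k = torch.randn(80, 64)    # e.g. "key" node features

# square affinity on one feature set
A_square = affinity_from_features(feats_q, affinity_focal_gamma=0.5)
print(A_square.shape)            # torch.Size([100, 100])

# asymmetric (directed) affinity between two feature sets
A_directed = affinity_from_features(feats_q, features_B=feats_k)
print(A_directed.shape)          # torch.Size([100, 80]); rows index feats_q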

from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt
import logging


def ncut(
    A,
    num_eig=20,
    eig_solver="svd_lowrank",
    make_symmetric=True,
):
    """PyTorch implementation of Normalized Cut without Nystrom-like approximation.

    Args:
        A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
        num_eig (int): number of eigenvectors to return
        eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
        make_symmetric (bool): symmetrize A as (A + A.T) / 2 before normalization, default True

    Returns:
        (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
    """
    if make_symmetric:
        # make sure A is symmetric
        A = (A + A.T) / 2

    # symmetric normalization, A = D^(-1/2) A D^(-1/2);
    # row and column degree sums are kept separate so an asymmetric A is normalized on both sides
    D_r = A.sum(dim=0).detach().clone()
    D_c = A.sum(dim=1).detach().clone()
    A /= torch.sqrt(D_r)[:, None]
    A /= torch.sqrt(D_c)[None, :]

    # compute eigenvectors
    if eig_solver == "svd_lowrank":  # default
        # only top q eigenvectors, fastest
        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
    elif eig_solver == "lobpcg":
        # only top k eigenvectors, fast
        eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
    elif eig_solver == "svd":
        # all eigenvectors, slow
        eigen_vector, eigen_value, _ = torch.svd(A)
    elif eig_solver == "eigh":
        # all eigenvectors, slow
        eigen_value, eigen_vector = torch.linalg.eigh(A)
    else:
        raise ValueError(
            "eig_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
        )

    # sort eigenvectors by eigenvalues, take top num_eig (descending order)
    eigen_value = eigen_value.real
    eigen_vector = eigen_vector.real

    sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
    eigen_value = eigen_value[sort_order]
    eigen_vector = eigen_vector[:, sort_order]

    if eigen_value.min() < 0:
        logging.warning(
            "negative eigenvalues detected, please make sure the affinity matrix is positive definite"
        )

    return eigen_vector, eigen_value
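A quick sanity-check sketch for ncut (random data, illustrative only; reuses affinity_from_features from above):

import torch

feats = torch.randn(200, 32)
A = affinity_from_features(feats)        # (200, 200) affinity
eigvecs, eigvals = ncut(A, num_eig=10)   # default solver: svd_lowrank
print(eigvecs.shape, eigvals.shape)      # torch.Size([200, 10]) torch.Size([10])

# for a directed (asymmetric) A, make_symmetric=True averages A with A.T first;
# make_symmetric=False keeps the asymmetry, which suits the SVD-based solvers
# ('svd_lowrank', 'svd') rather than 'eigh'/'lobpcg', which assume symmetry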

def nystrom_ncut(
    features,
    features_B=None,
    num_eig=100,
    num_sample=10000,
    knn=10,
    sample_method="farthest",
    distance="cosine",
    affinity_focal_gamma=1.0,
    indirect_connection=False,
    indirect_pca_dim=100,
    device=None,
    eig_solver="svd_lowrank",
    normalize_features=False,
    matmul_chunk_size=8096,
    make_orthogonal=False,
    verbose=False,
    no_propagation=False,
    make_symmetric=False,
    n_features=1,
):
    """PyTorch implementation of Faster Nystrom Normalized Cut.

    Args:
        features (torch.Tensor): feature matrix, shape (n_samples, n_features)
        features_B (torch.Tensor, optional): second feature matrix for an asymmetric (directed)
            affinity matrix, shape (n_samples2, n_features); defaults to features
        num_eig (int): default 100, number of top eigenvectors to return
        num_sample (int): default 10000, number of samples for Nystrom-like approximation
        knn (int): default 10, number of KNN for propagating eigenvectors from subgraph to full graph,
            smaller knn results in sharper eigenvectors
        sample_method (str): sample method, 'farthest' (default) or 'random';
            'farthest' is recommended for better approximation
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        affinity_focal_gamma (float): affinity matrix temperature, lower gamma reduces the weak edge weights,
            resulting in sharper eigenvectors, default 1.0
        indirect_connection (bool): include indirect connections in the subgraph, default False
            (not implemented in this directed variant)
        indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
            the non-sampled nodes, not applied to the sampled nodes
        device (str): device to use for computation; if None, the device is not changed.
            A good practice is to pass features on CPU since they are usually large,
            and move the subgraph affinity to GPU to speed up eigenvector computation
        eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh';
            'svd_lowrank' is recommended for large scale graphs, it's the fastest;
            they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
        normalize_features (bool): normalize input features before computing the affinity matrix,
            default False
        matmul_chunk_size (int): chunk size for matrix multiplication;
            large matrix multiplication is chunked to reduce memory usage,
            smaller chunk size reduces memory usage but slows computation, default 8096
        make_orthogonal (bool): make eigenvectors orthogonal after propagation, default False
        verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
        no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors
        make_symmetric (bool): symmetrize the subgraph affinity before the normalized cut, default False
        n_features (int): number of heads packed into the feature dimension, default 1

    Returns:
        (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
        (torch.Tensor): sampled_indices used by the Nystrom-like approximation subgraph, shape (num_sample,)
    """

    # check that the number of nodes is large enough for the eigensolver
    if eig_solver in ["svd_lowrank", "lobpcg"]:
        assert features.shape[0] > (
            num_eig * 2
        ), "number of nodes should be greater than 2*num_eig"
    if eig_solver in ["svd", "eigh"]:
        assert (
            features.shape[0] > num_eig
        ), "number of nodes should be greater than num_eig"

    features = features.clone()
    if normalize_features:
        # features need to be normalized for affinity matrix computation (cosine distance)
        features = torch.nn.functional.normalize(features, dim=-1)

    if features_B is None:
        # fall back to the undirected case: compute affinity on features x features
        features_B = features

    sampled_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_indices_B = run_subgraph_sampling(
        features_B,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_features = features[sampled_indices]
    sampled_features_B = features_B[sampled_indices_B]
    # move the subgraph to gpu to speed up
    original_device = sampled_features.device
    device = original_device if device is None else device
    sampled_features = sampled_features.to(device)
    sampled_features_B = sampled_features_B.to(device)

    # compute affinity matrix on the subgraph
    A = affinity_from_features(
        sampled_features, features_B=sampled_features_B,
        affinity_focal_gamma=affinity_focal_gamma, distance=distance,
        n_features=n_features,
    )

    not_sampled = torch.tensor(
        list(set(range(features.shape[0])) - set(sampled_indices))
    )

    if len(not_sampled) == 0:
        # if all nodes are sampled, no need for Nystrom approximation
        eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
        return eigen_vector, eigen_value, sampled_indices

    # 1) PCA to reduce the node dimension for the non-sampled nodes
    # 2) compute indirect connections on the PC nodes
    if len(not_sampled) > 0 and indirect_connection:
        raise NotImplementedError("indirect_connection is not implemented yet")
        indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
        U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
        feature_B = (features[not_sampled].T @ V).T  # project to PCA space
        feature_B = feature_B.to(device)
        B = affinity_from_features(
            sampled_features,
            feature_B,
            affinity_focal_gamma=affinity_focal_gamma,
            distance=distance,
            fill_diagonal=False,
        )
        # P is the 1-hop random walk matrix
        B_row = B / B.sum(axis=1, keepdim=True)
        B_col = B / B.sum(axis=0, keepdim=True)
        P = B_row @ B_col.T
        P = (P + P.T) / 2
        # fill diagonal with 0
        P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
        A = A + P

    # compute normalized cut on the subgraph
    eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
    eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
    eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)

    if no_propagation:
        return eigen_vector, eigen_value, sampled_indices

    # propagate eigenvectors from subgraph to full graph
    eigen_vector = propagate_knn(
        eigen_vector,
        features,
        sampled_features,
        knn,
        chunk_size=matmul_chunk_size,
        device=device,
        use_tqdm=verbose,
    )

    # post-hoc orthogonalization
    if make_orthogonal:
        eigen_vector = gram_schmidt(eigen_vector)

    return eigen_vector, eigen_value, sampled_indices
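As a usage sketch for the directed case (random stand-in features, hypothetical sizes; assumes the ncut_pytorch helpers imported above behave as in the library):

import torch

# e.g. attention "query" and "key" features of the same set of nodes
feats_q = torch.randn(5000, 64)
feats_k = torch.randn(5000, 64)

eigvecs, eigvals, sampled = nystrom_ncut(
    feats_q, features_B=feats_k,
    num_eig=30, num_sample=1000, knn=10,
    affinity_focal_gamma=0.5, make_symmetric=False,
)
print(eigvecs.shape)  # torch.Size([5000, 30]); one embedding per node in feats_q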
    	
requirements.txt
CHANGED

@@ -20,4 +20,4 @@ lisa @ git+https://github.com/huzeyann/LISA.git@7211e99
 timm==0.9.2
 open-clip-torch==2.20.0
 pytorch_lightning==1.9.4
-ncut-pytorch>=1.
+ncut-pytorch>=1.4.1