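"""Gradio demo that visualizes DINOv2 patch features.

A first PCA reduces each patch feature to 3 components and the first
component is thresholded to split foreground from background; a second PCA,
fit on the foreground patches only, is rendered as RGB.
"""
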
import torch
import numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler


device = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-B/14 from torch.hub; its 14x14 patch size drives the resize in app_fn
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dino.eval()
dino.to(device)

# Module-level PCA and scaler; both are re-fit on every call to app_fn
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)

def plot_img(img_array: np.ndarray) -> go.Figure:
    fig = px.imshow(img_array)
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False)
    )

    return fig


def app_fn(
        img: np.ndarray, 
        threshold: float, 
        object_larger_than_bg: bool
    ) -> go.Figure:
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

    # Resizing to (14 * 40, 14 * 40) = 560x560 yields a 40x40 grid of 14x14 patches
    patch_h = 40
    patch_w = 40

    transform = T.Compose([
        T.Resize((14 * patch_h, 14 * patch_w)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # HWC uint8 image -> CHW float tensor in [0, 1] (done manually, so no ToTensor above)
    img = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    # "x_prenorm" contains all tokens; drop the CLS token at index 0 to keep
    # the (patch_h * patch_w, embed_dim) patch-token matrix
    features = out["x_prenorm"][:, 1:, :]
    features = features.squeeze(0)
    features = features.cpu().numpy()

    # First PCA: 3 components per patch, min-max scaled to [0, 1]
    pca_features = pca.fit_transform(features)
    pca_features = scaler.fit_transform(pca_features)

    # The first PCA component tends to separate object from background; the
    # checkbox flips which side of the threshold is treated as background
    if object_larger_than_bg:
        pca_features_bg = pca_features[:, 0] > threshold
    else:
        pca_features_bg = pca_features[:, 0] < threshold

    pca_features_fg = ~pca_features_bg

    # Second PCA: re-fit on the foreground patches only, so their 3 components
    # can be displayed as RGB
    pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
    pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

    # Paint background patches black and foreground patches with their PCA-RGB,
    # then reshape the flat token list back into the patch grid
    pca_features_rgb = np.zeros((patch_h * patch_w, 3))
    pca_features_rgb[pca_features_bg] = 0
    pca_features_rgb[pca_features_fg] = pca_features_fg_seg
    pca_features_rgb = pca_features_rgb.reshape(patch_h, patch_w, 3)

    fig_pca = plot_img(pca_features_rgb)

    return fig_pca
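
# A minimal standalone sanity check (hypothetical, not part of the app),
# assuming Pillow is available and the example image exists at this path:
#
#   from PIL import Image
#   img = np.asarray(Image.open("assets/neca-the-cat.jpeg").convert("RGB"))
#   app_fn(img, threshold=0.6, object_larger_than_bg=True).show()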

if __name__ == "__main__":
    title = "πŸ¦– DINOv2 Features Visualization πŸ¦–"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            ### This app visualizes the features extracted by the [DINOv2](https://arxiv.org/pdf/2304.07193.pdf) model. \
            To create the visualizations we use a two-step PCA. \
            In the first step we reduce the features to 3 dimensions and threshold the first component \
            to separate the background from the foreground. Then we run a second PCA on the foreground features \
            so that foreground objects can be visualized as RGB.

            [Paper](https://arxiv.org/pdf/2304.07193.pdf)
            [Github](https://github.com/facebookresearch/dinov2)

            Created by: [Eduardo Pacheco](https://github.com/EduardoPach)
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
        btn = gr.Button("Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")
        
        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])
        examples = gr.Examples(
            examples=[
                ["assets/neca-the-cat.jpeg", 0.6, True],
                ["assets/dog.png", 0.7, False]
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True
        )

    demo.queue(max_size=5).launch()