# app.py
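"""Gradio demo for single-image DINOv3 patch similarity: process an image once, then
click anywhere to highlight the most similar regions as a heatmap, overlay, and top-K boxes."""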

import os
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torchvision.transforms.functional as TF

# --- Robust colormap import (Matplotlib ≥3.5 and older versions) ---
try:
    from matplotlib import colormaps as _mpl_colormaps
    def _get_cmap(name: str):
        return _mpl_colormaps[name]
except Exception:
    import matplotlib.cm as _cm
    def _get_cmap(name: str):
        return _cm.get_cmap(name)

from transformers import AutoModel  # uses trust_remote_code for DINOv3

# ----------------------------
# Configuration
# ----------------------------
# Default to smaller/faster ViT-S/16+; offer ViT-H/16+ as alternative.
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]

PATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Normalization constants (standard for ImageNet)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# ----------------------------
# Model Loading (Hugging Face Hub) with caching
# ----------------------------
_model_cache = {}
_current_model_id = None
model = None  # global reference used by extract_image_features()

def load_model_from_hub(model_id: str):
    """Loads a DINOv3 model from the Hugging Face Hub."""
    print(f"Loading model '{model_id}' from Hugging Face Hub...")
    try:
        token = os.environ.get("HF_TOKEN")  # optional, for gated models
        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
        mdl.to(DEVICE).eval()
        print(f"✅ Model '{model_id}' loaded successfully on device: {DEVICE}")
        return mdl
    except Exception as e:
        print(f"❌ Failed to load model '{model_id}': {e}")
        raise gr.Error(
            f"Could not load model '{model_id}'. "
            "If the model is gated, please accept the terms on its Hugging Face page "
            "and set HF_TOKEN in your environment. "
            f"Original error: {e}"
        )

def get_model(model_id: str):
    """Return a cached model if available, otherwise load and cache it."""
    if model_id in _model_cache:
        return _model_cache[model_id]
    mdl = load_model_from_hub(model_id)
    _model_cache[model_id] = mdl
    return mdl

# Load default model at startup
model = get_model(DEFAULT_MODEL_ID)
_current_model_id = DEFAULT_MODEL_ID

# ----------------------------
# Helper Functions (resize, viz)
# ----------------------------
def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
    """
    Resizes so max(h,w)=long_side (keeping aspect), then rounds each side UP to a multiple of 'patch'.
    Returns CHW float tensor in [0,1].
    """
    w, h = img.size
    scale = long_side / max(h, w)
    new_h = max(patch, int(round(h * scale)))
    new_w = max(patch, int(round(w * scale)))
    new_h = ((new_h + patch - 1) // patch) * patch
    new_w = ((new_w + patch - 1) // patch) * patch
    return TF.to_tensor(TF.resize(img.convert("RGB"), (new_h, new_w)))

def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
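    """Min-max normalize a similarity map to [0, 1] and render it through a Matplotlib colormap."""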
    x = sim_map_up.astype(np.float32)
    x = (x - x.min()) / (x.max() - x.min() + 1e-6)
    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)

def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    # Put alpha on heatmap and composite for a crisp overlay
    base = base.convert("RGBA")
    heat = heat.convert("RGBA")
    a = Image.new("L", heat.size, int(255 * alpha))
    heat.putalpha(a)
    out = Image.alpha_composite(base, heat)
    return out.convert("RGB")

def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
    r = radius if radius is not None else max(2, PATCH_SIZE // 2)
    out = img.copy()
    draw = ImageDraw.Draw(out)
    draw.line([(x - r, y), (x + r, y)], fill="red", width=3)
    draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
    return out

def draw_boxes(img: Image.Image, boxes, outline="yellow", width=3, labels=True):
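    """Draw numbered rectangles (pixel coordinates) on a copy of the image."""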
    out = img.copy()
    draw = ImageDraw.Draw(out)
    for i, (x0, y0, x1, y1) in enumerate(boxes, start=1):
        draw.rectangle([x0, y0, x1, y1], outline=outline, width=width)
        if labels:
            tx, ty = x0 + 2, y0 + 2
            draw.text((tx, ty), str(i), fill=outline)
    return out

def patch_neighborhood_box(r: int, c: int, Hp: int, Wp: int, rad: int, patch: int = PATCH_SIZE):
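    """Convert a patch (row, col) and a radius in patches into a pixel bounding box clamped to the Hp x Wp grid."""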
    r0 = max(0, r - rad)
    r1 = min(Hp - 1, r + rad)
    c0 = max(0, c - rad)
    c1 = min(Wp - 1, c + rad)
    x0 = int(c0 * patch)
    y0 = int(r0 * patch)
    x1 = int((c1 + 1) * patch) - 1
    y1 = int((r1 + 1) * patch) - 1
    return (x0, y0, x1, y1)

# ----------------------------
# Feature Extraction (using transformers)
# ----------------------------
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int):
    """
    Extracts patch features from an image using the loaded Hugging Face model.
    """
    t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE

    # The model outputs [CLS] + register tokens + patch tokens.
    outputs = model(t_norm)

    # Keep only the patch embeddings. DINOv3 checkpoints use 4 register tokens;
    # read the count from the model config when it is exposed there.
    n_special_tokens = 1 + getattr(model.config, "num_register_tokens", 4)
    patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]

    # L2-normalize features for cosine similarity
    X = F.normalize(patch_embeddings, p=2, dim=-1)

    img_resized = TF.to_pil_image(t)
    return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}

# ----------------------------
# Similarity inside the same image
# ----------------------------
def click_to_similarity_in_same_image(
    state: dict,
    click_xy: tuple[int, int],
    exclude_radius_patches: int = 1,
    topk: int = 10,
    alpha: float = 0.55,
    cmap_name: str = "viridis",
    box_radius_patches: int = 4,
):
    if not state:
        return None, None, None, None

    X = state["X"]
    Hp, Wp = state["Hp"], state["Wp"]
    base_img = state["img"]
    img_w, img_h = base_img.size

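    # Map the clicked pixel to its patch-grid cell; each cell covers PATCH_SIZE x PATCH_SIZE pixels.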
    x_pix, y_pix = click_xy
    col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
    row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
    idx = row * Wp + col

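    # Cosine similarity of the clicked patch against every patch (features are already L2-normalized).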
    q = X[idx]
    sims = torch.matmul(X, q)
    sim_map = sims.view(Hp, Wp)

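    # Optionally mask a square neighborhood around the click so trivially similar
    # neighboring patches do not dominate the heatmap and top-K boxes.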
    if exclude_radius_patches > 0:
        rr, cc = torch.meshgrid(
            torch.arange(Hp, device=sims.device),
            torch.arange(Wp, device=sims.device),
            indexing="ij",
        )
        mask = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
        sim_map = sim_map.masked_fill(mask, float("-inf"))

    # For visualization, replace -inf (excluded) cells with the finite minimum so bicubic
    # interpolation and colorization stay well-defined; sim_map keeps -inf for the top-K step.
    sim_vis = sim_map
    finite_mask = torch.isfinite(sim_vis)
    if not finite_mask.all() and finite_mask.any():
        sim_vis = torch.where(finite_mask, sim_vis, sim_vis[finite_mask].min())

    sim_up = F.interpolate(
        sim_vis.unsqueeze(0).unsqueeze(0),
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()

    heatmap_pil = colorize(sim_up, cmap_name)
    overlay_pil = blend(base_img, heatmap_pil, alpha=alpha)

    overlay_boxes_pil = overlay_pil
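    # Draw boxes around the K most similar patches; cells masked with -inf are skipped.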
    if topk and topk > 0:
        flat = sim_map.view(-1)
        valid = torch.isfinite(flat)
        if valid.any():
            vals = flat.clone()
            vals[~valid] = -1e9
            k = min(topk, int(valid.sum().item()))
            _, top_idx = torch.topk(vals, k=k, largest=True, sorted=True)
            boxes = [
                patch_neighborhood_box(
                    r, c, Hp, Wp, rad=int(box_radius_patches), patch=PATCH_SIZE
                )
                for r, c in [divmod(j.item(), Wp) for j in top_idx]
            ]
            overlay_boxes_pil = draw_boxes(overlay_pil, boxes, outline="yellow", width=3, labels=True)

    marked_ref = draw_crosshair(base_img, x_pix, y_pix, radius=PATCH_SIZE // 2)
    return marked_ref, heatmap_pil, overlay_pil, overlay_boxes_pil

# ----------------------------
# Gradio UI (Manual-only processing)
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Single-Image Patch Similarity") as demo:
    gr.Markdown("# 🦖 DINOv3 Single-Image Patch Similarity")
    gr.Markdown("Upload one image, adjust settings, then press **▶️ Start processing**. Click on the processed image to find similar regions.")

    app_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Image (click anywhere after processing)",
                type="pil",
                value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg"
            )
            target_long_side = gr.Slider(
                minimum=224, maximum=1024, value=768, step=16,
                label="Processing Resolution",
                info="Higher values = more detail but slower processing",
            )
            with gr.Row():
                alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
                cmap = gr.Dropdown(
                    ["viridis", "magma", "plasma", "inferno", "turbo", "cividis"],
                    value="viridis", label="Colormap",
                )
            # Backbone selector (default = smaller/faster ViT-S/16+)
            model_choice = gr.Dropdown(
                choices=AVAILABLE_MODELS,
                value=DEFAULT_MODEL_ID,
                label="Backbone (DINOv3)",
                info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
            )
            # Start processing button (manual trigger)
            with gr.Row():
                start_btn = gr.Button("▶️ Start processing", variant="primary")

        with gr.Column(scale=1):
            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches)")
            topk = gr.Slider(0, 200, value=20, step=1, label="Top-K boxes")
            box_radius = gr.Slider(0, 10, value=1, step=1, label="Box radius (patches)")

    with gr.Row():
        marked_image = gr.Image(label="Click marker / Preview", interactive=False)
        heatmap_output = gr.Image(label="Similarity heatmap", interactive=False)
    with gr.Row():
        overlay_output = gr.Image(label="Overlay (image ⊕ heatmap)", interactive=False)
        overlay_boxes_output = gr.Image(label="Overlay + top-K similar patch boxes", interactive=False)

    def _ensure_model(model_id: str):
        """Ensure the global 'model' matches the dropdown selection."""
        global model, _current_model_id
        if model_id != _current_model_id:
            model = get_model(model_id)
            _current_model_id = model_id

    # Manual feature extraction (only runs on Start button)
    def _run_extraction(img: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return None, None
        _ensure_model(model_id)
        progress(0, desc="Extracting features...")
        st = extract_image_features(img, int(long_side))
        progress(1, desc="Done!")
        return st["img"], st

    # Clicking on processed image to compute similarities
    def _on_click(st, a: float, m: str, excl: int, k: int, box_rad: int, evt: gr.SelectData):
        if not st or evt is None:
            return None, None, None, None
        return click_to_similarity_in_same_image(
            st, click_xy=evt.index, exclude_radius_patches=int(excl),
            topk=int(k), alpha=float(a), cmap_name=m,
            box_radius_patches=int(box_rad),
        )

    # On image change: just preview and clear outputs/state (NO extraction)
    def _on_image_changed(img: Image.Image):
        if img is None:
            return None, None, None, None, None
        return img, None, None, None, None

    # ---------- Wiring (Manual mode) ----------
    # Do NOT auto-run on upload/slider/model change or on app load.
    # Only the Start button triggers extraction.
    start_btn.click(
        _run_extraction,
        inputs=[input_image, target_long_side, model_choice],
        outputs=[marked_image, app_state],
    )

    # When a new image is picked, show it as preview and clear old results.
    input_image.change(
        _on_image_changed,
        inputs=[input_image],
        outputs=[marked_image, app_state, heatmap_output, overlay_output, overlay_boxes_output],
    )

    # Keep click handler the same.
    marked_image.select(
        _on_click,
        inputs=[app_state, alpha, cmap, exclude_r, topk, box_radius],
        outputs=[marked_image, heatmap_output, overlay_output, overlay_boxes_output],
    )

if __name__ == "__main__":
    demo.launch()