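"""Gradio demo: score an image against multiple text prompts with CLIP or SigLIP.

Written for a Hugging Face Space: the `spaces` import and the `@spaces.GPU`
decorator let the scoring function borrow a GPU on demand (ZeroGPU).
"""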
import torch
import torch.nn.functional as F
import gradio as gr
import spaces                               # needed so @spaces.GPU can request ZeroGPU hardware
from transformers import (
    CLIPProcessor,
    CLIPModel,
    SiglipProcessor,                        # transformers ≥ 4.40
    SiglipModel,
)

# ---------------------------------------------------------------------
# 1.  CONFIG
# ---------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# display name → (Hugging Face checkpoint id, input resolution in px, model family)
MODELS = {
    "CLIP ViT-B/32":      ("openai/clip-vit-base-patch32",        224, "clip"),
    "CLIP ViT-B/16":      ("openai/clip-vit-base-patch16",        224, "clip"),
    "CLIP ViT-L/14":      ("openai/clip-vit-large-patch14",       224, "clip"),
    "CLIP ViT-L/14@336":  ("openai/clip-vit-large-patch14-336",   336, "clip"),
    "SigLIP Large-256":   ("google/siglip-large-patch16-256",     256, "siglip"),
    "SigLIP Base-384":    ("google/siglip-base-patch16-384",      384, "siglip"),
    "SigLIP Large-384":   ("google/siglip-large-patch16-384",     384, "siglip"),
}

# ---------------------------------------------------------------------
# 2.  LAZY MODEL LOADING
# ---------------------------------------------------------------------
_models, _processors = {}, {}

def _load_model(name: str):
    """Instantiate the checkpoint behind `name` and move it to DEVICE."""
    path, _, kind = MODELS[name]

    kwargs = dict(
        low_cpu_mem_usage=False,     # avoid the meta-device init bug
        # fp16 is faster and smaller on GPU, but unreliable on CPU
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )

    if kind == "clip":
        model     = CLIPModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = CLIPProcessor.from_pretrained(path)
    else:
        model     = SiglipModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = SiglipProcessor.from_pretrained(path)

    model.eval()
    return model, processor

def get_model(name: str):
    """Return a cached (model, processor) pair, loading it on first use."""
    if name not in _models:
        _models[name], _processors[name] = _load_model(name)
    return _models[name], _processors[name]

# ---------------------------------------------------------------------
# 3.  SCORING FUNCTION (runs on GPU in Spaces)
# ---------------------------------------------------------------------
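# `@spaces.GPU` allocates ZeroGPU hardware only for the duration of each call.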
@spaces.GPU
def calculate_score(image, text: str, model_name: str):
    """Score `image` against every ';'-separated prompt in `text`; returns {prompt: score}."""
    labels = [t.strip() for t in text.split(";") if t.strip()]
    if not labels:
        return {}
    model, processor = get_model(model_name)
    kind = MODELS[model_name][2]

    # SigLIP is trained with fixed-length text padding; CLIP is fine with dynamic padding.
    inputs = processor(
        text=labels,
        images=image,
        padding="max_length" if kind == "siglip" else True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    # Pixel values arrive as float32; cast them to the model's dtype (fp16 on GPU).
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

    with torch.no_grad():
        if kind == "clip":
            out       = model(**inputs)
            img_emb   = out.image_embeds
            txt_emb   = out.text_embeds
        else:
            img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
            txt_emb = model.get_text_features(
                input_ids=inputs["input_ids"],
                # SigLIP's tokenizer may not return an attention mask
                attention_mask=inputs.get("attention_mask"),
            )

    img_emb = F.normalize(img_emb, p=2, dim=-1)
    txt_emb = F.normalize(txt_emb, p=2, dim=-1)

    scores = (txt_emb @ img_emb.T).squeeze(1)          # cosine similarity per prompt
    if kind == "siglip":
        # SigLIP's pairwise probability: sigmoid of the scaled and biased cosine
        scores = torch.sigmoid(scores * model.logit_scale.exp() + model.logit_bias)

    return {lbl: float(score.clamp(0, 1)) for lbl, score in zip(labels, scores.cpu())}

# ---------------------------------------------------------------------
# 4.  GRADIO UI
# ---------------------------------------------------------------------
with gr.Blocks(title="CLIP / SigLIP Image-Text Similarity") as demo:
    gr.Markdown("## Compare an image with multiple text prompts")

    with gr.Row():
        image_in  = gr.Image(type="pil", label="Image")
        score_out = gr.Label(label="Similarity (0‒1)")

    with gr.Row():
        text_in = gr.Textbox(
            label="Text prompts (use ‘;’ to separate)",
            placeholder="a cat; a flying cat; a dog",
        )
        model_in = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/16",
            label="Model",
        )

    def infer(img, txt, mdl):
        # Only score when both an image and at least one prompt are present.
        return calculate_score(img, txt, mdl) if img is not None and txt and txt.strip() else {}

    # Re-run scoring whenever the image, the prompts, or the model selection changes.
    for comp in (image_in, text_in, model_in):
        comp.change(infer, [image_in, text_in, model_in], score_out)

    gr.Examples(
        examples=[
            ["cat.jpg",
             "a cat stuck in a door; a cat jumping; a dog",
             "CLIP ViT-B/16"],
        ],
        inputs=[image_in, text_in, model_in],
        outputs=score_out,
    )
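
# Locally, `python app.py` serves the demo; on Spaces the app is launched automatically.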

demo.launch()