import os
import torch
import torch.nn.functional as F
import gradio as gr
import spaces # ← keep this!
from transformers import (
    CLIPProcessor,
    CLIPModel,
    SiglipProcessor,  # transformers ≥ 4.40
    SiglipModel,
)
# ---------------------------------------------------------------------
# 1. CONFIG
# ---------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
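# Display name -> (Hugging Face checkpoint, input resolution, model family)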
MODELS = {
    "CLIP ViT-B/32": ("openai/clip-vit-base-patch32", 224, "clip"),
    "CLIP ViT-B/16": ("openai/clip-vit-base-patch16", 224, "clip"),
    "CLIP ViT-L/14": ("openai/clip-vit-large-patch14", 224, "clip"),
    "CLIP ViT-L/14@336": ("openai/clip-vit-large-patch14-336", 336, "clip"),
    "SigLIP Large-256": ("google/siglip-large-patch16-256", 256, "siglip"),
    "SigLIP Base-384": ("google/siglip-base-patch16-384", 384, "siglip"),
    "SigLIP Large-384": ("google/siglip-large-patch16-384", 384, "siglip"),
}
# ---------------------------------------------------------------------
# 2. LAZY MODEL LOADING
# ---------------------------------------------------------------------
_models, _processors = {}, {}
def _load_model(name: str):
    path, _, kind = MODELS[name]
    kwargs = dict(
        low_cpu_mem_usage=False,  # avoid meta-device bug
        # Half precision is faster & smaller on GPU; fall back to fp32 on CPU,
        # where many ops lack a float16 implementation.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )
    if kind == "clip":
        model = CLIPModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = CLIPProcessor.from_pretrained(path)
    else:
        model = SiglipModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = SiglipProcessor.from_pretrained(path)
    model.eval()
    return model, processor
def get_model(name: str):
    # Cache loaded models/processors so each checkpoint is only loaded once.
    if name not in _models:
        _models[name], _processors[name] = _load_model(name)
    return _models[name], _processors[name]
# ---------------------------------------------------------------------
# 3. SCORING FUNCTION (runs on GPU in Spaces)
# ---------------------------------------------------------------------
@spaces.GPU
def calculate_score(image, text: str, model_name: str):
    """Return {prompt: similarity} for every ';'-separated prompt."""
    labels = [t.strip() for t in text.split(";") if t.strip()]
    if not labels:
        return {}
    model, processor = get_model(model_name)
    kind = MODELS[model_name][2]
    inputs = processor(
        text=labels,
        images=image,
        padding=True,
        return_tensors="pt",
    ).to(DEVICE)
    # The model may be in half precision; cast the image tensor to match,
    # otherwise the vision tower raises a dtype mismatch.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        if kind == "clip":
            out = model(**inputs)
            img_emb = out.image_embeds
            txt_emb = out.text_embeds
        else:
            img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
            # The SigLIP tokenizer may not return an attention mask, so pass it
            # only if present.
            txt_emb = model.get_text_features(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )
    img_emb = F.normalize(img_emb, p=2, dim=-1)
    txt_emb = F.normalize(txt_emb, p=2, dim=-1)
    scores = (txt_emb @ img_emb.T).squeeze(1)  # cosine similarity, one per prompt
    if kind == "siglip":
        scores = torch.sigmoid(scores)  # SigLIP pairs similarities with a sigmoid
    return {lbl: float(score.clamp(0, 1)) for lbl, score in zip(labels, scores.cpu())}
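# Quick local sanity check (outside the Gradio UI), a sketch assuming a
# "cat.jpg" file sits next to this script:
#   from PIL import Image
#   print(calculate_score(Image.open("cat.jpg"), "a cat; a dog", "CLIP ViT-B/16"))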
# ---------------------------------------------------------------------
# 4. GRADIO UI
# ---------------------------------------------------------------------
with gr.Blocks(title="CLIP / SigLIP Image-Text Similarity") as demo:
    gr.Markdown("## Compare an image with multiple text prompts")
    with gr.Row():
        image_in = gr.Image(type="pil", label="Image")
        score_out = gr.Label(label="Similarity (0‒1)")
    with gr.Row():
        text_in = gr.Textbox(
            label="Text prompts (use ‘;’ to separate)",
            placeholder="a cat; a flying cat; a dog",
        )
        model_in = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/16",
            label="Model",
        )

    def infer(img, txt, mdl):
        # Only score once both an image and at least one prompt are provided.
        return calculate_score(img, txt, mdl) if img is not None and txt.strip() else {}

    # Re-score whenever the image, the prompts, or the selected model changes.
    for comp in (image_in, text_in, model_in):
        comp.change(infer, [image_in, text_in, model_in], score_out)

    gr.Examples(
        examples=[
            ["cat.jpg",
             "a cat stuck in a door; a cat jumping; a dog",
             "CLIP ViT-B/16"],
        ],
        inputs=[image_in, text_in, model_in],
        outputs=score_out,
    )

demo.launch()