# CLIPScore / app.py
import os
import torch
import torch.nn.functional as F
import gradio as gr
import spaces # ← keep this!
from transformers import (
    CLIPProcessor,
    CLIPModel,
    SiglipProcessor,  # transformers ≥ 4.40
    SiglipModel,
)
# ---------------------------------------------------------------------
# 1. CONFIG
# ---------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
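
# Display name → (HF checkpoint, input resolution, model family).
# The resolution field is informational only; the processors handle resizing.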
MODELS = {
    "CLIP ViT-B/32": ("openai/clip-vit-base-patch32", 224, "clip"),
    "CLIP ViT-B/16": ("openai/clip-vit-base-patch16", 224, "clip"),
    "CLIP ViT-L/14": ("openai/clip-vit-large-patch14", 224, "clip"),
    "CLIP ViT-L/14@336": ("openai/clip-vit-large-patch14-336", 336, "clip"),
    "SigLIP Large-256": ("google/siglip-large-patch16-256", 256, "siglip"),
    "SigLIP Base-384": ("google/siglip-base-patch16-384", 384, "siglip"),
    "SigLIP Large-384": ("google/siglip-large-patch16-384", 384, "siglip"),
}
# ---------------------------------------------------------------------
# 2. LAZY MODEL LOADING
# ---------------------------------------------------------------------
_models, _processors = {}, {}

def _load_model(name: str):
    path, _, kind = MODELS[name]
    kwargs = dict(
        low_cpu_mem_usage=False,    # avoid meta-device bug
        torch_dtype=torch.float16,  # faster & smaller
    )
    if kind == "clip":
        model = CLIPModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = CLIPProcessor.from_pretrained(path)
    else:
        model = SiglipModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = SiglipProcessor.from_pretrained(path)
    model.eval()
    return model, processor
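
# Return a cached (model, processor) pair, loading the checkpoint on first use.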
def get_model(name: str):
    if name not in _models:
        _models[name], _processors[name] = _load_model(name)
    return _models[name], _processors[name]
# ---------------------------------------------------------------------
# 3. SCORING FUNCTION (runs on GPU in Spaces)
# ---------------------------------------------------------------------
@spaces.GPU
def calculate_score(image, text: str, model_name: str):
    # Prompts are separated by ";" in the textbox.
    labels = [t.strip() for t in text.split(";") if t.strip()]
    if not labels:
        return {}

    model, processor = get_model(model_name)
    kind = MODELS[model_name][2]

    inputs = processor(
        text=labels,
        images=image,
        padding=True,  # SigLIP model cards recommend padding="max_length"
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        if kind == "clip":
            out = model(**inputs)
            img_emb = out.image_embeds
            txt_emb = out.text_embeds
        else:
            img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
            txt_emb = model.get_text_features(
                input_ids=inputs["input_ids"],
                # SiglipTokenizer may not return an attention mask, so don't assume one.
                attention_mask=inputs.get("attention_mask"),
            )

    img_emb = F.normalize(img_emb, p=2, dim=-1)
    txt_emb = F.normalize(txt_emb, p=2, dim=-1)
    scores = (txt_emb @ img_emb.T).squeeze(1)  # cosine similarity, one score per label
    if kind == "siglip":
        scores = torch.sigmoid(scores)  # SigLIP pairs scores with a sigmoid, not a softmax
    return {lbl: float(score.clamp(0, 1)) for lbl, score in zip(labels, scores.cpu())}
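
# Minimal local sanity check (a sketch, not part of the app; assumes an image such as
# the cat.jpg used in the Examples below is present and the checkpoint can be fetched):
#
#   from PIL import Image
#   print(calculate_score(Image.open("cat.jpg"), "a cat; a dog", "CLIP ViT-B/16"))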
# ---------------------------------------------------------------------
# 4. GRADIO UI
# ---------------------------------------------------------------------
with gr.Blocks(title="CLIP / SigLIP Image-Text Similarity") as demo:
    gr.Markdown("## Compare an image with multiple text prompts")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Image")
        score_out = gr.Label(label="Similarity (0‒1)")

    with gr.Row():
        text_in = gr.Textbox(
            label="Text prompts (use ‘;’ to separate)",
            placeholder="a cat; a flying cat; a dog",
        )
        model_in = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/16",
            label="Model",
        )

    def infer(img, txt, mdl):
        # Only score once both an image and at least one prompt are provided.
        return calculate_score(img, txt, mdl) if img is not None and txt.strip() else {}

    # Re-run scoring whenever any input changes.
    for comp in (image_in, text_in, model_in):
        comp.change(infer, [image_in, text_in, model_in], score_out)

    gr.Examples(
        examples=[
            ["cat.jpg",
             "a cat stuck in a door; a cat jumping; a dog",
             "CLIP ViT-B/16"],
        ],
        inputs=[image_in, text_in, model_in],
        outputs=score_out,
    )
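
# launch() starts the Gradio server; locally it serves http://127.0.0.1:7860 by default,
# while Spaces exposes the public URL automatically.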
demo.launch()