import os

import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from transformers import (
    CLIPProcessor,
    CLIPModel,
    SiglipProcessor,
    SiglipModel,
)
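

# Run inference on the GPU when one is available; otherwise fall back to CPU.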
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Each entry maps a display name to (checkpoint path, input resolution, model family).
MODELS = {
    "CLIP ViT-B/32": ("openai/clip-vit-base-patch32", 224, "clip"),
    "CLIP ViT-B/16": ("openai/clip-vit-base-patch16", 224, "clip"),
    "CLIP ViT-L/14": ("openai/clip-vit-large-patch14", 224, "clip"),
    "CLIP ViT-L/14@336": ("openai/clip-vit-large-patch14-336", 336, "clip"),
    "SigLIP Large-256": ("google/siglip-large-patch16-256", 256, "siglip"),
    "SigLIP Base-384": ("google/siglip-base-patch16-384", 384, "siglip"),
    "SigLIP Large-384": ("google/siglip-large-patch16-384", 384, "siglip"),
}
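

# Models and processors are loaded lazily on first use and cached in memory,
# so switching checkpoints in the UI does not reload them from disk.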
_models, _processors = {}, {}


def _load_model(name: str):
    path, _, kind = MODELS[name]

    # Use half precision only on CUDA; fp16 inference is not reliably supported on CPU.
    kwargs = dict(
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )

    if kind == "clip":
        model = CLIPModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = CLIPProcessor.from_pretrained(path)
    else:
        model = SiglipModel.from_pretrained(path, **kwargs).to(DEVICE)
        processor = SiglipProcessor.from_pretrained(path)

    model.eval()
    return model, processor
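

# Return the cached (model, processor) pair for a registry entry, loading it on first use.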
def get_model(name: str):
    if name not in _models:
        _models[name], _processors[name] = _load_model(name)
    return _models[name], _processors[name]
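

# Score an image against a ';'-separated list of prompts with the selected model.
# CLIP scores are clamped cosine similarities; SigLIP scores are sigmoid probabilities.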
@spaces.GPU
def calculate_score(image, text: str, model_name: str):
    labels = [t.strip() for t in text.split(";") if t.strip()]
    if not labels:
        return {}

    model, processor = get_model(model_name)
    kind = MODELS[model_name][2]

    inputs = processor(
        text=labels,
        images=image,
        # SigLIP checkpoints were trained with text padded to max_length.
        padding="max_length" if kind == "siglip" else True,
        return_tensors="pt",
    ).to(DEVICE)
    # Match the pixel dtype to the model's (fp16 on GPU) to avoid dtype mismatches.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

    with torch.no_grad():
        if kind == "clip":
            out = model(**inputs)
            img_emb = out.image_embeds
            txt_emb = out.text_embeds
        else:
            img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
            txt_emb = model.get_text_features(
                input_ids=inputs["input_ids"],
                # SigLIP tokenizers do not always return an attention mask.
                attention_mask=inputs.get("attention_mask"),
            )

    img_emb = F.normalize(img_emb, p=2, dim=-1)
    txt_emb = F.normalize(txt_emb, p=2, dim=-1)

    scores = (txt_emb @ img_emb.T).squeeze(1)
    if kind == "siglip":
        # SigLIP applies a learned logit scale and bias before the sigmoid.
        scores = torch.sigmoid(scores * model.logit_scale.exp() + model.logit_bias)

    return {lbl: float(score.clamp(0, 1)) for lbl, score in zip(labels, scores.cpu())}
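

# Gradio interface: upload an image, enter ';'-separated prompts, and pick a model.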
with gr.Blocks(title="CLIP / SigLIP Image-Text Similarity") as demo:
    gr.Markdown("## Compare an image with multiple text prompts")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Image")
        score_out = gr.Label(label="Similarity (0‒1)")

    with gr.Row():
        text_in = gr.Textbox(
            label="Text prompts (use ‘;’ to separate)",
            placeholder="a cat; a flying cat; a dog",
        )
        model_in = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/16",
            label="Model",
        )

    def infer(img, txt, mdl):
        return calculate_score(img, txt, mdl) if img is not None and txt.strip() else {}

    # Recompute scores whenever the image, the prompts, or the model selection changes.
    for comp in (image_in, text_in, model_in):
        comp.change(infer, [image_in, text_in, model_in], score_out)
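
    # The example assumes a file named cat.jpg next to this script (e.g. committed
    # to the Space repo); adjust the path if the sample image lives elsewhere.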
    gr.Examples(
        examples=[
            ["cat.jpg",
             "a cat stuck in a door; a cat jumping; a dog",
             "CLIP ViT-B/16"],
        ],
        inputs=[image_in, text_in, model_in],
        outputs=score_out,
    )

demo.launch()