|
import gradio as gr |
|
import numpy as np |
|
import clip |
|
import torch |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
from decimal import Decimal |
|
|
|
|
|
# Pick the compute device once and load CLIP straight onto it; ``preprocess``
# is the image transform returned by ``clip.load`` for this checkpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14@336px", device=device)
model.to(device)
# The model is only used for inference in this script.
model.eval()
|
|
|
|
|
def find_similarity(base64_image, text_input):
    """Return the cosine similarity between a base64-encoded image and a text.

    Both inputs are embedded with the module-level CLIP ``model``; the two
    embeddings are L2-normalised, so their dot product is the cosine
    similarity.

    Args:
        base64_image: Base64-encoded image bytes. A leading data-URI header
            such as ``data:image/png;base64,`` is accepted and stripped.
            Characters outside the base64 alphabet (e.g. newlines from a
            multi-line textbox) are ignored by ``base64.b64decode``.
        text_input: Text to compare against the image. Input longer than
            CLIP's token context is truncated rather than rejected.

    Returns:
        The similarity as a Python ``float`` (in ``[-1, 1]``), or ``None``
        when no image payload has been supplied yet.

    Raises:
        binascii.Error: if the payload is not correctly padded base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image.
    """
    payload = (base64_image or "").strip()
    # Browsers commonly produce data URIs ("data:image/png;base64,...").
    # b64decode would drop ':', ';' and ',' but keep the header's letters,
    # corrupting the decoded bytes, so remove the header explicitly.
    if payload.startswith("data:"):
        payload = payload.partition(",")[2]
    # Nothing to score yet (e.g. live mode firing before an image is pasted);
    # decoding an empty payload would only make Image.open fail.
    if not payload:
        return None

    image_bytes = base64.b64decode(payload)

    # Preprocess inside the context manager so the PIL image can be closed.
    with Image.open(BytesIO(image_bytes)) as pil_image:
        image = preprocess(pil_image).unsqueeze(0).to(device)

    # truncate=True: clip.tokenize raises for text exceeding its context
    # length; truncating keeps long captions usable instead of erroring.
    text_tokens = clip.tokenize([text_input], truncate=True).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)

    # L2-normalise so the dot product below is the cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # One image and one text, so the product holds a single element;
    # .item() turns it into a plain float for the "number" output.
    return (text_features @ image_features.T).item()
|
|
|
|
|
# Two text inputs: the raw base64 image payload and the text to score.
similarity_inputs = [
    gr.Textbox(label="Base64 Image", lines=8),
    gr.Textbox(label="Text Input"),
]

# live=True re-runs find_similarity whenever an input changes; the single
# output is rendered as a number.
iface = gr.Interface(
    fn=find_similarity,
    inputs=similarity_inputs,
    outputs="number",
    title="CLIP Model Image-Text Cosine Similarity",
    description="Upload a base64 image and enter text to find their cosine similarity.",
    live=True,
)

iface.launch()
|
|