Spaces:
Running
Running
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer, CLIPModel
# Load the Jina CLIP checkpoint together with its tokenizer and image processor.
#
# The checkpoint's architecture is shipped as custom code in the Hub repository,
# so it is loaded through AutoModel with trust_remote_code=True (the loading
# path shown on the model card). The concrete CLIPModel class does not resolve
# that custom architecture.
#
# NOTE(review): trust_remote_code=True executes Python code downloaded from the
# model repository; consider pinning a `revision=` so the executed code cannot
# change underneath the app.
model_name = "jinaai/jina-clip-v1"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
def _embed_input(value, kind):
    """Return the model embedding for a single image or text input.

    Args:
        value: For ``kind == "Image"``, anything ``PIL.Image.open`` accepts
            (a path or a file object). Otherwise the text to tokenize.
        kind: ``"Image"`` selects the image branch; any other value is
            treated as text.

    Returns:
        The tensor produced by ``model.get_image_features`` or
        ``model.get_text_features``.
    """
    if kind == "Image":
        # The context manager releases the underlying file handle; convert()
        # returns an independent, fully loaded copy that outlives the `with`.
        with Image.open(value) as raw_image:
            rgb_image = raw_image.convert("RGB")
        pixel_values = processor(images=rgb_image, return_tensors="pt")["pixel_values"]
        return model.get_image_features(pixel_values=pixel_values)

    # truncation=True caps over-long text at the tokenizer's configured maximum
    # length instead of passing an oversized sequence to the text encoder.
    input_ids = tokenizer(value, return_tensors="pt", truncation=True)["input_ids"]
    return model.get_text_features(input_ids=input_ids)


def compute_similarity(input1, input2, type1, type2):
    """Compute the cosine similarity between two image/text inputs.

    Args:
        input1: First input; an image path/file when ``type1 == "Image"``,
            otherwise text.
        input2: Second input; an image path/file when ``type2 == "Image"``,
            otherwise text.
        type1: ``"Image"`` or ``"Text"`` (anything other than ``"Image"`` is
            handled as text).
        type2: Same as ``type1``, for the second input.

    Returns:
        The cosine similarity of the two embeddings as a Python float.

    Raises:
        ValueError: If either input is ``None`` (nothing uploaded or entered).
    """
    # Fail early with a clear message instead of an opaque error from PIL or
    # the tokenizer when the user has not provided one of the inputs.
    for position, value in enumerate((input1, input2), start=1):
        if value is None:
            raise ValueError(
                f"Input {position} is missing: upload an image or enter text."
            )

    # Inference only: gradients are not needed for the embeddings.
    with torch.no_grad():
        embedding1 = _embed_input(input1, type1)
        embedding2 = _embed_input(input2, type2)

    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
    # .item() requires a single-element tensor, i.e. one embedding per input.
    return similarity.item()
def _toggle_inputs(choice):
    """Show the image uploader or the text box matching the selected type.

    Returns two updates, in (image component, text component) order.
    """
    wants_image = choice == "Image"
    return gr.update(visible=wants_image), gr.update(visible=not wants_image)


def _compare_selected(image1, text1, image2, text2, type1, type2):
    """Pass the value matching each selected type on to compute_similarity."""
    first = image1 if type1 == "Image" else text1
    second = image2 if type2 == "Image" else text2
    return compute_similarity(first, second, type1, type2)


# UI: each input has both an image uploader and a text box; only the one that
# matches its "Type" radio is visible. (A file-upload widget alone cannot
# accept typed text, so text inputs need their own Textbox.)
with gr.Blocks() as demo:
    gr.Markdown("# CLIP-based Similarity Comparison")
    with gr.Row():
        type1 = gr.Radio(["Image", "Text"], label="Input 1 Type", value="Image")
        type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")
    with gr.Row():
        with gr.Column():
            # type="filepath" hands the handler a path that PIL can open.
            image1 = gr.Image(type="filepath", label="Upload Image 1", visible=True)
            text1 = gr.Textbox(label="Enter Text 1", visible=False)
        with gr.Column():
            image2 = gr.Image(type="filepath", label="Upload Image 2", visible=False)
            text2 = gr.Textbox(label="Enter Text 2", visible=True)
    # Initial visibility above mirrors the radio defaults (Image / Text).
    type1.change(_toggle_inputs, inputs=type1, outputs=[image1, text1])
    type2.change(_toggle_inputs, inputs=type2, outputs=[image2, text2])
    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score")
    compare_btn.click(
        _compare_selected,
        inputs=[image1, text1, image2, text2, type1, type2],
        outputs=output,
    )

demo.launch()