import gradio as gr
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoProcessor
import torch
# Ensure required dependencies are installed
try:
    import timm
except ImportError:
    import subprocess
    subprocess.run(["pip", "install", "timm"], check=True)
# Load the Jina CLIP model; AutoModel (rather than CLIPModel) is used so that
# trust_remote_code=True can dispatch to the repository's custom model code
model_name = "jinaai/jina-clip-v1"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
def compute_similarity(input1, input2, text1, text2, type1, type2):
    """Return the cosine similarity between two inputs, each of which is an image or a text."""
    # Process input 1
    if type1 == "Image":
        if not input1:
            return "Error: No image provided for Input 1"
        image1 = Image.open(input1).convert("RGB")
        input1_tensor = processor(images=image1, return_tensors="pt")["pixel_values"]
    elif type1 == "Text":
        if not text1.strip():
            return "Error: No text provided for Input 1"
        input1_tensor = tokenizer(text1, return_tensors="pt")["input_ids"]
    else:
        return "Error: Invalid input type for Input 1"

    # Process input 2
    if type2 == "Image":
        if not input2:
            return "Error: No image provided for Input 2"
        image2 = Image.open(input2).convert("RGB")
        input2_tensor = processor(images=image2, return_tensors="pt")["pixel_values"]
    elif type2 == "Text":
        if not text2.strip():
            return "Error: No text provided for Input 2"
        input2_tensor = tokenizer(text2, return_tensors="pt")["input_ids"]
    else:
        return "Error: Invalid input type for Input 2"
    # Compute embeddings
    with torch.no_grad():
        if type1 == "Image":
            embedding1 = model.get_image_features(pixel_values=input1_tensor)
        else:
            embedding1 = model.get_text_features(input_ids=input1_tensor)
        if type2 == "Image":
            embedding2 = model.get_image_features(pixel_values=input2_tensor)
        else:
            embedding2 = model.get_text_features(input_ids=input2_tensor)

    # Normalize embeddings
    embedding1 = embedding1 / embedding1.norm(dim=-1, keepdim=True)
    embedding2 = embedding2 / embedding2.norm(dim=-1, keepdim=True)

    # Compute cosine similarity
    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2).item()
    return f"Similarity Score: {similarity:.4f}"
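
# Example usage from a Python shell (a sketch; a text-to-text comparison needs no image files):
# print(compute_similarity(None, None, "a photo of a cat", "a cat photo", "Text", "Text"))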
with gr.Blocks() as demo:
    gr.Markdown("# CLIP-based Similarity Comparison")

    with gr.Row():
        type1 = gr.Radio(["Image", "Text"], label="Input 1 Type", value="Image")
        type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")

    with gr.Row():
        input1 = gr.Image(type="filepath", label="Upload Image 1")
        input2 = gr.Image(type="filepath", label="Upload Image 2")
        text1 = gr.Textbox(label="Enter Text 1")
        text2 = gr.Textbox(label="Enter Text 2")

    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score")

    compare_btn.click(
        compute_similarity,
        inputs=[input1, input2, text1, text2, type1, type2],
        outputs=output,
    )
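
    # Optional sketch (left commented out, not part of the original app): hide the widget
    # that does not match the selected input type, using Gradio's `.change()` event and
    # `gr.update(visible=...)`. The helper name `_toggle_inputs` is hypothetical.
    # def _toggle_inputs(choice):
    #     return gr.update(visible=(choice == "Image")), gr.update(visible=(choice == "Text"))
    # type1.change(_toggle_inputs, inputs=type1, outputs=[input1, text1])
    # type2.change(_toggle_inputs, inputs=type2, outputs=[input2, text2])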
demo.launch()