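# Gradio demo: compare two inputs (each an image or a piece of text) with
# jinaai/jina-clip-v1 and report their cosine similarity.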
import gradio as gr
from PIL import Image
from transformers import AutoModel
import torch
# jina-clip-v1 ships custom modeling code, so it is loaded via AutoModel with
# trust_remote_code=True; its encode_text/encode_image helpers handle
# tokenization and image preprocessing internally.
model_name = "jinaai/jina-clip-v1"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()
def compute_similarity(input1, input2, type1, type2):
    """Embed both inputs with jina-clip-v1 and return their cosine similarity."""

    def embed(value, input_type):
        # gr.File passes the path of the uploaded file (Gradio's default
        # type="filepath").
        if input_type == "Image":
            image = Image.open(value).convert("RGB")
            embedding = model.encode_image([image])
        else:
            # A File component cannot receive typed text, so "Text" inputs are
            # read from the uploaded file as UTF-8 plain text.
            with open(value, "r", encoding="utf-8") as f:
                text = f.read().strip()
            embedding = model.encode_text([text])
        # encode_image/encode_text may return a NumPy array; as_tensor also
        # accepts a torch tensor unchanged. Normalize the shape to (1, dim).
        return torch.as_tensor(embedding).reshape(1, -1)

    # Compute embeddings
    with torch.no_grad():
        embedding1 = embed(input1, type1)
        embedding2 = embed(input2, type2)

    # Compute similarity
    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
    return similarity.item()
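
# Gradio UI: a radio button selects how each uploaded file is interpreted,
# and the Compare button reports the similarity between the two inputs.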
with gr.Blocks() as demo:
    gr.Markdown("# CLIP-based Similarity Comparison")
    with gr.Row():
        type1 = gr.Radio(["Image", "Text"], label="Input 1 Type", value="Image")
        type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")
    with gr.Row():
        input1 = gr.File(label="Upload Image 1 or a plain-text (.txt) file")
        input2 = gr.File(label="Upload Image 2 or a plain-text (.txt) file")
    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score")
    compare_btn.click(compute_similarity, inputs=[input1, input2, type1, type2], outputs=output)
demo.launch()