import gradio as gr
from transformers import AutoModel
from PIL import Image
import torch
import numpy as np
# Load JinaAI CLIP model
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
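# Note (assumption based on the jina-clip-v1 model card): trust_remote_code=True is
# needed because the encode_text / encode_image helpers used below ship as custom
# code in the model repository rather than as part of transformers itself.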

def compute_similarity(input1, input2):
    """
    Computes similarity between:
    - Image and Text
    - Image and Image
    - Text and Text
    """
    # Detect input types
    input1_is_text = isinstance(input1, str) and input1.strip() != ""
    input2_is_text = isinstance(input2, str) and input2.strip() != ""
    input1_is_image = isinstance(input1, np.ndarray)
    input2_is_image = isinstance(input2, np.ndarray)

    # Ensure valid input
    if not (input1_is_text or input1_is_image) or not (input2_is_text or input2_is_image):
        return "Error: Both inputs must be valid (image or text)!"

    try:
        with torch.no_grad():
            if input1_is_text and input2_is_text:
                # Text-Text Similarity
                emb1 = model.encode_text([input1])
                emb2 = model.encode_text([input2])
            elif input1_is_image and input2_is_image:
                # Image-Image Similarity
                image1 = Image.fromarray(input1)
                image2 = Image.fromarray(input2)
                emb1 = model.encode_image([image1])
                emb2 = model.encode_image([image2])
            else:
                # Image-Text Similarity
                if input1_is_image:
                    image = Image.fromarray(input1)
                    text = input2
                    emb1 = model.encode_image([image])
                    emb2 = model.encode_text([text])
                else:
                    image = Image.fromarray(input2)
                    text = input1
                    emb1 = model.encode_text([text])
                    emb2 = model.encode_image([image])

        # Compute cosine similarity
        similarity_score = (emb1 @ emb2.T).item()
        return similarity_score
    except Exception as e:
        return f"Error: {str(e)}"

# Gradio UI (Blocks, so the visible widgets can follow the type selectors)
def update_visibility(input_type):
    """Show the text box or the image widget depending on the selected type."""
    return (
        gr.update(visible=input_type == "Text"),   # text widget
        gr.update(visible=input_type == "Image"),  # image widget
    )

def run_comparison(input1_type, text1, image1, input2_type, text2, image2):
    """Pick the selected value for each side and hand them to compute_similarity."""
    input1 = text1 if input1_type == "Text" else image1
    input2 = text2 if input2_type == "Text" else image2
    return compute_similarity(input1, input2)

with gr.Blocks(title="JinaAI CLIP Multimodal Similarity") as demo:
    gr.Markdown(
        "# JinaAI CLIP Multimodal Similarity\n"
        "Compare similarity between two inputs (Text, Image, or both)."
    )

    with gr.Row():
        with gr.Column():
            input1_type = gr.Radio(["Text", "Image"], label="Input 1 Type", value="Text")
            text1 = gr.Textbox(label="Text Input 1", visible=True)
            image1 = gr.Image(type="numpy", label="Image Input 1", visible=False)
        with gr.Column():
            input2_type = gr.Radio(["Text", "Image"], label="Input 2 Type", value="Text")
            text2 = gr.Textbox(label="Text Input 2", visible=True)
            image2 = gr.Image(type="numpy", label="Image Input 2", visible=False)

    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score / Error", interactive=False)

    # Update visibility dynamically when either input type changes
    input1_type.change(update_visibility, inputs=input1_type, outputs=[text1, image1])
    input2_type.change(update_visibility, inputs=input2_type, outputs=[text2, image2])

    compare_btn.click(
        run_comparison,
        inputs=[input1_type, text1, image1, input2_type, text2, image2],
        outputs=output,
    )

demo.launch()
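
# Running outside the Space (the package list below is an assumption; the Space's
# own requirements.txt is authoritative, and jina-clip-v1's custom code may need
# extras such as einops/timm):
#   pip install gradio transformers torch pillow numpy
#   python app.py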