File size: 3,365 Bytes
7544622
 
c4ab507
7544622
 
14e97b9
 
 
 
 
 
 
f63dbcd
7544622
f63dbcd
 
 
7544622
cbcffb4
7544622
 
cbcffb4
 
7544622
14e97b9
cbcffb4
 
 
 
08f9b31
cbcffb4
7544622
 
 
cbcffb4
 
7544622
14e97b9
cbcffb4
 
 
 
08f9b31
cbcffb4
7544622
 
 
 
14e97b9
7544622
14e97b9
7544622
 
14e97b9
7544622
14e97b9
7544622
629d862
 
 
 
14e97b9
 
 
7544622
 
 
 
 
 
 
 
 
14e97b9
 
 
 
7544622
 
 
 
14e97b9
 
 
2241cd5
 
cbcffb4
 
14e97b9
 
 
 
 
7544622
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from PIL import Image
from transformers import CLIPModel, AutoTokenizer, AutoProcessor
import torch

# Ensure required dependencies are installed.
# timm is not referenced directly in this file; NOTE(review): presumably it is
# needed by the remote model code loaded below - confirm.
try:
    import timm  # noqa: F401  (availability probe only)
except ImportError:
    import importlib
    import subprocess
    import sys

    # Run pip through the interpreter executing this script so the package
    # lands in the same environment; a bare "pip" on PATH may belong to a
    # different Python installation or virtualenv.
    subprocess.run([sys.executable, "-m", "pip", "install", "timm"], check=True)
    # Let the import system notice the freshly installed package without
    # requiring a process restart.
    importlib.invalidate_caches()

# Checkpoint identifier shared by the model, tokenizer and processor.
model_name = "jinaai/jina-clip-v1"

# Load the three components in order (model, tokenizer, processor), each with
# trust_remote_code=True.
# NOTE(review): the checkpoint may expect a different loader class than
# CLIPModel (e.g. an Auto* class resolving to its custom code) - confirm.
model, tokenizer, processor = (
    loader.from_pretrained(model_name, trust_remote_code=True)
    for loader in (CLIPModel, AutoTokenizer, AutoProcessor)
)

class _InputError(Exception):
    """Raised by _prepare_input; its message is the text shown to the user."""


def _prepare_input(input_type, image_path, text, label):
    """Validate one comparison slot and convert it to a model-ready tensor.

    Args:
        input_type: "Image" or "Text", as selected in the UI.
        image_path: Path of the uploaded image; only read for "Image".
        text: Raw text; only read for "Text". May be None or empty.
        label: Slot name ("Input 1" / "Input 2") interpolated into messages.

    Returns:
        The processor's "pixel_values" tensor for an image, or the
        tokenizer's "input_ids" tensor for text.

    Raises:
        _InputError: If the slot is empty or the type is not recognised.
    """
    if input_type == "Image":
        if not image_path:
            raise _InputError(f"Error: No image provided for {label}")
        # Close the file handle promptly; convert() returns an independent
        # in-memory copy, so the result stays usable after the file is closed.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        return processor(images=rgb_image, return_tensors="pt")["pixel_values"]

    if input_type == "Text":
        # `not text` also covers None, which would otherwise crash on .strip().
        if not text or not text.strip():
            raise _InputError(f"Error: No text provided for {label}")
        return tokenizer(text, return_tensors="pt")["input_ids"]

    raise _InputError(f"Error: Invalid input type for {label}")


def _embed(input_type, tensor):
    """Return the model embedding for a tensor produced by _prepare_input."""
    if input_type == "Image":
        return model.get_image_features(pixel_values=tensor)
    return model.get_text_features(input_ids=tensor)


def compute_similarity(input1, input2, text1, text2, type1, type2):
    """Compute the cosine similarity between two image/text inputs.

    Args:
        input1: File path of image 1 (used when type1 == "Image").
        input2: File path of image 2 (used when type2 == "Image").
        text1: Text 1 (used when type1 == "Text").
        text2: Text 2 (used when type2 == "Text").
        type1: "Image" or "Text" for the first input.
        type2: "Image" or "Text" for the second input.

    Returns:
        "Similarity Score: <value>" with four decimals on success, or an
        "Error: ..." message when an input is missing or its type is invalid.
    """
    # Validate and tensorise both slots before running the model, so a bad
    # second input is reported without paying for the first embedding.
    try:
        tensor1 = _prepare_input(type1, input1, text1, "Input 1")
        tensor2 = _prepare_input(type2, input2, text2, "Input 2")
    except _InputError as err:
        return str(err)

    with torch.no_grad():
        embedding1 = _embed(type1, tensor1)
        embedding2 = _embed(type2, tensor2)

    # cosine_similarity divides by the vector norms itself (with an epsilon
    # guard), so no separate normalisation step is needed.
    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2).item()
    return f"Similarity Score: {similarity:.4f}"

# Build the UI: two type selectors, one image and one text widget per slot,
# a trigger button, and a textbox that shows the result or an error message.
with gr.Blocks() as demo:
    gr.Markdown("# CLIP-based Similarity Comparison")

    # Selectors telling compute_similarity which widget to read for each slot.
    with gr.Row():
        type1 = gr.Radio(choices=["Image", "Text"], value="Image", label="Input 1 Type")
        type2 = gr.Radio(choices=["Image", "Text"], value="Text", label="Input 2 Type")

    # All four widgets are always present; the unused ones are simply ignored.
    with gr.Row():
        input1 = gr.Image(label="Upload Image 1", type="filepath")
        input2 = gr.Image(label="Upload Image 2", type="filepath")
        text1 = gr.Textbox(label="Enter Text 1")
        text2 = gr.Textbox(label="Enter Text 2")

    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score")

    # The order of `inputs` must match compute_similarity's parameter order.
    compare_btn.click(
        fn=compute_similarity,
        inputs=[input1, input2, text1, text2, type1, type2],
        outputs=output,
    )

# Start the Gradio server.
demo.launch()