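"""Gradio demo: embed two inputs (each an image or a text string) with
jinaai/jina-clip-v1 and report their cosine similarity."""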
import gradio as gr
from transformers import AutoModel
import torch

# Load the Jina CLIP model. Its custom architecture ships as remote code,
# so it must be loaded through AutoModel with trust_remote_code=True
# rather than the stock CLIPModel class.
model_name = "jinaai/jina-clip-v1"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

def embed(value, kind):
    # Embed a single input: `value` is an image file path when kind is
    # "Image", otherwise a plain text string. encode_image/encode_text
    # are the embedding methods exposed by the jina-clip-v1 remote code.
    with torch.no_grad():
        if kind == "Image":
            return torch.as_tensor(model.encode_image([value]))
        return torch.as_tensor(model.encode_text([value]))

def compute_similarity(type1, type2, image1, text1, image2, text2):
    # Pick the active input on each side based on the selected type
    input1 = image1 if type1 == "Image" else text1
    input2 = image2 if type2 == "Image" else text2
    if not input1 or not input2:
        return "Please provide both inputs."

    embedding1 = embed(input1, type1)
    embedding2 = embed(input2, type2)

    # Cosine similarity between the two embeddings
    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
    return f"{similarity.item():.4f}"

with gr.Blocks() as demo:
    gr.Markdown("# CLIP-based Similarity Comparison")

    with gr.Row():
        type1 = gr.Radio(["Image", "Text"], label="Input 1 Type", value="Image")
        type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")

    # Separate image and text components per input; the radio above selects
    # which one is actually used (a File component cannot accept typed text).
    with gr.Row():
        with gr.Column():
            image1 = gr.Image(label="Image 1", type="filepath")
            text1 = gr.Textbox(label="Text 1")
        with gr.Column():
            image2 = gr.Image(label="Image 2", type="filepath")
            text2 = gr.Textbox(label="Text 2")

    compare_btn = gr.Button("Compare")
    output = gr.Textbox(label="Similarity Score")

    compare_btn.click(
        compute_similarity,
        inputs=[type1, type2, image1, text1, image2, text2],
        outputs=output,
    )

demo.launch()