wilwork committed
Commit 7544622 · verified · 1 Parent(s): 95856bb

Create app.py

Files changed (1)
  app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
+ import gradio as gr
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel
+ import torch
+
+ # Load Jina CLIP model
+ model_name = "jinaai/jina-clip-v1"
+ model = CLIPModel.from_pretrained(model_name)
+ processor = CLIPProcessor.from_pretrained(model_name)
+
+ def compute_similarity(input1, input2, type1, type2):
+     inputs = []
+
+     # Process input1
+     if type1 == "Image":
+         image1 = Image.open(input1).convert("RGB")
+         inputs.append(processor(images=image1, return_tensors="pt"))
+     else:
+         inputs.append(processor(text=[input1], return_tensors="pt"))
+
+     # Process input2
+     if type2 == "Image":
+         image2 = Image.open(input2).convert("RGB")
+         inputs.append(processor(images=image2, return_tensors="pt"))
+     else:
+         inputs.append(processor(text=[input2], return_tensors="pt"))
+
+     # Compute embeddings
+     with torch.no_grad():
+         if type1 == "Image":
+             embedding1 = model.get_image_features(**inputs[0])
+         else:
+             embedding1 = model.get_text_features(**inputs[0])
+
+         if type2 == "Image":
+             embedding2 = model.get_image_features(**inputs[1])
+         else:
+             embedding2 = model.get_text_features(**inputs[1])
+
+     # Compute similarity
+     similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
+     return similarity.item()
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# CLIP-based Similarity Comparison")
+
+     with gr.Row():
+         type1 = gr.Radio(["Image", "Text"], label="Input 1 Type", value="Image")
+         type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")
+
+     with gr.Row():
+         input1 = gr.File(label="Upload Image 1 or Enter Text")
+         input2 = gr.File(label="Upload Image 2 or Enter Text")
+
+     compare_btn = gr.Button("Compare")
+     output = gr.Textbox(label="Similarity Score")
+
+     compare_btn.click(compute_similarity, inputs=[input1, input2, type1, type2], outputs=output)
+
+ demo.launch()
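
One caveat on the loading code in this commit: jinaai/jina-clip-v1 ships a custom architecture, so loading it through the stock CLIPModel/CLIPProcessor classes may not work as written. The model card instead suggests AutoModel with trust_remote_code=True. The sketch below shows that loading path; the encode_text/encode_image helpers and the "cat.jpg" path are assumptions taken from the model card, not part of this commit.

# Hedged sketch, not part of the committed app.py: loading jina-clip-v1
# through its custom code path (AutoModel + trust_remote_code), as the
# model card describes.
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)

# encode_text / encode_image are helpers exposed by the custom model class
# (assumed from the model card); "cat.jpg" is a hypothetical local file.
text_emb = model.encode_text(["a photo of a cat"])
image_emb = model.encode_image(["cat.jpg"])

If this loading path were used, compute_similarity would need the get_image_features/get_text_features calls swapped for these encode helpers as well.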