Update app.py
app.py CHANGED
@@ -3,6 +3,13 @@ from PIL import Image
 from transformers import CLIPModel, AutoTokenizer, AutoProcessor
 import torch
 
+# Ensure required dependencies are installed
+try:
+    import timm
+except ImportError:
+    import subprocess
+    subprocess.run(["pip", "install", "timm"], check=True)
+
 # Load Jina CLIP model with trust_remote_code=True
 model_name = "jinaai/jina-clip-v1"
 model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
@@ -10,37 +17,35 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
 def compute_similarity(input1, input2, type1, type2):
-    inputs = []
-
     # Process input1
     if type1 == "Image":
         image1 = Image.open(input1).convert("RGB")
-
+        input1_tensor = processor(images=image1, return_tensors="pt")["pixel_values"]
     else:
-
+        input1_tensor = tokenizer(input1, return_tensors="pt")["input_ids"]
 
     # Process input2
     if type2 == "Image":
         image2 = Image.open(input2).convert("RGB")
-
+        input2_tensor = processor(images=image2, return_tensors="pt")["pixel_values"]
     else:
-
+        input2_tensor = tokenizer(input2, return_tensors="pt")["input_ids"]
 
     # Compute embeddings
     with torch.no_grad():
         if type1 == "Image":
-            embedding1 = model.get_image_features(pixel_values=
+            embedding1 = model.get_image_features(pixel_values=input1_tensor)
         else:
-            embedding1 = model.get_text_features(input_ids=
+            embedding1 = model.get_text_features(input_ids=input1_tensor)
 
         if type2 == "Image":
-            embedding2 = model.get_image_features(pixel_values=
+            embedding2 = model.get_image_features(pixel_values=input2_tensor)
         else:
-            embedding2 = model.get_text_features(input_ids=
+            embedding2 = model.get_text_features(input_ids=input2_tensor)
 
-    # Compute similarity
-    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
-    return similarity
+    # Compute cosine similarity
+    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2).item()
+    return f"Similarity Score: {similarity:.4f}"
 
 with gr.Blocks() as demo:
     gr.Markdown("# CLIP-based Similarity Comparison")
@@ -50,12 +55,23 @@ with gr.Blocks() as demo:
     type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")
 
     with gr.Row():
-        input1 = gr.
-        input2 = gr.
+        input1 = gr.Image(type="filepath", label="Upload Image 1")
+        input2 = gr.Image(type="filepath", label="Upload Image 2")
+        text1 = gr.Textbox(label="Enter Text 1")
+        text2 = gr.Textbox(label="Enter Text 2")
 
     compare_btn = gr.Button("Compare")
     output = gr.Textbox(label="Similarity Score")
 
-    compare_btn.click(
+    compare_btn.click(
+        compute_similarity,
+        inputs=[
+            gr.State(input1) if type1 == "Image" else gr.State(text1),
+            gr.State(input2) if type2 == "Image" else gr.State(text2),
+            type1,
+            type2
+        ],
+        outputs=output
+    )
 
 demo.launch()
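
Note: the new try/except block installs timm with a pip subprocess at import time. On Hugging Face Spaces the same dependency is more commonly declared in the Space's requirements.txt, which is installed when the Space is built; a minimal sketch of that alternative (the file contents below are an assumption, not part of this commit):

# requirements.txt (hypothetical) -- Spaces installs these at build time
torch
transformers
gradio
pillow
timm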
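
The commit keeps loading the checkpoint through CLIPModel, AutoTokenizer, and AutoProcessor. The jinaai/jina-clip-v1 model card instead loads it with AutoModel and uses the encode_text / encode_image helpers supplied by the remote code; a minimal sketch under that assumption (method names and accepted input types are recalled from the model card and should be verified against the current revision):

from transformers import AutoModel
import numpy as np

# Assumption: the remote code for jinaai/jina-clip-v1 exposes encode_text / encode_image.
model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)

text_emb = model.encode_text(["a photo of a cat"])[0]   # text -> embedding vector
image_emb = model.encode_image(["cat.jpg"])[0]          # local path, URL, or PIL image

# Explicit cosine similarity, without assuming the embeddings are pre-normalized
score = float(np.dot(text_emb, image_emb) /
              (np.linalg.norm(text_emb) * np.linalg.norm(image_emb)))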
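
In the new compare_btn.click wiring, type1 and type2 are gr.Radio components, so the expression gr.State(input1) if type1 == "Image" else gr.State(text1) is evaluated once at build time and compares a component object against a string, which always selects the else branch. A common Gradio pattern is to pass all components as inputs and choose inside the callback, where the radio values arrive as plain strings. A minimal sketch, assuming the components defined in the diff above and a hypothetical wrapper named compare:

def compare(t1, t2, img1, img2, txt1, txt2):
    # Gradio passes component *values* here: t1/t2 are the strings "Image" or "Text".
    in1 = img1 if t1 == "Image" else txt1
    in2 = img2 if t2 == "Image" else txt2
    return compute_similarity(in1, in2, t1, t2)

compare_btn.click(
    compare,
    inputs=[type1, type2, input1, input2, text1, text2],
    outputs=output,
)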