wilwork commited on
Commit
7151f63
·
verified ·
1 Parent(s): e6e728f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -22
app.py CHANGED
@@ -2,33 +2,72 @@ import gradio as gr
2
  from transformers import AutoModel
3
  from PIL import Image
4
  import torch
5
- import torch.nn.functional as F
6
- import requests
7
- from io import BytesIO
8
 
9
- # Load model with remote code support
10
- model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
11
 
12
- def compute_similarity(image, text):
13
- image = Image.fromarray(image) # Convert NumPy array to PIL Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- with torch.no_grad():
16
- # Encode text and image using JinaAI CLIP model
17
- text_embeds = model.encode_text([text]) # Expecting list input
18
- image_embeds = model.encode_image([image]) # Expecting list input
19
 
20
- # Compute cosine similarity
21
- similarity_score = (text_embeds @ image_embeds.T).item()
 
22
 
23
- return similarity_score
 
 
 
 
 
 
24
 
25
- # Gradio UI
26
- demo = gr.Interface(
27
- fn=compute_similarity,
28
- inputs=[gr.Image(type="numpy"), gr.Textbox(label="Enter text")],
29
- outputs=gr.Number(label="Similarity Score"),
30
- title="JinaAI CLIP Image-Text Similarity",
31
- description="Upload an image and enter a text prompt to get the similarity score."
32
- )
 
 
 
 
 
 
 
 
33
 
34
  demo.launch()
 
2
  from transformers import AutoModel
3
  from PIL import Image
4
  import torch
 
 
 
5
 
6
+ # Load JinaAI CLIP model
7
+ model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)
8
 
9
+ # Function to compute similarity
10
+ def compute_similarity(input1, input2, input1_type, input2_type):
11
+ # Check if inputs are empty
12
+ if (input1_type == "Text" and not input1.strip()) or (input1_type == "Image" and input1 is None):
13
+ return "Error: Input 1 is empty!"
14
+ if (input2_type == "Text" and not input2.strip()) or (input2_type == "Image" and input2 is None):
15
+ return "Error: Input 2 is empty!"
16
+
17
+ inputs = []
18
+
19
+ # Process first input
20
+ if input1_type == "Text":
21
+ text1_embedding = model.encode_text([input1])
22
+ inputs.append(text1_embedding)
23
+ elif input1_type == "Image":
24
+ image1_embedding = model.encode_image([Image.fromarray(input1)])
25
+ inputs.append(image1_embedding)
26
+
27
+ # Process second input
28
+ if input2_type == "Text":
29
+ text2_embedding = model.encode_text([input2])
30
+ inputs.append(text2_embedding)
31
+ elif input2_type == "Image":
32
+ image2_embedding = model.encode_image([Image.fromarray(input2)])
33
+ inputs.append(image2_embedding)
34
+
35
+ # Compute cosine similarity
36
+ similarity_score = (inputs[0] @ inputs[1].T).item()
37
+
38
+ return similarity_score
39
 
40
+ # Gradio UI
41
+ with gr.Blocks() as demo:
42
+ gr.Markdown("## Multimodal Similarity: Text-Text, Text-Image, Image-Image")
 
43
 
44
+ with gr.Row():
45
+ input1_type = gr.Radio(["Text", "Image"], label="Input 1 Type", value="Text")
46
+ input2_type = gr.Radio(["Text", "Image"], label="Input 2 Type", value="Image")
47
 
48
+ with gr.Row():
49
+ input1 = gr.Textbox(label="Text Input 1", visible=True)
50
+ image1 = gr.Image(type="numpy", label="Image Input 1", visible=False)
51
+
52
+ with gr.Row():
53
+ input2 = gr.Textbox(label="Text Input 2", visible=False)
54
+ image2 = gr.Image(type="numpy", label="Image Input 2", visible=True)
55
 
56
+ output = gr.Textbox(label="Similarity Score / Error", interactive=False)
57
+
58
+ # Function to toggle visibility based on selected types
59
+ def update_visibility(input1_type, input2_type):
60
+ return (
61
+ input1_type == "Text",
62
+ input1_type == "Image",
63
+ input2_type == "Text",
64
+ input2_type == "Image"
65
+ )
66
+
67
+ input1_type.change(update_visibility, inputs=[input1_type, input2_type], outputs=[input1, image1, input2, image2])
68
+ input2_type.change(update_visibility, inputs=[input1_type, input2_type], outputs=[input1, image1, input2, image2])
69
+
70
+ btn = gr.Button("Compute Similarity")
71
+ btn.click(compute_similarity, inputs=[input1, input2, input1_type, input2_type], outputs=output)
72
 
73
  demo.launch()