wilwork commited on
Commit
f4bfa5f
·
verified ·
1 Parent(s): b57f8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -21
app.py CHANGED
@@ -1,49 +1,54 @@
1
  import gradio as gr
2
  from transformers import AutoModel
3
  from PIL import Image
 
4
  import torch
5
 
6
  # Load JinaAI CLIP model
7
  model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Function to compute similarity
10
  def compute_similarity(input1, input2, input1_type, input2_type):
11
- # Ensure inputs are valid
12
- if input1_type == "Text" and (not input1 or input1.strip() == ""):
13
  return "Error: Input 1 is empty!"
14
- if input2_type == "Text" and (not input2 or input2.strip() == ""):
15
  return "Error: Input 2 is empty!"
16
  if input1_type == "Image" and input1 is None:
17
  return "Error: Image 1 is missing!"
18
  if input2_type == "Image" and input2 is None:
19
  return "Error: Image 2 is missing!"
20
 
21
- # Encode inputs
22
- inputs = []
23
-
24
- if input1_type == "Text":
25
- text1_embedding = model.encode_text([input1])
26
- inputs.append(text1_embedding)
27
- elif input1_type == "Image":
28
- image1_embedding = model.encode_image([Image.fromarray(input1)])
29
- inputs.append(image1_embedding)
30
 
31
- if input2_type == "Text":
32
- text2_embedding = model.encode_text([input2])
33
- inputs.append(text2_embedding)
34
- elif input2_type == "Image":
35
- image2_embedding = model.encode_image([Image.fromarray(input2)])
36
- inputs.append(image2_embedding)
37
 
38
  # Compute cosine similarity
39
- similarity_score = (inputs[0] @ inputs[1].T).item()
40
  return f"Similarity Score: {similarity_score:.4f}"
41
 
42
  # Function to toggle input fields dynamically
43
  def update_visibility(input1_type, input2_type):
44
  return (
45
- gr.update(visible=(input1_type == "Text")), # Show text input if Text is selected
46
- gr.update(visible=(input1_type == "Image")), # Show image input if Image is selected
47
  gr.update(visible=(input2_type == "Text")),
48
  gr.update(visible=(input2_type == "Image"))
49
  )
 
1
  import gradio as gr
2
  from transformers import AutoModel
3
  from PIL import Image
4
+ import numpy as np
5
  import torch
6
 
7
  # Load JinaAI CLIP model
8
  model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)
9
 
10
+ # Function to process input (convert to text or PIL image)
11
+ def process_input(input_data, input_type):
12
+ if input_type == "Text":
13
+ return model.encode_text([input_data]) if input_data.strip() else None
14
+ elif input_type == "Image":
15
+ if isinstance(input_data, str): # If it's a file path
16
+ image = Image.open(input_data).convert("RGB")
17
+ elif isinstance(input_data, np.ndarray): # If it's a NumPy array (Gradio default)
18
+ image = Image.fromarray(input_data)
19
+ else:
20
+ return None # Invalid input type
21
+ return model.encode_image([image])
22
+ return None
23
+
24
  # Function to compute similarity
25
  def compute_similarity(input1, input2, input1_type, input2_type):
26
+ # Validate inputs
27
+ if input1_type == "Text" and not input1.strip():
28
  return "Error: Input 1 is empty!"
29
+ if input2_type == "Text" and not input2.strip():
30
  return "Error: Input 2 is empty!"
31
  if input1_type == "Image" and input1 is None:
32
  return "Error: Image 1 is missing!"
33
  if input2_type == "Image" and input2 is None:
34
  return "Error: Image 2 is missing!"
35
 
36
+ # Process inputs
37
+ embedding1 = process_input(input1, input1_type)
38
+ embedding2 = process_input(input2, input2_type)
 
 
 
 
 
 
39
 
40
+ if embedding1 is None or embedding2 is None:
41
+ return "Error: Failed to process input!"
 
 
 
 
42
 
43
  # Compute cosine similarity
44
+ similarity_score = (embedding1 @ embedding2.T).item()
45
  return f"Similarity Score: {similarity_score:.4f}"
46
 
47
  # Function to toggle input fields dynamically
48
  def update_visibility(input1_type, input2_type):
49
  return (
50
+ gr.update(visible=(input1_type == "Text")),
51
+ gr.update(visible=(input1_type == "Image")),
52
  gr.update(visible=(input2_type == "Text")),
53
  gr.update(visible=(input2_type == "Image"))
54
  )