wilwork committed
Commit 30bfbf8 · verified · 1 parent: d28a2eb

Update app.py

Files changed (1)
  1. app.py +13 -5
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 from transformers import CLIPModel, CLIPFeatureExtractor, BertTokenizer
 from PIL import Image
 import torch
+import torch.nn.functional as F
 
 # Load model and processors separately
 model_name = "jinaai/jina-clip-v1"
@@ -17,13 +18,20 @@ def compute_similarity(image, text):
 
     # Process text (Remove `token_type_ids`)
     text_inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-    text_inputs.pop("token_type_ids", None)  # Remove token_type_ids to avoid TypeError
+    text_inputs.pop("token_type_ids", None)
 
     with torch.no_grad():
-        outputs = model(**image_inputs, **text_inputs)
-        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
-        similarity_score = logits_per_image.item()
-
+        # Extract embeddings
+        image_embeds = model.get_image_features(**image_inputs)
+        text_embeds = model.get_text_features(**text_inputs)
+
+        # Normalize embeddings
+        image_embeds = F.normalize(image_embeds, p=2, dim=-1)
+        text_embeds = F.normalize(text_embeds, p=2, dim=-1)
+
+        # Compute cosine similarity
+        similarity_score = (image_embeds @ text_embeds.T).item()
+
     return similarity_score
 
 # Gradio UI
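For readers skimming the diff: once both embeddings are L2-normalized, the dot product computed here is exactly their cosine similarity. A minimal, self-contained sketch of the same computation follows; the random 768-dimensional vectors are illustrative stand-ins for the model's embeddings, not the actual jina-clip-v1 outputs.

import torch
import torch.nn.functional as F

# Stand-ins for one image embedding and one text embedding.
image_embeds = torch.randn(1, 768)
text_embeds = torch.randn(1, 768)

# L2-normalize, then a dot product yields cosine similarity in [-1, 1].
image_embeds = F.normalize(image_embeds, p=2, dim=-1)
text_embeds = F.normalize(text_embeds, p=2, dim=-1)
score = (image_embeds @ text_embeds.T).item()

# Sanity check against PyTorch's built-in cosine similarity.
assert abs(score - F.cosine_similarity(image_embeds, text_embeds, dim=-1).item()) < 1e-6

This also explains the switch away from `logits_per_image`: for standard CLIP models that value is the cosine similarity scaled by a learned temperature, whereas the raw score computed in this commit stays in [-1, 1] and reads directly as a similarity.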