wilwork commited on
Commit
d28a2eb
·
verified ·
1 Parent(s): 5fadd6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import CLIPModel, CLIPFeatureExtractor, BertTokenizer
3
  from PIL import Image
4
  import torch
5
 
6
- # Load model and appropriate processors separately
7
  model_name = "jinaai/jina-clip-v1"
8
  model = CLIPModel.from_pretrained(model_name)
9
  feature_extractor = CLIPFeatureExtractor.from_pretrained(model_name)
@@ -11,12 +11,13 @@ tokenizer = BertTokenizer.from_pretrained(model_name)
11
 
12
  def compute_similarity(image, text):
13
  image = Image.fromarray(image) # Convert NumPy array to PIL Image
14
-
15
  # Process image
16
  image_inputs = feature_extractor(images=image, return_tensors="pt")
17
 
18
- # Process text
19
  text_inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
 
20
 
21
  with torch.no_grad():
22
  outputs = model(**image_inputs, **text_inputs)
@@ -34,4 +35,4 @@ demo = gr.Interface(
34
  description="Upload an image and enter a text prompt to get the similarity score."
35
  )
36
 
37
- demo.launch()
 
3
  from PIL import Image
4
  import torch
5
 
6
+ # Load model and processors separately
7
  model_name = "jinaai/jina-clip-v1"
8
  model = CLIPModel.from_pretrained(model_name)
9
  feature_extractor = CLIPFeatureExtractor.from_pretrained(model_name)
 
11
 
12
  def compute_similarity(image, text):
13
  image = Image.fromarray(image) # Convert NumPy array to PIL Image
14
+
15
  # Process image
16
  image_inputs = feature_extractor(images=image, return_tensors="pt")
17
 
18
+ # Process text (Remove `token_type_ids`)
19
  text_inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
20
+ text_inputs.pop("token_type_ids", None) # Remove token_type_ids to avoid TypeError
21
 
22
  with torch.no_grad():
23
  outputs = model(**image_inputs, **text_inputs)
 
35
  description="Upload an image and enter a text prompt to get the similarity score."
36
  )
37
 
38
+ demo.launch()