wilwork committed
Commit 14e97b9 · verified · 1 Parent(s): f63dbcd

Update app.py

Files changed (1):
  1. app.py +32 -16
app.py CHANGED
@@ -3,6 +3,13 @@ from PIL import Image
 from transformers import CLIPModel, AutoTokenizer, AutoProcessor
 import torch
 
+# Ensure required dependencies are installed
+try:
+    import timm
+except ImportError:
+    import subprocess
+    subprocess.run(["pip", "install", "timm"], check=True)
+
 # Load Jina CLIP model with trust_remote_code=True
 model_name = "jinaai/jina-clip-v1"
 model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
@@ -10,37 +17,35 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
 def compute_similarity(input1, input2, type1, type2):
-    inputs = []
-
     # Process input1
     if type1 == "Image":
         image1 = Image.open(input1).convert("RGB")
-        inputs.append(processor(images=image1, return_tensors="pt")["pixel_values"])
+        input1_tensor = processor(images=image1, return_tensors="pt")["pixel_values"]
     else:
-        inputs.append(tokenizer(input1, return_tensors="pt")["input_ids"])
+        input1_tensor = tokenizer(input1, return_tensors="pt")["input_ids"]
 
     # Process input2
     if type2 == "Image":
         image2 = Image.open(input2).convert("RGB")
-        inputs.append(processor(images=image2, return_tensors="pt")["pixel_values"])
+        input2_tensor = processor(images=image2, return_tensors="pt")["pixel_values"]
     else:
-        inputs.append(tokenizer(input2, return_tensors="pt")["input_ids"])
+        input2_tensor = tokenizer(input2, return_tensors="pt")["input_ids"]
 
     # Compute embeddings
     with torch.no_grad():
         if type1 == "Image":
-            embedding1 = model.get_image_features(pixel_values=inputs[0])
+            embedding1 = model.get_image_features(pixel_values=input1_tensor)
         else:
-            embedding1 = model.get_text_features(input_ids=inputs[0])
+            embedding1 = model.get_text_features(input_ids=input1_tensor)
 
         if type2 == "Image":
-            embedding2 = model.get_image_features(pixel_values=inputs[1])
+            embedding2 = model.get_image_features(pixel_values=input2_tensor)
         else:
-            embedding2 = model.get_text_features(input_ids=inputs[1])
+            embedding2 = model.get_text_features(input_ids=input2_tensor)
 
-    # Compute similarity
-    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
-    return similarity.item()
+    # Compute cosine similarity
+    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2).item()
+    return f"Similarity Score: {similarity:.4f}"
 
 with gr.Blocks() as demo:
     gr.Markdown("# CLIP-based Similarity Comparison")
@@ -50,12 +55,23 @@ with gr.Blocks() as demo:
     type2 = gr.Radio(["Image", "Text"], label="Input 2 Type", value="Text")
 
     with gr.Row():
-        input1 = gr.File(label="Upload Image 1 or Enter Text")
-        input2 = gr.File(label="Upload Image 2 or Enter Text")
+        input1 = gr.Image(type="filepath", label="Upload Image 1")
+        input2 = gr.Image(type="filepath", label="Upload Image 2")
+        text1 = gr.Textbox(label="Enter Text 1")
+        text2 = gr.Textbox(label="Enter Text 2")
 
     compare_btn = gr.Button("Compare")
     output = gr.Textbox(label="Similarity Score")
 
-    compare_btn.click(compute_similarity, inputs=[input1, input2, type1, type2], outputs=output)
+    compare_btn.click(
+        compute_similarity,
+        inputs=[
+            gr.State(input1) if type1 == "Image" else gr.State(text1),
+            gr.State(input2) if type2 == "Image" else gr.State(text2),
+            type1,
+            type2
+        ],
+        outputs=output
+    )
 
 demo.launch()
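
Note on the new dependency check: installing timm via subprocess at import time works, but on Hugging Face Spaces the usual route is to declare the package in requirements.txt so it is installed once at build time. A minimal sketch of such a file (everything beyond timm is an assumption about this Space's dependencies, not part of the commit):

# requirements.txt (hypothetical)
torch
transformers
timm
gradio
Pillow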
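Note on model loading: jinaai/jina-clip-v1 ships a custom architecture via remote code, and its model card loads it with AutoModel rather than CLIPModel, using the remote code's encode_text / encode_image helpers. A hedged sketch of that path, in case the CLIPModel route misbehaves (the encode_* method names come from the model card, not from this commit):

from transformers import AutoModel

# Model-card style loading; encode_text/encode_image are provided by the
# repository's remote code, not by the base transformers CLIP API.
model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)
text_emb = model.encode_text(["a photo of a cat"])
image_emb = model.encode_image(["cat.jpg"])  # file path or URL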
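One caveat with the new click wiring: type1 == "Image" compares a gr.Radio component to a string when the Blocks layout is built, so it is always False and the handler is permanently bound to gr.State(text1) / gr.State(text2); gr.State also passes the wrapped component object rather than the user's input. A minimal sketch of an alternative, assuming a hypothetical dispatcher (compute_similarity_dispatch is not part of the commit) that receives all four components plus the radio values at click time:

# Hypothetical wrapper: Gradio passes each listed component's current
# value to the handler, so the Image-vs-Text choice happens per click.
def compute_similarity_dispatch(img1, txt1, img2, txt2, type1, type2):
    in1 = img1 if type1 == "Image" else txt1
    in2 = img2 if type2 == "Image" else txt2
    return compute_similarity(in1, in2, type1, type2)

compare_btn.click(
    compute_similarity_dispatch,
    inputs=[input1, text1, input2, text2, type1, type2],
    outputs=output,
)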