ktllc commited on
Commit
a012f3a
·
1 Parent(s): ac1aaed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -53
app.py CHANGED
@@ -5,75 +5,40 @@ import gradio as gr
5
 
6
  # Load the CLIP model
7
  model, preprocess = clip.load("ViT-B/32")
8
- device = "cuda" if torch.cuda.is_available() else "cpu" # Check for GPU availability
9
  model.to(device).eval()
10
 
11
  # Define the Business Listing variable
12
  Business_Listing = "Air Guide"
13
 
14
- def find_similar_images(text_input):
15
- # Directory where you want to load images
16
- image_dir = "/content/sample_data/Tourism"
17
 
18
- # Create an empty description dictionary
19
- description = f"{Business_Listing} Logo"
20
-
21
- # Set up the layout for displaying images
22
- num_rows = 4
23
- num_cols = 8
24
-
25
- original_images = []
26
- images = []
27
- texts = []
28
-
29
- # Load and preprocess images
30
- image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.ico', '.svg', '.eps', '.pdf']
31
- for filename in [filename for filename in os.listdir(image_dir) if any(filename.endswith(ext) for ext in image_extensions)]:
32
- # Get the image name (without extension)
33
- image_name, _ = os.path.splitext(filename)
34
-
35
- # Load the image
36
- image = Image.open(os.path.join(image_dir, filename)).convert("RGB")
37
- original_images.append(image)
38
- images.append(preprocess(image))
39
- texts.append(description)
40
-
41
- # Prepare input text and images
42
- image_input = torch.tensor(np.stack(images)).to(device)
43
- text_tokens = clip.tokenize([f"This is {text_input}"])
44
- text_tokens = text_tokens.to(device)
45
-
46
- # Encode text and image features
47
  with torch.no_grad():
48
- image_features = model.encode_image(image_input).float()
49
  text_features = model.encode_text(text_tokens).float()
50
-
51
  # Normalize features and calculate similarity
52
  image_features /= image_features.norm(dim=-1, keepdim=True)
53
  text_features /= text_features.norm(dim=-1, keepdim=True)
54
- similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
55
-
56
- # Find the maximum similarity value
57
- max_similarity_value = similarity[0, :].max()
58
-
59
- # Find all indices with the maximum similarity value
60
- max_similarity_indices = np.where(similarity[0, :] == max_similarity_value)
61
-
62
- # Get the filenames with the highest similarity
63
- valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.ico', '.svg', '.eps', '.pdf')
64
- image_files = [filename for filename in os.listdir(image_dir) if filename.endswith(valid_extensions)]
65
- filenames_with_highest_similarity = [image_files[i] for i in max_similarity_indices[0]]
66
-
67
- return filenames_with_highest_similarity, max_similarity_value
68
 
69
  # Define a Gradio interface
70
  iface = gr.Interface(
71
- fn=find_similar_images,
72
- inputs="text",
73
- outputs=["text", "number"],
74
  live=True,
75
  interpretation="default",
76
- title="CLIP Model Image Search",
 
77
  )
78
 
79
  iface.launch()
 
5
 
6
  # Load the CLIP model
7
  model, preprocess = clip.load("ViT-B/32")
8
+ device = "cuda" if torch.cuda.is available() else "cpu"
9
  model.to(device).eval()
10
 
11
  # Define the Business Listing variable
12
  Business_Listing = "Air Guide"
13
 
14
+ def find_similarity(image, text_input):
15
+ # Preprocess the uploaded image
16
+ image = preprocess(image).unsqueeze(0).to(device)
17
 
18
+ # Prepare input text
19
+ text_tokens = clip.tokenize([text_input]).to(device)
20
+
21
+ # Encode image and text features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  with torch.no_grad():
23
+ image_features = model.encode_image(image).float()
24
  text_features = model.encode_text(text_tokens).float()
25
+
26
  # Normalize features and calculate similarity
27
  image_features /= image_features.norm(dim=-1, keepdim=True)
28
  text_features /= text_features.norm(dim=-1, keepdim=True)
29
+ similarity = (text_features @ image_features.T).cpu().numpy()
30
+
31
+ return similarity[0, 0]
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Define a Gradio interface
34
  iface = gr.Interface(
35
+ fn=find_similarity,
36
+ inputs=[gr.Image(type="pil"), "text"],
37
+ outputs="number",
38
  live=True,
39
  interpretation="default",
40
+ title="CLIP Model Image-Text Cosine Similarity",
41
+ description="Upload an image and enter text to find their cosine similarity.",
42
  )
43
 
44
  iface.launch()