ktllc committed
Commit d666f15 · 1 Parent(s): c0b02e4

Update app.py

Files changed (1)
  1. app.py +53 -25
app.py CHANGED
@@ -1,31 +1,62 @@
 import gradio as gr
+import cv2
 import numpy as np
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
 import base64
-from gradio_client import Client
+from PIL import Image
+from io import BytesIO
+import torch
+import clip
 
 # Load the segmentation model
 sam_checkpoint = "sam_vit_h_4b8939.pth"
 model_type = "vit_h"
 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 
-# Define a function for image segmentation
-def segment_image(input_image):
-    # Convert Gradio input image to a NumPy array
-    image = input_image.astype(np.uint8)
-
-    # Initialize the mask generator
-    mask_generator = SamAutomaticMaskGenerator(sam)
-
-    # Generate masks
+# Load the CLIP model
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load("ViT-B/32", device=device)
+
+
+def find_similarity(base64_image, text_input):
+    # Decode the base64 image to bytes
+    image_bytes = base64.b64decode(base64_image)
+
+    # Convert the bytes to a PIL image
+    image = Image.open(BytesIO(image_bytes))
+
+    # Preprocess the image
+    image = preprocess(image).unsqueeze(0).to(device)
+
+    # Prepare input text
+    text_tokens = clip.tokenize([text_input]).to(device)
+
+    # Encode image and text features
+    with torch.no_grad():
+        image_features = model.encode_image(image)
+        text_features = model.encode_text(text_tokens)
+
+    # Normalize features and calculate similarity
+    image_features /= image_features.norm(dim=-1, keepdim=True)
+    text_features /= text_features.norm(dim=-1, keepdim=True)
+    similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
+
+    return similarity
+
+
+# Define a function for image segmentation
+def segment_image(input_image, text_input):
+    image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
+    mask_generator = SamAutomaticMaskGenerator(sam)
     masks = mask_generator.generate(image)
 
-    highest_cosine_value = -1
-    highest_cosine_base64 = ""
+    segmented_regions = []  # List to store segmented regions with similarity scores
 
     for i, mask_dict in enumerate(masks):
         mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
-        segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
+        segmented_region = cv2.bitwise_and(input_image, input_image, mask=mask_data)
 
         x, y, w, h = map(int, mask_dict['bbox'])
         cropped_region = segmented_region[y:y+h, x:x+w]
@@ -34,25 +65,22 @@ def segment_image(input_image):
         _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
         segmented_image_base64 = base64.b64encode(buffer).decode()
 
-        # Call the API to get the cosine similarity
-        client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/")
-        result = client.predict(
-            segmented_image_base64,  # Base64 Image
-            "Text input",  # Text input
-            api_name="/predict"
-        )
-
-        cosine_similarity = result[0].get("score", 0.0)
-
-        if cosine_similarity > highest_cosine_value:
-            highest_cosine_value = cosine_similarity
-            highest_cosine_base64 = segmented_image_base64
-
-    return highest_cosine_base64
+        # Calculate similarity for the segmented image
+        similarity = find_similarity(segmented_image_base64, text_input)
+
+        # Append the segmented image and its similarity score
+        segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
+
+    # Sort the segmented images by similarity in descending order
+    segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
+
+    # Return the segmented images in descending order of similarity
+    return segmented_regions
 
 # Create Gradio components
 input_image = gr.inputs.Image()
-output_image = gr.outputs.Image(type="pil")
+text_input = gr.inputs.Text()
+output_images = gr.outputs.JSON()
 
 # Create a Gradio interface
-gr.Interface(fn=segment_image, inputs=input_image, outputs=output_image).launch()
+gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
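
For context, a minimal smoke test of the updated pipeline as one might run it outside Gradio. This is a sketch, not part of the commit: it assumes the names defined in app.py above (segment_image, base64, cv2) are in scope and the SAM checkpoint is present, and the test image path "sample.jpg" is a hypothetical placeholder.

    # Sketch: exercise the new CLIP-ranked segmentation locally (assumes app.py above ran).
    test_image = cv2.imread("sample.jpg")  # hypothetical test image; cv2 loads it as BGR
    regions = segment_image(test_image, "a red car")

    # segment_image returns the masked crops sorted by CLIP similarity, best match first;
    # each entry pairs a base64-encoded PNG with its similarity score.
    best = regions[0]
    print("best similarity:", best["similarity"])
    with open("best_match.png", "wb") as f:
        f.write(base64.b64decode(best["image"]))

One caveat when reading the new code: find_similarity returns a NumPy array, so the scores stored in the JSON output may need casting to plain float before serialization.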