ktllc commited on
Commit
f21cc03
·
1 Parent(s): 431f965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -31
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
4
  import base64
@@ -6,7 +7,6 @@ from PIL import Image
6
  from io import BytesIO
7
  import torch
8
  import clip
9
- import json
10
 
11
  # Load the segmentation model
12
  sam_checkpoint = "sam_vit_h_4b8939.pth"
@@ -18,10 +18,16 @@ model, preprocess = clip.load("ViT-B/32")
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  model.to(device).eval()
20
 
21
- def find_similarity(image, text_input):
 
22
  try:
 
 
 
 
 
 
23
  # Preprocess the image
24
- image = Image.fromarray(image)
25
  image = preprocess(image).unsqueeze(0).to(device)
26
 
27
  # Prepare input text
@@ -37,45 +43,43 @@ def find_similarity(image, text_input):
37
  text_features /= text_features.norm(dim=-1, keepdim=True)
38
  similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
39
 
40
- return similarity.tolist() # Convert to a list
41
-
42
  except Exception as e:
43
- return json.dumps({"error": str(e)})
44
 
 
45
  def segment_image(input_image, text_input):
46
- try:
47
- image_bytes = base64.b64decode(input_image)
48
- image = Image.open(BytesIO(image_bytes))
49
-
50
- # Convert the image to RGB color mode
51
- image = image.convert("RGB")
52
 
53
- mask_generator = SamAutomaticMaskGenerator(sam)
54
- masks = mask_generator.generate(np.array(image))
55
 
56
- segmented_regions = [] # List to store segmented regions with similarity scores
 
 
57
 
58
- for i, mask_dict in enumerate(masks):
59
- mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
60
- segmented_region = Image.fromarray(mask_data) # Convert mask to an image
61
 
62
- x, y, w, h = map(int, mask_dict['bbox'])
63
- cropped_region = image.crop((x, y, x + w, y + h))
 
64
 
65
- # Calculate similarity for the segmented image
66
- similarity = find_similarity(np.array(cropped_region), text_input)
67
 
68
- # Append the segmented image and its similarity score
69
- segmented_regions.append({"image": input_image, "similarity": similarity})
70
 
71
- # Sort the segmented images by similarity in descending order
72
- segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
73
 
74
- # Return the segmented images in a JSON format
75
- return json.dumps(segmented_regions)
76
-
77
- except Exception as e:
78
- return json.dumps({"error": str(e)})
79
 
80
  # Create Gradio components
81
  input_image = gr.Textbox(label="Base64 Image", lines=8)
 
1
  import gradio as gr
2
+ import cv2
3
  import numpy as np
4
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
5
  import base64
 
7
  from io import BytesIO
8
  import torch
9
  import clip
 
10
 
11
  # Load the segmentation model
12
  sam_checkpoint = "sam_vit_h_4b8939.pth"
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  model.to(device).eval()
20
 
21
+
22
+ def find_similarity(base64_image, text_input):
23
  try:
24
+ # Decode the base64 image to bytes
25
+ image_bytes = base64.b64decode(base64_image)
26
+
27
+ # Convert the bytes to a PIL image
28
+ image = Image.open(BytesIO(image_bytes))
29
+
30
  # Preprocess the image
 
31
  image = preprocess(image).unsqueeze(0).to(device)
32
 
33
  # Prepare input text
 
43
  text_features /= text_features.norm(dim=-1, keepdim=True)
44
  similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
45
 
46
+ return similarity
 
47
  except Exception as e:
48
+ return str(e)
49
 
50
+ # Define a function for image segmentation
51
  def segment_image(input_image, text_input):
52
+ image_bytes = base64.b64decode(input_image)
53
+ image = Image.open(BytesIO(image_bytes))
54
+
55
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
56
+ mask_generator = SamAutomaticMaskGenerator(sam)
57
+ masks = mask_generator.generate(image)
58
 
59
+ segmented_regions = [] # List to store segmented regions with similarity scores
 
60
 
61
+ for i, mask_dict in enumerate(masks):
62
+ mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
63
+ segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
64
 
65
+ x, y, w, h = map(int, mask_dict['bbox'])
66
+ cropped_region = segmented_region[y:y+h, x:x+w]
 
67
 
68
+ # Convert to base64 image
69
+ _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
70
+ segmented_image_base64 = base64.b64encode(buffer).decode()
71
 
72
+ # Calculate similarity for the segmented image
73
+ similarity = find_similarity(segmented_image_base64, text_input)
74
 
75
+ # Append the segmented image and its similarity score
76
+ segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
77
 
78
+ # Sort the segmented images by similarity in descending order
79
+ segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
80
 
81
+ # Return the segmented images in descending order of similarity
82
+ return segmented_regions
 
 
 
83
 
84
  # Create Gradio components
85
  input_image = gr.Textbox(label="Base64 Image", lines=8)