ktllc commited on
Commit
8ea55d3
·
1 Parent(s): 251edb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -25
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import cv2
3
  import numpy as np
4
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
5
  import base64
@@ -49,37 +48,44 @@ def find_similarity(base64_image, text_input):
49
 
50
  # Define a function for image segmentation
51
  def segment_image(input_image, text_input):
52
- image_bytes = base64.b64decode(input_image)
53
- image = Image.open(BytesIO(image_bytes))
54
-
55
- image = np.array(image) # Remove the color mode conversion
56
- mask_generator = SamAutomaticMaskGenerator(sam)
57
- masks = mask_generator.generate(image)
58
-
59
- segmented_regions = [] # List to store segmented regions with similarity scores
 
60
 
61
- for i, mask_dict in enumerate(masks):
62
- mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
63
- segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
64
 
65
- x, y, w, h = map(int, mask_dict['bbox'])
66
- cropped_region = segmented_region[y:y+h, x:x+w]
 
67
 
68
- # Convert to base64 image
69
- _, buffer = cv2.imencode(".png", cropped_region)
70
- segmented_image_base64 = base64.b64encode(buffer).decode()
71
 
72
- # Calculate similarity for the segmented image
73
- similarity = find_similarity(segmented_image_base64, text_input)
 
 
74
 
75
- # Append the segmented image and its similarity score
76
- segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
77
 
78
- # Sort the segmented images by similarity in descending order
79
- segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
80
 
81
- # Return the segmented images in descending order of similarity
82
- return segmented_regions
 
 
 
 
 
 
83
 
84
  # Create Gradio components
85
  input_image = gr.Textbox(label="Base64 Image", lines=8)
 
1
  import gradio as gr
 
2
  import numpy as np
3
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
4
  import base64
 
48
 
49
  # Define a function for image segmentation
50
  def segment_image(input_image, text_input):
51
+ try:
52
+ image_bytes = base64.b64decode(input_image)
53
+ image = Image.open(BytesIO(image_bytes))
54
+
55
+ # Convert the image to RGB color mode
56
+ image = image.convert("RGB")
57
+
58
+ mask_generator = SamAutomaticMaskGenerator(sam)
59
+ masks = mask_generator.generate(image)
60
 
61
+ segmented_regions = [] # List to store segmented regions with similarity scores
 
 
62
 
63
+ for i, mask_dict in enumerate(masks):
64
+ mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
65
+ segmented_region = Image.fromarray(mask_data) # Convert mask to an image
66
 
67
+ x, y, w, h = map(int, mask_dict['bbox'])
68
+ cropped_region = image.crop((x, y, x + w, y + h))
 
69
 
70
+ # Convert to base64 image
71
+ buffered = BytesIO()
72
+ cropped_region.save(buffered, format="PNG")
73
+ segmented_image_base64 = base64.b64encode(buffered.getvalue()).decode()
74
 
75
+ # Calculate similarity for the segmented image
76
+ similarity = find_similarity(segmented_image_base64, text_input)
77
 
78
+ # Append the segmented image and its similarity score
79
+ segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
80
 
81
+ # Sort the segmented images by similarity in descending order
82
+ segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
83
+
84
+ # Return the segmented images in descending order of similarity
85
+ return segmented_regions
86
+
87
+ except Exception as e:
88
+ return str(e)
89
 
90
  # Create Gradio components
91
  input_image = gr.Textbox(label="Base64 Image", lines=8)