ktllc commited on
Commit
9223407
·
1 Parent(s): 914eb81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -52,8 +52,8 @@ def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
54
 
55
- # Convert the image to a numpy array
56
- image = np.array(image)
57
 
58
  mask_generator = SamAutomaticMaskGenerator(sam)
59
  masks = mask_generator.generate(image)
@@ -62,19 +62,13 @@ def segment_image(input_image, text_input):
62
 
63
  for i, mask_dict in enumerate(masks):
64
  mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
65
-
66
- # Create a mask with the same shape as the original image
67
- mask = np.zeros_like(image)
68
- mask[:, :] = mask_data[:, :, np.newaxis]
69
-
70
- # Apply the mask to the original image
71
- segmented_region = cv2.bitwise_and(image, mask)
72
 
73
  x, y, w, h = map(int, mask_dict['bbox'])
74
  cropped_region = segmented_region[y:y+h, x:x+w]
75
 
76
  # Convert to base64 image
77
- _, buffer = cv2.imencode(".png", cropped_region)
78
  segmented_image_base64 = base64.b64encode(buffer).decode()
79
 
80
  # Calculate similarity for the segmented image
 
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
54
 
55
+ # Convert the image to a NumPy array
56
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB) # Convert to RGB color space
57
 
58
  mask_generator = SamAutomaticMaskGenerator(sam)
59
  masks = mask_generator.generate(image)
 
62
 
63
  for i, mask_dict in enumerate(masks):
64
  mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
65
+ segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
 
 
 
 
 
 
66
 
67
  x, y, w, h = map(int, mask_dict['bbox'])
68
  cropped_region = segmented_region[y:y+h, x:x+w]
69
 
70
  # Convert to base64 image
71
+ _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR)) # Convert back to BGR for encoding
72
  segmented_image_base64 = base64.b64encode(buffer).decode()
73
 
74
  # Calculate similarity for the segmented image