ktllc commited on
Commit
7df14c5
·
1 Parent(s): de97c4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -25,7 +25,7 @@ def find_similarity(base64_image, text_input):
25
  image_bytes = base64.b64decode(base64_image)
26
 
27
  # Convert the bytes to a PIL image
28
- image = Image.open(BytesIO(image_bytes))
29
 
30
  # Preprocess the image
31
  image = preprocess(image).unsqueeze(0).to(device)
@@ -52,10 +52,10 @@ def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
54
 
55
- image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
56
  mask_generator = SamAutomaticMaskGenerator(sam)
57
  masks = mask_generator.generate(image)
58
-
59
  segmented_regions = [] # List to store segmented regions with similarity scores
60
 
61
  for i, mask_dict in enumerate(masks):
@@ -66,7 +66,7 @@ def segment_image(input_image, text_input):
66
  cropped_region = segmented_region[y:y+h, x:x+w]
67
 
68
  # Convert to base64 image
69
- _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
70
  segmented_image_base64 = base64.b64encode(buffer).decode()
71
 
72
  # Calculate similarity for the segmented image
@@ -88,4 +88,3 @@ output_images = gr.outputs.JSON()
88
 
89
  # Create a Gradio interface
90
  gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
91
-
 
25
  image_bytes = base64.b64decode(base64_image)
26
 
27
  # Convert the bytes to a PIL image
28
+ image = Image.open(BytesIO(image_bytes)
29
 
30
  # Preprocess the image
31
  image = preprocess(image).unsqueeze(0).to(device)
 
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
54
 
55
+ image = np.array(image) # Remove the color mode conversion
56
  mask_generator = SamAutomaticMaskGenerator(sam)
57
  masks = mask_generator.generate(image)
58
+
59
  segmented_regions = [] # List to store segmented regions with similarity scores
60
 
61
  for i, mask_dict in enumerate(masks):
 
66
  cropped_region = segmented_region[y:y+h, x:x+w]
67
 
68
  # Convert to base64 image
69
+ _, buffer = cv2.imencode(".png", cropped_region)
70
  segmented_image_base64 = base64.b64encode(buffer).decode()
71
 
72
  # Calculate similarity for the segmented image
 
88
 
89
  # Create a Gradio interface
90
  gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()