ktllc commited on
Commit
f687418
·
1 Parent(s): 8ea55d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from io import BytesIO
7
  import torch
8
  import clip
 
9
 
10
  # Load the segmentation model
11
  sam_checkpoint = "sam_vit_h_4b8939.pth"
@@ -17,7 +18,6 @@ model, preprocess = clip.load("ViT-B/32")
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model.to(device).eval()
19
 
20
-
21
  def find_similarity(base64_image, text_input):
22
  try:
23
  # Decode the base64 image to bytes
@@ -44,9 +44,8 @@ def find_similarity(base64_image, text_input):
44
 
45
  return similarity
46
  except Exception as e:
47
- return str(e)
48
 
49
- # Define a function for image segmentation
50
  def segment_image(input_image, text_input):
51
  try:
52
  image_bytes = base64.b64decode(input_image)
@@ -81,11 +80,11 @@ def segment_image(input_image, text_input):
81
  # Sort the segmented images by similarity in descending order
82
  segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
83
 
84
- # Return the segmented images in descending order of similarity
85
- return segmented_regions
86
 
87
  except Exception as e:
88
- return str(e)
89
 
90
  # Create Gradio components
91
  input_image = gr.Textbox(label="Base64 Image", lines=8)
 
6
  from io import BytesIO
7
  import torch
8
  import clip
9
+ import json
10
 
11
  # Load the segmentation model
12
  sam_checkpoint = "sam_vit_h_4b8939.pth"
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  model.to(device).eval()
20
 
 
21
  def find_similarity(base64_image, text_input):
22
  try:
23
  # Decode the base64 image to bytes
 
44
 
45
  return similarity
46
  except Exception as e:
47
+ return json.dumps({"error": str(e)})
48
 
 
49
  def segment_image(input_image, text_input):
50
  try:
51
  image_bytes = base64.b64decode(input_image)
 
80
  # Sort the segmented images by similarity in descending order
81
  segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
82
 
83
+ # Return the segmented images in a JSON format
84
+ return json.dumps(segmented_regions)
85
 
86
  except Exception as e:
87
+ return json.dumps({"error": str(e})
88
 
89
  # Create Gradio components
90
  input_image = gr.Textbox(label="Base64 Image", lines=8)