ktllc commited on
Commit
50cad22
·
1 Parent(s): 618ddfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  from io import BytesIO
8
  import torch
9
  import clip
 
10
 
11
  # Load the segmentation model
12
  sam_checkpoint = "sam_vit_h_4b8939.pth"
@@ -47,7 +48,6 @@ def find_similarity(base64_image, text_input):
47
  except Exception as e:
48
  return str(e)
49
 
50
- # Define a function for image segmentation
51
  def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
@@ -92,12 +92,25 @@ def segment_image(input_image, text_input):
92
  # Limit the output to the top 6 key-value pairs
93
  segmented_regions = segmented_regions[:6]
94
 
95
- # Print the top 6 results
 
96
  for result in segmented_regions:
 
 
 
 
 
 
 
 
 
 
 
 
97
  print(result)
98
 
99
  # Return the segmented images in descending order of similarity
100
- return segmented_regions
101
 
102
  # Create Gradio components
103
  input_image = gr.Textbox(label="Base64 Image", lines=8)
 
7
  from io import BytesIO
8
  import torch
9
  import clip
10
+ from gradio_client import Client
11
 
12
  # Load the segmentation model
13
  sam_checkpoint = "sam_vit_h_4b8939.pth"
 
48
  except Exception as e:
49
  return str(e)
50
 
 
51
  def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
 
92
  # Limit the output to the top 6 key-value pairs
93
  segmented_regions = segmented_regions[:6]
94
 
95
+ # Run each segmented image through the inference API and store results
96
+ results_with_similarity = []
97
  for result in segmented_regions:
98
+ image_base64 = result["image"]
99
+ similarity = result["similarity"]
100
+
101
+ # Make an API call with image_base64, get the API result
102
+ client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/dv889/")
103
+ api_result = client.predict(image_base64, text_input, api_name="/predict")
104
+
105
+ # Append the API result and similarity to the list
106
+ results_with_similarity.append({"api_result": api_result, "similarity": similarity})
107
+
108
+ # Print the top 6 results
109
+ for result in results_with_similarity:
110
  print(result)
111
 
112
  # Return the segmented images in descending order of similarity
113
+ return results_with_similarity
114
 
115
  # Create Gradio components
116
  input_image = gr.Textbox(label="Base64 Image", lines=8)