ktllc commited on
Commit
1f09605
·
1 Parent(s): 8be433f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -24
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
  from io import BytesIO
8
  import torch
9
  import clip
10
- from gradio_client import Client
11
 
12
  # Load the segmentation model
13
  sam_checkpoint = "sam_vit_h_4b8939.pth"
@@ -92,34 +92,13 @@ def segment_image(input_image, text_input):
92
  # Limit the output to the top 6 key-value pairs
93
  segmented_regions = segmented_regions[:6]
94
 
95
- # # Run each segmented image through the inference API and store results
96
- # results_with_similarity = []
97
- # for result in segmented_regions:
98
- # image_base64 = result["image"]
99
- # similarity = result["similarity"]
100
-
101
- # # Make an API call with image_base64, get the API result
102
- # client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/dv889/")
103
- # api_result = client.predict(image_base64, text_input, api_name="/predict")
104
-
105
- # # Append the API result and similarity to the list
106
- # results_with_similarity.append({"api_result": api_result, "image": image_base64, "similarity": similarity})
107
-
108
- # results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
109
-
110
- # results_with_similarity = results_with_similarity[:1]
111
-
112
- # Print the top 6 results
113
- # for result in results_with_similarity:
114
- # print(result)
115
-
116
  # Return the segmented images in descending order of similarity
117
  return segmented_regions
118
 
119
  # Create Gradio components
120
  input_image = gr.Textbox(label="Base64 Image", lines=8)
121
  text_input = gr.Textbox(label="Text Input") # Use Textbox with a label
122
- output_images = gr.outputs.JSON()
123
 
124
  # Create a Gradio interface
125
- gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
 
7
  from io import BytesIO
8
  import torch
9
  import clip
10
+
11
 
12
  # Load the segmentation model
13
  sam_checkpoint = "sam_vit_h_4b8939.pth"
 
92
  # Limit the output to the top 6 key-value pairs
93
  segmented_regions = segmented_regions[:6]
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Return the segmented images in descending order of similarity
96
  return segmented_regions
97
 
98
  # Create Gradio components
99
  input_image = gr.Textbox(label="Base64 Image", lines=8)
100
  text_input = gr.Textbox(label="Text Input") # Use Textbox with a label
101
+ #output_images = gr.outputs.JSON()
102
 
103
  # Create a Gradio interface
104
+ gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs="text").launch()