ktllc committed on
Commit
28d01e2
·
1 Parent(s): a3c777d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -34
app.py CHANGED
@@ -3,7 +3,7 @@ import cv2
3
  import numpy as np
4
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
5
  import base64
6
- from gradio_client import Client
7
 
8
  # Load the segmentation model
9
  sam_checkpoint = "sam_vit_h_4b8939.pth"
@@ -16,7 +16,8 @@ def segment_image(input_image):
16
  mask_generator = SamAutomaticMaskGenerator(sam)
17
  masks = mask_generator.generate(image)
18
 
19
- segmented_regions = [] # List to store segmented regions
 
20
 
21
  for i, mask_dict in enumerate(masks):
22
  mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
@@ -28,43 +29,22 @@ def segment_image(input_image):
28
  # Convert to base64 image
29
  _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
30
  segmented_image_base64 = base64.b64encode(buffer).decode()
31
- segmented_regions.append(segmented_image_base64) # Add to the list
32
 
33
- return segmented_regions
 
 
34
 
35
- # Function to call the API and calculate cosine similarity
36
- def calculate_cosine_similarity(segmented_images):
37
- highest_cosine = -1
38
- highest_cosine_base64 = ""
39
-
40
- client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/")
41
-
42
- for base64_image in segmented_images:
43
- # Call the API here using the base64 image
44
- result = client.predict(base64_image, base64_image, api_name="/predict")
45
-
46
- cosine_value = result['similarity']
47
- print(f"Base64 Image: {base64_image}, Cosine Similarity: {cosine_value}")
48
-
49
- if cosine_value > highest_cosine:
50
- highest_cosine = cosine_value
51
- highest_cosine_base64 = base64_image
52
 
53
- print(f"Highest Cosine Similarity: {highest_cosine} (Base64 Image: {highest_cosine_base64})")
54
 
55
  # Create Gradio components
56
  input_image = gr.inputs.Image()
57
- output_images = gr.outputs.JSON()
58
 
59
  # Create a Gradio interface
60
- segmentation_interface = gr.Interface(fn=segment_image, inputs=input_image, outputs=output_images)
61
-
62
- # Launch the segmentation interface
63
- segmentation_interface.launch()
64
-
65
- # Get the segmented images from the segmentation interface
66
- segmented_images = segmentation_interface.run()
67
- segmentation_interface.close()
68
-
69
- # Call the API for each segmented image and calculate cosine similarity
70
- calculate_cosine_similarity(segmented_images)
 
3
  import numpy as np
4
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
5
  import base64
6
+ from huggingface_hub import InferenceClient
7
 
8
  # Load the segmentation model
9
  sam_checkpoint = "sam_vit_h_4b8939.pth"
 
16
  mask_generator = SamAutomaticMaskGenerator(sam)
17
  masks = mask_generator.generate(image)
18
 
19
+ highest_cosine_value = -1
20
+ highest_cosine_base64 = ""
21
 
22
  for i, mask_dict in enumerate(masks):
23
  mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
 
29
  # Convert to base64 image
30
  _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
31
  segmented_image_base64 = base64.b64encode(buffer).decode()
 
32
 
33
+ # Call the API to get the cosine similarity
34
+ client = InferenceClient()
35
+ result = client.post(json={"inputs": segmented_image_base64}, model="https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/")
36
 
37
+ cosine_similarity = result[0].get("score", 0.0)
38
+
39
+ if cosine_similarity > highest_cosine_value:
40
+ highest_cosine_value = cosine_similarity
41
+ highest_cosine_base64 = segmented_image_base64
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ return highest_cosine_base64
44
 
45
  # Create Gradio components
46
  input_image = gr.inputs.Image()
47
+ output_image = gr.outputs.Image(type="pil")
48
 
49
  # Create a Gradio interface
50
+ gr.Interface(fn=segment_image, inputs=input_image, outputs=output_image).launch()