ktllc commited on
Commit
8ab94bb
·
1 Parent(s): 88947a0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
5
+ import base64
6
+ import requests
7
+ from gradio_client import Client
8
+
9
+ # Load the segmentation model
10
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
11
+ model_type = "vit_h"
12
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
13
+
14
+ # Define a function for image segmentation
15
+ def segment_image(input_image):
16
+ image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
17
+ mask_generator = SamAutomaticMaskGenerator(sam)
18
+ masks = mask_generator.generate(image)
19
+
20
+ segmented_regions = [] # List to store segmented regions
21
+
22
+ for i, mask_dict in enumerate(masks):
23
+ mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
24
+ segmented_region = cv2.bitwise_and(input_image, input_image, mask=mask_data)
25
+
26
+ x, y, w, h = map(int, mask_dict['bbox'])
27
+ cropped_region = segmented_region[y:y+h, x:x+w]
28
+
29
+ # Convert to base64 image
30
+ _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
31
+ segmented_image_base64 = base64.b64encode(buffer).decode()
32
+ segmented_regions.append(segmented_image_base64) # Add to the list
33
+
34
+ return segmented_regions
35
+
36
+ # Function to call the API and calculate cosine similarity
37
+ def calculate_cosine_similarity(segmented_images):
38
+ highest_cosine = -1
39
+ highest_cosine_base64 = ""
40
+
41
+ client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/")
42
+
43
+ for base64_image in segmented_images:
44
+ # Call the API here using the base64 image
45
+ result = client.predict(base64_image, base64_image, api_name="/predict")
46
+
47
+ cosine_value = result['similarity']
48
+ print(f"Base64 Image: {base64_image}, Cosine Similarity: {cosine_value}")
49
+
50
+ if cosine_value > highest_cosine:
51
+ highest_cosine = cosine_value
52
+ highest_cosine_base64 = base64_image
53
+
54
+ print(f"Highest Cosine Similarity: {highest_cosine} (Base64 Image: {highest_cosine_base64})")
55
+
56
+ # Create Gradio components
57
+ input_image = gr.inputs.Image()
58
+ output_images = gr.outputs.JSON()
59
+
60
+ # Create a Gradio interface
61
+ segmentation_interface = gr.Interface(fn=segment_image, inputs=input_image, outputs=output_images)
62
+
63
+ # Launch the segmentation interface
64
+ segmentation_interface.launch()
65
+
66
+ # Get the segmented images from the segmentation interface
67
+ segmented_images = segmentation_interface.run()
68
+ segmentation_interface.close()
69
+
70
+ # Call the API for each segmented image and calculate cosine similarity
71
+ calculate_cosine_similarity(segmented_images)