File size: 1,829 Bytes
8ab94bb
 
 
 
 
28d01e2
8ab94bb
 
 
 
 
 
 
 
 
 
 
 
28d01e2
 
8ab94bb
 
 
 
 
 
 
 
 
 
 
 
28d01e2
 
 
8ab94bb
28d01e2
 
 
 
 
8ab94bb
28d01e2
8ab94bb
 
 
28d01e2
8ab94bb
 
28d01e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import base64
import json

import cv2
import gradio as gr
import numpy as np
from huggingface_hub import InferenceClient
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Build the Segment Anything model once, at import time, so every call to the
# segmentation function reuses the same weights instead of reloading them.
# The checkpoint path is relative, so it is resolved against the process's
# current working directory.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Define a function for image segmentation
def segment_image(input_image):
    image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    highest_cosine_value = -1
    highest_cosine_base64 = ""

    for i, mask_dict in enumerate(masks):
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
        segmented_region = cv2.bitwise_and(input_image, input_image, mask=mask_data)

        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]

        # Convert to base64 image
        _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
        segmented_image_base64 = base64.b64encode(buffer).decode()

        # Call the API to get the cosine similarity
        client = InferenceClient()
        result = client.post(json={"inputs": segmented_image_base64}, model="https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/")

        cosine_similarity = result[0].get("score", 0.0)

        if cosine_similarity > highest_cosine_value:
            highest_cosine_value = cosine_similarity
            highest_cosine_base64 = segmented_image_base64

    return highest_cosine_base64

def _base64_png_to_rgb(encoded_png):
    """Decode a base64-encoded PNG string into an RGB array for display.

    Returns ``None`` for an empty string so Gradio shows an empty output
    instead of raising.
    """
    if not encoded_png:
        return None
    png_bytes = np.frombuffer(base64.b64decode(encoded_png), dtype=np.uint8)
    # cv2.imdecode yields BGR; Gradio displays numpy arrays as RGB.
    bgr_image = cv2.imdecode(png_bytes, cv2.IMREAD_COLOR)
    if bgr_image is None:
        return None
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


def _segment_image_for_display(image):
    """Gradio adapter: run ``segment_image`` and decode its base64 PNG result.

    ``segment_image`` returns a base64 string, which Gradio's ``Image``
    output would otherwise try to open as a file path/URL.
    """
    return _base64_png_to_rgb(segment_image(image))


# Create Gradio components. ``gr.Image`` replaces the deprecated
# ``gr.inputs``/``gr.outputs`` namespaces, which were removed in Gradio 4.
input_image = gr.Image(type="numpy")
output_image = gr.Image()

# Create a Gradio interface
gr.Interface(fn=_segment_image_for_display, inputs=input_image, outputs=output_image).launch()