import gradio as gr
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import base64
from PIL import Image
from io import BytesIO
import torch
import clip
from gradio_client import Client

# Select the device before loading the models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM segmentation model (the checkpoint file must be in the working directory)
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
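# Note: the checkpoint is not bundled with this script; it can be fetched from the
# official Segment Anything release, e.g.
#   wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth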

# Load the CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()


def find_similarity(base64_image, text_input):
    try:
        # Decode the base64 image to bytes
        image_bytes = base64.b64decode(base64_image)

        # Convert the bytes to a PIL image
        image = Image.open(BytesIO(image_bytes))

        # Preprocess the image
        image = preprocess(image).unsqueeze(0).to(device)

        # Prepare input text
        text_tokens = clip.tokenize([text_input]).to(device)

        # Encode image and text features
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # Normalize features and calculate similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (text_features @ image_features.T).item()

        return similarity
    except Exception as e:
        # Report the error and return the lowest score so a failed segment sorts last
        print(f"find_similarity failed: {e}")
        return -1.0

def segment_image(input_image, text_input):
    image_bytes = base64.b64decode(input_image)
    image = Image.open(BytesIO(image_bytes))

    # Convert the image to RGB color mode
    image = image.convert("RGB")

    # Convert the image to a numpy array
    image = np.array(image)

    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
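
    # Each generated mask is a dict; the fields used below are the boolean
    # 'segmentation' array and the XYWH 'bbox' of the masked region.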

    segmented_regions = []  # List to store segmented regions with similarity scores

    for mask_dict in masks:
        # Boolean segmentation mask scaled to a 0/255 uint8 mask
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)

        # Broadcast the single-channel mask across the three image channels
        mask = np.zeros_like(image)
        mask[:, :] = mask_data[:, :, np.newaxis]

        # Apply the mask to the original image
        segmented_region = cv2.bitwise_and(image, mask)

        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]

        # Encode the crop as PNG (OpenCV expects BGR channel order), then as base64
        _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR))
        segmented_image_base64 = base64.b64encode(buffer).decode()

        # Calculate similarity for the segmented image
        similarity = find_similarity(segmented_image_base64, text_input)

        # Append the segmented image and its similarity score
        segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})

    # Sort the segmented images by similarity in descending order
    segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)

    # Keep only the six most similar segments
    segmented_regions = segmented_regions[:6]

    # Run each segmented image through the remote inference API and store results
    client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/dv889/")
    results_with_similarity = []
    for result in segmented_regions:
        image_base64 = result["image"]
        similarity = result["similarity"]

        # Call the hosted endpoint with the cropped segment and the text prompt
        api_result = client.predict(image_base64, text_input, api_name="/predict")

        # Keep the API result together with the segment image and its similarity score
        results_with_similarity.append({"api_result": api_result, "image": image_base64, "similarity": similarity})

    # Keep only the single best-scoring segment in the response
    results_with_similarity = results_with_similarity[:1]

    # Return the result(s) in descending order of similarity
    return results_with_similarity

# Create Gradio components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
output_json = gr.JSON(label="Results")

# Create and launch the Gradio interface
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_json).launch()
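
# Usage sketch (assumptions: the app is running locally on Gradio's default URL and a
# file named "example.png" exists; both names are illustrative, not part of this app):
#
#   import base64
#   from gradio_client import Client
#
#   with open("example.png", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode()
#
#   caller = Client("http://127.0.0.1:7860/")
#   top_match = caller.predict(image_b64, "a red bicycle", api_name="/predict")
#   print(top_match)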