import gradio as gr
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import base64
from PIL import Image
from io import BytesIO
import torch
import clip

# Select a device before loading either model
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM segmentation model and move it to the device
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load the CLIP model and its preprocessing transform
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()
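
# Note: "sam_vit_h_4b8939.pth" is the ViT-H checkpoint distributed with the
# segment-anything repo (https://github.com/facebookresearch/segment-anything);
# it must be downloaded into the working directory before running this script.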


def find_similarity(base64_image, text_input):
    """Return the CLIP cosine similarity between a base64 image and a text prompt."""
    # Decode the base64 image to bytes
    image_bytes = base64.b64decode(base64_image)

    # Convert the bytes to a PIL image
    image = Image.open(BytesIO(image_bytes))

    # Preprocess the image into a single-item batch
    image = preprocess(image).unsqueeze(0).to(device)

    # Tokenize the input text
    text_tokens = clip.tokenize([text_input]).to(device)

    # Encode image and text features
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)

        # Normalize features and compute cosine similarity as a plain float,
        # so the scores are sortable and JSON-serializable; any error here
        # propagates to the caller, which reports it
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (text_features @ image_features.T).item()

    return similarity
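
# Usage sketch ("cat.jpg" is a hypothetical filename):
#   with open("cat.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   find_similarity(b64, "a photo of a cat")  # cosine similarity, roughly in [-1, 1]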

# Define a function for image segmentation
def segment_image(input_image, text_input):
    try:
        # Decode the base64 payload into an RGB PIL image
        image_bytes = base64.b64decode(input_image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # SAM's automatic mask generator expects an HxWx3 uint8 array
        mask_generator = SamAutomaticMaskGenerator(sam)
        masks = mask_generator.generate(np.array(image))

        segmented_regions = []  # List to store segmented regions with similarity scores

        for mask_dict in masks:
            # Crop the region proposed by SAM; 'bbox' is in XYWH format
            x, y, w, h = map(int, mask_dict['bbox'])
            cropped_region = image.crop((x, y, x + w, y + h))

            # Re-encode the crop as a base64 PNG
            buffered = BytesIO()
            cropped_region.save(buffered, format="PNG")
            segmented_image_base64 = base64.b64encode(buffered.getvalue()).decode()

            # Score the crop against the text prompt with CLIP
            similarity = find_similarity(segmented_image_base64, text_input)

            # Append the segmented image and its similarity score
            segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})

        # Sort the segmented images by similarity in descending order
        segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)

        # Return the segmented images in descending order of similarity
        return segmented_regions

    except Exception as e:
        return str(e)
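
# Each element of the returned list has the shape (values illustrative):
#   {"image": "<base64-encoded PNG of the crop>", "similarity": 0.27}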

# Create Gradio components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
output_images = gr.JSON(label="Segmented Regions")

# Create and launch the Gradio interface
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
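
# To exercise the UI, a base64 payload can be produced from any local image
# ("example.jpg" is a hypothetical path):
#   python -c "import base64; print(base64.b64encode(open('example.jpg', 'rb').read()).decode())"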