File size: 4,041 Bytes
0d11696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Registry of selectable models: short name shown in the UI -> checkpoint id
# handed to the `from_pretrained` loaders below.
MODEL_LIST = dict(
    beit="microsoft/beit-base-patch16-224-pt22k-ft22k",
    vit="google/vit-base-patch16-224",
    convnext="facebook/convnext-tiny-224",
)

# Preprocessing transforms
def get_preprocessor(model_name):
    """Return the pretrained feature extractor for a registered model.

    Args:
        model_name: A key of ``MODEL_LIST``; an unknown key raises ``KeyError``.

    Returns:
        The object produced by ``AutoFeatureExtractor.from_pretrained`` for
        the checkpoint registered under ``model_name``.
    """
    checkpoint = MODEL_LIST[model_name]
    return AutoFeatureExtractor.from_pretrained(checkpoint)

# Load a model from Hugging Face
def load_model(model_name):
    """Load a pretrained image-classification model in evaluation mode.

    The model is placed on the GPU when CUDA is available and on the CPU
    otherwise, so loading no longer fails on hosts without a CUDA device.

    Args:
        model_name: A key of ``MODEL_LIST``; an unknown key raises ``KeyError``.

    Returns:
        The model returned by ``AutoModelForImageClassification.from_pretrained``,
        moved to the selected device and switched to ``eval()`` mode.
    """
    # An unconditional .cuda() raises on CPU-only machines; pick the device
    # at call time instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name])
    return model.to(device).eval()

# Function to make predictions
def predict(image, model, preprocessor):
    """Classify ``image`` with ``model`` and return the predicted label.

    Args:
        image: The image (or patch) to classify; it is passed unchanged to
            ``preprocessor`` as ``images=image``.
        model: A classifier whose call result exposes ``.logits`` and whose
            ``config.id2label`` maps class indices to label strings.
        preprocessor: Callable invoked as
            ``preprocessor(images=..., return_tensors="pt")``; its result must
            provide ``.to(device)`` and be unpackable as keyword arguments.

    Returns:
        The ``id2label`` entry for the highest-scoring class of the first
        (only) item in the batch.
    """
    # Send the inputs to wherever the model actually lives rather than a
    # hard-coded "cuda", so CPU-only inference works and devices never mismatch.
    device = getattr(model, "device", None)
    if device is None:
        # Plain torch modules have no `.device`; fall back to a parameter's device.
        device = next(model.parameters()).device

    inputs = preprocessor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]

# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with a red square outline drawn on it.

    Args:
        image: Source image; it is copied first, so the caller's image is
            never modified.
        x: Horizontal coordinate of the square's top-left corner.
        y: Vertical coordinate of the square's top-left corner.
        size: Side length of the square in pixels.

    Returns:
        The annotated copy.
    """
    annotated = image.copy()
    top_left = (x, y)
    bottom_right = (x + size, y + size)
    ImageDraw.Draw(annotated).rectangle(
        [*top_left, *bottom_right], outline="red", width=5
    )
    return annotated

# Function to crop the image
def crop_image(image, x, y, size=224):
    """Crop a ``size`` x ``size`` patch whose top-left corner is near ``(x, y)``.

    The requested corner is clamped so the patch stays inside the image. If
    the image is smaller than ``size`` along an axis, the patch starts at 0 on
    that axis and is as large as the image allows.

    Args:
        image: Image convertible with ``np.array`` (2-D grayscale or 3-D with
            a channel axis).
        x: Requested left edge of the patch; floats are truncated to int.
        y: Requested top edge of the patch; floats are truncated to int.
        size: Side length of the patch in pixels.

    Returns:
        A new PIL image built from the cropped array.
    """
    image_np = np.array(image)
    # shape[:2] is (height, width) for both (H, W) and (H, W, C) arrays, so
    # grayscale input no longer fails on a three-way unpack.
    h, w = image_np.shape[:2]
    # max(..., 0) keeps the upper bound non-negative: a negative start index
    # would silently slice from the opposite edge and give a wrong/empty patch.
    # int() makes slider values that arrive as floats usable as slice indices.
    x = int(min(max(x, 0), max(w - size, 0)))
    y = int(min(max(y, 0), max(h - size, 0)))
    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)

# Module-level state shared by the UI callbacks: the active model and its
# preprocessor. Both are None until update_model() assigns them.
current_model = current_preprocessor = None

# Gradio Interface
# Side length (pixels) of the selection square; same value as the default
# `size` of draw_rectangle() and crop_image().
PATCH_SIZE = 224
# Model pre-selected in the dropdown; also used to lazily load a model when
# "Predict" is clicked before the dropdown was ever changed.
DEFAULT_MODEL_NAME = 'beit'

with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value=DEFAULT_MODEL_NAME, label="Select Model")
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    def update_model(model_name):
        """Load ``model_name`` and make it the active model/preprocessor pair.

        Returns a short status string naming the loaded model.
        """
        global current_model, current_preprocessor
        # Load both pieces before touching the globals so a failure in either
        # loader cannot leave a new model paired with an old preprocessor.
        model = load_model(model_name)
        preprocessor = get_preprocessor(model_name)
        current_model, current_preprocessor = model, preprocessor
        return f"Model {model_name} loaded successfully."

    def update_selection(image, x, y):
        """Return ``(overlay, patch)`` for the square selected at ``(x, y)``.

        Returns ``(None, None)`` when no image has been uploaded yet.
        """
        if image is None:
            # "Crop" was clicked before an upload; clear both outputs instead
            # of crashing on image.copy().
            return None, None
        width, height = image.size
        # Clamp once here so the drawn rectangle and the cropped patch always
        # describe the same region, even if the sliders are out of range.
        x = int(min(max(x, 0), max(width - PATCH_SIZE, 0)))
        y = int(min(max(y, 0), max(height - PATCH_SIZE, 0)))
        overlay_image = draw_rectangle(image, x, y, size=PATCH_SIZE)
        cropped = crop_image(image, x, y, size=PATCH_SIZE)
        return overlay_image, cropped

    def predict_from_cropped(cropped, model_name=None):
        """Predict the label of the cropped patch with the active model.

        ``model_name`` is only used to load a model when none is active yet;
        it falls back to ``DEFAULT_MODEL_NAME`` when omitted.
        """
        if cropped is None:
            return "No cropped patch available. Click 'Crop' first."
        if current_model is None or current_preprocessor is None:
            # The dropdown's change event has not fired yet (the default
            # selection was kept), so nothing has been loaded so far.
            update_model(model_name or DEFAULT_MODEL_NAME)
        return predict(cropped, current_model, current_preprocessor)

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=[cropped_image, model_selector], outputs=label_output)

    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    def update_sliders(image):
        """Limit the sliders so the selection square fits inside ``image``."""
        if image is not None:
            width, height = image.size
            # Never let the maximum drop below the slider minimum (0) when
            # the image is smaller than the patch.
            return (
                gr.update(maximum=max(width - PATCH_SIZE, 0)),
                gr.update(maximum=max(height - PATCH_SIZE, 0)),
            )
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

demo.launch(server_name="0.0.0.0", server_port=7860)