File size: 4,601 Bytes
471d95f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Selectable models: short UI key -> Hugging Face Hub checkpoint identifier.
MODEL_LIST = dict(
    beit="microsoft/beit-base-patch16-224-pt22k-ft22k",
    vit="google/vit-base-patch16-224",
    convnext="facebook/convnext-tiny-224",
)

# Module-level state shared by the loader and the UI callbacks; both stay
# None until load_model_and_preprocessor() has run.
current_model = None
current_preprocessor = None

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load model and preprocessor
def load_model_and_preprocessor(model_name):
    """Load the model and preprocessor registered under ``model_name``.

    On success the module-level ``current_model`` and ``current_preprocessor``
    are replaced together; the model is moved to ``device`` and put in eval
    mode.

    Args:
        model_name: A key of ``MODEL_LIST``.

    Returns:
        A human-readable status message.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    global current_model, current_preprocessor
    checkpoint = MODEL_LIST[model_name]
    print(f"Loading model and preprocessor for: {model_name} on {device}")
    # Load both pieces into locals first: if either load fails, the previously
    # active model/preprocessor pair is left untouched instead of ending up
    # with a new model paired with a stale preprocessor.
    preprocessor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint).to(device).eval()
    current_model, current_preprocessor = model, preprocessor
    return f"Model {model_name} loaded successfully on {device}."

# Predict function
def predict(image, model, preprocessor):
    """Classify an image patch and return the predicted label.

    Args:
        image: The patch to classify, in any form ``preprocessor`` accepts
            through its ``images=`` argument.
        model: A loaded classification model exposing ``device``,
            ``config.id2label`` and returning an object with ``logits``.
        preprocessor: Callable producing the model inputs as "pt" tensors.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class.

    Raises:
        ValueError: If the model, the preprocessor or the image is missing.
    """
    if model is None or preprocessor is None:
        raise ValueError("Model and preprocessor are not loaded.")
    if image is None:
        # Fail with a clear message instead of an opaque preprocessor error.
        raise ValueError("No image patch provided for prediction.")
    # Send the inputs to the device of the model actually passed in, so the
    # function does not rely on module state matching its arguments.
    inputs = preprocessor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]

# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with a red square outlining the selection.

    The square is clamped to the image bounds (and shrunk when the image is
    smaller than ``size``) so that the outline stays inside the image and
    matches a ``size`` x ``size`` crop window anchored at the same clamped
    position.

    Args:
        image: PIL image; it is not modified.
        x: Requested left edge of the selection, in pixels.
        y: Requested top edge of the selection, in pixels.
        size: Side length of the selection, in pixels.

    Returns:
        A new PIL image with the rectangle drawn on it.
    """
    width, height = image.size
    box_w = min(size, width)
    box_h = min(size, height)
    left = max(0, min(int(x), width - box_w))
    top = max(0, min(int(y), height - box_h))

    annotated = image.copy()  # Draw on a copy to keep the original untouched
    # PIL treats the second corner as inclusive, hence the -1 to cover exactly
    # box_w x box_h pixels.
    ImageDraw.Draw(annotated).rectangle(
        [left, top, left + box_w - 1, top + box_h - 1],
        outline="red",
        width=5,
    )
    return annotated

# Function to crop the image
def crop_image(image, x, y, size=224):
    """Crop a ``size`` x ``size`` patch from ``image``, kept inside its bounds.

    The requested origin is clamped so the patch never leaves the image. When
    the image is smaller than ``size`` along an axis, the patch spans the full
    extent of that axis (so it is smaller than ``size`` there).

    Args:
        image: PIL image, or an array-like convertible with
            ``Image.fromarray``.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the patch, in pixels.

    Returns:
        The cropped patch as an image object.
    """
    if not hasattr(image, "crop"):
        # Keep accepting raw arrays, as the array-based implementation did.
        image = Image.fromarray(np.asarray(image))

    width, height = image.size
    patch_w = min(size, width)
    patch_h = min(size, height)
    # int() guards against float coordinates coming from UI sliders; the
    # max(0, ...) keeps the origin non-negative even for tiny images.
    left = max(0, min(int(x), width - patch_w))
    top = max(0, min(int(y), height - patch_h))
    return image.crop((left, top, left + patch_w, top + patch_h))

# Gradio Interface
# Layout: the left column holds the model selector, the image upload and the
# X/Y sliders; the right column shows the annotated image, the cropped patch
# and the predicted label. "Crop" and "Predict" buttons drive the callbacks.
with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Update the model and preprocessor
    def update_model(model_name):
        """Load the model chosen in the dropdown."""
        return load_model_and_preprocessor(model_name)

    # Update the rectangle and crop the patch
    def update_selection(image, x, y):
        """Return the annotated image and the cropped patch for (x, y)."""
        if image is None:
            # Nothing uploaded yet: clear both outputs instead of crashing.
            return None, None
        # Normalise slider values to ints before drawing and slicing.
        x, y = int(x), int(y)
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped

    # Predict the label from the cropped patch
    def predict_from_cropped(cropped):
        """Classify the cropped patch with the currently loaded model."""
        if cropped is None:
            # Surface a readable message in the UI rather than a stack trace.
            raise gr.Error("Crop a patch before predicting.")
        print(f"Type of cropped_image before prediction: {type(cropped)}")
        return predict(cropped, current_model, current_preprocessor)

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)

    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    # Update sliders dynamically based on uploaded image size
    def update_sliders(image):
        """Limit the sliders so a 224-pixel patch fits inside the image."""
        if image is not None:
            width, height = image.size
            # Never go below 0: images smaller than the patch would otherwise
            # produce a maximum below the slider's minimum.
            return (
                gr.update(maximum=max(width - 224, 0)),
                gr.update(maximum=max(height - 224, 0)),
            )
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

    # Initialize model on app start
    demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)