# reefnet_demo_1.0 / app _bk.py
# yahiab
# fix
# 471d95f
# raw
# history blame
# 4.6 kB
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# Selectable models: short name shown in the UI -> identifier handed to
# ``from_pretrained`` when the model is loaded.
MODEL_LIST = dict(
    beit="microsoft/beit-base-patch16-224-pt22k-ft22k",
    vit="google/vit-base-patch16-224",
    convnext="facebook/convnext-tiny-224",
)

# Module-level state; both stay None until load_model_and_preprocessor() runs.
current_model = current_preprocessor = None

# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"
# Load model and preprocessor
def load_model_and_preprocessor(model_name):
    """Load the classifier and its preprocessor for ``model_name``.

    Both objects are created with ``from_pretrained`` using the identifier
    stored under ``model_name`` in ``MODEL_LIST``. The module-level
    ``current_model`` / ``current_preprocessor`` are replaced only after
    *both* loads have succeeded, so a failure part-way through never leaves
    a new model paired with the previous model's preprocessor.

    Args:
        model_name: Key into ``MODEL_LIST`` (for example ``'beit'``).

    Returns:
        A status message naming the loaded model and the device it is on.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    global current_model, current_preprocessor

    # Resolve the identifier up front so an unknown name fails with a
    # message that lists the valid choices instead of a bare KeyError.
    try:
        checkpoint = MODEL_LIST[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODEL_LIST)}"
        ) from None

    print(f"Loading model and preprocessor for: {model_name} on {device}")

    # Load into locals first; the globals are only touched once both exist.
    preprocessor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    model = model.to(device).eval()  # eval(): inference mode for the layers

    current_model, current_preprocessor = model, preprocessor
    return f"Model {model_name} loaded successfully on {device}."
# Predict function
def predict(image, model, preprocessor):
    """Classify one image patch and return its predicted label.

    Args:
        image: The patch to classify; it is passed to ``preprocessor`` as
            ``images=image``.
        model: A loaded classification model. It is called with the
            preprocessed inputs and must expose ``config.id2label``.
        preprocessor: Callable that turns ``image`` into model inputs
            (invoked with ``return_tensors="pt"``).

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.

    Raises:
        ValueError: If ``model`` or ``preprocessor`` is ``None``, or if
            ``image`` is ``None`` (nothing has been cropped yet).
    """
    if model is None or preprocessor is None:
        raise ValueError("Model and preprocessor are not loaded.")
    if image is None:
        # Fail with a clear message rather than an opaque preprocessor error.
        raise ValueError("No image patch was provided for prediction.")

    # Send the inputs to wherever *this* model lives; only fall back to the
    # module-level ``device`` when the model does not expose its own.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = device

    inputs = preprocessor(images=image, return_tensors="pt").to(target_device)

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]
# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with a red square outline drawn on it.

    The square's top-left corner is ``(x, y)`` and its bottom-right corner is
    ``(x + size, y + size)``. The outline is 5 pixels wide. ``image`` itself
    is left untouched because all drawing happens on a copy.
    """
    annotated = image.copy()
    corners = [x, y, x + size, y + size]
    ImageDraw.Draw(annotated).rectangle(corners, outline="red", width=5)
    return annotated
# Function to crop the image
def crop_image(image, x, y, size=224):
    """Crop a ``size`` x ``size`` patch from ``image``.

    The requested top-left corner ``(x, y)`` is clamped so the patch stays
    inside the image. If the image is smaller than ``size`` along an axis,
    the patch starts at 0 on that axis and covers whatever is available, so
    the result can be smaller than ``size``.

    Args:
        image: Source image; it is converted with ``np.array``.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the patch, in pixels.

    Returns:
        The cropped patch as a PIL image.
    """
    image_np = np.array(image)

    # Read only the first two dimensions so 2-D (single-channel) arrays work
    # as well as arrays with a trailing channel axis.
    h, w = image_np.shape[:2]

    # int(): slicing rejects floats. The inner max(..., 0) keeps the upper
    # bound from going negative when the image is smaller than ``size``.
    x = min(max(int(x), 0), max(w - size, 0))
    y = min(max(int(y), 0), max(h - size, 0))

    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)
# Gradio interface
# NOTE(review): the file's indentation was lost, so the nesting of widgets
# inside rows/columns below is a reconstruction — confirm against the
# intended layout.

# Side length, in pixels, of the square patch that is cropped and classified.
PATCH_SIZE = 224

with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=list(MODEL_LIST.keys()), value='beit', label="Select Model"
            )
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    def update_model(model_name):
        """Load the model chosen in the dropdown."""
        return load_model_and_preprocessor(model_name)

    def update_selection(image, x, y):
        """Return the image with the selection outlined, plus the cropped patch."""
        if image is None:
            raise gr.Error("Please upload an image before cropping.")

        # Clamp once here so the outline is drawn exactly where the patch is
        # cut, even if a slider still holds a value from a larger image.
        width, height = image.size
        x = min(max(int(x), 0), max(width - PATCH_SIZE, 0))
        y = min(max(int(y), 0), max(height - PATCH_SIZE, 0))

        overlay_image = draw_rectangle(image, x, y, PATCH_SIZE)
        cropped = crop_image(image, x, y, PATCH_SIZE)
        return overlay_image, cropped

    def predict_from_cropped(cropped):
        """Classify the current cropped patch with the loaded model."""
        if cropped is None:
            raise gr.Error("Please crop a patch before predicting.")
        print(f"Type of cropped_image before prediction: {type(cropped)}")
        return predict(cropped, current_model, current_preprocessor)

    def update_sliders(image):
        """Limit the sliders so the patch stays inside the uploaded image."""
        if image is None:
            return gr.update(), gr.update()
        width, height = image.size
        # Never publish a negative maximum for images smaller than the patch.
        return (
            gr.update(maximum=max(width - PATCH_SIZE, 0)),
            gr.update(maximum=max(height - PATCH_SIZE, 0)),
        )

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(
        fn=update_selection,
        inputs=[image_input, x_slider, y_slider],
        outputs=[interactive_image, cropped_image],
    )

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)

    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    # Update the slider ranges whenever a new image is uploaded.
    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

    # Load the default model when the app starts.
    demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)

demo.launch(server_name="0.0.0.0", server_port=7860)