# Hugging Face Space (status when captured: Sleeping)
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# All models selectable in the UI: short key -> Hugging Face Hub checkpoint id.
# The keys are what the "Select Model" dropdown shows; the values are passed
# to ``from_pretrained`` by the loader helpers.
MODEL_LIST = {
    'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
    'vit': "google/vit-base-patch16-224",
    'convnext': "facebook/convnext-tiny-224",
}
# Preprocessing transforms
def get_preprocessor(model_name):
    """Return the feature extractor that matches ``model_name``.

    Args:
        model_name: A key of ``MODEL_LIST`` (e.g. ``'beit'``).

    Returns:
        The object produced by ``AutoFeatureExtractor.from_pretrained`` for
        the checkpoint registered under ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    return AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
# Load a model from Hugging Face
def load_model(model_name):
    """Load the classification checkpoint for ``model_name`` in eval mode.

    The model is placed on CUDA when a GPU is available and on the CPU
    otherwise, so the app also starts on CPU-only hosts instead of raising
    from an unconditional ``.cuda()`` call.

    Args:
        model_name: A key of ``MODEL_LIST`` (e.g. ``'vit'``).

    Returns:
        The loaded model, moved to the selected device and switched to
        evaluation mode.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name])
    return model.to(device).eval()
# Function to make predictions
def predict(image, model, preprocessor):
    """Classify ``image`` and return the predicted label string.

    Args:
        image: The image (or patch) to classify, in any form the
            preprocessor accepts.
        model: A loaded classification model exposing ``config.id2label``
            and returning an object with a ``logits`` attribute.
        preprocessor: Callable invoked as
            ``preprocessor(images=..., return_tensors="pt")`` whose result
            supports ``.to(device)`` and ``**`` unpacking.

    Returns:
        ``model.config.id2label`` entry for the highest-scoring class.

    Raises:
        ValueError: If ``model`` or ``preprocessor`` is ``None`` (nothing has
            been loaded yet).
    """
    if model is None or preprocessor is None:
        raise ValueError("No model is loaded; load a model before predicting.")

    # Send the inputs to wherever the model actually lives instead of
    # assuming CUDA, so CPU-only hosts work as well.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    inputs = preprocessor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]
# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with the selected patch outlined in red.

    The top-left corner is clamped into the image with the same rule
    ``crop_image`` uses, so the outline always marks the region that will
    actually be cropped.

    Args:
        image: PIL image to annotate; it is not modified.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the square patch, in pixels.

    Returns:
        A new PIL image with a 5-pixel-wide red rectangle drawn on it.
    """
    image_pil = image.copy()  # Create a copy to avoid modifying the original image
    width, height = image_pil.size

    # Keep the box inside the image; for images smaller than ``size`` the
    # origin collapses to 0.
    x1 = min(max(int(x), 0), max(width - size, 0))
    y1 = min(max(int(y), 0), max(height - size, 0))
    # Rectangle endpoints are inclusive, so a ``size``-pixel box ends at
    # origin + size - 1.
    x2, y2 = x1 + size - 1, y1 + size - 1

    draw = ImageDraw.Draw(image_pil)
    draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
    return image_pil
# Function to crop the image
def crop_image(image, x, y, size=224):
    """Cut a ``size`` x ``size`` patch out of ``image``.

    The requested origin is clamped so the patch stays inside the image.
    When the image is smaller than ``size`` along an axis, the origin on that
    axis becomes 0 and the patch is as large as the image allows (previously
    the origin went negative and an unrelated region was sliced).

    Args:
        image: PIL image (or anything ``np.array`` can convert) to crop.
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the square patch, in pixels.

    Returns:
        The cropped patch as a PIL image.
    """
    image_np = np.array(image)
    # Only the first two axes are needed, which also accepts 2-D
    # (single-channel) arrays that a three-way unpack would reject.
    h, w = image_np.shape[:2]

    # int() guards against float slider values, which cannot be used as
    # slice bounds.
    x = min(max(int(x), 0), max(w - size, 0))
    y = min(max(int(y), 0), max(h - size, 0))

    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)
# Global variables
# Model and preprocessor currently in use by the event handlers below.
# They are module-level, so every visitor of the app shares them.
current_model = None
current_preprocessor = None

# Side length (pixels) of the square patch selected by the sliders.
PATCH_SIZE = 224

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Update the current model and preprocessor
    def update_model(model_name):
        """Load ``model_name`` and make it the model used for prediction."""
        global current_model, current_preprocessor
        current_model = load_model(model_name)
        current_preprocessor = get_preprocessor(model_name)
        return f"Model {model_name} loaded successfully."

    # Update the rectangle and crop the patch
    def update_selection(image, x, y):
        """Return the annotated image and the patch selected at (x, y)."""
        if image is None:
            # Without this guard the helpers fail on ``None``.
            raise gr.Error("Upload an image before cropping.")
        x, y = int(x), int(y)
        overlay_image = draw_rectangle(image, x, y, PATCH_SIZE)
        cropped = crop_image(image, x, y, PATCH_SIZE)
        return overlay_image, cropped

    # Predict the label from the cropped patch
    def predict_from_cropped(cropped, model_name):
        """Classify the cropped patch with the currently selected model."""
        if cropped is None:
            raise gr.Error("Crop a patch before predicting.")
        # The dropdown's ``change`` event does not fire for its initial
        # value, so the default model is loaded lazily on first use.
        if current_model is None or current_preprocessor is None:
            update_model(model_name)
        return predict(cropped, current_model, current_preprocessor)

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=[cropped_image, model_selector], outputs=label_output)
    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    # Update sliders dynamically based on uploaded image size
    def update_sliders(image):
        """Limit both sliders so the patch stays inside the uploaded image."""
        if image is not None:
            width, height = image.size
            # Never go below 0: images smaller than the patch would
            # otherwise give the sliders a negative maximum.
            return (
                gr.update(maximum=max(width - PATCH_SIZE, 0)),
                gr.update(maximum=max(height - PATCH_SIZE, 0)),
            )
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

demo.launch(server_name="0.0.0.0", server_port=7860)