import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Define all available models
MODEL_LIST = {
    'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
    'vit': "google/vit-base-patch16-224",
    'convnext': "facebook/convnext-tiny-224",
}
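
# Note on label spaces: the BEiT checkpoint is fine-tuned on ImageNet-22k,
# while the ViT and ConvNeXt checkpoints are fine-tuned on ImageNet-1k.
# None of these models is trained on coral imagery, so predictions are
# generic ImageNet labels.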

# Load the preprocessing pipeline (feature extractor) for a model
def get_preprocessor(model_name):
    extractor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
    return extractor

# Run on GPU when available, otherwise fall back to CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load a model from the Hugging Face Hub
def load_model(model_name):
    model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(DEVICE).eval()
    return model
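
# Optional sketch: memoize loaded models so switching back to a previously
# selected checkpoint does not reload its weights. Assumes enough memory to
# keep all listed models resident; `load_model_cached` is a hypothetical
# helper, not part of the original app.
from functools import lru_cache

@lru_cache(maxsize=len(MODEL_LIST))
def load_model_cached(model_name):
    return load_model(model_name)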

# Run a forward pass and return the top-1 class label
def predict(image, model, preprocessor):
    inputs = preprocessor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]
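
# Optional sketch (not wired into the UI): return the top-k labels with
# softmax confidences instead of a bare argmax, which is often more useful
# when probing an unfamiliar label space. `predict_topk` and its `topk`
# parameter are illustrative additions, not part of the original app.
def predict_topk(image, model, preprocessor, topk=5):
    inputs = preprocessor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    scores, indices = probs.topk(topk)
    return {model.config.id2label[i.item()]: round(s.item(), 4) for i, s in zip(indices, scores)}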

# Draw a red selection rectangle on a copy of the image
def draw_rectangle(image, x, y, size=224):
    image_pil = image.copy()  # Work on a copy to avoid modifying the original image
    draw = ImageDraw.Draw(image_pil)
    x1, y1 = x, y
    x2, y2 = x + size, y + size
    draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
    return image_pil

# Crop a size x size patch whose top-left corner is (x, y), clamped so the
# patch stays inside the image even when the image is smaller than `size`
def crop_image(image, x, y, size=224):
    image_np = np.array(image)
    h, w, _ = image_np.shape
    x = int(min(max(x, 0), max(w - size, 0)))  # int() guards against float slider values
    y = int(min(max(y, 0), max(h - size, 0)))
    cropped = image_np[y:y+size, x:x+size]
    return Image.fromarray(cropped)
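
# Note: the feature extractor resizes its input to the model's expected
# resolution anyway, so the 224x224 crop mainly controls how much of the
# source image the model sees, not the shape of the input tensor.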

# Currently selected model and preprocessor. Loading the dropdown default up
# front avoids a crash if "Predict" is clicked before the model is changed.
current_model = load_model('beit')
current_preprocessor = get_preprocessor('beit')

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Swap in the newly selected model and preprocessor. No value is returned
    # because the change handler below registers no outputs.
    def update_model(model_name):
        global current_model, current_preprocessor
        current_model = load_model(model_name)
        current_preprocessor = get_preprocessor(model_name)

    # Redraw the selection rectangle and crop the chosen patch
    def update_selection(image, x, y):
        if image is None:
            return None, None
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped

    # Classify the cropped patch with the currently selected model
    def predict_from_cropped(cropped):
        if cropped is None:
            return "Crop a patch first."
        return predict(cropped, current_model, current_preprocessor)

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)
    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    # Keep the sliders within the bounds of the uploaded image
    def update_sliders(image):
        if image is not None:
            width, height = image.size
            return gr.update(maximum=max(width - 224, 0)), gr.update(maximum=max(height - 224, 0))
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
demo.launch(server_name="0.0.0.0", server_port=7860)