import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Define all available models (dropdown key -> Hugging Face Hub checkpoint).
MODEL_LIST = {
    'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
    'vit': "google/vit-base-patch16-224",
    'convnext': "facebook/convnext-tiny-224",
}

# Side length, in pixels, of the square patch that is highlighted, cropped and classified.
PATCH_SIZE = 224

# Prefer the GPU, but fall back to CPU so the app still starts on machines
# without CUDA instead of crashing inside `.cuda()`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables: one model/preprocessor pair shared by every UI session.
current_model = None
current_preprocessor = None
current_model_name = None  # MODEL_LIST key of the pair currently loaded, if any


# Load model and preprocessor
def load_model_and_preprocessor(model_name):
    """Load the model and preprocessor registered under ``model_name``.

    The loaded pair is stored in the module globals ``current_model`` and
    ``current_preprocessor``. Asking for the model that is already loaded is a
    no-op, so repeated page loads do not reload weights.

    Args:
        model_name: A key of ``MODEL_LIST``.

    Returns:
        A human-readable status message.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    global current_model, current_preprocessor, current_model_name
    if model_name == current_model_name and current_model is not None:
        return f"Model {model_name} loaded successfully."

    print(f"Loading model and preprocessor for: {model_name}")
    checkpoint = MODEL_LIST[model_name]
    # Load both into locals first so a failure in either step cannot leave the
    # globals holding a mismatched model/preprocessor pair.
    model = AutoModelForImageClassification.from_pretrained(checkpoint).to(DEVICE).eval()
    preprocessor = AutoFeatureExtractor.from_pretrained(checkpoint)
    current_model, current_preprocessor, current_model_name = model, preprocessor, model_name
    return f"Model {model_name} loaded successfully."


# Predict function
def predict(image, model, preprocessor):
    """Classify ``image`` with ``model`` and return its predicted label string.

    Args:
        image: An image the preprocessor accepts (PIL image or numpy array).
        model: A loaded image-classification model.
        preprocessor: The matching feature extractor / image processor.

    Raises:
        ValueError: If the model or preprocessor has not been loaded.
    """
    if model is None or preprocessor is None:
        raise ValueError("Model and preprocessor are not loaded.")
    # Send the tensors to wherever the model lives rather than assuming CUDA.
    inputs = preprocessor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = logits.argmax(dim=1).item()
    return model.config.id2label[predicted_class]


def _clamp_origin(width, height, x, y, size=PATCH_SIZE):
    """Clamp the top-left corner (x, y) so a size x size square fits inside the image.

    Coordinates are cast to int because slider values may arrive as floats,
    which numpy slicing rejects. For images smaller than ``size`` the origin is
    clamped to 0 (never negative).
    """
    x = min(max(int(x), 0), max(width - size, 0))
    y = min(max(int(y), 0), max(height - size, 0))
    return x, y


# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=PATCH_SIZE):
    """Return a copy of the PIL ``image`` with a red size x size box at (x, y)."""
    image_pil = image.copy()  # Create a copy to avoid modifying the original image
    draw = ImageDraw.Draw(image_pil)
    # PIL treats the second corner as inclusive, so subtract 1 to outline
    # exactly the ``size`` pixels that crop_image extracts.
    draw.rectangle([x, y, x + size - 1, y + size - 1], outline="red", width=5)
    return image_pil


# Function to crop the image
def crop_image(image, x, y, size=PATCH_SIZE):
    """Crop a size x size patch whose top-left corner is (x, y), clamped to the image.

    Returns a PIL image. If the image is smaller than ``size`` along an axis,
    the whole extent along that axis is returned.
    """
    image_np = np.array(image)
    h, w = image_np.shape[:2]  # works for both HxW (grayscale) and HxWxC arrays
    x, y = _clamp_origin(w, h, x, y, size)
    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Update the model and preprocessor
    def update_model(model_name):
        """Dropdown callback: switch the globally loaded model."""
        return load_model_and_preprocessor(model_name)

    # Update the rectangle and crop the patch
    def update_selection(image, x, y):
        """Crop-button callback: return (overlay image, cropped patch).

        The origin is clamped once up front so the red box drawn on the overlay
        marks exactly the region that is cropped.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")
        x, y = _clamp_origin(image.width, image.height, x, y)
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped

    # Predict the label from the cropped patch
    def predict_from_cropped(cropped):
        """Predict-button callback: classify the current cropped patch."""
        if cropped is None:
            raise gr.Error("Please crop a patch before predicting.")
        return predict(cropped, current_model, current_preprocessor)

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider],
                      outputs=[interactive_image, cropped_image])
    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)
    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)

    # Update sliders dynamically based on uploaded image size
    def update_sliders(image):
        """Limit the slider ranges so the patch origin stays inside the image.

        The maximum never goes below 0, even for images smaller than PATCH_SIZE.
        """
        if image is None:
            return gr.update(), gr.update()
        width, height = image.size
        return (
            gr.update(maximum=max(width - PATCH_SIZE, 0)),
            gr.update(maximum=max(height - PATCH_SIZE, 0)),
        )

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

    # Initialize model on app start (a no-op if 'beit' is already loaded)
    demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)

demo.launch(server_name="0.0.0.0", server_port=7860)