reefnet_demo_1.0 / app _bk.py
yahiab
Track coral images with Git LFS
fe0c1e0
raw
history blame
5.53 kB
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
import timm
# Fine-tuned checkpoint fetched by load_model().
CHECKPOINT_URL = (
    "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"
)

# Labels the classifier can emit, grouped by initial letter for readability.
# predict_label() looks a label up by the argmax index of the model output,
# so the ORDER of this list is significant (presumably it mirrors the label
# order used when the checkpoint was trained -- confirm before reordering).
all_classes = [
    # A
    "Acanthastrea", "Acropora", "Agaricia", "Alveopora", "Astrea", "Astreopora",
    # C
    "Caulastraea", "Coeloseris", "Colpophyllia", "Coscinaraea", "Ctenactis",
    "Cycloseris", "Cyphastrea",
    # D
    "Dendrogyra", "Dichocoenia", "Diploastrea", "Diploria", "Dipsastraea",
    # E
    "Echinophyllia", "Echinopora", "Euphyllia", "Eusmilia",
    # F
    "Favia", "Favites", "Fungia",
    # G
    "Galaxea", "Gardineroseris", "Goniastrea", "Goniopora",
    # H
    "Halomitra", "Herpolitha", "Hydnophora",
    # I
    "Isophyllia", "Isopora",
    # L
    "Leptastrea", "Leptoria", "Leptoseris", "Lithophyllon", "Lobactis",
    "Lobophyllia",
    # M
    "Madracis", "Meandrina", "Merulina", "Montastraea", "Montipora", "Mussa",
    "Mussismilia", "Mycedium",
    # O
    "Orbicella", "Oulastrea", "Oulophyllia", "Oxypora",
    # P
    "Pachyseris", "Pavona", "Pectinia", "Physogyra", "Platygyra", "Plerogyra",
    "Plesiastrea", "Pocillopora", "Podabacia", "Porites", "Psammocora",
    "Pseudodiploria",
    # S
    "Sandalolitha", "Scolymia", "Seriatopora", "Siderastrea", "Stephanocoenia",
    "Stylocoeniella", "Stylophora",
    # T
    "Tubastraea", "Turbinaria",
]
def load_model(model_name):
    """Build the classifier and load its fine-tuned weights.

    Only ``'beit'`` is implemented: a BEiT-v2 large model is created through
    ``timm`` with one output per entry in ``all_classes``, and the checkpoint
    at ``CHECKPOINT_URL`` is loaded into it.

    Args:
        model_name: Identifier of the architecture to load.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available.

    Raises:
        ValueError: If ``model_name`` is not an implemented architecture.
    """
    print(f"Loading {model_name} model...")
    if model_name != 'beit':
        raise ValueError(f"Model {model_name} not implemented!")

    # Create model
    model = timm.create_model(
        'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
        pretrained=False,
        num_classes=len(all_classes),
        drop_path_rate=0.1,
        use_rel_pos_bias=True,
        use_abs_pos_emb=True,
    )

    # Load checkpoint from Hugging Face. The weights may be nested under a
    # 'model' key or stored as a bare state dict; accept both layouts.
    checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
    state_dict = checkpoint.get('model', checkpoint)

    # Drop the relative_position_index entries from the checkpoint before
    # loading, as the original code did.
    filtered_state_dict = {
        k: v for k, v in state_dict.items() if "relative_position_index" not in k
    }

    # strict=False tolerates key mismatches, but tolerating them silently can
    # hide a wrong or partial checkpoint (leaving layers randomly
    # initialised). Surface whatever did not line up.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    missing = [k for k in load_result.missing_keys if "relative_position_index" not in k]
    if missing:
        print(f"Warning: {len(missing)} model keys missing from checkpoint: {missing}")
    if load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.unexpected_keys)} unexpected checkpoint keys "
            f"ignored: {list(load_result.unexpected_keys)}"
        )

    # Inference only: switch to eval mode, then move to CUDA if available.
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    return model
# Input pipeline applied to every image before inference: resize to
# 224x224, convert to a tensor, then normalise each of the three channels
# (the mean/std values are the commonly used ImageNet statistics).
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the model once at import time; predict_label() reuses this instance.
selected_model_name = 'beit'
model = load_model(selected_model_name)
def predict_label(image):
    """Predict the label for the given image.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray`` holding the image.

    Returns:
        The entry of ``all_classes`` with the highest model score.

    Raises:
        TypeError: If ``image`` is neither a PIL image nor a numpy array.
    """
    # Ensure the image is a PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError(f"Unexpected type {type(image)}, expected PIL.Image or numpy.ndarray.")

    # The Normalize step in `preprocess` is configured with three channel
    # statistics, so grayscale / RGBA / palette images must be converted
    # first or normalisation fails on the channel mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add the batch dimension the model expects.
    input_tensor = preprocess(image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class = torch.argmax(outputs, dim=1).item()
    return all_classes[predicted_class]
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with a red square outline drawn on it.

    The square's top-left corner is ``(x, y)`` and its sides are ``size``
    pixels long. The input image itself is not modified.
    """
    annotated = image.copy()
    corners = [x, y, x + size, y + size]
    ImageDraw.Draw(annotated).rectangle(corners, outline="red", width=3)
    return annotated
def crop_image(image, x, y, size=224):
    """Crop a ``size`` x ``size`` region of interest out of ``image``.

    The requested top-left corner ``(x, y)`` is clamped so the patch stays
    inside the image. If the image is smaller than ``size`` along an axis,
    the patch covers the whole image along that axis instead.

    Args:
        image: Anything ``numpy.array`` can turn into an H x W or
            H x W x C array (e.g. a PIL image).
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the square patch.

    Returns:
        The cropped patch as a PIL image.
    """
    image_np = np.array(image)
    # Take only the first two dimensions so 2-D (grayscale) arrays work too.
    h, w = image_np.shape[:2]

    # Clamp the upper bound first and the lower bound last: when the image
    # is smaller than `size`, `w - size` is negative, and a negative offset
    # would make the slice below count from the far edge and return the
    # wrong region. int() guards against float coordinates.
    x = max(0, min(int(x), w - size))
    y = max(0, min(int(y), h - size))

    cropped = image_np[y:y + size, x:x + size]
    return Image.fromarray(cropped)
# Gradio UI: upload an image, pick a 224x224 patch with the sliders, crop it,
# and classify the cropped patch.
with gr.Blocks() as demo:
    gr.Markdown("## Coral Classification with BeIT Model")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Interactions
    def update_selection(image, x, y):
        """Return (image with the selection outlined, cropped patch)."""
        # "Crop" can be clicked before anything is uploaded; clear both
        # outputs instead of failing on a None image.
        if image is None:
            return None, None
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped

    def predict_from_cropped(cropped):
        """Classify the currently displayed cropped patch."""
        return predict_label(cropped)

    crop_button = gr.Button("Crop")
    crop_button.click(
        fn=update_selection,
        inputs=[image_input, x_slider, y_slider],
        outputs=[interactive_image, cropped_image],
    )

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)

    def update_sliders(image):
        """Limit the sliders to offsets at which a 224-pixel patch still fits."""
        if image is not None:
            width, height = image.size
            # 224 matches the default patch size of crop_image/draw_rectangle.
            # Floor at 0 so an image smaller than the patch does not produce
            # a slider whose maximum is below its minimum.
            return (
                gr.update(maximum=max(width - 224, 0)),
                gr.update(maximum=max(height - 224, 0)),
            )
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)