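"""Gradio demo: coral genus classification with a fine-tuned BEiT v2 model.

Upload a reef image, select a 224x224 patch with the X/Y sliders, and the app
predicts the coral genus for that patch using the ReefNet checkpoint hosted
on Hugging Face.
"""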
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
import timm

# URL for the Hugging Face checkpoint
CHECKPOINT_URL = "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"
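# torch.hub caches the download (by default under ~/.cache/torch/hub/checkpoints),
# so the weights are fetched from Hugging Face only once.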

# Coral genus labels; the index in this list matches the model's output class
all_classes = [
    'Acanthastrea', 'Acropora', 'Agaricia', 'Alveopora', 'Astrea', 'Astreopora',
    'Caulastraea', 'Coeloseris', 'Colpophyllia', 'Coscinaraea', 'Ctenactis',
    'Cycloseris', 'Cyphastrea', 'Dendrogyra', 'Dichocoenia', 'Diploastrea',
    'Diploria', 'Dipsastraea', 'Echinophyllia', 'Echinopora', 'Euphyllia',
    'Eusmilia', 'Favia', 'Favites', 'Fungia', 'Galaxea', 'Gardineroseris',
    'Goniastrea', 'Goniopora', 'Halomitra', 'Herpolitha', 'Hydnophora',
    'Isophyllia', 'Isopora', 'Leptastrea', 'Leptoria', 'Leptoseris',
    'Lithophyllon', 'Lobactis', 'Lobophyllia', 'Madracis', 'Meandrina', 'Merulina',
    'Montastraea', 'Montipora', 'Mussa', 'Mussismilia', 'Mycedium', 'Orbicella',
    'Oulastrea', 'Oulophyllia', 'Oxypora', 'Pachyseris', 'Pavona', 'Pectinia',
    'Physogyra', 'Platygyra', 'Plerogyra', 'Plesiastrea', 'Pocillopora',
    'Podabacia', 'Porites', 'Psammocora', 'Pseudodiploria', 'Sandalolitha',
    'Scolymia', 'Seriatopora', 'Siderastrea', 'Stephanocoenia', 'Stylocoeniella',
    'Stylophora', 'Tubastraea', 'Turbinaria'
]

# Load the BEiT classifier and its fine-tuned weights
def load_model(model_name):
    print(f"Loading {model_name} model...")
    if model_name == 'beit':
        # Create the BEiT v2 backbone with a classification head sized to our label set
        model = timm.create_model(
            'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
            pretrained=False,
            num_classes=len(all_classes),
            drop_path_rate=0.1,
            use_rel_pos_bias=True,
            use_abs_pos_emb=True,
        )

        # Load the fine-tuned checkpoint from Hugging Face
        checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
        state_dict = checkpoint.get('model', checkpoint)

        # Drop relative_position_index buffers; the model recreates them itself
        filtered_state_dict = {k: v for k, v in state_dict.items() if "relative_position_index" not in k}
        model.load_state_dict(filtered_state_dict, strict=False)
    else:
        raise ValueError(f"Model {model_name} not implemented!")

    # Inference mode; move to the GPU if one is available
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    return model

# Preprocessing: resize to the model's 224x224 input and normalize with the
# ImageNet statistics used during the backbone's pretraining
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize selected model
selected_model_name = 'beit'
model = load_model(selected_model_name)
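# The model is created once at import time and reused for every request.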

def predict_label(image):
    """Predict the label for the given image."""
    # Ensure the image is a PIL Image in RGB mode (an RGBA or grayscale input
    # would otherwise produce a tensor with the wrong number of channels)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError(f"Unexpected type {type(image)}, expected PIL.Image or numpy.ndarray.")
    image = image.convert("RGB")

    input_tensor = preprocess(image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    with torch.no_grad():
        outputs = model(input_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()

    return all_classes[predicted_class]


# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    image_pil = image.copy()
    draw = ImageDraw.Draw(image_pil)
    draw.rectangle([x, y, x + size, y + size], outline="red", width=3)
    return image_pil

# Crop a size x size region of interest, clamped to the image bounds
def crop_image(image, x, y, size=224):
    image_np = np.array(image.convert("RGB"))
    h, w, _ = image_np.shape
    x, y = int(x), int(y)
    # Clamp the top-left corner; max(..., 0) also handles images smaller than `size`
    x = min(max(x, 0), max(w - size, 0))
    y = min(max(y, 0), max(h - size, 0))
    cropped = image_np[y:y+size, x:x+size]
    return Image.fromarray(cropped)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Coral Classification with the BEiT Model")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")
    
    # Interactions
    def update_selection(image, x, y):
        if image is None:
            return None, None
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped

    def predict_from_cropped(cropped):
        if cropped is None:
            return "No patch selected yet. Crop a region first."
        return predict_label(cropped)

    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)

    def update_sliders(image):
        if image is not None:
            width, height = image.size
            # Keep the maxima non-negative so images smaller than 224px don't break the sliders
            return gr.update(maximum=max(width - 224, 0)), gr.update(maximum=max(height - 224, 0))
        return gr.update(), gr.update()

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

demo.launch(server_name="0.0.0.0", server_port=7860)
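
# Note: server_name="0.0.0.0" exposes the app on all network interfaces, which
# container hosts such as Hugging Face Spaces expect; for local-only use, a
# plain demo.launch() works as well.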