yahiab commited on
Commit
0d11696
·
1 Parent(s): 882271f

Add public model testing app

Browse files
Files changed (2) hide show
  1. app.py +105 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
7
+
8
+ # Define all available models
9
+ MODEL_LIST = {
10
+ 'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
11
+ 'vit': "google/vit-base-patch16-224",
12
+ 'convnext': "facebook/convnext-tiny-224",
13
+ }
14
+
15
+ # Preprocessing transforms
16
+ def get_preprocessor(model_name):
17
+ extractor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
18
+ return extractor
19
+
20
+ # Load a model from Hugging Face
21
+ def load_model(model_name):
22
+ model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).cuda().eval()
23
+ return model
24
+
25
+ # Function to make predictions
26
+ def predict(image, model, preprocessor):
27
+ inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
28
+ with torch.no_grad():
29
+ outputs = model(**inputs)
30
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
31
+ return model.config.id2label[predicted_class]
32
+
33
+ # Function to draw a rectangle on the image
34
+ def draw_rectangle(image, x, y, size=224):
35
+ image_pil = image.copy() # Create a copy to avoid modifying the original image
36
+ draw = ImageDraw.Draw(image_pil)
37
+ x1, y1 = x, y
38
+ x2, y2 = x + size, y + size
39
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
40
+ return image_pil
41
+
42
+ # Function to crop the image
43
+ def crop_image(image, x, y, size=224):
44
+ image_np = np.array(image)
45
+ h, w, _ = image_np.shape
46
+ x = min(max(x, 0), w - size)
47
+ y = min(max(y, 0), h - size)
48
+ cropped = image_np[y:y+size, x:x+size]
49
+ return Image.fromarray(cropped)
50
+
51
+ # Global variables
52
+ current_model = None
53
+ current_preprocessor = None
54
+
55
+ # Gradio Interface
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("## Test Public Models for Coral Classification")
58
+
59
+ with gr.Row():
60
+ with gr.Column():
61
+ model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
62
+ image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
63
+ x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
64
+ y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
65
+ with gr.Column():
66
+ interactive_image = gr.Image(label="Interactive Image with Selection")
67
+ cropped_image = gr.Image(label="Cropped Patch")
68
+ label_output = gr.Textbox(label="Predicted Label")
69
+
70
+ # Update the current model and preprocessor
71
+ def update_model(model_name):
72
+ global current_model, current_preprocessor
73
+ current_model = load_model(model_name)
74
+ current_preprocessor = get_preprocessor(model_name)
75
+ return f"Model {model_name} loaded successfully."
76
+
77
+ # Update the rectangle and crop the patch
78
+ def update_selection(image, x, y):
79
+ overlay_image = draw_rectangle(image, x, y)
80
+ cropped = crop_image(image, x, y)
81
+ return overlay_image, cropped
82
+
83
+ # Predict the label from the cropped patch
84
+ def predict_from_cropped(cropped):
85
+ return predict(cropped, current_model, current_preprocessor)
86
+
87
+ # Buttons and interactions
88
+ crop_button = gr.Button("Crop")
89
+ crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
90
+
91
+ predict_button = gr.Button("Predict")
92
+ predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)
93
+
94
+ model_selector.change(fn=update_model, inputs=model_selector, outputs=None)
95
+
96
+ # Update sliders dynamically based on uploaded image size
97
+ def update_sliders(image):
98
+ if image is not None:
99
+ width, height = image.size
100
+ return gr.update(maximum=width - 224), gr.update(maximum=height - 224)
101
+ return gr.update(), gr.update()
102
+
103
+ image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
104
+
105
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ transformers