Spaces:
Sleeping
Sleeping
yahiab
commited on
Commit
·
f311e6e
1
Parent(s):
0d11696
fix
Browse files
app.py
CHANGED
@@ -12,18 +12,24 @@ MODEL_LIST = {
|
|
12 |
'convnext': "facebook/convnext-tiny-224",
|
13 |
}
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
|
18 |
-
return extractor
|
19 |
|
20 |
-
# Load
|
21 |
-
def
|
22 |
-
model
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
#
|
26 |
def predict(image, model, preprocessor):
|
|
|
|
|
|
|
27 |
inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
|
28 |
with torch.no_grad():
|
29 |
outputs = model(**inputs)
|
@@ -32,6 +38,7 @@ def predict(image, model, preprocessor):
|
|
32 |
|
33 |
# Function to draw a rectangle on the image
|
34 |
def draw_rectangle(image, x, y, size=224):
|
|
|
35 |
image_pil = image.copy() # Create a copy to avoid modifying the original image
|
36 |
draw = ImageDraw.Draw(image_pil)
|
37 |
x1, y1 = x, y
|
@@ -41,6 +48,7 @@ def draw_rectangle(image, x, y, size=224):
|
|
41 |
|
42 |
# Function to crop the image
|
43 |
def crop_image(image, x, y, size=224):
|
|
|
44 |
image_np = np.array(image)
|
45 |
h, w, _ = image_np.shape
|
46 |
x = min(max(x, 0), w - size)
|
@@ -48,10 +56,6 @@ def crop_image(image, x, y, size=224):
|
|
48 |
cropped = image_np[y:y+size, x:x+size]
|
49 |
return Image.fromarray(cropped)
|
50 |
|
51 |
-
# Global variables
|
52 |
-
current_model = None
|
53 |
-
current_preprocessor = None
|
54 |
-
|
55 |
# Gradio Interface
|
56 |
with gr.Blocks() as demo:
|
57 |
gr.Markdown("## Test Public Models for Coral Classification")
|
@@ -67,12 +71,9 @@ with gr.Blocks() as demo:
|
|
67 |
cropped_image = gr.Image(label="Cropped Patch")
|
68 |
label_output = gr.Textbox(label="Predicted Label")
|
69 |
|
70 |
-
# Update the
|
71 |
def update_model(model_name):
|
72 |
-
|
73 |
-
current_model = load_model(model_name)
|
74 |
-
current_preprocessor = get_preprocessor(model_name)
|
75 |
-
return f"Model {model_name} loaded successfully."
|
76 |
|
77 |
# Update the rectangle and crop the patch
|
78 |
def update_selection(image, x, y):
|
@@ -82,6 +83,7 @@ with gr.Blocks() as demo:
|
|
82 |
|
83 |
# Predict the label from the cropped patch
|
84 |
def predict_from_cropped(cropped):
|
|
|
85 |
return predict(cropped, current_model, current_preprocessor)
|
86 |
|
87 |
# Buttons and interactions
|
@@ -102,4 +104,7 @@ with gr.Blocks() as demo:
|
|
102 |
|
103 |
image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
|
104 |
|
|
|
|
|
|
|
105 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
12 |
'convnext': "facebook/convnext-tiny-224",
|
13 |
}
|
14 |
|
15 |
+
# Global variables
|
16 |
+
current_model = None
|
17 |
+
current_preprocessor = None
|
|
|
18 |
|
19 |
+
# Load model and preprocessor
|
20 |
+
def load_model_and_preprocessor(model_name):
|
21 |
+
"""Load model and preprocessor for a given model name."""
|
22 |
+
global current_model, current_preprocessor
|
23 |
+
print(f"Loading model and preprocessor for: {model_name}")
|
24 |
+
current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).cuda().eval()
|
25 |
+
current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
|
26 |
+
return f"Model {model_name} loaded successfully."
|
27 |
|
28 |
+
# Predict function
|
29 |
def predict(image, model, preprocessor):
|
30 |
+
"""Make a prediction on the given image patch using the loaded model."""
|
31 |
+
if model is None or preprocessor is None:
|
32 |
+
raise ValueError("Model and preprocessor are not loaded.")
|
33 |
inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
|
34 |
with torch.no_grad():
|
35 |
outputs = model(**inputs)
|
|
|
38 |
|
39 |
# Function to draw a rectangle on the image
|
40 |
def draw_rectangle(image, x, y, size=224):
|
41 |
+
"""Draw a rectangle on the image."""
|
42 |
image_pil = image.copy() # Create a copy to avoid modifying the original image
|
43 |
draw = ImageDraw.Draw(image_pil)
|
44 |
x1, y1 = x, y
|
|
|
48 |
|
49 |
# Function to crop the image
|
50 |
def crop_image(image, x, y, size=224):
|
51 |
+
"""Crop a region from the image."""
|
52 |
image_np = np.array(image)
|
53 |
h, w, _ = image_np.shape
|
54 |
x = min(max(x, 0), w - size)
|
|
|
56 |
cropped = image_np[y:y+size, x:x+size]
|
57 |
return Image.fromarray(cropped)
|
58 |
|
|
|
|
|
|
|
|
|
59 |
# Gradio Interface
|
60 |
with gr.Blocks() as demo:
|
61 |
gr.Markdown("## Test Public Models for Coral Classification")
|
|
|
71 |
cropped_image = gr.Image(label="Cropped Patch")
|
72 |
label_output = gr.Textbox(label="Predicted Label")
|
73 |
|
74 |
+
# Update the model and preprocessor
|
75 |
def update_model(model_name):
|
76 |
+
return load_model_and_preprocessor(model_name)
|
|
|
|
|
|
|
77 |
|
78 |
# Update the rectangle and crop the patch
|
79 |
def update_selection(image, x, y):
|
|
|
83 |
|
84 |
# Predict the label from the cropped patch
|
85 |
def predict_from_cropped(cropped):
|
86 |
+
print(f"Type of cropped_image before prediction: {type(cropped)}")
|
87 |
return predict(cropped, current_model, current_preprocessor)
|
88 |
|
89 |
# Buttons and interactions
|
|
|
104 |
|
105 |
image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
|
106 |
|
107 |
+
# Initialize model on app start
|
108 |
+
demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)
|
109 |
+
|
110 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|