Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,12 @@ import numpy as np
|
|
7 |
import cv2
|
8 |
from PIL import Image, ImageDraw
|
9 |
|
10 |
-
#
|
11 |
rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
|
12 |
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
|
13 |
model = project.version(16).model
|
14 |
|
15 |
-
#
|
16 |
def apply_nms(predictions, iou_threshold=0.5):
|
17 |
boxes = []
|
18 |
scores = []
|
@@ -31,6 +31,9 @@ def apply_nms(predictions, iou_threshold=0.5):
|
|
31 |
# Perform NMS using OpenCV
|
32 |
indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), score_threshold=0.25, nms_threshold=iou_threshold)
|
33 |
|
|
|
|
|
|
|
34 |
# Check if indices is empty or invalid
|
35 |
if not indices or not isinstance(indices, tuple) or len(indices) == 0:
|
36 |
print("No valid indices returned from NMS.")
|
@@ -53,14 +56,14 @@ def apply_nms(predictions, iou_threshold=0.5):
|
|
53 |
|
54 |
return nms_predictions
|
55 |
|
56 |
-
#
|
57 |
def detect_objects(image):
|
58 |
-
#
|
59 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
60 |
image.save(temp_file, format="JPEG")
|
61 |
temp_file_path = temp_file.name
|
62 |
|
63 |
-
# Slice
|
64 |
slice_image_result = slice_image(
|
65 |
image=temp_file_path,
|
66 |
output_file_name="sliced_image",
|
@@ -82,10 +85,10 @@ def detect_objects(image):
|
|
82 |
print("Failed to access sliced_image_paths attribute.")
|
83 |
sliced_image_paths = []
|
84 |
|
85 |
-
#
|
86 |
all_predictions = []
|
87 |
|
88 |
-
#
|
89 |
for sliced_image_path in sliced_image_paths:
|
90 |
if isinstance(sliced_image_path, str):
|
91 |
predictions = model.predict(image_path=sliced_image_path).json()
|
@@ -93,10 +96,10 @@ def detect_objects(image):
|
|
93 |
else:
|
94 |
print(f"Skipping invalid image path: {sliced_image_path}")
|
95 |
|
96 |
-
#
|
97 |
postprocessed_predictions = apply_nms(all_predictions, iou_threshold=0.5)
|
98 |
|
99 |
-
# Annotate
|
100 |
img = cv2.imread(temp_file_path)
|
101 |
for prediction in postprocessed_predictions:
|
102 |
class_name = prediction['class']
|
@@ -118,11 +121,11 @@ def detect_objects(image):
|
|
118 |
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
119 |
annotated_image = Image.fromarray(img_rgb)
|
120 |
|
121 |
-
#
|
122 |
output_image_path = "/tmp/prediction.jpg"
|
123 |
annotated_image.save(output_image_path)
|
124 |
|
125 |
-
#
|
126 |
class_count = {}
|
127 |
for detection in postprocessed_predictions:
|
128 |
class_name = detection['class']
|
@@ -131,23 +134,23 @@ def detect_objects(image):
|
|
131 |
else:
|
132 |
class_count[class_name] = 1
|
133 |
|
134 |
-
#
|
135 |
result_text = "Jumlah objek per kelas:\n"
|
136 |
for class_name, count in class_count.items():
|
137 |
result_text += f"{class_name}: {count} objek\n"
|
138 |
|
139 |
-
#
|
140 |
os.remove(temp_file_path)
|
141 |
|
142 |
return output_image_path, result_text
|
143 |
|
144 |
-
#
|
145 |
iface = gr.Interface(
|
146 |
-
fn=detect_objects, #
|
147 |
-
inputs=gr.Image(type="pil"), # Input
|
148 |
-
outputs=[gr.Image(), gr.Textbox()], # Output
|
149 |
-
live=True #
|
150 |
)
|
151 |
|
152 |
-
#
|
153 |
iface.launch()
|
|
|
7 |
import cv2
|
8 |
from PIL import Image, ImageDraw
|
9 |
|
10 |
+
# Initialize Roboflow (NOTE(review): the API key is hard-coded and committed here — rotate it and load it from an environment variable instead)
|
11 |
rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
|
12 |
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
|
13 |
model = project.version(16).model
|
14 |
|
15 |
+
# Apply NMS (Non-Maximum Suppression)
|
16 |
def apply_nms(predictions, iou_threshold=0.5):
|
17 |
boxes = []
|
18 |
scores = []
|
|
|
31 |
# Perform NMS using OpenCV
|
32 |
indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), score_threshold=0.25, nms_threshold=iou_threshold)
|
33 |
|
34 |
+
print(f"Predictions before NMS: {predictions}")
|
35 |
+
print(f"Indices after NMS: {indices}")
|
36 |
+
|
37 |
# Check if indices is empty or invalid (NOTE(review): recent OpenCV returns a numpy array from cv2.dnn.NMSBoxes, so the isinstance(..., tuple) test may reject valid results, and `not indices` on an array raises — verify against the installed cv2 version)
|
38 |
if not indices or not isinstance(indices, tuple) or len(indices) == 0:
|
39 |
print("No valid indices returned from NMS.")
|
|
|
56 |
|
57 |
return nms_predictions
|
58 |
|
59 |
+
# Detect objects and annotate the image
|
60 |
def detect_objects(image):
|
61 |
+
# Save the image temporarily
|
62 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
63 |
image.save(temp_file, format="JPEG")
|
64 |
temp_file_path = temp_file.name
|
65 |
|
66 |
+
# Slice the image into smaller pieces
|
67 |
slice_image_result = slice_image(
|
68 |
image=temp_file_path,
|
69 |
output_file_name="sliced_image",
|
|
|
85 |
print("Failed to access sliced_image_paths attribute.")
|
86 |
sliced_image_paths = []
|
87 |
|
88 |
+
# Accumulate predictions from every sliced image into a single list
|
89 |
all_predictions = []
|
90 |
|
91 |
+
# Predict on each sliced image
|
92 |
for sliced_image_path in sliced_image_paths:
|
93 |
if isinstance(sliced_image_path, str):
|
94 |
predictions = model.predict(image_path=sliced_image_path).json()
|
|
|
96 |
else:
|
97 |
print(f"Skipping invalid image path: {sliced_image_path}")
|
98 |
|
99 |
+
# Apply NMS to remove duplicate detections
|
100 |
postprocessed_predictions = apply_nms(all_predictions, iou_threshold=0.5)
|
101 |
|
102 |
+
# Annotate the image with prediction results using OpenCV
|
103 |
img = cv2.imread(temp_file_path)
|
104 |
for prediction in postprocessed_predictions:
|
105 |
class_name = prediction['class']
|
|
|
121 |
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
122 |
annotated_image = Image.fromarray(img_rgb)
|
123 |
|
124 |
+
# Save the annotated image
|
125 |
output_image_path = "/tmp/prediction.jpg"
|
126 |
annotated_image.save(output_image_path)
|
127 |
|
128 |
+
# Count objects per class
|
129 |
class_count = {}
|
130 |
for detection in postprocessed_predictions:
|
131 |
class_name = detection['class']
|
|
|
134 |
else:
|
135 |
class_count[class_name] = 1
|
136 |
|
137 |
+
# Build the human-readable per-class object count summary
|
138 |
result_text = "Jumlah objek per kelas:\n"
|
139 |
for class_name, count in class_count.items():
|
140 |
result_text += f"{class_name}: {count} objek\n"
|
141 |
|
142 |
+
# Remove temporary file
|
143 |
os.remove(temp_file_path)
|
144 |
|
145 |
return output_image_path, result_text
|
146 |
|
147 |
+
# Gradio interface
|
148 |
iface = gr.Interface(
|
149 |
+
fn=detect_objects, # Function called when image is uploaded
|
150 |
+
inputs=gr.Image(type="pil"), # Input is an image
|
151 |
+
outputs=[gr.Image(), gr.Textbox()], # Output is an image and text
|
152 |
+
live=True # Display results live
|
153 |
)
|
154 |
|
155 |
+
# Run the interface
|
156 |
iface.launch()
|