import gradio as gr
from roboflow import Roboflow
import tempfile
import os
from sahi.slicing import slice_image
import numpy as np
import cv2
from PIL import Image, ImageDraw
# Initialize Roboflow
rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
model = project.version(16).model
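# Note: in a public Space the API key would normally come from a secret rather
# than being hardcoded, e.g. rf = Roboflow(api_key=os.getenv("ROBOFLOW_API_KEY")).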

def apply_nms(predictions, iou_threshold=0.5):
    boxes = []
    scores = []
    classes = []

    # Extract boxes, scores, and class info.
    # Roboflow reports each box by its center (x, y) plus width/height, while
    # cv2.dnn.NMSBoxes (and cv2.rectangle later) expect the top-left corner,
    # so convert each box to [x_min, y_min, width, height] here.
    for prediction in predictions:
        width = prediction['width']
        height = prediction['height']
        x_min = prediction['x'] - width / 2
        y_min = prediction['y'] - height / 2
        boxes.append([x_min, y_min, width, height])
        scores.append(prediction['confidence'])
        classes.append(prediction['class'])

    boxes = np.array(boxes)
    scores = np.array(scores)
    classes = np.array(classes)

    # Perform NMS using OpenCV
    indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(),
                               score_threshold=0.25, nms_threshold=iou_threshold)
    print(f"Predictions before NMS: {predictions}")
    print(f"Indices after NMS: {indices}")

    # Check if indices is empty or invalid
    if indices is None or len(indices) == 0:
        print("No valid indices returned from NMS.")
        return []  # Return an empty list if no valid indices are found

    # Depending on the OpenCV version, NMSBoxes returns a flat array, an Nx1
    # array, or a tuple, so normalize it to a 1-D numpy array of indices.
    indices = np.array(indices).flatten()

    nms_predictions = []
    for i in indices:
        nms_predictions.append({
            'class': classes[i],
            'bbox': boxes[i],  # [x_min, y_min, width, height]
            'confidence': scores[i]
        })
    return nms_predictions
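
# Example of one prediction entry from the Roboflow JSON response that
# apply_nms() consumes; the values below are illustrative only:
#   {'x': 320.5, 'y': 240.0, 'width': 50.0, 'height': 80.0,
#    'confidence': 0.91, 'class': 'nescafe'}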

# Detect objects and annotate the image
def detect_objects(image):
    # Save the uploaded PIL image temporarily so it can be passed around as a file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Slice the image into smaller pieces
    slice_image_result = slice_image(
        image=temp_file_path,
        output_file_name="sliced_image",
        output_dir="/tmp/sliced/",
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.1,
        overlap_width_ratio=0.1
    )

    # Print to check the available attributes of the slice_image_result object
    print(f"Slice result: {slice_image_result}")

    # Try accessing the sliced image paths from the result object
    try:
        sliced_image_paths = slice_image_result.sliced_image_paths  # Assuming this attribute exists
        print(f"Sliced image paths: {sliced_image_paths}")
    except AttributeError:
        print("Failed to access sliced_image_paths attribute.")
        sliced_image_paths = []

    # Check predictions for the whole image first
    print("Predicting on the whole image (without slicing)...")
    whole_image_predictions = model.predict(image_path=temp_file_path).json()
    print(f"Whole image predictions: {whole_image_predictions}")

    # If there are predictions for the whole image, use them
    if whole_image_predictions['predictions']:
        print("Using predictions from the whole image.")
        all_predictions = whole_image_predictions['predictions']
    else:
        print("No predictions found for the whole image. Predicting on slices...")
        # If there are no predictions for the whole image, predict on each slice
        all_predictions = []
        for sliced_image_path in sliced_image_paths:
            if isinstance(sliced_image_path, str):
                predictions = model.predict(image_path=sliced_image_path).json()
                all_predictions.extend(predictions['predictions'])
            else:
                print(f"Skipping invalid image path: {sliced_image_path}")

    # Apply NMS to remove duplicate detections
    postprocessed_predictions = apply_nms(all_predictions, iou_threshold=0.5)

    # Annotate the image with the prediction results using OpenCV
    img = cv2.imread(temp_file_path)
    for prediction in postprocessed_predictions:
        class_name = prediction['class']
        bbox = prediction['bbox']
        confidence = prediction['confidence']

        # Unpack the bounding box (top-left corner plus width/height)
        x, y, w, h = map(int, bbox)

        # Draw the bounding box and label on the image
        color = (0, 255, 0)  # Green box
        thickness = 2
        cv2.rectangle(img, (x, y), (x + w, y + h), color, thickness)
        label = f"{class_name}: {confidence:.2f}"
        cv2.putText(img, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)

    # Convert the image from BGR to RGB for PIL compatibility
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(img_rgb)

    # Save the annotated image
    output_image_path = "/tmp/prediction.jpg"
    annotated_image.save(output_image_path)

    # Count objects per class
    class_count = {}
    for detection in postprocessed_predictions:
        class_name = detection['class']
        class_count[class_name] = class_count.get(class_name, 0) + 1

    # Build the per-class object count text
    result_text = "Number of objects per class:\n"
    for class_name, count in class_count.items():
        result_text += f"{class_name}: {count} objects\n"

    # Remove the temporary file
    os.remove(temp_file_path)

    return output_image_path, result_text

# Gradio interface
iface = gr.Interface(
    fn=detect_objects,                    # Function called when an image is uploaded
    inputs=gr.Image(type="pil"),          # Input is a PIL image
    outputs=[gr.Image(), gr.Textbox()],   # Outputs: annotated image and count text
    live=True                             # Update results live as the input changes
)

# Run the interface
iface.launch()
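
# The script assumes these packages are available in the Space (e.g. via
# requirements.txt): gradio, roboflow, sahi, opencv-python-headless (or
# opencv-python), numpy, Pillow.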