import gradio as gr
from dotenv import load_dotenv
from roboflow import Roboflow
import tempfile
import os
import requests
import numpy as np  # Slices arrive in the inference callback as numpy arrays
from PIL import Image  # Convert annotated numpy frames back to PIL for saving
import supervision as sv  # InferenceSlicer (SAHI-style slicing) and annotators
# Load environment variables from the .env file
load_dotenv()
api_key = os.getenv("ROBOFLOW_API_KEY")
workspace = os.getenv("ROBOFLOW_WORKSPACE")
project_name = os.getenv("ROBOFLOW_PROJECT")
model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
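# For reference, the .env file is expected to provide values like the
# following (placeholders, not real credentials):
#   ROBOFLOW_API_KEY=xxxxxxxxxxxx
#   ROBOFLOW_WORKSPACE=my-workspace
#   ROBOFLOW_PROJECT=my-project
#   ROBOFLOW_MODEL_VERSION=1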
# Initialize Roboflow with the credentials read from the environment
rf = Roboflow(api_key=api_key)
project = rf.workspace(workspace).project(project_name)
model = project.version(model_version).model
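# Note: project.version(...).model is a client for Roboflow's hosted
# inference API, so the slicing callback below issues one network request
# per tile; small tiles on large images can mean many requests.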
# Run sliced inference on the uploaded image and return an annotated copy
# along with a per-class count summary. The slicer works on the in-memory
# image, so no temporary copy of the upload is needed.
def detect_objects(image):
    try:
        # Callback the slicer invokes once per tile. The roboflow SDK's
        # hosted model exposes predict() (not infer()); recent SDK versions
        # accept numpy arrays directly, older ones require a file path.
        def callback(image_slice: np.ndarray) -> sv.Detections:
            result = model.predict(image_slice).json()
            return sv.Detections.from_inference(result)
        # Configure the SAHI-style slicer with slice dimensions and overlap
        slicer = sv.InferenceSlicer(
            callback=callback,
            slice_wh=(320, 320),  # Adjust slice dimensions as needed
            overlap_wh=(64, 64),  # Overlap in pixels (overlap_wh, not overlap_ratio_wh)
            overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION,  # Filter overlapping detections
            iou_threshold=0.5,  # Intersection over Union threshold for NMS
        )
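        # Sizing note: a 64 px overlap on 320 px tiles is 20% per edge, a
        # common starting point. Objects wider than the overlap can still be
        # split across tiles; the NMS pass configured above suppresses the
        # resulting duplicate boxes.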
        # Run slicing-based inference; the slicer expects a numpy array,
        # so convert the PIL input first
        image_np = np.array(image)
        detections = slicer(image_np)
        # Annotate the results on a copy of the image
        box_annotator = sv.BoxAnnotator()
        label_annotator = sv.LabelAnnotator()
        annotated_image = box_annotator.annotate(
            scene=image_np.copy(), detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image, detections=detections)
        # Save the annotated image; annotate() returns a numpy array,
        # so convert back to PIL before saving
        output_image_path = os.path.join(tempfile.gettempdir(), "prediction_visual.png")
        Image.fromarray(annotated_image).save(output_image_path)
        # Count the number of detected objects per class;
        # Detections.from_inference stores names in detections.data["class_name"]
        class_count = {}
        total_count = len(detections)
        class_names = detections.data.get("class_name", detections.class_id)
        for class_name in class_names:
            class_count[class_name] = class_count.get(class_name, 0) + 1
        # Create a result text with the object counts
        result_text = "Detected Objects:\n\n"
        for class_name, count in class_count.items():
            result_text += f"{class_name}: {count}\n"
        result_text += f"\nTotal objects detected: {total_count}"
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors from the hosted inference API
        result_text = f"HTTP error occurred: {http_err}"
        output_image_path = image  # gr.Image also accepts a PIL image, so return the original input
    except Exception as err:
        # Handle other errors
        result_text = f"An error occurred: {err}"
        output_image_path = image  # Return the unannotated input on error
    return output_image_path, result_text
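# Quick smoke test without the UI (sample.jpg is a placeholder path):
#   img = Image.open("sample.jpg")
#   path, summary = detect_objects(img)
#   print(summary)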
# Create the Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detected Objects")
        with gr.Column():
            output_text = gr.Textbox(label="Object Count")
    # Button to trigger object detection
    detect_button = gr.Button("Detect Objects")
    # Link the button to the detect_objects function
    detect_button.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text]
    )
# Launch the interface
iface.launch()
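# If a temporary public URL is needed (e.g. for demos), Gradio also
# supports iface.launch(share=True) in place of the plain launch() above.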