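"""Gradio app for object detection with a Roboflow-hosted model.

Uploaded images are run through the model with SAHI-style sliced inference
(useful for large images), detections are drawn with Supervision, and a
per-class plus total object count is reported. Configuration is read from
ROBOFLOW_API_KEY, ROBOFLOW_WORKSPACE, ROBOFLOW_PROJECT and
ROBOFLOW_MODEL_VERSION (e.g. via a .env file).
"""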
import gradio as gr
import numpy as np
import cv2
import supervision as sv
from roboflow import Roboflow
import tempfile
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
api_key = os.getenv("ROBOFLOW_API_KEY")
workspace = os.getenv("ROBOFLOW_WORKSPACE")
project_name = os.getenv("ROBOFLOW_PROJECT")
model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
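# The .env file is expected to provide these values, for example (illustrative only):
#   ROBOFLOW_API_KEY=your-api-key
#   ROBOFLOW_WORKSPACE=your-workspace-slug
#   ROBOFLOW_PROJECT=your-project-slug
#   ROBOFLOW_MODEL_VERSION=1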

# Initialize Roboflow with the API key
rf = Roboflow(api_key=api_key)
project = rf.workspace(workspace).project(project_name)
model = project.version(model_version).model
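# For reference, each entry in model.predict(...).json()["predictions"] is a dict
# shaped roughly like (values illustrative):
#   {"x": 320.5, "y": 240.0, "width": 50.0, "height": 80.0,
#    "confidence": 0.91, "class": "car", ...}
# where x/y give the box centre in pixels; this format is converted below.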

def detect_objects(image):
    # Save the uploaded PIL image to a temporary JPEG (convert drops any alpha channel)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Read the image using OpenCV
    original_image = cv2.imread(temp_file_path)

    try:
        # The original SAHI predict() call cannot drive a Roboflow-hosted model
        # directly, so SAHI-style slicing is done with Supervision's InferenceSlicer:
        # the image is cut into overlapping tiles, each tile is sent to the Roboflow
        # model, and the per-tile detections are merged with NMS.
        class_name_to_id = {}

        def slice_callback(image_slice: np.ndarray) -> sv.Detections:
            # Run the hosted Roboflow model on a single slice (BGR frame from OpenCV)
            result = model.predict(image_slice, confidence=40, overlap=30).json()
            predictions = result.get("predictions", [])
            if not predictions:
                return sv.Detections.empty()

            xyxy, confidences, class_ids = [], [], []
            for p in predictions:
                # Roboflow boxes are centre x/y plus width/height; convert to corners
                xyxy.append([
                    p["x"] - p["width"] / 2, p["y"] - p["height"] / 2,
                    p["x"] + p["width"] / 2, p["y"] + p["height"] / 2,
                ])
                confidences.append(p["confidence"])
                class_ids.append(class_name_to_id.setdefault(p["class"], len(class_name_to_id)))

            return sv.Detections(
                xyxy=np.array(xyxy, dtype=float),
                confidence=np.array(confidences, dtype=float),
                class_id=np.array(class_ids, dtype=int),
            )

        slicer = sv.InferenceSlicer(
            callback=slice_callback,
            slice_wh=(800, 800),            # width and height of each slice
            overlap_ratio_wh=(0.2, 0.2),    # 20% overlap between neighbouring slices
        )
        detections = slicer(original_image)

        # Annotate the image with bounding boxes and class labels
        id_to_name = {v: k for k, v in class_name_to_id.items()}
        labels = [
            f"{id_to_name[int(class_id)]} {confidence:.2f}"
            for class_id, confidence in zip(detections.class_id, detections.confidence)
        ]
        box_annotator = sv.BoxAnnotator()
        label_annotator = sv.LabelAnnotator()
        annotated_image = box_annotator.annotate(scene=original_image.copy(), detections=detections)
        annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

        # Count detected objects per class
        class_count = {}
        total_count = 0
        for class_id in detections.class_id:
            class_name = id_to_name[int(class_id)]
            class_count[class_name] = class_count.get(class_name, 0) + 1
            total_count += 1

        # Prepare result text
        result_text = "Detected Objects:\n\n"
        for class_name, count in class_count.items():
            result_text += f"{class_name}: {count}\n"
        result_text += f"\nTotal objects detected: {total_count}"

        # Save the annotated image as output
        output_image_path = "/tmp/prediction.jpg"
        cv2.imwrite(output_image_path, annotated_image)

    except Exception as err:
        result_text = f"An error occurred: {err}"
        # Write the unannotated image out so a valid path is still returned after
        # the temporary file is deleted below
        output_image_path = "/tmp/prediction.jpg"
        if original_image is not None:
            cv2.imwrite(output_image_path, original_image)

    # Clean up by removing the temporary file
    os.remove(temp_file_path)
    
    return output_image_path, result_text
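# Quick manual check without the UI (hypothetical image path):
#   from PIL import Image
#   out_path, summary = detect_objects(Image.open("sample.jpg"))
#   print(summary)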

# Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detected Image")
        with gr.Column():
            output_text = gr.Textbox(label="Object Count Results")
    
    detect_button = gr.Button("Detect")
    
    detect_button.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text]
    )

# Launch the Gradio interface
iface.launch()