muhammadsalmanalfaridzi committed on
Commit
62b1b2e
·
verified ·
1 Parent(s): 2af14ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -102
app.py CHANGED
@@ -1,121 +1,63 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  import cv2
4
- import supervision as sv
5
- from roboflow import Roboflow
6
- import tempfile
7
- import os
8
- from sahi.predict import predict
9
- from dotenv import load_dotenv
10
-
11
- # Load environment variables from .env file
12
- load_dotenv()
13
- api_key = os.getenv("ROBOFLOW_API_KEY")
14
- workspace = os.getenv("ROBOFLOW_WORKSPACE")
15
- project_name = os.getenv("ROBOFLOW_PROJECT")
16
- model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
17
-
18
- # Initialize Roboflow with the API key
19
- rf = Roboflow(api_key=api_key)
20
- project = rf.workspace(workspace).project(project_name)
21
- model = project.version(model_version).model
22
-
23
- def detect_objects(image):
24
- # Save the uploaded image to a temporary file
25
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
26
- image.save(temp_file, format="JPEG")
27
- temp_file_path = temp_file.name
28
 
29
- # Read the image using OpenCV
30
- original_image = cv2.imread(temp_file_path)
 
31
 
32
- try:
33
- # Use SAHI to slice the image (optional for large images)
34
- predictions = predict(
35
- detection_model=model, # Use Roboflow model for prediction
36
- image=original_image,
37
- slice_height=800, # Height of each slice
38
- slice_width=800, # Width of each slice
39
- overlap_height_ratio=0.2,
40
- overlap_width_ratio=0.2,
41
- return_slice_result=False, # We don't need slice results, just detections
42
- )
43
 
44
- # Initialize Supervision annotations
45
- detections = []
46
- for prediction in predictions:
47
- bbox = prediction.bbox
48
- class_name = prediction.category
49
- confidence = prediction.score
50
 
51
- # Add detection to Supervision Detections list
52
- detections.append(
53
- sv.Detection(
54
- x1=bbox[0],
55
- y1=bbox[1],
56
- x2=bbox[2],
57
- y2=bbox[3],
58
- confidence=confidence,
59
- class_name=class_name
60
- )
61
- )
62
 
63
- # Convert detections to a Detections object for Supervision
64
- detections = sv.Detections(detections)
 
 
 
65
 
66
- # Annotate the image with bounding boxes and labels
67
- label_annotator = sv.LabelAnnotator()
68
- box_annotator = sv.BoxAnnotator()
69
-
70
- # Annotate and create the final result
71
- annotated_image = box_annotator.annotate(scene=original_image.copy(), detections=detections)
72
- annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
73
 
74
- # Count detected objects per class
75
- class_count = {}
76
- total_count = 0
77
-
78
- for detection in detections:
79
- class_name = detection.class_name
80
- class_count[class_name] = class_count.get(class_name, 0) + 1
81
- total_count += 1
82
 
83
- # Prepare result text
84
- result_text = "Detected Objects:\n\n"
85
- for class_name, count in class_count.items():
86
- result_text += f"{class_name}: {count}\n"
87
- result_text += f"\nTotal objects detected: {total_count}"
88
 
89
- # Save the annotated image as output
90
- output_image_path = "/tmp/prediction.jpg"
91
- cv2.imwrite(output_image_path, annotated_image)
 
 
92
 
93
- except Exception as err:
94
- result_text = f"An error occurred: {err}"
95
- output_image_path = temp_file_path # Return original image on error
96
 
97
- # Clean up by removing the temporary file
98
- os.remove(temp_file_path)
99
-
100
- return output_image_path, result_text
101
 
102
  # Gradio interface
103
- with gr.Blocks() as iface:
104
- with gr.Row():
105
- with gr.Column():
106
- input_image = gr.Image(type="pil", label="Input Image")
107
- with gr.Column():
108
- output_image = gr.Image(label="Detected Image")
109
- with gr.Column():
110
- output_text = gr.Textbox(label="Object Count Results")
111
-
112
- detect_button = gr.Button("Detect")
113
-
114
- detect_button.click(
115
- fn=detect_objects,
116
- inputs=input_image,
117
- outputs=[output_image, output_text]
118
- )
119
 
120
  # Launch the Gradio interface
121
  iface.launch()
 
import os

import gradio as gr
import numpy as np
import cv2
import supervision as sv
from inference import get_roboflow_model

# Roboflow credentials: read from environment variables so the API key is
# never committed to source control (the previous revision of this file
# already did this via ROBOFLOW_API_KEY). The string literals are only
# placeholder fallbacks for local experimentation.
model_id = os.getenv("ROBOFLOW_MODEL_ID", "your-model-id")
api_key = os.getenv("ROBOFLOW_API_KEY", "your-api-key")

# Load the Roboflow model once at startup; it is reused for every request.
model = get_roboflow_model(model_id=model_id, api_key=api_key)
 
 
 
 
 
 
 
 
 
# --- Sliced-inference setup -------------------------------------------------

def callback(image_slice: np.ndarray) -> sv.Detections:
    """Run the Roboflow model on a single image tile.

    Receives one slice produced by the InferenceSlicer, infers on it, and
    wraps the first result in a supervision ``Detections`` object.
    """
    return sv.Detections.from_inference(model.infer(image_slice)[0])


# SAHI-style slicer: splits a large image into tiles, runs ``callback`` on
# each tile, and merges the per-tile detections back into one result.
slicer = sv.InferenceSlicer(callback=callback)
 
 
 
 
 
 
 
 
 
22
 
23
+ # Function to handle image processing, inference, and annotation
24
+ def process_image(image):
25
+ # Convert the PIL image to OpenCV format (BGR)
26
+ image = np.array(image)
27
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
28
 
29
+ # Run inference using SAHI (splitting the image into slices)
30
+ sliced_detections = slicer(image=image)
 
 
 
 
 
31
 
32
+ # Annotate the detections with bounding boxes and labels
33
+ label_annotator = sv.LabelAnnotator()
34
+ box_annotator = sv.BoxAnnotator()
35
+
36
+ annotated_image = box_annotator.annotate(scene=image.copy(), detections=sliced_detections)
37
+ annotated_image = label_annotator.annotate(scene=annotated_image, detections=sliced_detections)
 
 
38
 
39
+ # Convert the annotated image back to RGB for display in Gradio
40
+ result_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
 
 
 
41
 
42
+ # Count the number of objects detected
43
+ class_count = {}
44
+ for detection in sliced_detections:
45
+ class_name = detection.class_name
46
+ class_count[class_name] = class_count.get(class_name, 0) + 1
47
 
48
+ total_count = sum(class_count.values())
 
 
49
 
50
+ return result_image, class_count, total_count
 
 
 
# Gradio interface
demo_input = gr.Image(type="pil", label="Upload Image")
demo_outputs = [
    gr.Image(type="pil", label="Annotated Image"),
    gr.JSON(label="Object Count"),
    gr.Number(label="Total Objects Detected"),
]

# `live=True` re-runs detection automatically whenever the input changes.
iface = gr.Interface(
    fn=process_image,
    inputs=demo_input,
    outputs=demo_outputs,
    live=True,
)

# Launch the Gradio interface
iface.launch()