initial commit
Browse files- app.py +236 -43
- requirements.txt +5 -2
app.py
CHANGED
@@ -1,82 +1,275 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Load YOLO model
|
7 |
-
|
|
|
8 |
|
9 |
-
# Define trapezoidal restricted area
|
10 |
trapezoid_pts = np.array([[250, 150], [400, 150], [450, 300], [200, 300]], np.int32)
|
11 |
|
|
|
|
|
|
|
12 |
def is_inside_trapezoid(box, trapezoid_pts):
|
13 |
"""Check if the center of a detected object is inside the trapezoidal area."""
|
14 |
x1, y1, x2, y2 = box
|
15 |
-
cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
|
|
|
|
|
16 |
return cv2.pointPolygonTest(trapezoid_pts, (cx, cy), False) >= 0
|
17 |
|
18 |
-
def
|
19 |
-
|
20 |
-
|
21 |
|
|
|
22 |
results = model.predict(frame, conf=0.5)
|
23 |
-
annotated_frame = results[0].plot()
|
24 |
-
|
25 |
-
# Draw
|
26 |
cv2.polylines(annotated_frame, [trapezoid_pts.reshape((-1, 1, 2))], isClosed=True, color=(0, 0, 255), thickness=2)
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
for r in results:
|
32 |
-
for box, cls in zip(r.boxes.xyxy, r.boxes.cls):
|
33 |
-
class_id = int(cls.item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if class_id == 0: # Person
|
35 |
-
isAlert[
|
|
|
36 |
if class_id in [0, 1, 2, 3]: # Person, bicycle, car, motorcycle
|
37 |
if is_inside_trapezoid(box.tolist(), trapezoid_pts):
|
38 |
-
isAlert[
|
39 |
# Mark the intrusion with a red box
|
40 |
-
x1, y1, x2, y2 = map(int, box.tolist())
|
41 |
cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 0, 255), 3)
|
42 |
|
43 |
# Add alert text on the frame
|
44 |
alert_text = f"Intrusion Alert: {isAlert['alert'][0]}, Object: {isAlert['alert'][1]}, Persons: {isAlert['personCount']}"
|
45 |
cv2.putText(annotated_frame, alert_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
|
46 |
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
if not cap.isOpened():
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
while True:
|
57 |
ret, frame = cap.read()
|
58 |
if not ret:
|
59 |
break
|
60 |
|
61 |
-
# Process frame
|
62 |
-
|
63 |
|
64 |
-
#
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException
|
2 |
+
from fastapi.responses import FileResponse, JSONResponse
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
import uvicorn
|
5 |
import cv2
|
6 |
import numpy as np
|
7 |
from ultralytics import YOLO
|
8 |
+
import os
|
9 |
+
import shutil
|
10 |
+
from typing import Optional
|
11 |
+
import uuid
|
12 |
+
import base64
|
13 |
+
from io import BytesIO
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
# Create the FastAPI application; title/description/version feed the
# auto-generated API documentation.
app = FastAPI(
    title="YOLO Intrusion Detection API",
    description="API for detecting intrusions using YOLOv8 model",
    version="1.0.0",
)

# Allow cross-origin requests from any site.  Credentials are disabled on
# purpose: a wildcard origin must not be combined with credentialed requests,
# and none of the endpoints defined in this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load YOLO model
# Only the weights file name is fixed here; the model object is created by
# the startup hook so importing this module stays cheap.
model_name = 'yolov8n.pt'
model = None  # Will be loaded on startup

# Define trapezoidal restricted area as (x, y) pixel vertices
trapezoid_pts = np.array([[250, 150], [400, 150], [450, 300], [200, 300]], np.int32)

# Create temp directory for uploads if it doesn't exist
os.makedirs("temp", exist_ok=True)
|
41 |
+
|
42 |
def is_inside_trapezoid(box, trapezoid_pts):
    """Tell whether the midpoint of a bounding box is in or on the polygon.

    Args:
        box: Four values ``(x1, y1, x2, y2)`` describing opposite corners.
        trapezoid_pts: Polygon vertices handed to ``cv2.pointPolygonTest``.

    Returns:
        True when the truncated integer midpoint is inside the polygon or
        exactly on its edge, False otherwise.
    """
    left, top, right, bottom = box

    # Midpoint of the box, truncated to whole pixels
    center = (int((left + right) / 2), int((top + bottom) / 2))

    # With measureDist=False the result is only a sign: +1 inside, 0 on the
    # edge, -1 outside.
    side = cv2.pointPolygonTest(trapezoid_pts, center, False)
    return side >= 0
|
49 |
|
50 |
+
def process_image(frame):
    """Run detection on one frame and flag objects inside the restricted area.

    Args:
        frame: Image handed unchanged to ``model.predict``.

    Returns:
        A tuple ``(annotated_frame, response)``.  ``annotated_frame`` is the
        plotted detection image with the restricted area outline, a thick red
        box around every intruding object and a status line drawn on top.
        ``response`` is a dict with the keys ``intrusion_detected``,
        ``intruding_object``, ``person_count`` and ``detections``.

    Raises:
        RuntimeError: If the module-level model has not been loaded yet.
    """
    # Fail with a clear message instead of "'NoneType' has no attribute 'predict'"
    if model is None:
        raise RuntimeError("YOLO model is not loaded; process_image() was called before startup")

    # Perform object detection
    results = model.predict(frame, conf=0.5)
    annotated_frame = results[0].plot()  # Draw bounding boxes

    # Draw trapezoidal restricted area
    cv2.polylines(annotated_frame, [trapezoid_pts.reshape((-1, 1, 2))], isClosed=True, color=(0, 0, 255), thickness=2)

    # Class ids 0-3 are the ones that can raise an alert; the list index is
    # the class id, so it doubles as the id -> name lookup.
    classInIntrusion = ['person', 'bicycle', 'car', 'motorcycle']

    intrusion_detected = False
    intruding_object = ""
    person_count = 0
    detections = []

    # Loop through detected objects
    for r in results:
        for box, cls, conf in zip(r.boxes.xyxy, r.boxes.cls, r.boxes.conf):
            class_id = int(cls.item())
            confidence = float(conf.item())
            coords = box.tolist()
            x1, y1, x2, y2 = map(int, coords)

            is_watched_class = 0 <= class_id < len(classInIntrusion)
            class_name = classInIntrusion[class_id] if is_watched_class else f"class_{class_id}"

            # Evaluate the polygon test once and reuse it for both the
            # per-detection record and the alert decision.
            in_restricted_area = is_inside_trapezoid(coords, trapezoid_pts)

            detections.append({
                "class": class_name,
                "confidence": confidence,
                "bbox": [x1, y1, x2, y2],
                "in_restricted_area": in_restricted_area,
            })

            if class_id == 0:  # Person
                person_count += 1

            if is_watched_class and in_restricted_area:
                # The most recently seen intruder is the one reported
                intrusion_detected = True
                intruding_object = class_name
                # Mark the intrusion with a red box
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 0, 255), 3)

    # Add alert text on the frame
    alert_text = f"Intrusion Alert: {intrusion_detected}, Object: {intruding_object}, Persons: {person_count}"
    cv2.putText(annotated_frame, alert_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    response = {
        "intrusion_detected": intrusion_detected,
        "intruding_object": intruding_object,
        "person_count": person_count,
        "detections": detections,
    }

    return annotated_frame, response
|
105 |
+
|
106 |
+
def encode_image_to_base64(image):
    """Encode an OpenCV image as JPEG and return it as a base64 text string.

    Args:
        image: Image accepted by ``cv2.imencode``.

    Returns:
        The JPEG bytes, base64-encoded and decoded to a ``str`` as UTF-8.

    Raises:
        ValueError: If OpenCV reports that the JPEG encoding did not succeed.
    """
    # imencode returns (success_flag, buffer); the buffer is only meaningful
    # when the flag is true.
    ok, buffer = cv2.imencode('.jpg', image)
    if not ok:
        raise ValueError("Could not encode image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')
|
110 |
+
|
111 |
+
@app.on_event("startup")
async def startup_event():
    """Create the YOLO model once at application startup.

    Stores the instance in the module-level ``model`` variable and prints a
    confirmation line.
    """
    global model
    loaded_model = YOLO(model_name)
    model = loaded_model
    print(f"Model {model_name} loaded successfully")
|
117 |
|
118 |
+
@app.get("/")
async def root():
    """Return a short service description with the docs and endpoint paths."""
    endpoint_paths = {
        "process_image": "/process_image/",
        "process_video": "/process_video/",
        "health": "/health/",
    }
    return {
        "message": "YOLO Intrusion Detection API is running",
        "documentation": "/docs",
        "endpoints": endpoint_paths,
    }
|
130 |
+
|
131 |
+
@app.get("/health/")
async def health_check():
    """Report a fixed "healthy" status together with the configured model name."""
    payload = dict(status="healthy", model=model_name)
    return payload
|
135 |
+
|
136 |
+
@app.post("/process_image/")
async def api_process_image(file: UploadFile = File(...), return_image: bool = True):
    """
    Process an image file and detect intrusions.

    Args:
        file: The image file to process
        return_image: If True, returns the annotated image as base64

    Returns:
        JSON with detection results and optionally the annotated image

    Raises:
        HTTPException: 400 if the file name does not end in .png/.jpg/.jpeg,
            if the upload is empty, or if the bytes cannot be decoded as an
            image.
    """
    # The upload may arrive without a file name; treat that as "no extension"
    # so it is rejected with a 400 instead of crashing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        raise HTTPException(status_code=400, detail="Only PNG and JPG images are supported")

    # Read the whole upload into memory
    contents = await file.read()
    if not contents:
        # Nothing to decode; answer with a client error rather than handing
        # an empty buffer to OpenCV.
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    # imdecode signals undecodable data by returning None
    if img is None:
        raise HTTPException(status_code=400, detail="Could not decode image")

    # Process the image
    annotated_img, results = process_image(img)

    # Optionally include the annotated image
    if return_image:
        results["image"] = encode_image_to_base64(annotated_img)

    return results
|
168 |
+
|
169 |
+
@app.post("/process_video/")
async def api_process_video(file: UploadFile = File(...)):
    """
    Process a video file and detect intrusions.

    Every frame is run through ``process_image`` and the annotated frames are
    written to a new MP4 file under ``temp/``.

    Args:
        file: The video file to process

    Returns:
        JSON with aggregated detection results (``results``) and the download
        URL path of the processed video (``video_path``)

    Raises:
        HTTPException: 400 if the file name does not end in .mp4/.avi/.mov or
            the video cannot be opened; 500 if the output video cannot be
            created.
    """
    # The upload may arrive without a file name; treat that as "no extension"
    filename = file.filename or ""
    if not filename.lower().endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Only MP4, AVI, and MOV videos are supported")

    # Create unique temporary file names
    temp_input = f"temp/input_{uuid.uuid4()}.mp4"
    temp_output = f"temp/output_{uuid.uuid4()}.mp4"

    final_results = {
        "intrusion_detected": False,
        "intruding_objects": set(),
        "max_person_count": 0,
        "frames_processed": 0,
        "total_detections": 0,
    }

    cap = None
    out = None
    completed = False
    try:
        # Save uploaded file
        with open(temp_input, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        cap = cv2.VideoCapture(temp_input)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file")

        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The container did not report a usable frame rate (0 or NaN);
            # fall back to a conventional default so the writer can be opened.
            fps = 30.0

        # Create output video file
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
        if not out.isOpened():
            # Without this check write() would do nothing and the response
            # would point at a file that does not exist.
            raise HTTPException(status_code=500, detail="Could not create output video file")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Process the frame
            annotated_frame, frame_results = process_image(frame)

            # Update final results
            final_results["frames_processed"] += 1
            final_results["total_detections"] += len(frame_results["detections"])

            if frame_results["intrusion_detected"]:
                final_results["intrusion_detected"] = True
                if frame_results["intruding_object"]:
                    final_results["intruding_objects"].add(frame_results["intruding_object"])

            final_results["max_person_count"] = max(
                final_results["max_person_count"],
                frame_results["person_count"],
            )

            # Write the frame
            out.write(annotated_frame)

        completed = True
    finally:
        # Release handles first so the files can be removed on every platform
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # The uploaded copy is never needed after processing
        if os.path.exists(temp_input):
            os.remove(temp_input)
        # Do not leave a partial output behind when processing failed
        if not completed and os.path.exists(temp_output):
            os.remove(temp_output)

    # Convert set to a sorted list for deterministic JSON serialization
    final_results["intruding_objects"] = sorted(final_results["intruding_objects"])

    return {
        "results": final_results,
        "video_path": f"/download_video/{os.path.basename(temp_output)}",
    }
|
255 |
+
|
256 |
+
@app.get("/download_video/{filename}")
async def download_video(filename: str):
    """
    Download the processed video file.

    Args:
        filename: The name of the processed video file (a bare file name
            inside the ``temp`` directory, no directory components)

    Returns:
        The video file

    Raises:
        HTTPException: 404 if the name contains a path component or no
            regular file with that name exists in ``temp``.
    """
    # The name comes straight from the URL: accept only a bare file name so
    # the request cannot address anything outside the temp directory.
    if "\\" in filename or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Video not found")

    file_path = f"temp/{filename}"
    # isfile (not exists) so that directories such as "temp/.." are rejected
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Video not found")

    return FileResponse(file_path, media_type="video/mp4", filename="processed_video.mp4")
|
272 |
|
273 |
+
# # For local development
|
274 |
+
# if __name__ == "__main__":
|
275 |
+
# uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
ultralytics
|
2 |
-
opencv-python
|
3 |
numpy
|
4 |
-
|
|
|
|
|
|
|
|
|
|
1 |
ultralytics
|
|
|
2 |
numpy
|
3 |
+
fastapi
|
4 |
+
uvicorn
|
5 |
+
opencv-python-headless
|
6 |
+
pillow
|
7 |
+
python-multipart
|