umarbalak committed on
Commit
7be1f70
·
1 Parent(s): 3854f63

initial commit

Browse files
Files changed (2) hide show
  1. app.py +236 -43
  2. requirements.txt +5 -2
app.py CHANGED
@@ -1,82 +1,275 @@
1
- import gradio as gr
 
 
 
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Load YOLO model
7
model = YOLO("yolov8n.pt")  # weights file name

# Restricted zone vertices in pixel coordinates, listed top-left, top-right,
# bottom-right, bottom-left.
trapezoid_pts = np.array(
    [[250, 150], [400, 150], [450, 300], [200, 300]],
    dtype=np.int32,
)


def is_inside_trapezoid(box, trapezoid_pts):
    """Return True when the center of ``box`` lies inside or on the polygon.

    Args:
        box: Four numbers ``(x1, y1, x2, y2)`` giving opposite box corners.
        trapezoid_pts: Polygon vertices handed to ``cv2.pointPolygonTest``.
    """
    left, top, right, bottom = box
    center = (int((left + right) / 2), int((top + bottom) / 2))
    # measureDist=False: only the sign of the result matters here
    # (a negative value means the point is outside the polygon).
    position = cv2.pointPolygonTest(trapezoid_pts, center, False)
    return position >= 0
17
 
18
def detect_objects(frame):
    """Annotate one frame with detections and the restricted-zone alert.

    Args:
        frame: Image to analyse, or None when no frame is available.

    Returns:
        ``(annotated_frame, alert_text)``. For a None frame this is a black
        480x640 image together with the text "No input frame".
    """
    if frame is None:
        placeholder = np.zeros((480, 640, 3), dtype=np.uint8)
        return placeholder, "No input frame"

    results = model.predict(frame, conf=0.5)
    canvas = results[0].plot()

    # Outline the restricted zone on top of the plotted detections.
    zone_outline = trapezoid_pts.reshape((-1, 1, 2))
    cv2.polylines(canvas, [zone_outline], isClosed=True, color=(0, 0, 255), thickness=2)

    # Class ids 0-3 map onto these labels; only they can raise an alert.
    intrusion_labels = ("person", "bicycle", "car", "motorcycle")
    intrusion_seen = False
    intruder = ""
    person_total = 0

    for result in results:
        for xyxy, cls in zip(result.boxes.xyxy, result.boxes.cls):
            class_id = int(cls.item())
            if class_id == 0:  # Person
                person_total += 1
            if class_id not in (0, 1, 2, 3):
                continue
            corners = xyxy.tolist()
            if not is_inside_trapezoid(corners, trapezoid_pts):
                continue
            # The most recently seen intruder is the one that gets reported.
            intrusion_seen, intruder = True, intrusion_labels[class_id]
            # Mark the intrusion with a red box
            x1, y1, x2, y2 = map(int, corners)
            cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 0, 255), 3)

    # Add alert text on the frame
    alert_text = f"Intrusion Alert: {intrusion_seen}, Object: {intruder}, Persons: {person_total}"
    cv2.putText(canvas, alert_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    return canvas, alert_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
def webcam_feed():
    """Stream annotated webcam frames together with their alert text.

    Yields:
        ``(frame, message)`` tuples. If the webcam cannot be opened, a single
        black 480x640 frame with the message "Failed to open webcam" is
        yielded and the stream ends. Otherwise one tuple from
        ``detect_objects`` is yielded per captured frame until a read fails.
    """
    cap = cv2.VideoCapture(0)
    try:
        # Check if the webcam opened successfully
        if not cap.isOpened():
            # This function is a generator, so ``return <value>`` would only
            # end the iteration and the consumer would never see the value.
            # Yield the placeholder first, then stop.
            yield np.zeros((480, 640, 3), dtype=np.uint8), "Failed to open webcam"
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Process the frame and hand the (image, message) pair to the consumer
            yield detect_objects(frame)
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer closes
        # the generator, so the capture device is always freed.
        cap.release()
66
-
67
# Gradio UI: no input components; ``webcam_feed`` supplies an image and a
# status string for the two output components.
demo = gr.Interface(
    title="YOLO Intrusion Detection",
    description="Real-time detection of persons and vehicles inside a restricted trapezoidal area.",
    fn=webcam_feed,
    inputs=[],
    outputs=[
        gr.Image(label="Detection Output"),
        gr.Textbox(label="Alert Status"),
    ],
    live=True,
    allow_flagging="never",
)


if __name__ == "__main__":
    # Enable the request queue (capped at one entry) before starting the server.
    queued_demo = demo.queue(max_size=1)
    queued_demo.launch()
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import FileResponse, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import uvicorn
5
  import cv2
6
  import numpy as np
7
  from ultralytics import YOLO
8
+ import os
9
+ import shutil
10
+ from typing import Optional
11
+ import uuid
12
+ import base64
13
+ from io import BytesIO
14
+ from PIL import Image
15
+
16
# Create FastAPI app
app = FastAPI(
    title="YOLO Intrusion Detection API",
    description="API for detecting intrusions using YOLOv8 model",
    version="1.0.0",
)

# Add CORS middleware.
# The API is open to every origin. Credentials are switched off because the
# CORS rules do not allow wildcard origins/methods/headers together with
# credentialed requests, and no endpoint here relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
31
 
32
  # Load YOLO model
33
model_name = 'yolov8n.pt'
model = None  # assigned by the startup handler

# Restricted zone polygon, in pixel coordinates.
trapezoid_pts = np.array(
    [[250, 150], [400, 150], [450, 300], [200, 300]],
    dtype=np.int32,
)

# Uploaded and processed videos are written under ./temp.
os.makedirs("temp", exist_ok=True)


def is_inside_trapezoid(box, trapezoid_pts):
    """Tell whether a box's center point falls within the given polygon.

    Args:
        box: ``(x1, y1, x2, y2)`` corner coordinates of the detected object.
        trapezoid_pts: Polygon vertices for ``cv2.pointPolygonTest``.

    Returns:
        True when the center is inside the polygon or on its boundary.
    """
    x1, y1, x2, y2 = box
    mid_x = int((x1 + x2) / 2)
    mid_y = int((y1 + y2) / 2)

    # With measureDist=False the test result is negative only for outside points.
    return cv2.pointPolygonTest(trapezoid_pts, (mid_x, mid_y), False) >= 0
49
 
50
def process_image(frame):
    """Detect objects in one frame and report intrusions into the restricted area.

    Args:
        frame: Image passed straight to ``model.predict``.

    Returns:
        A tuple ``(annotated_frame, response)``. ``annotated_frame`` is the
        plotted detection image with the restricted area, intruder boxes and
        the alert text drawn on it. ``response`` is a dict with the keys
        ``intrusion_detected``, ``intruding_object`` (label of the last
        intruder found, or ""), ``person_count`` and ``detections`` (one dict
        per detected box).

    Raises:
        RuntimeError: If the module-level ``model`` has not been loaded yet.
    """
    # ``model`` is only read here, so no ``global`` declaration is needed.
    if model is None:
        raise RuntimeError("YOLO model is not loaded; the startup handler must run first")

    # Perform object detection
    results = model.predict(frame, conf=0.5)
    annotated_frame = results[0].plot()  # Draw bounding boxes

    # Draw trapezoidal restricted area
    cv2.polylines(annotated_frame, [trapezoid_pts.reshape((-1, 1, 2))], isClosed=True, color=(0, 0, 255), thickness=2)

    isAlert = {'alert': [False, ""], 'personCount': 0}
    # Labels for class ids 0-3; only these classes can trigger an alert.
    classInIntrusion = ['person', 'bicycle', 'car', 'motorcycle']

    detections = []

    # Loop through detected objects
    for r in results:
        for box, cls, conf in zip(r.boxes.xyxy, r.boxes.cls, r.boxes.conf):
            class_id = int(cls.item())  # Convert to integer
            confidence = float(conf.item())
            corners = box.tolist()
            x1, y1, x2, y2 = map(int, corners)

            # Evaluate the zone test once and reuse it for both the detection
            # record and the alert decision.
            in_zone = is_inside_trapezoid(corners, trapezoid_pts)

            tracked = 0 <= class_id < len(classInIntrusion)
            class_name = classInIntrusion[class_id] if tracked else f"class_{class_id}"

            # Add to detections list
            detections.append({
                "class": class_name,
                "confidence": confidence,
                "bbox": [x1, y1, x2, y2],
                "in_restricted_area": in_zone
            })

            if class_id == 0:  # Person
                isAlert['personCount'] += 1

            if tracked and in_zone:  # Person, bicycle, car, motorcycle
                # A later intruder overwrites an earlier one in the summary.
                isAlert['alert'] = [True, class_name]
                # Mark the intrusion with a red box
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 0, 255), 3)

    # Add alert text on the frame
    alert_text = f"Intrusion Alert: {isAlert['alert'][0]}, Object: {isAlert['alert'][1]}, Persons: {isAlert['personCount']}"
    cv2.putText(annotated_frame, alert_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    # Build the response
    response = {
        "intrusion_detected": isAlert['alert'][0],
        "intruding_object": isAlert['alert'][1],
        "person_count": isAlert['personCount'],
        "detections": detections
    }

    return annotated_frame, response
105
+
106
def encode_image_to_base64(image):
    """JPEG-encode an OpenCV image and return it as a base64 text string."""
    # imencode returns (success_flag, encoded_buffer); only the buffer is used.
    jpeg_buffer = cv2.imencode('.jpg', image)[1]
    b64_bytes = base64.b64encode(jpeg_buffer)
    return b64_bytes.decode('utf-8')
110
+
111
@app.on_event("startup")
async def startup_event():
    """Instantiate the YOLO detector once, when the application starts."""
    global model
    detector = YOLO(model_name)
    model = detector  # publish through the module-level name used by process_image
    print(f"Model {model_name} loaded successfully")
117
 
118
@app.get("/")
async def root():
    """Describe the service: status message, docs location and main endpoints."""
    endpoints = {
        "process_image": "/process_image/",
        "process_video": "/process_video/",
        "health": "/health/",
    }
    return {
        "message": "YOLO Intrusion Detection API is running",
        "documentation": "/docs",
        "endpoints": endpoints,
    }
130
+
131
@app.get("/health/")
async def health_check():
    """Report a fixed "healthy" status along with the configured model name."""
    return dict(status="healthy", model=model_name)
135
+
136
@app.post("/process_image/")
async def api_process_image(file: UploadFile = File(...), return_image: bool = True):
    """
    Process an image file and detect intrusions.

    Args:
        file: The image file to process
        return_image: If True, returns the annotated image as base64

    Returns:
        JSON with detection results and optionally the annotated image

    Raises:
        HTTPException: 400 when the file extension is not PNG/JPG/JPEG, or
            when the upload is empty or cannot be decoded as an image
    """
    # Check file extension; the upload may arrive without a filename at all.
    filename = (file.filename or "").lower()
    if not filename.endswith(('.png', '.jpg', '.jpeg')):
        raise HTTPException(status_code=400, detail="Only PNG and JPG images are supported")

    # Read and decode the image
    contents = await file.read()
    if not contents:
        # cv2.imdecode raises on an empty buffer instead of returning None,
        # so reject empty uploads explicitly to keep this a 400.
        raise HTTPException(status_code=400, detail="Could not decode image")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    if img is None:
        raise HTTPException(status_code=400, detail="Could not decode image")

    # Process the image
    annotated_img, results = process_image(img)

    # Optionally include the annotated image
    if return_image:
        results["image"] = encode_image_to_base64(annotated_img)

    return results
168
+
169
@app.post("/process_video/")
async def api_process_video(file: UploadFile = File(...)):
    """
    Process a video file and detect intrusions.

    Every frame is run through ``process_image`` and the annotated frames are
    written to a new MP4 under ``temp/``. Frame processing happens
    synchronously inside this handler.

    Args:
        file: The video file to process

    Returns:
        JSON with aggregated detection results and the download path of the
        processed video

    Raises:
        HTTPException: 400 when the extension is not MP4/AVI/MOV or the
            uploaded file cannot be opened as a video
    """
    # Check file extension; the upload may arrive without a filename at all.
    filename = (file.filename or "").lower()
    if not filename.endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Only MP4, AVI, and MOV videos are supported")

    # Create unique temporary file names
    temp_input = f"temp/input_{uuid.uuid4()}.mp4"
    temp_output = f"temp/output_{uuid.uuid4()}.mp4"

    intruding_objects = set()
    final_results = {
        "intrusion_detected": False,
        "intruding_objects": [],
        "max_person_count": 0,
        "frames_processed": 0,
        "total_detections": 0
    }

    cap = None
    out = None
    completed = False
    try:
        # Save uploaded file
        with open(temp_input, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        cap = cv2.VideoCapture(temp_input)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file")

        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Missing or invalid frame rate (0 or NaN): fall back to an
            # assumed nominal rate so the writer receives a usable value.
            fps = 30.0

        # Create output video file
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Process the frame
            annotated_frame, frame_results = process_image(frame)

            # Update aggregated results
            final_results["frames_processed"] += 1
            final_results["total_detections"] += len(frame_results["detections"])

            if frame_results["intrusion_detected"]:
                final_results["intrusion_detected"] = True
                if frame_results["intruding_object"]:
                    intruding_objects.add(frame_results["intruding_object"])

            final_results["max_person_count"] = max(
                final_results["max_person_count"],
                frame_results["person_count"]
            )

            # Write the frame
            out.write(annotated_frame)

        completed = True
    finally:
        # Always release the handles and drop the uploaded copy, even when a
        # frame fails mid-way; release first so the files are no longer in use.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_input):
            os.remove(temp_input)
        # A partially written output is useless to the caller; discard it.
        if not completed and os.path.exists(temp_output):
            os.remove(temp_output)

    # Sets are not JSON serializable; sorting also makes the order stable.
    final_results["intruding_objects"] = sorted(intruding_objects)

    return {
        "results": final_results,
        "video_path": f"/download_video/{os.path.basename(temp_output)}"
    }
255
+
256
@app.get("/download_video/{filename}")
async def download_video(filename: str):
    """
    Download the processed video file.

    Args:
        filename: The name of the processed video file (a bare file name
            inside the ``temp`` directory, with no path components)

    Returns:
        The video file

    Raises:
        HTTPException: 404 when the name contains path components or no such
            regular file exists in ``temp``
    """
    # The name comes from the URL, so treat it as untrusted: accept only a
    # bare file name and never let it point outside the temp directory.
    safe_name = os.path.basename(filename)
    if safe_name != filename or safe_name in ("", ".", ".."):
        raise HTTPException(status_code=404, detail="Video not found")

    file_path = os.path.join("temp", safe_name)
    # isfile (not exists) so that directories are never handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Video not found")

    return FileResponse(file_path, media_type="video/mp4", filename="processed_video.mp4")
272
 
273
+ # # For local development
274
+ # if __name__ == "__main__":
275
+ # uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  ultralytics
2
- opencv-python
3
  numpy
4
- gradio
 
 
 
 
 
1
  ultralytics
 
2
  numpy
3
+ fastapi
4
+ uvicorn
5
+ opencv-python-headless
6
+ pillow
7
+ python-multipart