|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
from fastapi.responses import FileResponse, JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import uvicorn |
|
import cv2 |
|
import numpy as np |
|
from ultralytics import YOLO |
|
import os |
|
import shutil |
|
from typing import Optional |
|
import uuid |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
|
|
# --- Application object -------------------------------------------------------
app = FastAPI(
    title="YOLO Intrusion Detection API",
    description="API for detecting intrusions using YOLOv8 model",
    version="1.0.0",
)

# Cross-origin access is left fully open: every origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Weights file handed to YOLO(); `model` stays None until the startup hook
# replaces it with the loaded model.
model_name = 'yolov8n.pt'
model = None

# Corner points of the restricted zone, in frame pixel coordinates
# (the polygon is drawn onto, and tested against, each processed frame).
trapezoid_pts = np.array(
    [[250, 150], [400, 150], [450, 300], [200, 300]],
    dtype=np.int32,
)

# Scratch directory for uploaded and processed videos.
os.makedirs("temp", exist_ok=True)
|
|
|
def is_inside_trapezoid(box, trapezoid_pts):
    """Tell whether the center of a bounding box lies within the given polygon.

    Args:
        box: Four values ``(x1, y1, x2, y2)`` -- two opposite corners of the box.
        trapezoid_pts: Polygon contour forwarded to ``cv2.pointPolygonTest``.

    Returns:
        True when the box center (truncated to integer coordinates) is inside
        the polygon or exactly on its edge, False otherwise.
    """
    left, top, right, bottom = box
    center = (int((left + right) / 2), int((top + bottom) / 2))

    # With measureDist=False the result is only a sign: positive inside,
    # zero on the edge, negative outside.
    position = cv2.pointPolygonTest(trapezoid_pts, center, False)
    return position >= 0
|
|
|
def process_image(frame):
    """Run detection on one frame and flag objects inside the restricted zone.

    The frame is passed to the module-level ``model``; the model's own
    annotated rendering is then overlaid with the restricted-zone outline, a
    thick red box around every intruding object, and a one-line status text.

    Args:
        frame: Image handed directly to ``model.predict``.

    Returns:
        tuple: ``(annotated_frame, response)`` where ``response`` is a dict with
        the keys ``intrusion_detected`` (bool), ``intruding_object`` (name of
        the last intruding object seen, ``""`` if none), ``person_count``
        (number of class-0 detections in the whole frame) and ``detections``
        (one dict per detection: ``class``, ``confidence``, ``bbox``,
        ``in_restricted_area``).

    Raises:
        RuntimeError: If the module-level ``model`` has not been loaded yet.
    """
    # Fail with an explicit message instead of "'NoneType' has no attribute
    # 'predict'" when the function is used before the startup hook has run.
    if model is None:
        raise RuntimeError("YOLO model is not loaded; cannot process image")

    results = model.predict(frame, conf=0.5)
    annotated_frame = results[0].plot()

    # Outline of the restricted zone.
    cv2.polylines(annotated_frame, [trapezoid_pts.reshape((-1, 1, 2))], isClosed=True, color=(0, 0, 255), thickness=2)

    # Classes that can trigger an alert; the list index is the class id.
    classInIntrusion = ['person', 'bicycle', 'car', 'motorcycle']

    intrusion_detected = False
    intruding_object = ""
    person_count = 0
    detections = []

    for r in results:
        for box, cls, conf in zip(r.boxes.xyxy, r.boxes.cls, r.boxes.conf):
            class_id = int(cls.item())
            confidence = float(conf.item())
            coords = box.tolist()
            x1, y1, x2, y2 = map(int, coords)

            is_tracked_class = 0 <= class_id < len(classInIntrusion)
            class_name = classInIntrusion[class_id] if is_tracked_class else f"class_{class_id}"

            # Evaluate the zone test once and reuse it for both the report
            # entry and the alert decision.
            in_zone = is_inside_trapezoid(coords, trapezoid_pts)

            detections.append({
                "class": class_name,
                "confidence": confidence,
                "bbox": [x1, y1, x2, y2],
                "in_restricted_area": in_zone,
            })

            if class_id == 0:
                person_count += 1

            if is_tracked_class and in_zone:
                # Only the most recent intruder is remembered.
                intrusion_detected = True
                intruding_object = class_name
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 0, 255), 3)

    alert_text = f"Intrusion Alert: {intrusion_detected}, Object: {intruding_object}, Persons: {person_count}"
    cv2.putText(annotated_frame, alert_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    response = {
        "intrusion_detected": intrusion_detected,
        "intruding_object": intruding_object,
        "person_count": person_count,
        "detections": detections,
    }

    return annotated_frame, response
|
|
|
def encode_image_to_base64(image):
    """Convert an OpenCV image to a base64-encoded JPEG string.

    Args:
        image: Image array accepted by ``cv2.imencode``.

    Returns:
        str: Base64 text of the JPEG-encoded image.

    Raises:
        ValueError: If OpenCV reports that the JPEG encoding failed.
    """
    success, buffer = cv2.imencode('.jpg', image)
    # imencode signals failure through its first return value; do not try to
    # base64-encode a buffer that was never filled.
    if not success:
        raise ValueError("Could not encode image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')
|
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the YOLO weights into the module-level ``model``."""
    global model

    weights = model_name
    model = YOLO(weights)
    print(f"Model {weights} loaded successfully")
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: a status message, the docs URL and the main routes."""
    endpoints = {
        "process_image": "/process_image/",
        "process_video": "/process_video/",
        "health": "/health/",
    }
    return {
        "message": "YOLO Intrusion Detection API is running",
        "documentation": "/docs",
        "endpoints": endpoints,
    }
|
|
|
@app.get("/health/")
async def health_check():
    """Liveness probe: reports a fixed status plus the configured model file name."""
    payload = {"status": "healthy"}
    payload["model"] = model_name
    return payload
|
|
|
@app.post("/process_image/")
async def api_process_image(file: UploadFile = File(...), return_image: bool = True):
    """
    Process an image file and detect intrusions.

    Args:
        file: The image file to process (PNG or JPG, judged by file extension)
        return_image: If True, returns the annotated image as base64

    Returns:
        JSON with detection results and optionally the annotated image
        (base64-encoded JPEG under the ``image`` key)

    Raises:
        HTTPException: 400 when the extension is not supported, the upload is
            empty, or the bytes cannot be decoded as an image
    """
    # An upload may arrive without a filename; treat that as an unsupported
    # type instead of crashing on None.lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(('.png', '.jpg', '.jpeg')):
        raise HTTPException(status_code=400, detail="Only PNG and JPG images are supported")

    contents = await file.read()
    if not contents:
        # An empty buffer makes cv2.imdecode raise (or return None, depending
        # on the OpenCV version); answer with a clean 400 either way.
        raise HTTPException(status_code=400, detail="Could not decode image")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    if img is None:
        raise HTTPException(status_code=400, detail="Could not decode image")

    annotated_img, results = process_image(img)

    if return_image:
        results["image"] = encode_image_to_base64(annotated_img)

    return results
|
|
|
@app.post("/process_video/")
async def api_process_video(file: UploadFile = File(...)):
    """
    Process a video file and detect intrusions.

    The upload is stored under ``temp/``, every frame is run through
    ``process_image`` and the annotated frames are written to a new MP4 file.
    The uploaded copy is always deleted; the output file is kept only when
    processing finishes successfully.

    Args:
        file: The video file to process (MP4, AVI or MOV, judged by extension)

    Returns:
        JSON with aggregated detection results (``results``) and the download
        path of the processed video (``video_path``)

    Raises:
        HTTPException: 400 when the extension is not supported or the video
            cannot be opened; 500 when the output video cannot be created
    """
    # An upload may arrive without a filename; treat that as unsupported.
    filename = (file.filename or "").lower()
    if not filename.endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Only MP4, AVI, and MOV videos are supported")

    temp_input = f"temp/input_{uuid.uuid4()}.mp4"
    temp_output = f"temp/output_{uuid.uuid4()}.mp4"

    cap = None
    out = None
    completed = False

    intrusion_detected = False
    intruding_objects = set()
    max_person_count = 0
    frames_processed = 0
    total_detections = 0

    try:
        with open(temp_input, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        cap = cv2.VideoCapture(temp_input)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # The capture backend reports 0 when it cannot determine the frame
            # rate; a writer needs a positive value, so use a common default.
            fps = 25.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
        if not out.isOpened():
            # Without this check every write() would be silently dropped and
            # the returned download path would point at a missing file.
            raise HTTPException(status_code=500, detail="Could not create output video")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            annotated_frame, frame_results = process_image(frame)

            frames_processed += 1
            total_detections += len(frame_results["detections"])

            if frame_results["intrusion_detected"]:
                intrusion_detected = True
                if frame_results["intruding_object"]:
                    intruding_objects.add(frame_results["intruding_object"])

            max_person_count = max(max_person_count, frame_results["person_count"])

            out.write(annotated_frame)

        completed = True
    finally:
        # Release the OpenCV handles before touching the files they use, and
        # clean up on every exit path (success, HTTP error, unexpected error).
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_input):
            os.remove(temp_input)
        # A partially written output is useless; keep it only on success.
        if not completed and os.path.exists(temp_output):
            os.remove(temp_output)

    final_results = {
        "intrusion_detected": intrusion_detected,
        # Sets are not JSON-serializable; sorting also makes the order stable.
        "intruding_objects": sorted(intruding_objects),
        "max_person_count": max_person_count,
        "frames_processed": frames_processed,
        "total_detections": total_detections,
    }

    return {
        "results": final_results,
        "video_path": f"/download_video/{os.path.basename(temp_output)}"
    }
|
|
|
@app.get("/download_video/{filename}")
async def download_video(filename: str):
    """
    Download the processed video file.

    Args:
        filename: The name of the processed video file (a bare file name
            inside the ``temp`` directory, no path components)

    Returns:
        The video file, served as ``video/mp4`` under the download name
        ``processed_video.mp4``

    Raises:
        HTTPException: 404 when the name is not a plain file name or no such
            file exists
    """
    # The name comes straight from the URL: accept only a bare file name so
    # the request can never resolve to something outside temp/. Backslashes
    # are checked explicitly because basename() ignores them on POSIX.
    is_plain_name = (
        filename not in ("", ".", "..")
        and filename == os.path.basename(filename)
        and "\\" not in filename
    )
    if not is_plain_name:
        raise HTTPException(status_code=404, detail="Video not found")

    file_path = f"temp/{filename}"
    # isfile (not exists) so that a directory is never handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Video not found")

    return FileResponse(file_path, media_type="video/mp4", filename="processed_video.mp4")
|
|
|
|
|
|
|
|