birdcount / app.py
import cv2
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from typing import Generator
from ultralytics import YOLO
import numpy as np
app = FastAPI()
# Load the YOLOv8 model
model = YOLO("yolov8n.pt")
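# If yolov8n.pt is not already present locally, ultralytics downloads the
# pretrained weights on first use.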
# Open the video file
video_path = "demo.mp4"
cap = cv2.VideoCapture(video_path)
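# If demo.mp4 is missing, cap.isOpened() stays False and the stream below
# simply ends without yielding any frames.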
bird_count = 0
tracker_initialized = False
# Initialize trackers based on OpenCV version
try:
    if hasattr(cv2, 'legacy'):
        trackers = cv2.legacy.MultiTracker_create()
    else:
        trackers = cv2.MultiTracker_create()
except AttributeError:
    trackers = None
    tracker_initialized = False
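# Note: the MultiTracker API lives under cv2.legacy since OpenCV 4.5.1, while
# older builds expose cv2.MultiTracker_create at the top level; the CSRT tracker
# generally requires an opencv-contrib-python build. The AttributeError fallback
# above covers installs where neither variant is available.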
def process_video() -> Generator[bytes, None, None]:
    global bird_count, tracker_initialized, trackers
    while cap.isOpened():
        # Read a frame from the video
        success, frame = cap.read()
        if success:
            frame_height, frame_width = frame.shape[:2]
            if not tracker_initialized:
                # Run YOLOv8 inference on the frame
                results = model(frame)
                # Extract the detected objects
                detections = results[0].boxes.data.cpu().numpy()
                # Filter results to include only the "bird" class (class id 14 in COCO)
                bird_results = [detection for detection in detections if int(detection[5]) == 14]
                # Initialize trackers for bird results
                try:
                    if hasattr(cv2, 'legacy'):
                        trackers = cv2.legacy.MultiTracker_create()
                    else:
                        trackers = cv2.MultiTracker_create()
                    for res in bird_results:
                        x1, y1, x2, y2, confidence, class_id = res
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        if 0 <= x1 < frame_width and 0 <= y1 < frame_height and x2 <= frame_width and y2 <= frame_height:
                            bbox = (x1, y1, x2 - x1, y2 - y1)
                            tracker = cv2.legacy.TrackerCSRT_create() if hasattr(cv2, 'legacy') else cv2.TrackerCSRT_create()
                            trackers.add(tracker, frame, bbox)
                    bird_count = len(bird_results)
                    tracker_initialized = True
                except AttributeError:
                    trackers = None
                    tracker_initialized = False
            else:
                # Update trackers and get updated positions
                success, boxes = trackers.update(frame)
                if success:
                    bird_count = len(boxes)
                    for box in boxes:
                        x, y, w, h = [int(v) for v in box]
                        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                        cv2.putText(frame, 'bird', (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                else:
                    tracker_initialized = False
            # Encode the frame in JPEG format
            ret, buffer = cv2.imencode('.jpg', frame)
            frame = buffer.tobytes()
            # Use generator to yield the frame
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n')
        else:
            break
    cap.release()
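# Each chunk yielded by process_video is a single JPEG frame preceded by the
# "--frame" boundary; served with the multipart/x-mixed-replace media type
# below, browsers render this as a continuously refreshing image (MJPEG streaming).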
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request, "bird_count": bird_count})
@app.get("/video_feed")
async def video_feed():
return StreamingResponse(process_video(), media_type='multipart/x-mixed-replace; boundary=frame')
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
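# Note: the "/" route above renders templates/index.html, which is not included
# in this upload. A minimal template (an illustrative sketch, not the author's
# actual file) could show the count and embed the MJPEG stream like so:
#
#   <html>
#     <body>
#       <h1>Birds detected: {{ bird_count }}</h1>
#       <img src="/video_feed" alt="live detection feed">
#     </body>
#   </html>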