# birdcount / app.py
# pyresearch's picture
# Upload 6 files
# 4aadcb7 verified
# raw
# history blame
# 3.48 kB
import os
import cv2
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from typing import Generator
from ultralytics import YOLO
import numpy as np
# Web application object and the Jinja2 loader for pages under ./templates.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# YOLOv8 detector, constructed once at import time from the "yolov8l.pt" weights.
model = YOLO("yolov8l.pt")

# Module-level state shared by the request handlers in this file:
#   video_path          - placeholder, initialised to None
#   cap                 - the cv2.VideoCapture of the most recent upload
#   trackers            - the OpenCV multi-tracker following detected birds
#   bird_count          - number of birds reported for the latest frame
#   tracker_initialized - False whenever a fresh detection pass is required
video_path = cap = trackers = None
bird_count = 0
tracker_initialized = False
@app.post("/upload_video/")
async def upload_video(file: UploadFile = File(...)):
    """Store an uploaded video under ``uploads/`` and make it the active capture.

    The client-supplied file name is reduced to its final path component
    before it is used, so a name such as ``../../app.py`` cannot write outside
    the upload directory.  Any capture that was open before is released, and
    the tracker state and bird counter are reset so the next stream starts
    with a fresh detection pass.

    Args:
        file: The multipart upload carrying the video.

    Returns:
        A dict with a single ``"info"`` message naming the uploaded file and
        the path it was saved to.

    Raises:
        HTTPException: 400 when the upload has no usable file name.
    """
    global cap, bird_count, tracker_initialized, trackers
    # Imported here because the module's import list does not include it.
    from fastapi import HTTPException

    # file.filename is untrusted input: keep only the last path component,
    # treating backslashes as separators as well.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="upload has no usable file name")

    # Release the previous capture before its file may be overwritten.
    if cap is not None:
        cap.release()

    # Save uploaded video file, creating the target directory on first use and
    # copying in chunks so large videos are not held in memory all at once.
    upload_dir = "uploads"
    os.makedirs(upload_dir, exist_ok=True)
    file_location = f"{upload_dir}/{safe_name}"
    with open(file_location, "wb") as f:
        for chunk in iter(lambda: file.file.read(1024 * 1024), b""):
            f.write(chunk)

    # Open the uploaded video file
    cap = cv2.VideoCapture(file_location)

    # Reset tracker and counter
    tracker_initialized = False
    trackers = None
    bird_count = 0

    return {"info": f"file '{file.filename}' saved at '{file_location}'"}
# Class index compared against column 5 of each detection row; the matching
# boxes are the ones labelled 'bird' on the output frames.
_BIRD_CLASS_ID = 14


def _start_bird_trackers(frame):
    """Detect birds in *frame* and attach one CSRT tracker to each usable box.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read``.

    Returns:
        ``(multi_tracker, tracked)`` where ``multi_tracker`` is a new OpenCV
        multi-tracker and ``tracked`` is the number of boxes added to it.
    """
    frame_height, frame_width = frame.shape[:2]
    # Builds that ship ``cv2.legacy`` expose the tracker factories there;
    # otherwise fall back to the top-level ``cv2`` names.
    tracking_api = getattr(cv2, "legacy", cv2)

    detections = model(frame)[0].boxes.data.cpu().numpy()
    multi_tracker = tracking_api.MultiTracker_create()
    tracked = 0
    for x1, y1, x2, y2, _confidence, class_id in detections:
        if int(class_id) != _BIRD_CLASS_ID:
            continue
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        inside_frame = (
            0 <= x1 < frame_width
            and 0 <= y1 < frame_height
            and x2 <= frame_width
            and y2 <= frame_height
        )
        # Integer truncation can collapse a thin box to zero width or height;
        # such a box cannot be tracked, so it is skipped like an out-of-frame one.
        if not inside_frame or x2 <= x1 or y2 <= y1:
            continue
        bbox = (x1, y1, x2 - x1, y2 - y1)
        multi_tracker.add(tracking_api.TrackerCSRT_create(), frame, bbox)
        tracked += 1
    return multi_tracker, tracked


def process_video() -> Generator[bytes, None, None]:
    """Yield the active video as JPEG parts of a ``--frame`` multipart stream.

    Each frame read from the module-level ``cap`` is either used to (re)start
    detection and tracking, or passed to the existing trackers, whose boxes are
    drawn on it.  The frame that triggers a detection pass is emitted without
    boxes.  ``bird_count`` is updated to the number of birds currently tracked.

    Detection is repeated on the next frame whenever no bird is being tracked
    or the trackers report a failure, so birds that appear later in the video
    are still picked up.

    Yields:
        One multipart chunk per successfully encoded frame: the ``--frame``
        boundary, a ``Content-Type: image/jpeg`` header and the JPEG bytes.
    """
    global bird_count, tracker_initialized, trackers, cap

    if cap is None:
        # Nothing has been uploaded yet: produce an empty stream.
        return

    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break

        if not tracker_initialized:
            trackers, bird_count = _start_bird_trackers(frame)
            # An empty multi-tracker keeps reporting success, so only treat
            # tracking as started when at least one bird was added.
            tracker_initialized = bird_count > 0
        else:
            tracked_ok, boxes = trackers.update(frame)
            if tracked_ok:
                bird_count = len(boxes)
                for box in boxes:
                    x, y, w, h = [int(v) for v in box]
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                    cv2.putText(frame, 'bird', (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                # Tracking was lost: run detection again on the next frame.
                tracker_initialized = False

        encoded_ok, buffer = cv2.imencode('.jpg', frame)
        if not encoded_ok:
            # Skip a frame that could not be encoded.
            continue
        jpeg_bytes = buffer.tobytes()
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + jpeg_bytes + b'\r\n')

    cap.release()
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` with the current value of ``bird_count``.

    Args:
        request: The incoming request, passed through to the template context.

    Returns:
        The rendered template response.
    """
    context = {"request": request, "bird_count": bird_count}
    return templates.TemplateResponse("index.html", context)
@app.get("/video_feed")
async def video_feed():
    """Stream the frames produced by ``process_video`` to the client.

    Returns:
        A ``StreamingResponse`` of type ``multipart/x-mixed-replace`` whose
        part boundary is ``frame``.
    """
    frame_stream = process_video()
    return StreamingResponse(
        frame_stream,
        media_type='multipart/x-mixed-replace; boundary=frame',
    )
if __name__ == "__main__":
    # uvicorn is only needed when this file is launched directly as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)