File size: 3,862 Bytes
4aadcb7
e3bd12a
4aadcb7
 
 
 
 
 
 
 
 
e3bd12a
 
 
 
 
4aadcb7
 
 
 
e3bd12a
 
 
 
 
 
 
4aadcb7
e3bd12a
4aadcb7
 
e3bd12a
4aadcb7
e3bd12a
4aadcb7
 
 
 
 
e3bd12a
4aadcb7
e3bd12a
 
4aadcb7
e3bd12a
 
4aadcb7
 
e3bd12a
4aadcb7
e3bd12a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aadcb7
e3bd12a
 
4aadcb7
e3bd12a
4aadcb7
 
 
 
 
 
 
 
 
 
 
e3bd12a
4aadcb7
 
e3bd12a
 
4aadcb7
 
 
 
 
 
e3bd12a
 
4aadcb7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import cv2
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from typing import Generator
from ultralytics import YOLO
import numpy as np

app = FastAPI()

# Load the YOLOv8 model
model = YOLO("yolov8n.pt")

# Open the video file
video_path = "demo.mp4"
cap = cv2.VideoCapture(video_path)

# Most recent number of birds seen in the stream; rendered by the index page.
bird_count = 0
tracker_initialized = False

# Create an (empty) multi-object tracker from whichever OpenCV tracking API
# this build exposes. Both branches must produce a MultiTracker: a single
# CSRT tracker has no add() and a different update() result, so it cannot be
# used where a MultiTracker is expected.
try:
    if hasattr(cv2, 'legacy'):
        trackers = cv2.legacy.MultiTracker_create()
    else:
        trackers = cv2.MultiTracker_create()
except AttributeError:
    # Tracking factories are missing from this OpenCV build (presumably one
    # without the contrib tracking module -- confirm for your install).
    trackers = None

def _detect_bird_boxes(frame):
    """Run the YOLO model on *frame* and return usable bird boxes.

    Each returned box is an ``(x, y, w, h)`` tuple of ints. A detection is
    kept only when its class id is 14 ("bird" in COCO), its corners lie
    inside the frame, and it has non-zero width and height (a zero-area box
    cannot seed a tracker).
    """
    frame_height, frame_width = frame.shape[:2]
    detections = model(frame)[0].boxes.data.cpu().numpy()
    boxes = []
    for detection in detections:
        # Row layout: x1, y1, x2, y2, confidence, class_id.
        if int(detection[5]) != 14:
            continue
        x1, y1, x2, y2 = (int(v) for v in detection[:4])
        in_frame = (0 <= x1 < frame_width and 0 <= y1 < frame_height
                    and x2 <= frame_width and y2 <= frame_height)
        if in_frame and x2 > x1 and y2 > y1:
            boxes.append((x1, y1, x2 - x1, y2 - y1))
    return boxes


def _start_trackers(frame, boxes):
    """Return a MultiTracker seeded with one CSRT tracker per ``(x, y, w, h)`` box.

    Returns None when this OpenCV build does not expose the tracking
    factories (the lookups raise AttributeError).
    """
    use_legacy = hasattr(cv2, 'legacy')
    try:
        if use_legacy:
            multi_tracker = cv2.legacy.MultiTracker_create()
        else:
            multi_tracker = cv2.MultiTracker_create()
        for bbox in boxes:
            if use_legacy:
                tracker = cv2.legacy.TrackerCSRT_create()
            else:
                tracker = cv2.TrackerCSRT_create()
            multi_tracker.add(tracker, frame, bbox)
    except AttributeError:
        return None
    return multi_tracker


def _draw_bird(frame, bbox):
    """Draw a green rectangle labelled 'bird' on *frame* in place."""
    x, y, w, h = (int(v) for v in bbox)
    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.putText(frame, 'bird', (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX,
                0.9, (0, 255, 0), 2)


def process_video() -> Generator[bytes, None, None]:
    """Yield annotated frames of ``video_path`` as multipart JPEG chunks.

    Birds are detected with the YOLO model and then followed with CSRT
    trackers. Detection reruns whenever tracking fails or there is nothing to
    track. If the tracking API is unavailable, every frame is detected instead.
    The module-level ``bird_count`` is updated with the number of birds
    drawn on the latest frame.

    Each chunk is ``--frame`` + a JPEG part, matching the
    ``boundary=frame`` media type used by the streaming endpoint.
    """
    global bird_count
    # A private capture per stream fixes two problems with a shared global
    # capture. Releasing it after one stream would leave later requests with
    # no frames. Concurrent streams would also steal frames from each other.
    capture = cv2.VideoCapture(video_path)
    trackers = None
    tracking = False
    try:
        while capture.isOpened():
            success, frame = capture.read()
            if not success:
                break

            if tracking:
                # Update trackers and get updated positions
                ok, boxes = trackers.update(frame)
                if ok:
                    bird_count = len(boxes)
                    for box in boxes:
                        _draw_bird(frame, box)
                else:
                    # At least one target was lost: re-detect next frame.
                    tracking = False
            else:
                boxes = _detect_bird_boxes(frame)
                # Count what is actually drawn/tracked so the count agrees
                # with subsequent tracking frames.
                bird_count = len(boxes)
                for bbox in boxes:
                    _draw_bird(frame, bbox)
                # Only switch to tracking when there is something to track.
                # With zero trackers, update() has nothing that can fail, so
                # detection might never run again.
                trackers = _start_trackers(frame, boxes) if boxes else None
                tracking = trackers is not None

            # Encode the frame in JPEG format; skip frames that fail to encode.
            encoded, buffer = cv2.imencode('.jpg', frame)
            if not encoded:
                continue
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
    finally:
        # Also runs when the client disconnects and the generator is closed.
        capture.release()

# Jinja2 templates are loaded from the local "templates" directory.
templates = Jinja2Templates(directory="templates")

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page.

    The page receives the value of ``bird_count`` at the moment of the
    request.
    """
    context = {"request": request, "bird_count": bird_count}
    return templates.TemplateResponse("index.html", context)

@app.get("/video_feed")
async def video_feed():
    """Stream the annotated video as a multipart JPEG response.

    The ``boundary=frame`` parameter matches the ``--frame`` separator
    emitted by ``process_video``.
    """
    frame_stream = process_video()
    mjpeg_type = 'multipart/x-mixed-replace; boundary=frame'
    return StreamingResponse(frame_stream, media_type=mjpeg_type)

if __name__ == "__main__":
    # Run the app with uvicorn on all interfaces, port 7860, when executed directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)