pyresearch commited on
Commit
e3bd12a
·
verified ·
1 Parent(s): 4aadcb7

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +45 -37
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import os
2
  import cv2
3
- from fastapi import FastAPI, Request, UploadFile, File
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from fastapi.templating import Jinja2Templates
6
  from typing import Generator
@@ -8,63 +7,67 @@ from ultralytics import YOLO
8
  import numpy as np
9
 
10
  app = FastAPI()
11
- templates = Jinja2Templates(directory="templates")
12
 
13
  # Load the YOLOv8 model
14
- model = YOLO("yolov8l.pt")
 
 
 
 
15
 
16
- video_path = None
17
- cap = None
18
  bird_count = 0
19
  tracker_initialized = False
20
- trackers = None
21
-
22
- @app.post("/upload_video/")
23
- async def upload_video(file: UploadFile = File(...)):
24
- global cap, tracker_initialized, trackers
25
-
26
- # Save uploaded video file
27
- file_location = f"uploads/{file.filename}"
28
- with open(file_location, "wb") as f:
29
- f.write(file.file.read())
30
 
31
- # Open the uploaded video file
32
- cap = cv2.VideoCapture(file_location)
33
-
34
- # Reset tracker and counter
35
- tracker_initialized = False
 
 
36
  trackers = None
37
-
38
- return {"info": f"file '{file.filename}' saved at '{file_location}'"}
39
 
40
  def process_video() -> Generator[bytes, None, None]:
41
- global bird_count, tracker_initialized, trackers, cap
42
  while cap.isOpened():
 
43
  success, frame = cap.read()
44
 
45
  if success:
46
  frame_height, frame_width = frame.shape[:2]
47
  if not tracker_initialized:
 
48
  results = model(frame)
 
 
49
  detections = results[0].boxes.data.cpu().numpy()
 
 
50
  bird_results = [detection for detection in detections if int(detection[5]) == 14]
51
 
 
52
  try:
53
- trackers = cv2.legacy.MultiTracker_create()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  except AttributeError:
55
- trackers = cv2.MultiTracker_create()
56
-
57
- for res in bird_results:
58
- x1, y1, x2, y2, confidence, class_id = res
59
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
60
- if 0 <= x1 < frame_width and 0 <= y1 < frame_height and x2 <= frame_width and y2 <= frame_height:
61
- bbox = (x1, y1, x2 - x1, y2 - y1)
62
- tracker = cv2.legacy.TrackerCSRT_create() if hasattr(cv2, 'legacy') else cv2.TrackerCSRT_create()
63
- trackers.add(tracker, frame, bbox)
64
-
65
- bird_count = len(bird_results)
66
- tracker_initialized = True
67
  else:
 
68
  success, boxes = trackers.update(frame)
69
 
70
  if success:
@@ -76,14 +79,19 @@ def process_video() -> Generator[bytes, None, None]:
76
  else:
77
  tracker_initialized = False
78
 
 
79
  ret, buffer = cv2.imencode('.jpg', frame)
80
  frame = buffer.tobytes()
 
 
81
  yield (b'--frame\r\n'
82
  b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n')
83
  else:
84
  break
85
  cap.release()
86
 
 
 
87
  @app.get("/", response_class=HTMLResponse)
88
  async def index(request: Request):
89
  return templates.TemplateResponse("index.html", {"request": request, "bird_count": bird_count})
 
 
1
  import cv2
2
+ from fastapi import FastAPI, Request
3
  from fastapi.responses import StreamingResponse, HTMLResponse
4
  from fastapi.templating import Jinja2Templates
5
  from typing import Generator
 
7
  import numpy as np
8
 
9
  app = FastAPI()
 
10
 
11
  # Load the YOLOv8 model
12
+ model = YOLO("yolov8n.pt")
13
+
14
+ # Open the video file
15
+ video_path = "demo.mp4"
16
+ cap = cv2.VideoCapture(video_path)
17
 
 
 
18
  bird_count = 0
19
  tracker_initialized = False
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Initialize trackers based on OpenCV version
22
+ try:
23
+ if hasattr(cv2, 'legacy'):
24
+ trackers = cv2.legacy.MultiTracker_create()
25
+ else:
26
+ trackers = cv2.TrackerCSRT_create()
27
+ except AttributeError:
28
  trackers = None
29
+ tracker_initialized = False
 
30
 
31
  def process_video() -> Generator[bytes, None, None]:
32
+ global bird_count, tracker_initialized, trackers
33
  while cap.isOpened():
34
+ # Read a frame from the video
35
  success, frame = cap.read()
36
 
37
  if success:
38
  frame_height, frame_width = frame.shape[:2]
39
  if not tracker_initialized:
40
+ # Run YOLOv8 inference on the frame
41
  results = model(frame)
42
+
43
+ # Extract the detected objects
44
  detections = results[0].boxes.data.cpu().numpy()
45
+
46
+ # Filter results to include only the "bird" class (class id 14 in COCO)
47
  bird_results = [detection for detection in detections if int(detection[5]) == 14]
48
 
49
+ # Initialize trackers for bird results
50
  try:
51
+ if hasattr(cv2, 'legacy'):
52
+ trackers = cv2.legacy.MultiTracker_create()
53
+ else:
54
+ trackers = cv2.MultiTracker_create()
55
+
56
+ for res in bird_results:
57
+ x1, y1, x2, y2, confidence, class_id = res
58
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
59
+ if 0 <= x1 < frame_width and 0 <= y1 < frame_height and x2 <= frame_width and y2 <= frame_height:
60
+ bbox = (x1, y1, x2 - x1, y2 - y1)
61
+ tracker = cv2.legacy.TrackerCSRT_create() if hasattr(cv2, 'legacy') else cv2.TrackerCSRT_create()
62
+ trackers.add(tracker, frame, bbox)
63
+
64
+ bird_count = len(bird_results)
65
+ tracker_initialized = True
66
  except AttributeError:
67
+ trackers = None
68
+ tracker_initialized = False
 
 
 
 
 
 
 
 
 
 
69
  else:
70
+ # Update trackers and get updated positions
71
  success, boxes = trackers.update(frame)
72
 
73
  if success:
 
79
  else:
80
  tracker_initialized = False
81
 
82
+ # Encode the frame in JPEG format
83
  ret, buffer = cv2.imencode('.jpg', frame)
84
  frame = buffer.tobytes()
85
+
86
+ # Use generator to yield the frame
87
  yield (b'--frame\r\n'
88
  b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n')
89
  else:
90
  break
91
  cap.release()
92
 
93
+ templates = Jinja2Templates(directory="templates")
94
+
95
  @app.get("/", response_class=HTMLResponse)
96
  async def index(request: Request):
97
  return templates.TemplateResponse("index.html", {"request": request, "bird_count": bird_count})