from typing import List

import cv2
import numpy as np
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors


class ObjectDetector:
    def __init__(self, pretrained_model: str = 'yolov8n.pt', debug: bool = False):
        self.model = YOLO(pretrained_model)
        self.debug = debug

        # Fixed BGR colour (OpenCV channel order) for each class index.
        self.color_map = {
            0: (0, 255, 0),
            1: (255, 128, 0),
            2: (0, 0, 255),
            3: (255, 0, 0),
            4: (0, 255, 255),
            5: (128, 0, 255),
            6: (255, 0, 255),
            7: (0, 128, 255),
            8: (255, 255, 0),
            9: (128, 255, 0),
            10: (0, 165, 255),
            11: (139, 69, 19),
            12: (128, 128, 128),
            13: (192, 192, 192),
            14: (255, 191, 0),
            15: (255, 0, 128),
            16: (0, 255, 191),
            17: (128, 128, 0),
            18: (0, 140, 255),
            19: (0, 215, 255),
            20: (34, 139, 34),
            21: (75, 75, 75),
            22: (0, 69, 255),
            23: (122, 61, 0),
            24: (108, 108, 108),
            25: (211, 211, 211),
            26: (0, 43, 27),
            27: (22, 22, 22),
            28: (17, 211, 0),
            29: (121, 132, 9),
        }

    def train_model(self, yaml_filepath):
        # Fine-tune the loaded weights on a custom dataset described by a YOLO data YAML file.
        self.model.train(data=yaml_filepath, epochs=100, imgsz=640, batch=16, patience=50)

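    # Illustrative sketch of the dataset YAML that train_model expects; the paths and
    # class names below are placeholders, not part of this project:
    #
    #   path: datasets/my_dataset     # dataset root
    #   train: images/train           # training images, relative to path
    #   val: images/val               # validation images, relative to path
    #   names:
    #     0: person
    #     1: car
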
    def detect_object(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        # Run tracking on each frame, draw labelled boxes for detections above a
        # 0.4 confidence threshold, and return the annotated frames.
        annotated_frames = []

        for frame in frames:
            # persist=True keeps track IDs consistent from one frame to the next.
            results = self.model.track(frame, stream=True, persist=True)

            for result in results:
                class_names = result.names
                annotated_frame = frame.copy()

                for box in result.boxes:
                    if box.conf[0] > 0.4:
                        x1, y1, x2, y2 = map(int, box.xyxy[0])

                        cls = int(box.cls[0])
                        color = self.color_map.get(cls, (0, 255, 0))

                        cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=2)

                        # A filled background behind the label keeps the text readable.
                        text = f'{class_names[cls]} {box.conf[0]:.2f}'
                        (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
                        cv2.rectangle(annotated_frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
                        cv2.putText(annotated_frame, text, (x1, y1 - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), thickness=2)

                annotated_frames.append(annotated_frame)

                # In debug mode, display each annotated frame until 'q' is pressed.
                while self.debug:
                    cv2.imshow('frame', annotated_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        # Close the debug window once, after all frames have been processed.
        if self.debug:
            cv2.destroyAllWindows()

        return annotated_frames

    def export_model(self, format: str = 'onnx'):
        # Export the model weights to a deployment format such as ONNX.
        self.model.export(format=format)
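

# Minimal usage sketch (assumptions: 'sample.mp4' is a placeholder video path, and the
# default 'yolov8n.pt' weights are available locally or downloadable by ultralytics).
if __name__ == '__main__':
    detector = ObjectDetector(pretrained_model='yolov8n.pt', debug=True)

    # Read up to 100 frames from the video into memory before running detection.
    cap = cv2.VideoCapture('sample.mp4')  # placeholder path, not part of the original code
    frames = []
    while len(frames) < 100:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
    cap.release()

    annotated = detector.detect_object(frames)
    print(f'Annotated {len(annotated)} frames')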