File size: 3,696 Bytes
019f9fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
import cv2
from typing import List
import numpy as np

class ObjectDetector:
    """Thin wrapper around an Ultralytics ``YOLO`` model: train, annotate frames, export.

    Args:
        pretrained_model: Weights file / model name handed to ``YOLO``.
        debug: When True, ``detect_object`` shows every annotated frame in an
            OpenCV window named ``'frame'``; pressing ``q`` stops processing.
    """

    def __init__(self, pretrained_model: str = 'yolov8n.pt', debug: bool = False):
        self.model = YOLO(pretrained_model)
        self.debug = debug

        # Per-class box colours. OpenCV drawing functions interpret colour
        # tuples as BGR, so every value below is (blue, green, red).
        # NOTE(review): keys presumably follow the class order of the training
        # dataset YAML — confirm against it.
        self.color_map = {
            0: (0, 255, 0),         # player: green
            1: (255, 128, 0),       # Storm Timer: light blue
            2: (0, 0, 255),         # Killfeed: red
            3: (255, 0, 0),         # Player Count: blue
            4: (0, 255, 255),       # Minimap: yellow
            5: (128, 0, 255),       # Storm Shrink Warning: rose (pinkish red)
            6: (255, 0, 255),       # Eliminations: magenta
            7: (0, 128, 255),       # Health: orange
            8: (255, 255, 0),       # Shield: cyan
            9: (128, 255, 0),       # Inventory: light green
            10: (0, 165, 255),      # Buildings: orange-yellow
            11: (19, 69, 139),      # Wood Material: brown (was RGB-ordered)
            12: (128, 128, 128),    # Brick Material: gray
            13: (192, 192, 192),    # Metal Material: light gray
            14: (255, 191, 0),      # Compass: deep sky blue
            15: (255, 0, 128),      # Equipped Item: purple
            16: (0, 255, 191),      # Waypoint: yellow-green
            17: (128, 128, 0),      # Sprint Meter: teal
            18: (0, 140, 255),      # Safe Zone: orange-red
            19: (0, 215, 255),      # playerIcon: gold
            20: (34, 139, 34),      # Tree: forest green
            21: (75, 75, 75),       # Stone: dark gray
            22: (0, 69, 255),       # Building: orange-red
            23: (0, 61, 122),       # Wood Building: dark brown (was RGB-ordered)
            24: (108, 108, 108),    # Stone Building: medium gray
            25: (211, 211, 211),    # Metal Building: silver
            26: (0, 43, 27),        # Wall: dark green
            27: (22, 22, 22),       # Ramp: dark gray
            28: (17, 211, 0),       # Pyramid: bright green
            29: (9, 132, 121),      # Floor: olive green (was RGB-ordered)
        }

    def train_model(self, yaml_filepath, epochs: int = 100, imgsz: int = 640,
                    batch: int = 16, patience: int = 50, **train_kwargs):
        """Train ``self.model`` on the dataset described by ``yaml_filepath``.

        Args:
            yaml_filepath: Dataset YAML passed as ``data=`` to ``self.model.train``.
            epochs, imgsz, batch, patience: Forwarded unchanged to
                ``self.model.train``; the defaults are the values previously
                hard-coded here.
            **train_kwargs: Any further keyword arguments for ``self.model.train``.

        Returns:
            Whatever ``self.model.train`` returns (presumably training metrics —
            check the Ultralytics version in use).
        """
        return self.model.train(data=yaml_filepath, epochs=epochs, imgsz=imgsz,
                                batch=batch, patience=patience, **train_kwargs)

    def _draw_boxes(self, image: np.ndarray, result, conf_threshold: float) -> None:
        """Draw, in place on ``image``, a box plus ``"<name> <conf>"`` label for each detection above ``conf_threshold``."""
        class_names = result.names
        for box in result.boxes:
            conf = float(box.conf[0])
            if conf > conf_threshold:
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
                cls = int(box.cls[0])
                color = self.color_map.get(cls, (0, 255, 0))

                cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2)

                text = f'{class_names[cls]} {conf:.2f}'
                (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
                # Labels sit above the box; push them down when the box touches
                # the top edge so the text is not drawn off-image.
                label_bottom = max(y1, text_height + 5)
                cv2.rectangle(image, (x1, label_bottom - text_height - 5),
                              (x1 + text_width, label_bottom), color, -1)
                cv2.putText(image, text, (x1, label_bottom - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), thickness=2)

    def detect_object(self, frames: List[np.ndarray], conf_threshold: float = 0.4,
                      persist: bool = True) -> List[np.ndarray]:
        """Track objects in each frame and draw labelled boxes on copies of the frames.

        Args:
            frames: Images passed one at a time to ``self.model.track``.
                Presumably BGR arrays as read by OpenCV — confirm with callers.
            conf_threshold: Only detections with confidence strictly above this
                value are drawn (default 0.4, the previously hard-coded value).
            persist: Forwarded to ``self.model.track``. True keeps tracker state
                between successive frames; without it the tracker is reset on
                every call, so consecutive frames are treated as unrelated images.

        Returns:
            Annotated copies of the processed frames, in order. In debug mode,
            pressing ``q`` in the preview window stops early, and only the
            frames annotated so far are returned.
        """
        annotated_frames: List[np.ndarray] = []
        stop_requested = False

        for frame in frames:
            results = self.model.track(frame, persist=persist, stream=True)
            for result in results:
                annotated_frame = frame.copy()
                self._draw_boxes(annotated_frame, result, conf_threshold)
                annotated_frames.append(annotated_frame)

                if self.debug:
                    cv2.imshow('frame', annotated_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        stop_requested = True
                        break
            if stop_requested:
                break

        # Only tear down windows we may have opened; on headless OpenCV builds
        # destroyAllWindows is unavailable and would raise.
        if self.debug:
            cv2.destroyAllWindows()
        return annotated_frames

    def export_model(self, format: str = 'onnx'):
        """Export ``self.model`` to ``format`` via ``self.model.export``.

        Returns:
            Whatever ``self.model.export`` returns (presumably the exported
            file path — verify for the Ultralytics version in use).
        """
        return self.model.export(format=format)