Spaces:
Running
Running
File size: 7,772 Bytes
1999a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from collections import defaultdict
import cv2
from ultralytics import YOLO
from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
from ultralytics.utils.checks import check_imshow, check_requirements
class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
    and region initialization.

    Attributes:
        LineString (shapely.geometry.LineString): Class for creating line string geometries.
        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
        Point (shapely.geometry.Point): Class for creating point geometries.
        prep (Callable): ``shapely.prepared.prep``, kept for building prepared geometries.
        CFG (Dict): Effective configuration: solution defaults, overlaid by global defaults, overlaid by the
            keyword arguments passed to the constructor (highest priority).
        region (List[Tuple[int, int]] | None): List of coordinate tuples defining a region of interest.
        line_width (int): Width of lines used in visualizations (2 when not configured).
        model (ultralytics.YOLO): Loaded YOLO model instance.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        track_add_args (Dict): Extra keyword arguments forwarded to ``model.track``.
        env_check (bool): Flag indicating whether the environment supports image display.
        track_history (collections.defaultdict): Per-track-ID list of recent bounding-box centers.
        tracks, track_data, boxes, clss, track_ids: Results of the most recent ``extract_tracks`` call.
        track_line (List[Tuple[float, float]] | None): History list of the most recently updated track.
        r_s (shapely geometry | None): Polygon or LineString built by ``initialize_region``.

    Methods:
        extract_tracks: Apply object tracking and extract tracks from an input image.
        store_tracking_history: Store object tracking history for a given track ID and bounding box.
        initialize_region: Initialize the counting region and line segment based on configuration.
        display_output: Display the results of processing, including showing frames or saving results.

    Examples:
        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, IS_CLI=False, **kwargs):
        """
        Initialize the solution with configuration settings and a YOLO model.

        Args:
            IS_CLI (bool): Enables CLI mode. When True and no ``source`` is configured, a demo video is downloaded
                from the Ultralytics assets and used as the source.
            **kwargs (Any): Configuration overrides applied on top of the solution and global default settings.
        """
        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep
        self.annotator = None  # Initialize annotator
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.r_s = None

        # Build the effective config in a fresh dict. The shared module-level default dicts are never mutated,
        # so overrides passed to one instance cannot leak into other instances or other users of those dicts.
        self.CFG = self._build_cfg(DEFAULT_SOL_DICT, DEFAULT_CFG_DICT, kwargs)
        LOGGER.info(f"Ultralytics Solutions: ✅ {({**DEFAULT_SOL_DICT, **kwargs})}")

        self.region = self.CFG["region"]  # Store region data for other classes usage
        self.line_width = (
            self.CFG["line_width"] if self.CFG["line_width"] is not None else 2
        )  # Store line_width for usage

        # Load model and store class names
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names

        self.track_add_args = {  # Tracker additional arguments for advanced configuration
            k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
        }

        if IS_CLI and self.CFG["source"] is None:
            # str() so a pathlib.Path model argument does not raise TypeError on the substring test
            d_s = "solutions_ci_demo.mp4" if "-pose" not in str(self.CFG["model"]) else "solution_ci_pose_demo.mp4"
            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
            self.CFG["source"] = d_s  # set default source

        # Initialize environment and region setup
        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

    @staticmethod
    def _build_cfg(sol_defaults, cfg_defaults, overrides):
        """
        Merge configuration sources into a new dict without mutating any of them.

        Later sources win on key conflicts: ``overrides`` > ``cfg_defaults`` > ``sol_defaults``.

        Args:
            sol_defaults (Dict): Solution-specific default settings.
            cfg_defaults (Dict): Global default settings.
            overrides (Dict): User-supplied overrides.

        Returns:
            (Dict): A new merged dictionary.
        """
        return {**sol_defaults, **cfg_defaults, **overrides}

    def extract_tracks(self, im0):
        """
        Apply object tracking and extract tracks from an input image or frame.

        Results are stored on the instance: ``tracks``, ``track_data``, ``boxes``, ``clss`` and ``track_ids``.
        When no tracked objects are found, a warning is logged and ``boxes``, ``clss`` and ``track_ids``
        are reset to empty lists.

        Args:
            im0 (ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)

        # Prefer oriented boxes (OBB models) and fall back to regular detection boxes
        self.track_data = self.tracks[0].obb or self.tracks[0].boxes

        if self.track_data and self.track_data.id is not None:
            self.boxes = self.track_data.xyxy.cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
        else:
            LOGGER.warning("WARNING ⚠️ no tracks found!")
            self.boxes, self.clss, self.track_ids = [], [], []

    def store_tracking_history(self, track_id, box):
        """
        Store the tracking history of an object.

        Appends the center point of ``box`` to the history list for ``track_id`` and keeps only the 30 most
        recent points. The updated list is also exposed as ``self.track_line``.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        self.track_line = self.track_history[track_id]
        self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(self.track_line) > 30:
            self.track_line.pop(0)  # drop the oldest point to cap the history length

    def initialize_region(self):
        """
        Initialize the counting region or line segment based on configuration settings.

        If no region is configured, a default 4-point rectangle is used. Three or more points produce a
        Polygon; fewer produce a LineString. The result is stored in ``self.r_s``.
        """
        if self.region is None:
            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
        self.r_s = (
            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
        )  # region or line

    def display_output(self, im0):
        """
        Show the processed frame in an OpenCV window when display is enabled.

        The frame is shown only if the ``show`` configuration is truthy and the environment supports image
        display. After showing it, the method polls the keyboard for 1 ms.

        Args:
            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            Pressing 'q' is detected, but it currently only ends this call early. It does not close the window
            or signal the caller to stop processing.
        """
        if self.CFG.get("show") and self.env_check:
            cv2.imshow("Ultralytics Solutions", im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return
|