DawnC commited on
Commit
1e960be
·
verified ·
1 Parent(s): 3684eb4

Delete detection_model.py

Browse files
Files changed (1) hide show
  1. detection_model.py +0 -164
detection_model.py DELETED
@@ -1,164 +0,0 @@
1
- from ultralytics import YOLO
2
- from typing import Any, List, Dict, Optional
3
- import torch
4
- import numpy as np
5
- import os
6
-
7
- class DetectionModel:
8
- """Core detection model class for object detection using YOLOv8"""
9
-
10
- # Model information dictionary
11
- MODEL_INFO = {
12
- "yolov8n.pt": {
13
- "name": "YOLOv8n (Nano)",
14
- "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
15
- "size_mb": 6,
16
- "inference_speed": "Very Fast"
17
- },
18
- "yolov8m.pt": {
19
- "name": "YOLOv8m (Medium)",
20
- "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
21
- "size_mb": 25,
22
- "inference_speed": "Medium"
23
- },
24
- "yolov8x.pt": {
25
- "name": "YOLOv8x (XLarge)",
26
- "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
27
- "size_mb": 68,
28
- "inference_speed": "Slower"
29
- }
30
- }
31
-
32
- def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
33
- """
34
- Initialize the detection model
35
-
36
- Args:
37
- model_name: Model name or path, default is yolov8m.pt
38
- confidence: Confidence threshold, default is 0.25
39
- iou: IoU threshold for non-maximum suppression, default is 0.45
40
- """
41
- self.model_name = model_name
42
- self.confidence = confidence
43
- self.iou = iou
44
- self.model = None
45
- self.class_names = {}
46
- self.is_model_loaded = False
47
-
48
- # Load model on initialization
49
- self._load_model()
50
-
51
- def _load_model(self):
52
- """Load the YOLO model"""
53
- try:
54
- print(f"Loading model: {self.model_name}")
55
- self.model = YOLO(self.model_name)
56
- self.class_names = self.model.names
57
- self.is_model_loaded = True
58
- print(f"Successfully loaded model: {self.model_name}")
59
- print(f"Number of classes the model can recognize: {len(self.class_names)}")
60
- except Exception as e:
61
- print(f"Error occurred when loading the model: {e}")
62
- self.is_model_loaded = False
63
-
64
- def change_model(self, new_model_name: str) -> bool:
65
- """
66
- Change the currently loaded model
67
-
68
- Args:
69
- new_model_name: Name of the new model to load
70
-
71
- Returns:
72
- bool: True if model changed successfully, False otherwise
73
- """
74
- if self.model_name == new_model_name and self.is_model_loaded:
75
- print(f"Model {new_model_name} is already loaded")
76
- return True
77
-
78
- print(f"Changing model from {self.model_name} to {new_model_name}")
79
-
80
- # Unload current model to free memory
81
- if self.model is not None:
82
- del self.model
83
- self.model = None
84
-
85
- # Clean GPU memory if available
86
- if torch.cuda.is_available():
87
- torch.cuda.empty_cache()
88
-
89
- # Update model name and load new model
90
- self.model_name = new_model_name
91
- self._load_model()
92
-
93
- return self.is_model_loaded
94
-
95
- def reload_model(self):
96
- """Reload the model (useful for changing model or after error)"""
97
- if self.model is not None:
98
- del self.model
99
- self.model = None
100
-
101
- # Clean GPU memory if available
102
- if torch.cuda.is_available():
103
- torch.cuda.empty_cache()
104
-
105
- self._load_model()
106
-
107
- def detect(self, image_input: Any) -> Optional[Any]:
108
- """
109
- Perform object detection on a single image
110
-
111
- Args:
112
- image_input: Image path (str), PIL Image, or numpy array
113
-
114
- Returns:
115
- Detection result object or None if error occurred
116
- """
117
- if self.model is None or not self.is_model_loaded:
118
- print("Model not found or not loaded. Attempting to reload...")
119
- self._load_model()
120
- if self.model is None or not self.is_model_loaded:
121
- print("Failed to load model. Cannot perform detection.")
122
- return None
123
-
124
- try:
125
- results = self.model(image_input, conf=self.confidence, iou=self.iou)
126
- return results[0]
127
- except Exception as e:
128
- print(f"Error occurred during detection: {e}")
129
- return None
130
-
131
- def get_class_names(self, class_id: int) -> str:
132
- """Get class name for a given class ID"""
133
- return self.class_names.get(class_id, "Unknown Class")
134
-
135
- def get_supported_classes(self) -> Dict[int, str]:
136
- """Get all supported classes as a dictionary of {id: class_name}"""
137
- return self.class_names
138
-
139
- @classmethod
140
- def get_available_models(cls) -> List[Dict]:
141
- """
142
- Get list of available models with their information
143
-
144
- Returns:
145
- List of dictionaries containing model information
146
- """
147
- models = []
148
- for model_file, info in cls.MODEL_INFO.items():
149
- models.append({
150
- "model_file": model_file,
151
- "name": info["name"],
152
- "description": info["description"],
153
- "size_mb": info["size_mb"],
154
- "inference_speed": info["inference_speed"]
155
- })
156
- return models
157
-
158
- @classmethod
159
- def get_model_description(cls, model_name: str) -> str:
160
- """Get description for a specific model"""
161
- if model_name in cls.MODEL_INFO:
162
- info = cls.MODEL_INFO[model_name]
163
- return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
164
- return "Model information not available"