DawnC committed on
Commit
3684eb4
·
verified ·
1 Parent(s): edda8e1

Delete evaluation_metrics.py

Browse files
Files changed (1) hide show
  1. evaluation_metrics.py +0 -323
evaluation_metrics.py DELETED
@@ -1,323 +0,0 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- from typing import Dict, List, Any, Optional, Tuple
4
-
5
class EvaluationMetrics:
    """Compute detection metrics, summary statistics and visualization data.

    All methods are static and operate on a YOLO-style detection result
    object that exposes ``result.boxes`` (with ``cls``, ``conf`` and
    ``xyxy`` tensors), ``result.names`` (class-id -> class-name mapping)
    and optionally ``result.orig_shape`` (image height, width).
    """

    @staticmethod
    def calculate_basic_stats(result: Any) -> Dict:
        """
        Calculate basic statistics for a single detection result.

        Args:
            result: Detection result object (may be None).

        Returns:
            Dictionary with ``total_objects``, per-class ``class_statistics``
            (count, average confidence, confidence std) and the overall
            ``average_confidence``; ``{"error": ...}`` when result is None.
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Move tensors to CPU and convert to numpy arrays.
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Accumulate count and confidences per class name.
        class_counts: Dict[str, Dict] = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            entry = class_counts.setdefault(
                cls_name, {"count": 0, "total_confidence": 0.0, "confidences": []}
            )
            entry["count"] += 1
            entry["total_confidence"] += float(conf)
            entry["confidences"].append(float(conf))

        # Derive per-class mean/std and drop the intermediate running total.
        # Entries are only created on an append, so count >= 1 always holds
        # and no guard is needed. The loop variable is deliberately NOT named
        # "stats" (the original shadowed the summary dict returned below).
        for entry in class_counts.values():
            entry["average_confidence"] = entry["total_confidence"] / entry["count"]
            entry["confidence_std"] = (
                float(np.std(entry["confidences"])) if len(entry["confidences"]) > 1 else 0
            )
            entry.pop("total_confidence")  # Remove intermediate calculation

        return {
            "total_objects": len(classes),
            "class_statistics": class_counts,
            "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0,
        }

    @staticmethod
    def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
        """
        Generate structured data suitable for visualization.

        Args:
            result: Detection result object (may be None).
            class_colors: Optional mapping of class name -> color code;
                classes without an entry fall back to "#CCCCCC".

        Returns:
            Dictionary with totals plus a ``class_data`` list sorted by
            per-class count (descending); ``{"error": ...}`` when result
            is None.
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Reuse the basic statistics as the data source.
        stats = EvaluationMetrics.calculate_basic_stats(result)

        viz_data = {
            "total_objects": stats["total_objects"],
            "average_confidence": stats["average_confidence"],
            "class_data": [],
        }

        # Sort classes by detection count, most frequent first.
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda item: item[1]["count"],
            reverse=True,
        )

        # Build a reverse name -> id map once, keeping the FIRST id for any
        # duplicated name (matches the original first-match linear scan),
        # instead of re-scanning result.names for every class (was O(n^2)).
        name_to_id: Dict[str, int] = {}
        for idx, name in result.names.items():
            name_to_id.setdefault(name, idx)

        for cls_name, cls_stats in sorted_classes:
            viz_data["class_data"].append({
                "name": cls_name,
                "class_id": name_to_id.get(cls_name, -1),
                "count": cls_stats["count"],
                "average_confidence": cls_stats.get("average_confidence", 0),
                "confidence_std": cls_stats.get("confidence_std", 0),
                "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC",
            })

        return viz_data

    @staticmethod
    def _message_figure(message: str, figsize: Tuple[int, int]) -> "plt.Figure":
        """Return a blank figure showing only a centered text message."""
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, message,
                ha='center', va='center', fontsize=12)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    @staticmethod
    def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7),
                          max_classes: int = 30) -> "plt.Figure":
        """
        Create a horizontal bar chart showing detection statistics.

        Args:
            viz_data: Output of :meth:`generate_visualization_data`.
            figsize: Figure size (width, height) in inches.
            max_classes: Maximum number of classes to display.

        Returns:
            Matplotlib figure object (a message-only figure on error or
            when there is no class data).
        """
        # Error and no-data cases render as a message-only figure
        # (deduplicated via the _message_figure helper).
        if "error" in viz_data:
            return EvaluationMetrics._message_figure(viz_data["error"], figsize)

        if not viz_data.get("class_data"):
            return EvaluationMetrics._message_figure("No detection data available", figsize)

        # Show at most max_classes entries (input is already sorted by count).
        class_data = viz_data["class_data"][:max_classes]

        # Extract data for plotting.
        class_names = [item["name"] for item in class_data]
        counts = [item["count"] for item in class_data]
        colors = [item["color"] for item in class_data]

        fig, ax = plt.subplots(figsize=figsize)
        y_pos = np.arange(len(class_names))

        # Horizontal bars with per-class colors.
        bars = ax.barh(y_pos, counts, color=colors, alpha=0.8)

        # Annotate each bar with its count and average confidence.
        for i, bar in enumerate(bars):
            width = bar.get_width()
            conf = class_data[i]["average_confidence"]
            ax.text(width + 0.3, bar.get_y() + bar.get_height() / 2,
                    f"{width:.0f} (conf: {conf:.2f})",
                    va='center', fontsize=9)

        # Customize axis and labels.
        ax.set_yticks(y_pos)
        ax.set_yticklabels(class_names)
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count')
        ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total')

        # Grid behind the bars for readability.
        ax.set_axisbelow(True)
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Summary text box in the bottom-left corner.
        summary_text = (
            f"Total Objects: {viz_data['total_objects']}\n"
            f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
            f"Unique Classes: {len(viz_data['class_data'])}"
        )
        plt.figtext(0.02, 0.02, summary_text, fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.8, boxstyle='round'))

        plt.tight_layout()
        return fig

    @staticmethod
    def format_detection_summary(viz_data: Dict) -> str:
        """
        Format detection results as a readable text summary.

        Args:
            viz_data: Output of :meth:`generate_visualization_data`.

        Returns:
            Multi-line, human-readable summary string.
        """
        if "error" in viz_data:
            return viz_data["error"]

        if "total_objects" not in viz_data:
            return "No detection data available."

        # No timestamp is included on purpose.
        total_objects = viz_data["total_objects"]
        avg_confidence = viz_data["average_confidence"]

        # Header lines.
        lines = [
            f"Detected {total_objects} objects.",
            f"Average confidence: {avg_confidence:.2f}",
            "",
            "Objects by class:",
        ]

        # Per-class details.
        if viz_data.get("class_data"):
            for item in viz_data["class_data"]:
                lines.append(
                    f"• {item['name']}: {item['count']} (avg conf: {item['average_confidence']:.2f})"
                )
        else:
            lines.append("No class information available.")

        return "\n".join(lines)

    @staticmethod
    def calculate_distance_metrics(result: Any) -> Dict:
        """
        Calculate distance-related metrics for detected objects.

        Args:
            result: Detection result object (may be None).

        Returns:
            Dictionary with ``proximity`` (up to five closest class pairs),
            ``spatial_distribution`` and ``size_distribution``, all
            normalized by the image dimensions; ``{"error": ...}`` when
            result is None.
        """
        if result is None:
            return {"error": "No detection result provided"}

        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names

        metrics = {
            "proximity": {},             # Classes that appear close to each other
            "spatial_distribution": {},  # Distribution across the image
            "size_distribution": {},     # Size distribution of objects
        }

        # Fall back to unit dimensions when the original shape is unknown
        # (coordinates are then assumed to be normalized already).
        img_width, img_height = 1, 1
        if hasattr(result, "orig_shape"):
            img_height, img_width = result.orig_shape[:2]

        # Collect area, center and class name for every bounding box.
        areas = []
        centers = []
        class_names = []
        for box, cls in zip(boxes, classes):
            x1, y1, x2, y2 = box
            areas.append((x2 - x1) * (y2 - y1))
            centers.append(((x1 + x2) / 2, (y1 + y2) / 2))
            class_names.append(names[int(cls)])

        # Spatial distribution of box centers, normalized by image size.
        if centers:
            x_coords = [c[0] for c in centers]
            y_coords = [c[1] for c in centers]
            metrics["spatial_distribution"] = {
                "x_mean": float(np.mean(x_coords)) / img_width,
                "y_mean": float(np.mean(y_coords)) / img_height,
                "x_std": float(np.std(x_coords)) / img_width,
                "y_std": float(np.std(y_coords)) / img_height,
            }

        # Size distribution of box areas, normalized by image area.
        if areas:
            img_area = img_width * img_height
            metrics["size_distribution"] = {
                "mean_area": float(np.mean(areas)) / img_area,
                "std_area": float(np.std(areas)) / img_area,
                "min_area": float(np.min(areas)) / img_area,
                "max_area": float(np.max(areas)) / img_area,
            }

        # Group centers per class for the proximity computation.
        class_centers: Dict[str, List] = {}
        for cls_name, center in zip(class_names, centers):
            class_centers.setdefault(cls_name, []).append(center)

        # Distance normalizer: the image diagonal. Hoisted out of the loops;
        # the original recomputed it for every class pair.
        img_diagonal = np.sqrt(img_width ** 2 + img_height ** 2)

        # Minimum center-to-center distance between every pair of distinct
        # classes (iterating i < j avoids duplicates and self-comparison).
        proximity_pairs = []
        cls_list = list(class_centers.keys())
        for i, cls1 in enumerate(cls_list):
            for cls2 in cls_list[i + 1:]:
                min_distance = float('inf')
                for cx1, cy1 in class_centers[cls1]:
                    for cx2, cy2 in class_centers[cls2]:
                        dist = np.sqrt((cx1 - cx2) ** 2 + (cy1 - cy2) ** 2)
                        min_distance = min(min_distance, dist)

                proximity_pairs.append({
                    "class1": cls1,
                    "class2": cls2,
                    "distance": float(min_distance / img_diagonal),
                })

        # Keep only the five closest pairs.
        proximity_pairs.sort(key=lambda pair: pair["distance"])
        metrics["proximity"] = proximity_pairs[:5]

        return metrics