DawnC committed on
Commit a3d5402 · verified · 1 Parent(s): fcf2ce0

Delete app.py

Files changed (1)
  1. app.py +0 -474
app.py DELETED
@@ -1,474 +0,0 @@
- import os
- import numpy as np
- import torch
- import cv2
- import matplotlib.pyplot as plt
- import gradio as gr
- import io
- from PIL import Image, ImageDraw, ImageFont
- import spaces
- from typing import Dict, List, Any, Optional, Tuple
- from ultralytics import YOLO
-
- from detection_model import DetectionModel
- from color_mapper import ColorMapper
- from visualization_helper import VisualizationHelper
- from evaluation_metrics import EvaluationMetrics
- from style import Style
-
-
- color_mapper = ColorMapper()
- model_instances = {}
-
- @spaces.GPU
- def process_image(image, model_instance, confidence_threshold, filter_classes=None):
-     """
-     Process an image for object detection
-
-     Args:
-         image: Input image (numpy array or PIL Image)
-         model_instance: DetectionModel instance to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of classes to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, stats_data)
-     """
-     # initialize key variables
-     result = None
-     stats = {}
-     temp_path = None
-
-     try:
-         # update the confidence threshold
-         model_instance.confidence = confidence_threshold
-
-         # process the input image
-         if image is None:
-             return None, "No image provided. Please upload an image.", {}
-         elif isinstance(image, np.ndarray):
-             # convert BGR to RGB if needed
-             if image.shape[2] == 3:
-                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             else:
-                 image_rgb = image
-             pil_image = Image.fromarray(image_rgb)
-         else:
-             pil_image = image
-
-         # save a temp file for the detector
-         import uuid
-         import tempfile
-
-         temp_dir = tempfile.gettempdir()  # use the system temp directory
-         temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
-         temp_path = os.path.join(temp_dir, temp_filename)
-         pil_image.save(temp_path)
-
-         # object detection
-         result = model_instance.detect(temp_path)
-
-         if result is None:
-             return None, "Detection failed. Please try again with a different image.", {}
-
-         # calculate stats
-         stats = EvaluationMetrics.calculate_basic_stats(result)
-
-         # add spatial metrics
-         spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
-         stats["spatial_metrics"] = spatial_metrics
-
-         if filter_classes and len(filter_classes) > 0:
-             # get classes, boxes, confidences
-             classes = result.boxes.cls.cpu().numpy().astype(int)
-             confs = result.boxes.conf.cpu().numpy()
-             boxes = result.boxes.xyxy.cpu().numpy()
-
-             mask = np.zeros_like(classes, dtype=bool)
-             for cls_id in filter_classes:
-                 mask = np.logical_or(mask, classes == cls_id)
-
-             filtered_stats = {
-                 "total_objects": int(np.sum(mask)),
-                 "class_statistics": {},
-                 "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
-                 "spatial_metrics": stats["spatial_metrics"]
-             }
-
-             # update per-class stats
-             names = result.names
-             for cls, conf in zip(classes[mask], confs[mask]):
-                 cls_name = names[int(cls)]
-                 if cls_name not in filtered_stats["class_statistics"]:
-                     filtered_stats["class_statistics"][cls_name] = {
-                         "count": 0,
-                         "average_confidence": 0
-                     }
-
-                 cls_entry = filtered_stats["class_statistics"][cls_name]
-                 cls_entry["count"] += 1
-                 # keep a running mean so the value is an average, not just the last confidence seen
-                 cls_entry["average_confidence"] += (conf - cls_entry["average_confidence"]) / cls_entry["count"]
-
-             stats = filtered_stats
-
-         viz_data = EvaluationMetrics.generate_visualization_data(
-             result,
-             color_mapper.get_all_colors()
-         )
-
-         result_image = VisualizationHelper.visualize_detection(
-             temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
-         )
-
-         result_text = EvaluationMetrics.format_detection_summary(viz_data)
-
-         return result_image, result_text, stats
-
-     except Exception as e:
-         error_message = f"Error occurred: {str(e)}"
-         import traceback
-         traceback.print_exc()
-         print(error_message)
-         return None, error_message, {}
-
-     finally:
-         if temp_path and os.path.exists(temp_path):
-             try:
-                 os.remove(temp_path)
-             except Exception as e:
-                 print(f"Could not delete temp file {temp_path}: {str(e)}")
-
- def format_result_text(stats):
-     """Format detection statistics into readable text"""
-     if not stats or "total_objects" not in stats:
-         return "No objects detected."
-
-     lines = [
-         f"Detected {stats['total_objects']} objects.",
-         f"Average confidence: {stats.get('average_confidence', 0):.2f}",
-         "",
-         "Objects by class:",
-     ]
-
-     if "class_statistics" in stats and stats["class_statistics"]:
-         # Sort classes by count
-         sorted_classes = sorted(
-             stats["class_statistics"].items(),
-             key=lambda x: x[1]["count"],
-             reverse=True
-         )
-
-         for cls_name, cls_stats in sorted_classes:
-             lines.append(f"• {cls_name}: {cls_stats['count']} (avg conf: {cls_stats.get('average_confidence', 0):.2f})")
-     else:
-         lines.append("No class information available.")
-
-     return "\n".join(lines)
-
- def get_all_classes():
-     """Get all available COCO classes"""
-     try:
-         # `model` is a module-level instance that may not be defined yet
-         class_names = model.class_names
-         return [(idx, name) for idx, name in class_names.items()]
-     except (NameError, AttributeError):
-         # Fallback to standard COCO classes
-         return [
-             (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
-             (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
-             (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
-             (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
-             (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
-             (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
-             (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
-             (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
-             (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
-             (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
-             (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
-             (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
-             (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
-             (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
-             (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
-             (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
-             (79, 'toothbrush')
-         ]
-
- def create_interface():
-     """Create the Gradio interface"""
-
-     css = Style.get_css()
-
-     # get model info
-     available_models = DetectionModel.get_available_models()
-     model_choices = [model["model_file"] for model in available_models]
-     model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
-
-     # class options
-     available_classes = get_all_classes()
-     class_choices = [f"{id}: {name}" for id, name in available_classes]
-
-     # create the Blocks layout
-     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
-         # Header
-         with gr.Group(elem_classes="app-header"):
-             gr.HTML("""
-                 <div style="text-align: center; width: 100%;">
-                     <h1 class="app-title">VisionScout</h1>
-                     <h2 class="app-subtitle">Detect and identify objects in your images</h2>
-                     <div class="app-divider"></div>
-                 </div>
-             """)
-
-         current_model = gr.State("yolov8m.pt")  # use the medium model as default
-
-         # main content area - input and output panels
-         with gr.Row(equal_height=True):
-             # left side - input controls
-             with gr.Column(scale=4, elem_classes="input-panel"):
-                 with gr.Group():
-                     gr.Markdown("<div style='text-align: center;'>### Upload Image</div>")
-                     image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")
-
-                     with gr.Accordion("Advanced Settings", open=False):
-                         with gr.Row():
-                             model_dropdown = gr.Dropdown(
-                                 choices=model_choices,
-                                 value="yolov8m.pt",
-                                 label="Select Model",
-                                 info="Choose different models based on your needs for speed vs. accuracy"
-                             )
-
-                         # display model info
-                         model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))
-
-                         confidence = gr.Slider(
-                             minimum=0.1,
-                             maximum=0.9,
-                             value=0.25,
-                             step=0.05,
-                             label="Confidence Threshold",
-                             info="Higher values show fewer but more confident detections"
-                         )
-
-                         with gr.Accordion("Filter Classes", open=False):
-                             # quick-select buttons for common object categories
-                             gr.Markdown("<div style='text-align: center;'>Common Categories</div>")
-                             with gr.Row():
-                                 people_btn = gr.Button("People", size="sm")
-                                 vehicles_btn = gr.Button("Vehicles", size="sm")
-                                 animals_btn = gr.Button("Animals", size="sm")
-                                 objects_btn = gr.Button("Common Objects", size="sm")
-
-                             # class selection dropdown
-                             class_filter = gr.Dropdown(
-                                 choices=class_choices,
-                                 multiselect=True,
-                                 label="Select Classes to Display",
-                                 info="Leave empty to show all detected objects"
-                             )
-
-                 # detection button
-                 detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
-
-                 # usage instructions
-                 with gr.Group(elem_classes="how-to-use"):
-                     gr.Markdown("<div style='text-align: center;'>### How to Use</div>")
-                     gr.Markdown("""
-                     1. Upload an image or use the camera
-                     2. Adjust the confidence threshold if needed
-                     3. Optionally filter to specific object classes
-                     4. Click the "Detect Objects" button
-
-                     The model will identify objects in your image and display them with bounding boxes.
-
-                     **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
-                     """)
-
-             # right side - results display
-             with gr.Column(scale=6, elem_classes="output-panel"):
-                 with gr.Tabs(elem_classes="tabs"):
-                     with gr.Tab("Detection Result"):
-                         result_image = gr.Image(type="pil", label="Detection Result")
-                         result_text = gr.Textbox(label="Detection Details", lines=10)
-
-                     with gr.Tab("Statistics"):
-                         with gr.Row():
-                             with gr.Column(scale=1):
-                                 stats_json = gr.JSON(label="Full Statistics")
-
-                             with gr.Column(scale=1):
-                                 gr.Markdown("<div style='text-align: center;'>### Object Distribution</div>")
-                                 plot_output = gr.Plot(label="Object Distribution")
-
-         detect_btn.click(
-             fn=lambda img, model, conf, classes: process_and_plot(img, model, conf, classes),
-             inputs=[image_input, current_model, confidence, class_filter],
-             outputs=[result_image, result_text, stats_json, plot_output]
-         )
-
-         model_dropdown.change(
-             fn=lambda model: (model, DetectionModel.get_model_description(model)),
-             inputs=[model_dropdown],
-             outputs=[current_model, model_info]
-         )
-
-         # quick class filter buttons
-         people_classes = [0]  # person
-         vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]  # vehicles
-         animals_classes = list(range(14, 24))  # COCO animal classes
-         common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # common household objects
-
-         people_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
-             outputs=class_filter
-         )
-
-         vehicles_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
-             outputs=class_filter
-         )
-
-         animals_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
-             outputs=class_filter
-         )
-
-         objects_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
-             outputs=class_filter
-         )
-
-         example_images = [
-             "room_01.jpg",
-             "street_01.jpg",
-             "street_02.jpg",
-             "street_03.jpg"
-         ]
-
-         # add example images
-         gr.Examples(
-             examples=example_images,
-             inputs=image_input,
-             outputs=None,
-             fn=None,
-             cache_examples=False,
-         )
-
-         # footer
-         gr.HTML("""
-             <div class="footer">
-                 <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
-                 <p>Model can detect 80 different classes of objects</p>
-             </div>
-         """)
-
-     return demo
-
- @spaces.GPU
- def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
-     """
-     Process image and create plots for statistics
-
-     Args:
-         image: Input image
-         model_name: Name of the model to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of classes to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, stats_json, plot_figure)
-     """
-     global model_instances
-
-     if model_name not in model_instances:
-         print(f"Creating new model instance for {model_name}")
-         model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
-     else:
-         print(f"Using existing model instance for {model_name}")
-         model_instances[model_name].confidence = confidence_threshold
-
-     class_ids = None
-     if filter_classes:
-         class_ids = []
-         for class_str in filter_classes:
-             try:
-                 # extract the ID from the "id: name" format
-                 class_id = int(class_str.split(":")[0].strip())
-                 class_ids.append(class_id)
-             except ValueError:
-                 continue
-
-     # execute detection
-     result_image, result_text, stats = process_image(
-         image,
-         model_instances[model_name],
-         confidence_threshold,
-         class_ids
-     )
-
-     # create the stats plot
-     plot_figure = create_stats_plot(stats)
-
-     return result_image, result_text, stats, plot_figure
-
- def create_stats_plot(stats):
-     """
-     Create a visualization of statistics data
-
-     Args:
-         stats: Dictionary containing detection statistics
-
-     Returns:
-         Matplotlib figure with visualization
-     """
-     if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
-         # Create empty plot if no data
-         fig, ax = plt.subplots(figsize=(8, 6))
-         ax.text(0.5, 0.5, "No detection data available",
-                 ha='center', va='center', fontsize=12)
-         ax.set_xlim(0, 1)
-         ax.set_ylim(0, 1)
-         ax.axis('off')
-         return fig
-
-     # preparing visualization data
-     viz_data = {
-         "total_objects": stats.get("total_objects", 0),
-         "average_confidence": stats.get("average_confidence", 0),
-         "class_data": []
-     }
-
-     # get current model classes
-     # This uses the get_all_classes function which should retrieve from the current model
-     available_classes = dict(get_all_classes())
-
-     # process class data
-     for cls_name, cls_stats in stats.get("class_statistics", {}).items():
-         # search for class ID
-         class_id = -1
-
-         # Try to find the class ID from class names
-         for id, name in available_classes.items():
-             if name == cls_name:
-                 class_id = id
-                 break
-
-         cls_data = {
-             "name": cls_name,
-             "class_id": class_id,
-             "count": cls_stats.get("count", 0),
-             "average_confidence": cls_stats.get("average_confidence", 0),
-             "color": color_mapper.get_color(class_id if class_id >= 0 else cls_name)
-         }
-
-         viz_data["class_data"].append(cls_data)
-
-     # Sort by count in descending order
-     viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
-
-     return EvaluationMetrics.create_stats_plot(viz_data)
-
-
- if __name__ == "__main__":
-     import time
-
-     demo = create_interface()
-     demo.launch()