Update app.py
Browse files
app.py
CHANGED
@@ -94,18 +94,33 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
94 |
|
95 |
boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
draw.text(text_position, annotation, fill="white", font=font)
|
107 |
|
108 |
-
|
109 |
|
110 |
# Save frame to disk
|
111 |
frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
|
@@ -125,17 +140,17 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
125 |
return frame_paths, frame_scores, None
|
126 |
|
127 |
def create_heatmap(frame_scores, current_frame):
|
128 |
-
plt.figure(figsize=(
|
129 |
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
130 |
-
plt.title('Object Detection Heatmap')
|
131 |
-
plt.xlabel('Frame')
|
132 |
plt.yticks([])
|
133 |
|
134 |
# Add more frame numbers on x-axis
|
135 |
num_frames = len(frame_scores)
|
136 |
-
step = max(1, num_frames //
|
137 |
frame_numbers = range(0, num_frames, step)
|
138 |
-
plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
|
139 |
|
140 |
# Add vertical line for current frame
|
141 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|
|
|
94 |
|
95 |
boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
|
96 |
|
97 |
+
for box, score, label in zip(boxes, scores, labels):
|
98 |
+
if score.item() >= 0.5:
|
99 |
+
box = [round(i, 2) for i in box.tolist()]
|
100 |
+
object_label = target
|
101 |
+
confidence = round(score.item(), 3)
|
102 |
+
annotation = f"{object_label}: {confidence}"
|
103 |
+
|
104 |
+
# Increase line width for the bounding box
|
105 |
+
draw.rectangle(box, outline="red", width=4)
|
106 |
+
|
107 |
+
# Increase font size and change color to red
|
108 |
+
font_size = 30 # Increased from 20
|
109 |
+
try:
|
110 |
+
font = ImageFont.truetype("arial.ttf", font_size)
|
111 |
+
except IOError:
|
112 |
+
font = ImageFont.load_default()
|
113 |
+
|
114 |
+
text_position = (box[0], box[1] - font_size - 5)
|
115 |
+
|
116 |
+
# Add a semi-transparent background for better text visibility
|
117 |
+
text_bbox = draw.textbbox(text_position, annotation, font=font)
|
118 |
+
draw.rectangle(text_bbox, fill=(0, 0, 0, 128))
|
119 |
|
120 |
+
# Draw text in red
|
121 |
+
draw.text(text_position, annotation, fill="red", font=font)
|
|
|
122 |
|
123 |
+
max_score = max(max_score, confidence)
|
124 |
|
125 |
# Save frame to disk
|
126 |
frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
|
|
|
140 |
return frame_paths, frame_scores, None
|
141 |
|
142 |
def create_heatmap(frame_scores, current_frame):
|
143 |
+
plt.figure(figsize=(16, 4))
|
144 |
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
145 |
+
plt.title('Object Detection Heatmap', fontsize=14)
|
146 |
+
plt.xlabel('Frame', fontsize=12)
|
147 |
plt.yticks([])
|
148 |
|
149 |
# Add more frame numbers on x-axis
|
150 |
num_frames = len(frame_scores)
|
151 |
+
step = max(1, num_frames // 20) # Show at most 20 frame numbers
|
152 |
frame_numbers = range(0, num_frames, step)
|
153 |
+
plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=45, ha='right')
|
154 |
|
155 |
# Add vertical line for current frame
|
156 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|