reab5555 commited on
Commit
8714cd1
·
verified ·
1 Parent(s): fa3925e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -34,7 +34,7 @@ def process_video(video_path, target, progress=gr.Progress()):
34
 
35
  processed_frames = []
36
  frame_scores = []
37
- batch_size = 1
38
  batch_frames = []
39
  batch_times = []
40
 
@@ -63,7 +63,7 @@ def process_video(video_path, target, progress=gr.Progress()):
63
  max_score = 0
64
 
65
  try:
66
- font = ImageFont.truetype("arial.ttf", 30)
67
  except IOError:
68
  font = ImageFont.load_default()
69
 
@@ -76,7 +76,7 @@ def process_video(video_path, target, progress=gr.Progress()):
76
  confidence = round(score.item(), 3)
77
  annotation = f"{object_label}: {confidence}"
78
 
79
- draw.rectangle(box, outline="red", width=4)
80
  text_position = (box[0], box[1] - 30)
81
  draw.text(text_position, annotation, fill="white", font=font)
82
 
@@ -91,17 +91,29 @@ def process_video(video_path, target, progress=gr.Progress()):
91
  cap.release()
92
  return processed_frames, frame_scores, None
93
 
94
- def create_heatmap(frame_scores):
95
- plt.figure(figsize=(10, 2))
96
- plt.imshow([frame_scores], cmap='hot', aspect='auto')
97
- plt.colorbar(label='Confidence')
 
 
98
  plt.title('Object Detection Heatmap')
99
  plt.xlabel('Frame')
100
  plt.yticks([])
 
 
 
 
 
 
 
 
 
 
101
  plt.tight_layout()
102
 
103
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
104
- plt.savefig(tmp_file.name, format='png')
105
  plt.close()
106
 
107
  return tmp_file.name
@@ -124,10 +136,10 @@ def gradio_app():
124
  video_input = gr.Video(label="Upload Video")
125
  target_input = gr.Textbox(label="Target Object", value="Elephant")
126
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
127
- output_image = gr.Image(label="Processed Frame")
128
  heatmap_output = gr.Image(label="Detection Heatmap")
 
129
  error_output = gr.Textbox(label="Error Messages", visible=False)
130
- sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
131
  use_sample_button = gr.Button("Use Sample Video")
132
  progress_bar = gr.Progress()
133
 
@@ -137,22 +149,23 @@ def gradio_app():
137
  def process_and_update(video, target):
138
  frames, scores, error = process_video(video, target, progress_bar)
139
  if frames is not None:
140
- heatmap_path = create_heatmap(scores)
141
  return frames, scores, frames[0], heatmap_path, error, gr.Slider(maximum=len(frames) - 1, value=0)
142
  return None, None, None, None, error, gr.Slider(maximum=100, value=0)
143
 
144
- def update_frame(frame_index, frames):
145
  if frames and 0 <= frame_index < len(frames):
146
- return frames[frame_index]
147
- return None
 
148
 
149
  video_input.upload(process_and_update,
150
  inputs=[video_input, target_input],
151
  outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
152
 
153
- frame_slider.change(update_frame,
154
- inputs=[frame_slider, processed_frames],
155
- outputs=[output_image])
156
 
157
  def use_sample_video():
158
  sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
@@ -162,6 +175,19 @@ def gradio_app():
162
  inputs=None,
163
  outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  return app
166
 
167
  if __name__ == "__main__":
 
34
 
35
  processed_frames = []
36
  frame_scores = []
37
+ batch_size = 2
38
  batch_frames = []
39
  batch_times = []
40
 
 
63
  max_score = 0
64
 
65
  try:
66
+ font = ImageFont.truetype("arial.ttf", 40)
67
  except IOError:
68
  font = ImageFont.load_default()
69
 
 
76
  confidence = round(score.item(), 3)
77
  annotation = f"{object_label}: {confidence}"
78
 
79
+ draw.rectangle(box, outline="red", width=2)
80
  text_position = (box[0], box[1] - 30)
81
  draw.text(text_position, annotation, fill="white", font=font)
82
 
 
91
  cap.release()
92
  return processed_frames, frame_scores, None
93
 
94
+ def create_heatmap(frame_scores, current_frame):
95
+ plt.figure(figsize=(12, 3))
96
+ plt.imshow([frame_scores], cmap='hot_r', aspect='auto') # 'hot_r' for reversed hot colormap
97
+ cbar = plt.colorbar(label='Confidence')
98
+ cbar.ax.yaxis.set_ticks_position('left')
99
+ cbar.ax.yaxis.set_label_position('left')
100
  plt.title('Object Detection Heatmap')
101
  plt.xlabel('Frame')
102
  plt.yticks([])
103
+
104
+ # Add more frame numbers on x-axis
105
+ num_frames = len(frame_scores)
106
+ step = max(1, num_frames // 10) # Show at most 10 frame numbers
107
+ frame_numbers = range(0, num_frames, step)
108
+ plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
109
+
110
+ # Add vertical line for current frame
111
+ plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
112
+
113
  plt.tight_layout()
114
 
115
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
116
+ plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
117
  plt.close()
118
 
119
  return tmp_file.name
 
136
  video_input = gr.Video(label="Upload Video")
137
  target_input = gr.Textbox(label="Target Object", value="Elephant")
138
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
 
139
  heatmap_output = gr.Image(label="Detection Heatmap")
140
+ output_image = gr.Image(label="Processed Frame")
141
  error_output = gr.Textbox(label="Error Messages", visible=False)
142
+ sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame")
143
  use_sample_button = gr.Button("Use Sample Video")
144
  progress_bar = gr.Progress()
145
 
 
149
  def process_and_update(video, target):
150
  frames, scores, error = process_video(video, target, progress_bar)
151
  if frames is not None:
152
+ heatmap_path = create_heatmap(scores, 0) # Initial heatmap with current frame at 0
153
  return frames, scores, frames[0], heatmap_path, error, gr.Slider(maximum=len(frames) - 1, value=0)
154
  return None, None, None, None, error, gr.Slider(maximum=100, value=0)
155
 
156
+ def update_frame_and_heatmap(frame_index, frames, scores):
157
  if frames and 0 <= frame_index < len(frames):
158
+ heatmap_path = create_heatmap(scores, frame_index)
159
+ return frames[frame_index], heatmap_path
160
+ return None, None
161
 
162
  video_input.upload(process_and_update,
163
  inputs=[video_input, target_input],
164
  outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
165
 
166
+ frame_slider.change(update_frame_and_heatmap,
167
+ inputs=[frame_slider, processed_frames, frame_scores],
168
+ outputs=[output_image, heatmap_output])
169
 
170
  def use_sample_video():
171
  sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
 
175
  inputs=None,
176
  outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
177
 
178
+ # Rearrange the layout
179
+ video_input.render()
180
+ target_input.render()
181
+ with gr.Row():
182
+ with gr.Column(scale=2):
183
+ output_image.render()
184
+ with gr.Column(scale=1):
185
+ sample_video_frame.render()
186
+ use_sample_button.render()
187
+ frame_slider.render()
188
+ heatmap_output.render()
189
+ error_output.render()
190
+
191
  return app
192
 
193
  if __name__ == "__main__":