reab5555 committed on
Commit
5dde850
·
verified ·
1 Parent(s): 1b0b39f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -28
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  import cv2
4
  from PIL import Image, ImageDraw, ImageFont
@@ -9,6 +8,7 @@ import os
9
  import matplotlib.pyplot as plt
10
  from io import BytesIO
11
  import tempfile
 
12
 
13
  # Check if CUDA is available, otherwise use CPU
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -33,18 +33,19 @@ def process_video(video_path, target, progress=gr.Progress()):
33
  frame_duration = 1 / output_fps
34
  video_duration = frame_count / original_fps
35
 
36
- processed_frames = []
37
  frame_scores = []
 
 
38
 
39
- for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
40
  frame_number = int(time * original_fps)
41
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
42
  ret, img = cap.read()
43
  if not ret:
44
  break
45
 
46
- # Resize the frame to 640x480
47
- #img_resized = cv2.resize(img, (640, 360))
48
  pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
49
 
50
  # Process single image
@@ -58,7 +59,7 @@ def process_video(video_path, target, progress=gr.Progress()):
58
  max_score = 0
59
 
60
  try:
61
- font = ImageFont.truetype("arial.ttf", 20) # Reduced font size for smaller image
62
  except IOError:
63
  font = ImageFont.load_default()
64
 
@@ -77,15 +78,22 @@ def process_video(video_path, target, progress=gr.Progress()):
77
 
78
  max_score = max(max_score, confidence)
79
 
80
- processed_frames.append(np.array(pil_img))
 
 
 
81
  frame_scores.append(max_score)
82
 
 
 
 
 
83
  cap.release()
84
- return processed_frames, frame_scores, None
85
-
86
  def create_heatmap(frame_scores, current_frame):
87
  plt.figure(figsize=(12, 3))
88
- plt.imshow([frame_scores], cmap='hot_r', aspect='auto') # 'hot_r' for reversed hot colormap
89
  cbar = plt.colorbar(label='Confidence')
90
  cbar.ax.yaxis.set_ticks_position('left')
91
  cbar.ax.yaxis.set_label_position('left')
@@ -93,13 +101,11 @@ def create_heatmap(frame_scores, current_frame):
93
  plt.xlabel('Frame')
94
  plt.yticks([])
95
 
96
- # Add more frame numbers on x-axis
97
  num_frames = len(frame_scores)
98
- step = max(1, num_frames // 10) # Show at most 10 frame numbers
99
  frame_numbers = range(0, num_frames, step)
100
  plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
101
 
102
- # Add vertical line for current frame
103
  plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
104
 
105
  plt.tight_layout()
@@ -121,6 +127,13 @@ def load_sample_frame(video_path):
121
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
122
  return frame_rgb
123
 
 
 
 
 
 
 
 
124
  def gradio_app():
125
  with gr.Blocks() as app:
126
  gr.Markdown("# Video Object Detection with Owlv2")
@@ -135,28 +148,23 @@ def gradio_app():
135
  use_sample_button = gr.Button("Use Sample Video")
136
  progress_bar = gr.Progress()
137
 
138
- processed_frames = gr.State([])
139
  frame_scores = gr.State([])
140
 
141
  def process_and_update(video, target):
142
- frames, scores, error = process_video(video, target, progress_bar)
143
- if frames is not None:
144
- heatmap_path = create_heatmap(scores, 0) # Initial heatmap with current frame at 0
145
- return frames, scores, frames[0], heatmap_path, error, gr.Slider(maximum=len(frames) - 1, value=0)
 
146
  return None, None, None, None, error, gr.Slider(maximum=100, value=0)
147
 
148
- def update_frame_and_heatmap(frame_index, frames, scores):
149
- if frames and 0 <= frame_index < len(frames):
150
- heatmap_path = create_heatmap(scores, frame_index)
151
- return frames[frame_index], heatmap_path
152
- return None, None
153
-
154
  video_input.upload(process_and_update,
155
  inputs=[video_input, target_input],
156
- outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
157
 
158
  frame_slider.change(update_frame_and_heatmap,
159
- inputs=[frame_slider, processed_frames, frame_scores],
160
  outputs=[output_image, heatmap_output])
161
 
162
  def use_sample_video():
@@ -165,7 +173,7 @@ def gradio_app():
165
 
166
  use_sample_button.click(use_sample_video,
167
  inputs=None,
168
- outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
169
 
170
  # Layout
171
  with gr.Row():
@@ -179,4 +187,15 @@ def gradio_app():
179
 
180
  if __name__ == "__main__":
181
  app = gradio_app()
182
- app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
  from PIL import Image, ImageDraw, ImageFont
 
8
  import matplotlib.pyplot as plt
9
  from io import BytesIO
10
  import tempfile
11
+ import shutil
12
 
13
  # Check if CUDA is available, otherwise use CPU
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
33
  frame_duration = 1 / output_fps
34
  video_duration = frame_count / original_fps
35
 
 
36
  frame_scores = []
37
+ temp_dir = tempfile.mkdtemp()
38
+ frame_paths = []
39
 
40
+ for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
41
  frame_number = int(time * original_fps)
42
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
43
  ret, img = cap.read()
44
  if not ret:
45
  break
46
 
47
+ # Resize the frame
48
+ img_resized = cv2.resize(img, (640, 360))
49
  pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
50
 
51
  # Process single image
 
59
  max_score = 0
60
 
61
  try:
62
+ font = ImageFont.truetype("arial.ttf", 20)
63
  except IOError:
64
  font = ImageFont.load_default()
65
 
 
78
 
79
  max_score = max(max_score, confidence)
80
 
81
+ # Save frame to disk
82
+ frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
83
+ pil_img.save(frame_path)
84
+ frame_paths.append(frame_path)
85
  frame_scores.append(max_score)
86
 
87
+ # Clear GPU cache every 10 frames
88
+ if i % 10 == 0:
89
+ torch.cuda.empty_cache()
90
+
91
  cap.release()
92
+ return frame_paths, frame_scores, None
93
+
94
  def create_heatmap(frame_scores, current_frame):
95
  plt.figure(figsize=(12, 3))
96
+ plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
97
  cbar = plt.colorbar(label='Confidence')
98
  cbar.ax.yaxis.set_ticks_position('left')
99
  cbar.ax.yaxis.set_label_position('left')
 
101
  plt.xlabel('Frame')
102
  plt.yticks([])
103
 
 
104
  num_frames = len(frame_scores)
105
+ step = max(1, num_frames // 10)
106
  frame_numbers = range(0, num_frames, step)
107
  plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
108
 
 
109
  plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
110
 
111
  plt.tight_layout()
 
127
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
128
  return frame_rgb
129
 
130
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Load the frame selected on the slider and redraw the heatmap for it.

    Args:
        frame_index: Position chosen on the frame slider. It is coerced to
            ``int`` because it is used as a list index; a UI slider may
            deliver the value as a float.
        frame_paths: Paths of the frame images saved to disk, in frame order.
        scores: Per-frame scores, passed through to ``create_heatmap``.

    Returns:
        A ``(frame, heatmap)`` tuple where ``frame`` is the image as a NumPy
        array and ``heatmap`` is whatever ``create_heatmap`` returns, or
        ``(None, None)`` when there are no frames or the index is out of
        range.
    """
    if not frame_paths:
        return None, None

    # A float index would raise TypeError when indexing the list below.
    frame_index = int(frame_index)
    if not 0 <= frame_index < len(frame_paths):
        return None, None

    # Copy the pixels out inside the context manager so the file handle is
    # released deterministically instead of waiting for garbage collection.
    with Image.open(frame_paths[frame_index]) as frame:
        frame_array = np.array(frame)

    heatmap_path = create_heatmap(scores, frame_index)
    return frame_array, heatmap_path
136
+
137
  def gradio_app():
138
  with gr.Blocks() as app:
139
  gr.Markdown("# Video Object Detection with Owlv2")
 
148
  use_sample_button = gr.Button("Use Sample Video")
149
  progress_bar = gr.Progress()
150
 
151
+ frame_paths = gr.State([])
152
  frame_scores = gr.State([])
153
 
154
  def process_and_update(video, target):
155
+ paths, scores, error = process_video(video, target, progress_bar)
156
+ if paths is not None:
157
+ heatmap_path = create_heatmap(scores, 0)
158
+ first_frame = Image.open(paths[0])
159
+ return paths, scores, np.array(first_frame), heatmap_path, error, gr.Slider(maximum=len(paths) - 1, value=0)
160
  return None, None, None, None, error, gr.Slider(maximum=100, value=0)
161
 
 
 
 
 
 
 
162
  video_input.upload(process_and_update,
163
  inputs=[video_input, target_input],
164
+ outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
165
 
166
  frame_slider.change(update_frame_and_heatmap,
167
+ inputs=[frame_slider, frame_paths, frame_scores],
168
  outputs=[output_image, heatmap_output])
169
 
170
  def use_sample_video():
 
173
 
174
  use_sample_button.click(use_sample_video,
175
  inputs=None,
176
+ outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
177
 
178
  # Layout
179
  with gr.Row():
 
187
 
188
  if __name__ == "__main__":
189
  app = gradio_app()
190
+ app.launch(share=True)
191
+
192
+ # Cleanup temporary files
193
def cleanup(paths=None, directory=None):
    """Delete saved frame images and the temporary directory holding them.

    Args:
        paths: Iterable of frame file paths to delete. When omitted, the
            ``value`` of a module-level ``frame_paths`` object is used if
            one exists; otherwise no files are removed.
        directory: Directory to remove recursively. When omitted, a
            module-level ``temp_dir`` is used if one exists; otherwise no
            directory is removed.

    NOTE(review): nothing visible here calls this function, and the frame
    list / temp directory appear to be created inside other functions.
    Callers should pass both explicitly, e.g. from a shutdown hook.
    """
    if paths is None:
        # Fall back to the names the zero-argument form relied on, without
        # raising NameError when they are not defined at module level.
        state = globals().get("frame_paths")
        paths = getattr(state, "value", None) or []
    if directory is None:
        directory = globals().get("temp_dir")

    for path in paths:
        # EAFP: a file that is already gone is not an error during cleanup.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    if directory and os.path.exists(directory):
        shutil.rmtree(directory)
199
+
200
+ # Make sure to call cleanup when the app is closed
201
+ # This might require additional setup depending on how you're running the app