Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import spaces
|
2 |
import gradio as gr
|
3 |
import cv2
|
4 |
from PIL import Image, ImageDraw, ImageFont
|
@@ -9,6 +8,7 @@ import os
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
from io import BytesIO
|
11 |
import tempfile
|
|
|
12 |
|
13 |
# Check if CUDA is available, otherwise use CPU
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
@@ -33,18 +33,19 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
33 |
frame_duration = 1 / output_fps
|
34 |
video_duration = frame_count / original_fps
|
35 |
|
36 |
-
processed_frames = []
|
37 |
frame_scores = []
|
|
|
|
|
38 |
|
39 |
-
for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
|
40 |
frame_number = int(time * original_fps)
|
41 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
42 |
ret, img = cap.read()
|
43 |
if not ret:
|
44 |
break
|
45 |
|
46 |
-
# Resize the frame
|
47 |
-
|
48 |
pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
|
49 |
|
50 |
# Process single image
|
@@ -58,7 +59,7 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
58 |
max_score = 0
|
59 |
|
60 |
try:
|
61 |
-
font = ImageFont.truetype("arial.ttf", 20)
|
62 |
except IOError:
|
63 |
font = ImageFont.load_default()
|
64 |
|
@@ -77,15 +78,22 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
77 |
|
78 |
max_score = max(max_score, confidence)
|
79 |
|
80 |
-
|
|
|
|
|
|
|
81 |
frame_scores.append(max_score)
|
82 |
|
|
|
|
|
|
|
|
|
83 |
cap.release()
|
84 |
-
return
|
85 |
-
|
86 |
def create_heatmap(frame_scores, current_frame):
|
87 |
plt.figure(figsize=(12, 3))
|
88 |
-
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
89 |
cbar = plt.colorbar(label='Confidence')
|
90 |
cbar.ax.yaxis.set_ticks_position('left')
|
91 |
cbar.ax.yaxis.set_label_position('left')
|
@@ -93,13 +101,11 @@ def create_heatmap(frame_scores, current_frame):
|
|
93 |
plt.xlabel('Frame')
|
94 |
plt.yticks([])
|
95 |
|
96 |
-
# Add more frame numbers on x-axis
|
97 |
num_frames = len(frame_scores)
|
98 |
-
step = max(1, num_frames // 10)
|
99 |
frame_numbers = range(0, num_frames, step)
|
100 |
plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
|
101 |
|
102 |
-
# Add vertical line for current frame
|
103 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|
104 |
|
105 |
plt.tight_layout()
|
@@ -121,6 +127,13 @@ def load_sample_frame(video_path):
|
|
121 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
122 |
return frame_rgb
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
def gradio_app():
|
125 |
with gr.Blocks() as app:
|
126 |
gr.Markdown("# Video Object Detection with Owlv2")
|
@@ -135,28 +148,23 @@ def gradio_app():
|
|
135 |
use_sample_button = gr.Button("Use Sample Video")
|
136 |
progress_bar = gr.Progress()
|
137 |
|
138 |
-
|
139 |
frame_scores = gr.State([])
|
140 |
|
141 |
def process_and_update(video, target):
|
142 |
-
|
143 |
-
if
|
144 |
-
heatmap_path = create_heatmap(scores, 0)
|
145 |
-
|
|
|
146 |
return None, None, None, None, error, gr.Slider(maximum=100, value=0)
|
147 |
|
148 |
-
def update_frame_and_heatmap(frame_index, frames, scores):
|
149 |
-
if frames and 0 <= frame_index < len(frames):
|
150 |
-
heatmap_path = create_heatmap(scores, frame_index)
|
151 |
-
return frames[frame_index], heatmap_path
|
152 |
-
return None, None
|
153 |
-
|
154 |
video_input.upload(process_and_update,
|
155 |
inputs=[video_input, target_input],
|
156 |
-
outputs=[
|
157 |
|
158 |
frame_slider.change(update_frame_and_heatmap,
|
159 |
-
inputs=[frame_slider,
|
160 |
outputs=[output_image, heatmap_output])
|
161 |
|
162 |
def use_sample_video():
|
@@ -165,7 +173,7 @@ def gradio_app():
|
|
165 |
|
166 |
use_sample_button.click(use_sample_video,
|
167 |
inputs=None,
|
168 |
-
outputs=[
|
169 |
|
170 |
# Layout
|
171 |
with gr.Row():
|
@@ -179,4 +187,15 @@ def gradio_app():
|
|
179 |
|
180 |
if __name__ == "__main__":
|
181 |
app = gradio_app()
|
182 |
-
app.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import cv2
|
3 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
from io import BytesIO
|
10 |
import tempfile
|
11 |
+
import shutil
|
12 |
|
13 |
# Check if CUDA is available, otherwise use CPU
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
33 |
frame_duration = 1 / output_fps
|
34 |
video_duration = frame_count / original_fps
|
35 |
|
|
|
36 |
frame_scores = []
|
37 |
+
temp_dir = tempfile.mkdtemp()
|
38 |
+
frame_paths = []
|
39 |
|
40 |
+
for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
|
41 |
frame_number = int(time * original_fps)
|
42 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
43 |
ret, img = cap.read()
|
44 |
if not ret:
|
45 |
break
|
46 |
|
47 |
+
# Resize the frame
|
48 |
+
img_resized = cv2.resize(img, (640, 360))
|
49 |
pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
|
50 |
|
51 |
# Process single image
|
|
|
59 |
max_score = 0
|
60 |
|
61 |
try:
|
62 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
63 |
except IOError:
|
64 |
font = ImageFont.load_default()
|
65 |
|
|
|
78 |
|
79 |
max_score = max(max_score, confidence)
|
80 |
|
81 |
+
# Save frame to disk
|
82 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
|
83 |
+
pil_img.save(frame_path)
|
84 |
+
frame_paths.append(frame_path)
|
85 |
frame_scores.append(max_score)
|
86 |
|
87 |
+
# Clear GPU cache every 10 frames
|
88 |
+
if i % 10 == 0:
|
89 |
+
torch.cuda.empty_cache()
|
90 |
+
|
91 |
cap.release()
|
92 |
+
return frame_paths, frame_scores, None
|
93 |
+
|
94 |
def create_heatmap(frame_scores, current_frame):
|
95 |
plt.figure(figsize=(12, 3))
|
96 |
+
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
97 |
cbar = plt.colorbar(label='Confidence')
|
98 |
cbar.ax.yaxis.set_ticks_position('left')
|
99 |
cbar.ax.yaxis.set_label_position('left')
|
|
|
101 |
plt.xlabel('Frame')
|
102 |
plt.yticks([])
|
103 |
|
|
|
104 |
num_frames = len(frame_scores)
|
105 |
+
step = max(1, num_frames // 10)
|
106 |
frame_numbers = range(0, num_frames, step)
|
107 |
plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
|
108 |
|
|
|
109 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|
110 |
|
111 |
plt.tight_layout()
|
|
|
127 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
128 |
return frame_rgb
|
129 |
|
130 |
+
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Load the frame selected on the slider and redraw the heatmap for it.

    Args:
        frame_index: Position of the wanted frame in ``frame_paths``. It comes
            from a slider, so it may arrive as a float (or ``None``) rather
            than an int.
        frame_paths: Paths of the processed frames saved on disk, in order.
        scores: Per-frame confidence scores, forwarded to ``create_heatmap``.

    Returns:
        ``(frame, heatmap)`` where ``frame`` is the image as a NumPy array and
        ``heatmap`` is whatever ``create_heatmap`` returns, or ``(None, None)``
        when there are no frames or the index is out of range.
    """
    if not frame_paths or frame_index is None:
        return None, None
    # Range-check the raw value first so a fractional negative such as -0.5 is
    # still rejected instead of being truncated to a valid index 0.
    if not 0 <= frame_index < len(frame_paths):
        return None, None

    # List indexing requires an int; a float slider value would raise TypeError.
    index = int(frame_index)

    # Materialise the pixels inside the context manager so the file handle is
    # released deterministically.
    with Image.open(frame_paths[index]) as frame:
        frame_array = np.array(frame)

    heatmap_path = create_heatmap(scores, index)
    return frame_array, heatmap_path
|
136 |
+
|
137 |
def gradio_app():
|
138 |
with gr.Blocks() as app:
|
139 |
gr.Markdown("# Video Object Detection with Owlv2")
|
|
|
148 |
use_sample_button = gr.Button("Use Sample Video")
|
149 |
progress_bar = gr.Progress()
|
150 |
|
151 |
+
frame_paths = gr.State([])
|
152 |
frame_scores = gr.State([])
|
153 |
|
154 |
def process_and_update(video, target):
    """Process an uploaded video and build the initial UI state.

    Calls ``process_video`` and, when at least one frame was produced, returns
    the values for the six outputs wired to this callback: frame paths, frame
    scores, the first frame as a NumPy array, the heatmap for frame 0, the
    error value and a slider update spanning the available frames.

    When no frames are available (``process_video`` returned ``None`` or an
    empty list) the frame-related outputs are cleared and the slider is reset.
    """
    paths, scores, error = process_video(video, target, progress_bar)

    # `paths` can be an empty list (e.g. the very first read failed), not only
    # None; indexing paths[0] or using len(paths) - 1 as the slider maximum
    # would break in that case.
    if not paths:
        if paths is not None and error is None:
            # NOTE(review): assumes the error output accepts a plain string,
            # like the error value returned by process_video - confirm.
            error = "No frames could be read from the video."
        return None, None, None, None, error, gr.Slider(maximum=100, value=0)

    heatmap_path = create_heatmap(scores, 0)

    # Load the pixels inside the context manager so the file is closed promptly.
    with Image.open(paths[0]) as first_frame:
        first_frame_array = np.array(first_frame)

    # Frames are addressed by integer list index, so step in whole frames.
    slider_update = gr.Slider(maximum=len(paths) - 1, value=0, step=1)
    return paths, scores, first_frame_array, heatmap_path, error, slider_update
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
video_input.upload(process_and_update,
|
163 |
inputs=[video_input, target_input],
|
164 |
+
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
|
165 |
|
166 |
frame_slider.change(update_frame_and_heatmap,
|
167 |
+
inputs=[frame_slider, frame_paths, frame_scores],
|
168 |
outputs=[output_image, heatmap_output])
|
169 |
|
170 |
def use_sample_video():
|
|
|
173 |
|
174 |
use_sample_button.click(use_sample_video,
|
175 |
inputs=None,
|
176 |
+
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
|
177 |
|
178 |
# Layout
|
179 |
with gr.Row():
|
|
|
187 |
|
188 |
if __name__ == "__main__":
|
189 |
app = gradio_app()
|
190 |
+
app.launch(share=True)
|
191 |
+
|
192 |
+
# Cleanup temporary files
|
193 |
+
def cleanup(frame_paths=(), temp_dir=None):
    """Delete saved frame images and, optionally, their temporary directory.

    The frame paths and the temporary directory are taken as arguments rather
    than read from surrounding names. Calling ``cleanup()`` with no arguments
    is a no-op.

    Args:
        frame_paths: Iterable of file paths to delete. Paths that no longer
            exist are skipped.
        temp_dir: Directory to remove recursively, or ``None`` to leave
            directories alone.

    Raises:
        OSError: If an existing file or directory cannot be removed.
    """
    for path in frame_paths or ():
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. removed by an earlier cleanup); nothing to do.
            continue

    if temp_dir is not None and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
|
199 |
+
|
200 |
+
# Make sure to call cleanup when the app is closed
|
201 |
+
# This might require additional setup depending on how you're running the app
|