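"""Video object detection with OWLv2, served through Gradio.

Uploads (or a bundled sample video) are sampled at 1 FPS, each sampled frame
is scored against a free-text target query, detections above 0.5 confidence
are drawn onto the frame, and a heatmap of per-frame scores lets the user
scrub to the strongest detections with a slider.
"""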
import spaces
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display is available on the server
import matplotlib.pyplot as plt
import tempfile
import shutil
device = "cuda"
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16")
model = model.to(device)
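
# Temp directories created by process_video, removed at shutdown.
# NOTE: TEMP_DIRS and the @spaces.GPU decorator below are additions here; the
# otherwise-unused `import spaces` suggests the ZeroGPU decorator was intended
# for the GPU-bound detection function.
TEMP_DIRS = []

@spaces.GPU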
def process_video(video_path, target, progress=gr.Progress()):
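    """Sample the video at 1 FPS, query OWLv2 with `target`, and annotate frames.

    Returns (frame_paths, frame_scores, error): paths to annotated PNGs, the
    max detection confidence per frame, and an error string (None on success).
    """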
if video_path is None:
return None, None, "Error: No video uploaded"
if not os.path.exists(video_path):
return None, None, f"Error: Video file not found at {video_path}"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, None, f"Error: Unable to open video file at {video_path}"
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = int(cap.get(cv2.CAP_PROP_FPS))
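    # Sample the video at a fixed 1 FPS, independent of the source frame rate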
output_fps = 1
frame_duration = 1 / output_fps
video_duration = frame_count / original_fps
frame_scores = []
    temp_dir = tempfile.mkdtemp()
    TEMP_DIRS.append(temp_dir)  # registered for removal at shutdown
frame_paths = []
batch_size = 1
batch_frames = []
batch_indices = []
for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
frame_number = int(time * original_fps)
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, img = cap.read()
if not ret:
break
# Convert to RGB without resizing
pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
batch_frames.append(pil_img)
batch_indices.append(i)
if len(batch_frames) == batch_size or i == int(video_duration / frame_duration) - 1:
# Process batch
            # One query list per image (list of lists) so the call stays valid
            # if batch_size is ever raised above 1
            inputs = processor(text=[[target]] * len(batch_frames), images=batch_frames, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
outputs = model(**inputs)
            # Use each frame's own (height, width) rather than the last frame's
            target_sizes = torch.tensor([im.size[::-1] for im in batch_frames]).to(device)
            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
for idx, (pil_img, result) in enumerate(zip(batch_frames, results)):
                # "RGBA" draw mode so the semi-transparent label background blends
                draw = ImageDraw.Draw(pil_img, "RGBA")
max_score = 0
boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
# Inside the loop where bounding boxes are drawn
for box, score, label in zip(boxes, scores, labels):
if score.item() >= 0.5:
box = [round(i, 2) for i in box.tolist()]
object_label = target
confidence = round(score.item(), 3)
annotation = f"{object_label}: {confidence}"
# Increase line width for the bounding box
draw.rectangle(box, outline="red", width=3)
# Calculate font size based on image dimensions
img_width, img_height = pil_img.size
font_size = int(min(img_width, img_height) * 0.03) # 3% of the smaller dimension
try:
font = ImageFont.truetype("arial.ttf", font_size)
except IOError:
font = ImageFont.load_default()
# Calculate text size
text_bbox = draw.textbbox((0, 0), annotation, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
# Position text inside the top of the bounding box
text_position = (box[0], box[1])
# Draw semi-transparent background for text
draw.rectangle([text_position[0], text_position[1],
text_position[0] + text_width, text_position[1] + text_height],
fill=(0, 0, 0, 128))
# Draw text in red
draw.text(text_position, annotation, fill="red", font=font)
max_score = max(max_score, confidence)
frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
pil_img.save(frame_path)
frame_paths.append(frame_path)
frame_scores.append(max_score)
# Clear batch
batch_frames = []
batch_indices = []
        # Periodically release cached GPU memory
        if torch.cuda.is_available() and i % 10 == 0:
            torch.cuda.empty_cache()
cap.release()
return frame_paths, frame_scores, None
def create_heatmap(frame_scores, current_frame):
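    """Plot per-frame max confidences as a heatmap strip, with a dashed line at
    `current_frame`; returns the path of the saved PNG."""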
plt.figure(figsize=(16, 4))
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
plt.title('Object Detection Heatmap', fontsize=14)
plt.xlabel('Frame', fontsize=12)
plt.yticks([])
num_frames = len(frame_scores)
step = max(1, num_frames // 20)
frame_numbers = range(0, num_frames, step)
plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=90, ha='right')
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
plt.tight_layout()
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
plt.close()
return tmp_file.name
def load_sample_frame(video_path, target_frame=87, original_fps=30, processing_fps=1):
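    """Fetch one frame from the sample video for display at app build time.

    `target_frame` indexes the 1 FPS processed timeline, so it maps back to
    source frame `target_frame * original_fps / processing_fps`.
    """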
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None
# Calculate the corresponding frame number in the original video
original_frame_number = int(target_frame * (original_fps / processing_fps))
# Set the frame position
cap.set(cv2.CAP_PROP_POS_FRAMES, original_frame_number)
ret, frame = cap.read()
cap.release()
if not ret:
return None
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame_rgb
def update_frame_and_heatmap(frame_index, frame_paths, scores):
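    """Return (frame pixels, heatmap path) for a slider position, or (None, None)."""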
if frame_paths and 0 <= frame_index < len(frame_paths):
frame = Image.open(frame_paths[frame_index])
heatmap_path = create_heatmap(scores, frame_index)
return np.array(frame), heatmap_path
return None, None
def gradio_app():
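    """Build the Blocks UI: video/target inputs, frame slider, annotated frame,
    score heatmap, and a one-click sample video."""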
with gr.Blocks() as app:
gr.Markdown("# Video Object Detection with Owlv2")
video_input = gr.Video(label="Upload Video")
target_input = gr.Textbox(label="Target Object", value="Elephant")
frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
heatmap_output = gr.Image(label="Detection Heatmap")
        output_image = gr.Image(label="Processed Frame", render=False)  # placed in the layout below
error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(
            value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4", target_frame=87),
            label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame (Frame 87 at 1 FPS)",
            render=False,  # placed in the layout below
        )
        use_sample_button = gr.Button("Use Sample Video", render=False)
progress_bar = gr.Progress()
frame_paths = gr.State([])
frame_scores = gr.State([])
        def process_and_update(video, target):
            paths, scores, error = process_video(video, target, progress_bar)
            if paths is not None:
                heatmap_path = create_heatmap(scores, 0)
                first_frame = Image.open(paths[0])
                return (paths, scores, np.array(first_frame), heatmap_path,
                        gr.Textbox(value="", visible=False),
                        gr.Slider(maximum=len(paths) - 1, value=0))
            # Surface the failure message; the textbox is hidden by default
            return (None, None, None, None,
                    gr.Textbox(value=error, visible=True),
                    gr.Slider(maximum=100, value=0))
video_input.upload(process_and_update,
inputs=[video_input, target_input],
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
frame_slider.change(update_frame_and_heatmap,
inputs=[frame_slider, frame_paths, frame_scores],
outputs=[output_image, heatmap_output])
def use_sample_video():
sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
return process_and_update(sample_video_path, "Elephant")
use_sample_button.click(use_sample_video,
inputs=None,
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
        # Layout: the components above were created with render=False because a
        # bare variable reference inside a Row/Column does not move a component
        # that has already been rendered
        with gr.Row():
            with gr.Column(scale=2):
                output_image.render()
            with gr.Column(scale=1):
                sample_video_frame.render()
                use_sample_button.render()
return app
if __name__ == "__main__":
    app = gradio_app()
    try:
        app.launch()
    finally:
        # Remove the annotated-frame dumps created by process_video
        for d in TEMP_DIRS:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)