import spaces
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib.pyplot as plt
from io import BytesIO
import tempfile
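
# Gradio demo: run OWLv2 zero-shot object detection over frames sampled from an
# uploaded video, draw the detections on each frame, and chart the per-frame
# confidence as a heatmap.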
# Check if CUDA is available, otherwise use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
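
# OWLv2 is an open-vocabulary detector: the text prompt supplied at inference time
# defines what to look for, so no task-specific training is needed.
# @spaces.GPU below requests a GPU from Hugging Face Spaces (ZeroGPU) for the
# decorated function; `duration` is the per-call GPU time budget in seconds.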
@spaces.GPU(duration=120)
def process_video(video_path, target, progress=gr.Progress()):
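    # Sample the video at 1 frame per second, detect `target` in each sampled frame,
    # annotate detections scoring >= 0.25, and return (annotated frames,
    # per-frame max confidence, error message or None).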
    if video_path is None:
        return None, None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = int(cap.get(cv2.CAP_PROP_FPS))
    output_fps = 1  # sample one frame per second of video
    frame_duration = 1 / output_fps
    video_duration = frame_count / original_fps

    processed_frames = []
    frame_scores = []

    for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
        frame_number = int(time * original_fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, img = cap.read()
        if not ret:
            break

        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Run zero-shot detection on the single frame; no_grad avoids building
        # autograd state during inference
        inputs = processor(text=[target], images=pil_img, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        target_sizes = torch.Tensor([pil_img.size[::-1]])
        results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

        draw = ImageDraw.Draw(pil_img)
        max_score = 0
        try:
            font = ImageFont.truetype("arial.ttf", 40)
        except IOError:
            font = ImageFont.load_default()

        boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
        for box, score, label in zip(boxes, scores, labels):
            if score.item() >= 0.25:
                box = [round(i, 2) for i in box.tolist()]
                object_label = target
                confidence = round(score.item(), 3)
                annotation = f"{object_label}: {confidence}"

                draw.rectangle(box, outline="red", width=2)
                text_position = (box[0], box[1] - 30)
                draw.text(text_position, annotation, fill="white", font=font)

                max_score = max(max_score, confidence)

        processed_frames.append(np.array(pil_img))
        frame_scores.append(max_score)

    cap.release()
    return processed_frames, frame_scores, None
def create_heatmap(frame_scores, current_frame):
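    # Plot the per-frame confidence scores as a one-row heatmap, mark the currently
    # selected frame with a dashed line, and return the path of a saved PNG.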
    plt.figure(figsize=(12, 3))
    plt.imshow([frame_scores], cmap='hot_r', aspect='auto')  # 'hot_r' for reversed hot colormap
    cbar = plt.colorbar(label='Confidence')
    cbar.ax.yaxis.set_ticks_position('left')
    cbar.ax.yaxis.set_label_position('left')
    plt.title('Object Detection Heatmap')
    plt.xlabel('Frame')
    plt.yticks([])

    # Add frame numbers on the x-axis (show at most 10 labels)
    num_frames = len(frame_scores)
    step = max(1, num_frames // 10)
    frame_numbers = range(0, num_frames, step)
    plt.xticks(frame_numbers, [str(i) for i in frame_numbers])

    # Add a vertical line for the current frame
    plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
    plt.tight_layout()

    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
    plt.close()
    return tmp_file.name
def load_sample_frame(video_path):
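    # Grab the first frame of the given video (as RGB) for the sample preview image.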
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb
def gradio_app():
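    # Build the Gradio Blocks UI: upload/sample-video inputs, a frame slider, and the
    # processed-frame + heatmap outputs wired to the processing functions above.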
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")
        # render=False defers placement of these components to the layout block below
        output_image = gr.Image(label="Processed Frame", render=False)
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(
            value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"),
            label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame",
            render=False,
        )
        use_sample_button = gr.Button("Use Sample Video", render=False)
        progress_bar = gr.Progress()

        processed_frames = gr.State([])
        frame_scores = gr.State([])

        def process_and_update(video, target):
            frames, scores, error = process_video(video, target, progress_bar)
            if frames is not None:
                heatmap_path = create_heatmap(scores, 0)  # initial heatmap with current frame at 0
                return frames, scores, frames[0], heatmap_path, error, gr.Slider(maximum=len(frames) - 1, value=0)
            return None, None, None, None, error, gr.Slider(maximum=100, value=0)

        def update_frame_and_heatmap(frame_index, frames, scores):
            if frames and 0 <= frame_index < len(frames):
                heatmap_path = create_heatmap(scores, frame_index)
                return frames[frame_index], heatmap_path
            return None, None

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, processed_frames, frame_scores],
                            outputs=[output_image, heatmap_output])

        def use_sample_video():
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "Elephant")

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        # Layout: place the deferred components so the processed frame sits next to
        # the sample-video preview and button
        with gr.Row():
            with gr.Column(scale=2):
                output_image.render()
            with gr.Column(scale=1):
                sample_video_frame.render()
                use_sample_button.render()

    return app
if __name__ == "__main__":
    app = gradio_app()
    app.launch(share=True)