"""Gradio app: zero-shot object detection over video frames with OWLv2.

Frames are sampled at OUTPUT_FPS, run through OWLv2 in batches of BATCH_SIZE,
annotated with red boxes for detections scoring >= SCORE_THRESHOLD, and shown
with a slider next to a per-frame max-score heatmap.
"""

import atexit
import os
import shutil
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from gradio import spaces
from matplotlib.figure import Figure
from PIL import Image, ImageDraw, ImageFont
from transformers import Owlv2ForObjectDetection, Owlv2Processor

MODEL_ID = "google/owlv2-base-patch16"
SAMPLE_VIDEO_PATH = "Drone Video of African Wildlife Wild Botswan.mp4"
OUTPUT_FPS = 1  # frames sampled per second of video
BATCH_SIZE = 4  # frames per model forward pass
SCORE_THRESHOLD = 0.5  # minimum detection score that gets drawn
FONT_SIZE = 50
HEATMAP_DPI = 400

# Every temporary frame/heatmap is written under this single root so one
# rmtree at exit removes everything, even files written by a worker process.
_TEMP_ROOT = tempfile.mkdtemp(prefix="owlv2_video_")
_TEMP_ROOT_OWNER_PID = os.getpid()

# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Owlv2Processor.from_pretrained(MODEL_ID)
model = Owlv2ForObjectDetection.from_pretrained(MODEL_ID)

# Try to move the model to the GPU. Half precision is applied only on CUDA;
# on CPU the model stays float32 to match the processor's float32 pixels.
try:
    model = model.to(device)
    if device.type == "cuda":
        model = model.half()
except RuntimeError:
    print("GPU out of memory, using CPU instead")
    device = torch.device("cpu")
    model = model.float().to(device)


def _load_font(size):
    """Return a TrueType font of ``size`` px, falling back to Pillow's default font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        try:
            # Pillow >= 10.1 can render its built-in font at a given size.
            return ImageFont.load_default(size=size)
        except TypeError:
            # Older Pillow: fixed-size bitmap default font.
            return ImageFont.load_default()


def _annotate_frame(pil_img, result, target, font, threshold=SCORE_THRESHOLD):
    """Draw boxes and labels for detections >= ``threshold`` onto ``pil_img`` in place.

    Returns:
        The highest drawn confidence (rounded to 3 decimals), or 0 if none passed.
    """
    draw = ImageDraw.Draw(pil_img)
    max_score = 0
    for box, score in zip(result["boxes"], result["scores"]):
        raw_score = score.item()
        if raw_score < threshold:
            continue
        box = [round(coord, 2) for coord in box.tolist()]
        confidence = round(raw_score, 3)
        draw.rectangle(box, outline="red", width=4)
        # Clamp so labels on boxes touching the top edge stay inside the image.
        text_position = (box[0], max(0, box[1] - FONT_SIZE - 5))
        draw.text(text_position, f"{target}: {confidence}", fill="red", font=font)
        max_score = max(max_score, confidence)
    return max_score


def _detect_batch(frames, indices, target, temp_dir, font):
    """Run OWLv2 on ``frames``, annotate and save each one; return (paths, max_scores)."""
    inputs = processor(
        text=[target] * len(frames), images=frames, return_tensors="pt", padding=True
    ).to(device)
    # The processor always emits float32 pixels; match the model's dtype (fp16 on CUDA).
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype=model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)

    # Sizes are (height, width) per frame, as post-processing expects.
    target_sizes = torch.tensor(
        [img.size[::-1] for img in frames], dtype=torch.float32, device=device
    )
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

    paths = []
    scores = []
    for frame_index, pil_img, result in zip(indices, frames, results):
        scores.append(_annotate_frame(pil_img, result, target, font))
        frame_path = os.path.join(temp_dir, f"frame_{frame_index:04d}.png")
        pil_img.save(frame_path)
        paths.append(frame_path)

    if device.type == "cuda":
        torch.cuda.empty_cache()
    return paths, scores


@spaces.GPU(duration=120)
def process_video(video_path, target, progress=gr.Progress()):
    """Sample ``video_path`` at OUTPUT_FPS and detect ``target`` in every sampled frame.

    Returns:
        ``(frame_paths, frame_scores, None)`` on success, where ``frame_paths``
        are annotated PNGs and ``frame_scores`` the per-frame max confidence;
        ``(None, None, error_message)`` on failure.
    """
    if video_path is None:
        return None, None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    frame_paths = []
    frame_scores = []
    try:
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Keep FPS as a float: truncating e.g. 29.97 to 29 drifts the seek position.
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if frame_count <= 0 or original_fps <= 0:
            return None, None, f"Error: Unable to read frame count or FPS from {video_path}"
        video_duration = frame_count / original_fps
        sample_times = np.arange(0, video_duration, 1 / OUTPUT_FPS)

        temp_dir = tempfile.mkdtemp(dir=_TEMP_ROOT)
        font = _load_font(FONT_SIZE)  # loaded once, not once per box
        batch_frames = []
        batch_indices = []

        for i, time in enumerate(progress.tqdm(sample_times)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(time * original_fps))
            ret, img = cap.read()
            if not ret:
                break
            batch_frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            batch_indices.append(i)
            if len(batch_frames) == BATCH_SIZE:
                paths, scores = _detect_batch(batch_frames, batch_indices, target, temp_dir, font)
                frame_paths.extend(paths)
                frame_scores.extend(scores)
                batch_frames = []
                batch_indices = []

        # Process the trailing partial batch (also reached when a read fails early).
        if batch_frames:
            paths, scores = _detect_batch(batch_frames, batch_indices, target, temp_dir, font)
            frame_paths.extend(paths)
            frame_scores.extend(scores)
    finally:
        cap.release()

    if not frame_paths:
        return None, None, "Error: No frames could be read from the video"
    return frame_paths, frame_scores, None


def create_heatmap(frame_scores, current_frame, dpi=HEATMAP_DPI):
    """Render ``frame_scores`` as a 1-row heatmap with a marker at ``current_frame``.

    Uses a standalone ``Figure`` (not pyplot's global state) so concurrent
    requests cannot draw onto each other's plots.

    Returns:
        Path of a PNG file under the app's temporary directory.
    """
    fig = Figure(figsize=(16, 4))
    ax = fig.add_subplot()
    ax.imshow([frame_scores], cmap="hot_r", aspect="auto")
    ax.set_title("Object Detection Heatmap", fontsize=14)
    ax.set_xlabel("Frame", fontsize=12)
    ax.set_yticks([])

    num_frames = len(frame_scores)
    step = max(1, num_frames // 20)  # at most ~20 tick labels
    frame_numbers = list(range(0, num_frames, step))
    ax.set_xticks(frame_numbers)
    ax.set_xticklabels([str(i) for i in frame_numbers], rotation=45, ha="right")
    ax.axvline(x=current_frame, color="blue", linestyle="--", linewidth=2)
    fig.tight_layout()

    # Close the descriptor before saving so the write also works on Windows.
    fd, heatmap_path = tempfile.mkstemp(suffix=".png", dir=_TEMP_ROOT)
    os.close(fd)
    fig.savefig(heatmap_path, format="png", dpi=dpi, bbox_inches="tight")
    return heatmap_path


def load_sample_frame(video_path):
    """Return the first frame of ``video_path`` as an RGB array, or None if unreadable."""
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Return (frame array, heatmap path) for the slider position, or (None, None)."""
    if not frame_paths or frame_index is None:
        return None, None
    index = int(frame_index)  # slider values may arrive as floats
    if not 0 <= index < len(frame_paths):
        return None, None
    with Image.open(frame_paths[index]) as frame:
        frame_array = np.array(frame)
    return frame_array, create_heatmap(scores, index)


def gradio_app():
    """Build the Gradio Blocks UI and wire its events; return the Blocks app."""
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")
        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")

        # Components must be created inside the layout context to be placed there.
        with gr.Row():
            with gr.Column(scale=2):
                output_image = gr.Image(label="Processed Frame")
            with gr.Column(scale=1):
                sample_video_frame = gr.Image(
                    value=load_sample_frame(SAMPLE_VIDEO_PATH),
                    label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame",
                )
                use_sample_button = gr.Button("Use Sample Video")

        error_output = gr.Textbox(label="Error Messages", visible=False)

        frame_paths = gr.State([])
        frame_scores = gr.State([])

        # Gradio only tracks progress when gr.Progress is a default argument
        # of the event handler itself, so it is declared here and passed down.
        def process_and_update(video, target, progress=gr.Progress()):
            paths, scores, error = process_video(video, target, progress)
            if paths:
                heatmap_path = create_heatmap(scores, 0)
                with Image.open(paths[0]) as first:
                    first_frame = np.array(first)
                return (
                    paths,
                    scores,
                    first_frame,
                    heatmap_path,
                    gr.Textbox(value="", visible=False),
                    gr.Slider(maximum=len(paths) - 1, value=0),
                )
            if error is None:
                error = "Error: No frames could be read from the video"
            return (
                [],
                [],
                None,
                None,
                gr.Textbox(value=error, visible=True),
                gr.Slider(maximum=100, value=0),
            )

        outputs = [frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider]

        video_input.upload(process_and_update, inputs=[video_input, target_input], outputs=outputs)
        frame_slider.change(
            update_frame_and_heatmap,
            inputs=[frame_slider, frame_paths, frame_scores],
            outputs=[output_image, heatmap_output],
        )

        def use_sample_video(progress=gr.Progress()):
            return process_and_update(SAMPLE_VIDEO_PATH, "Elephant", progress)

        use_sample_button.click(use_sample_video, inputs=None, outputs=outputs)

    return app


def cleanup():
    """Delete every temporary frame and heatmap created by this app.

    Only the process that created the temp root removes it, so a child
    process exiting cannot delete files the main server is still serving.
    """
    if os.getpid() != _TEMP_ROOT_OWNER_PID:
        return
    shutil.rmtree(_TEMP_ROOT, ignore_errors=True)


# Run cleanup automatically when the interpreter exits (e.g. after launch() returns).
atexit.register(cleanup)


if __name__ == "__main__":
    app = gradio_app()
    app.launch()