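"""Gradio app: zero-shot video object detection with OWLv2.

Frames are sampled from an uploaded (or sample) video down to a fixed
output frame rate, run through google/owlv2-base-patch16 against a
free-text target query, annotated with bounding boxes, and re-encoded
to an MP4.
"""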
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
# Check if CUDA is available, otherwise use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
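
# Annotate a single PIL frame with OWLv2 zero-shot detections for `target`.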
def detect_objects_in_frame(image, target):
    draw = ImageDraw.Draw(image)
    texts = [[target]]  # OWLv2 expects one list of text queries per image
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():  # inference only; skip gradient tracking
        outputs = model(**inputs)
    # Post-processing expects target sizes as (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
    color_map = {target: "red"}
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    # Single image in the batch, so read index 0 directly
    text = texts[0]
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    for box, score, label in zip(boxes, scores, labels):
        if score.item() >= 0.25:  # keep only reasonably confident detections
            box = [round(coord, 2) for coord in box.tolist()]
            object_label = text[label]
            confidence = round(score.item(), 3)
            annotation = f"{object_label}: {confidence}"
            draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
            text_position = (box[0], box[1] - 10)
            draw.text(text_position, annotation, fill="white", font=font)
    return image
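
# Read the video, sample frames down to `output_fps`, annotate them in
# batches, and write the result to a new MP4.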
def process_video(video_path, target, progress=gr.Progress()):
    if video_path is None:
        return None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Unable to open video file at {video_path}"
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = int(cap.get(cv2.CAP_PROP_FPS))
    output_fps = 5
    # Guard against videos that report an FPS at or below the target rate
    frame_step = max(1, original_fps // output_fps)
    output_path = "output_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height))
    batch_size = 64

    def flush(buffered):
        # Annotate the buffered frames and append them to the output video
        for pil_frame in buffered:
            annotated_img = detect_objects_in_frame(pil_frame, target)
            annotated_frame = cv2.cvtColor(np.array(annotated_img), cv2.COLOR_RGB2BGR)
            out.write(annotated_frame)

    frames = []
    for frame_idx in progress.tqdm(range(frame_count)):
        ret, img = cap.read()
        if not ret:
            break
        if frame_idx % frame_step != 0:
            continue
        # OpenCV reads BGR; PIL (and the model) expect RGB
        frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        if len(frames) == batch_size:
            flush(frames)
            frames = []
    # Flush any frames still buffered when the loop ends
    flush(frames)
    cap.release()
    out.release()
    return output_path, None
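
# Grab the first frame of a video (as RGB) for the sample preview image.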
def load_sample_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
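
# Build the Gradio Blocks UI and wire up the event handlers.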
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object")
        output_video = gr.Video(label="Output Video")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")

        def process_and_update(video, target, progress=gr.Progress()):
            output_video_path, error = process_video(video, target, progress)
            # Components can't be mutated from inside a handler; return an
            # update so the error box is only shown when there is an error
            error_update = gr.update(value=error, visible=error is not None)
            return output_video_path, error_update

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[output_video, error_output])

        def use_sample_video():
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "animal")

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[output_video, error_output])

    return app

if __name__ == "__main__":
    app = gradio_app()
    app.launch(share=True)