|
import gradio as gr |
|
import cv2 |
|
from PIL import Image, ImageDraw, ImageFont |
|
import torch |
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection |
|
import numpy as np |
|
import os |
|
|
|
|
|
# Run the detector on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor and the detection model are both loaded from this checkpoint.
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16"

processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)
model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT).to(device)
|
|
|
def detect_objects_in_frame(image, target, score_threshold=0.25):
    """Draw OWLv2 detections of ``target`` onto ``image`` and return it.

    The module-level ``processor`` and ``model`` are used to search the image
    for the free-text query ``target``. Every detection whose score reaches
    ``score_threshold`` is outlined in red and captioned with
    ``"<target>: <score>"``.

    Args:
        image: PIL image to search. It is annotated in place.
        target: Text description of the object to look for.
        score_threshold: Minimum score a detection needs in order to be drawn.

    Returns:
        The same ``image`` object that was passed in, with annotations drawn.
    """
    texts = [[target]]
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)

    # Inference only: without no_grad every frame would build an autograd
    # graph, wasting memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing is given (height, width).
    target_sizes = torch.Tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs=outputs,
        # Never let the pre-filter be stricter than the caller's threshold.
        threshold=min(0.1, score_threshold),
        target_sizes=target_sizes,
    )

    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except OSError:
        # Arial is not installed everywhere; fall back to PIL's built-in font.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(image)

    # A single image and a single query list were submitted, so only the first
    # result entry is relevant.
    queries = texts[0]
    detections = results[0]

    for box, score, label in zip(detections["boxes"], detections["scores"], detections["labels"]):
        confidence = score.item()
        if confidence < score_threshold:
            continue

        x0, y0, x1, y1 = (round(coord, 2) for coord in box.tolist())
        annotation = f"{queries[int(label)]}: {round(confidence, 3)}"

        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        # Place the caption just above the box, but keep it inside the image
        # when the box touches the top edge.
        draw.text((x0, max(y0 - 10, 0)), annotation, fill="white", font=font)

    return image
|
|
|
def process_video(video_path, target, progress=gr.Progress()):
    """Annotate a video with detections of ``target`` and save the result.

    The input is sampled at roughly five frames per second. Each sampled frame
    is annotated by ``detect_objects_in_frame`` and appended to
    ``output_video.mp4`` in the current working directory.

    Args:
        video_path: Path of the video to process, or ``None``.
        target: Text description of the object to detect.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        ``(output_path, None)`` on success, or ``(None, message)`` when the
        input is missing or unreadable or the output file cannot be created.
    """
    if video_path is None:
        return None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Unable to open video file at {video_path}"

    output_path = "output_video.mp4"
    output_fps = 5
    out = None

    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        source_fps = cap.get(cv2.CAP_PROP_FPS)

        if source_fps > 0:
            # Keep every ``stride``-th frame. The stride is at least 1 so that
            # sources slower than ``output_fps`` do not cause a modulo by zero.
            stride = max(1, int(source_fps) // output_fps)
            # Write at the rate the frames were actually sampled so the output
            # plays back at the original speed.
            writer_fps = source_fps / stride
        else:
            # The frame rate is unknown: keep every frame at the nominal rate.
            stride = 1
            writer_fps = output_fps

        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, writer_fps, frame_size)
        if not out.isOpened():
            return None, f"Error: Unable to create output video at {output_path}"

        for frame_index in progress.tqdm(range(frame_count)):
            ret, img = cap.read()
            if not ret:
                # The container may report more frames than can be decoded.
                break

            if frame_index % stride != 0:
                continue

            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            annotated_img = detect_objects_in_frame(pil_img, target)
            # Each frame is written as soon as it is annotated, so nothing is
            # buffered and no sampled frame is lost when the video ends.
            out.write(cv2.cvtColor(np.array(annotated_img), cv2.COLOR_RGB2BGR))
    finally:
        # Release both handles even when decoding or detection raises.
        cap.release()
        if out is not None:
            out.release()

    return output_path, None
|
|
|
def load_sample_frame(video_path):
    """Return the first frame of a video as an RGB array.

    Args:
        video_path: Path of the video to read.

    Returns:
        The first frame converted from BGR to RGB, or ``None`` when the video
        cannot be opened or no frame can be read from it.
    """
    capture = cv2.VideoCapture(video_path)
    if capture.isOpened():
        ok, first_frame = capture.read()
        capture.release()
        if ok:
            return cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    return None
|
|
|
def gradio_app():
    """Build the Gradio interface for video object detection.

    The interface offers a video upload, a text box for the target object, an
    output video, an error box that is shown only when processing fails, a
    preview frame of the bundled sample video and a button that runs the
    detector on that sample.

    Returns:
        The assembled ``gr.Blocks`` application; it is not launched here.
    """
    sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"

    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object")
        output_video = gr.Video(label="Output Video")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        gr.Image(value=load_sample_frame(sample_video_path), label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")

        # The tracker must be a default argument of the event handler for
        # Gradio to bind it to the running event and display progress.
        def process_and_update(video, target, progress=gr.Progress()):
            output_video_path, error = process_video(video, target, progress)
            # Changing a component attribute inside a handler does not reach
            # the browser; visibility has to be returned as an update.
            return output_video_path, gr.update(value=error or "", visible=bool(error))

        video_input.upload(
            process_and_update,
            inputs=[video_input, target_input],
            outputs=[output_video, error_output],
        )

        def use_sample_video(progress=gr.Progress()):
            return process_and_update(sample_video_path, "animal", progress)

        use_sample_button.click(
            use_sample_video,
            inputs=None,
            outputs=[output_video, error_output],
        )

    return app
|
|
|
if __name__ == "__main__":
    # Build and serve the interface only when this file is run as a script.
    app = gradio_app()
    app.launch(share=True)
|
|