import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
# Check if CUDA is available, otherwise use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
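# OWLv2 is an open-vocabulary detector: the free-text query (e.g. "Elephant")
# is embedded by the text encoder and matched against image regions, so no
# class-specific training is required for new targets.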
def detect_objects_in_frame(image, target):
    draw = ImageDraw.Draw(image)
    texts = [[target]]
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = model(**inputs)
    # target_sizes is (height, width); keep it on the same device as the outputs
    target_sizes = torch.tensor([image.size[::-1]], device=device)
    results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
    color_map = {target: "red"}
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    # Only a single text query is passed, so only the first result matters
    text = texts[0]
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    for box, score, label in zip(boxes, scores, labels):
        if score.item() >= 0.25:
            box = [round(coord, 2) for coord in box.tolist()]
            object_label = text[label]
            confidence = round(score.item(), 3)
            annotation = f"{object_label}: {confidence}"
            draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
            text_position = (box[0], box[1] - 10)
            draw.text(text_position, annotation, fill="white", font=font)
    return image
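# Minimal usage sketch (not part of the app flow): annotate a single image from
# disk. "frame.jpg" is a hypothetical path, assumed to exist.
#
#   img = Image.open("frame.jpg").convert("RGB")
#   annotated = detect_objects_in_frame(img, "Elephant")
#   annotated.save("frame_annotated.jpg")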
def process_video(video_path, target, progress=gr.Progress()):
    if video_path is None:
        return None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Unable to open video file at {video_path}"
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    processed_frames = []
    for frame_idx in progress.tqdm(range(frame_count)):
        ret, img = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; PIL expects RGB
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        annotated_img = detect_objects_in_frame(pil_img, target)
        processed_frames.append(np.array(annotated_img))
    cap.release()
    return processed_frames, None
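# Note: every frame gets a full OWLv2 forward pass, which is slow on CPU. One
# possible speed-up (an assumption, not in the original design) is to run
# detection only on every Nth frame inside the loop above, e.g.:
#
#   if frame_idx % 5 == 0:
#       annotated_img = detect_objects_in_frame(pil_img, target)
#   else:
#       annotated_img = pil_img  # pass intermediate frames through unannotated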
def load_sample_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with OWLv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        output_image = gr.Image(label="Processed Frame")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")
        processed_frames = gr.State([])

        def process_and_update(video, target, progress=gr.Progress()):
            frames, error = process_video(video, target, progress)
            if frames:
                # Resize the slider to the number of frames and show the first one.
                # Mutating frame_slider attributes directly has no effect after
                # render, so the new bounds are returned as an update instead.
                return (frames, frames[0],
                        gr.update(value=error, visible=error is not None),
                        gr.update(maximum=len(frames) - 1, value=0))
            return (None, None,
                    gr.update(value=error, visible=error is not None),
                    gr.update(maximum=100, value=0))

        def update_frame(frame_index, frames):
            if frames and 0 <= frame_index < len(frames):
                return frames[int(frame_index)]
            return None

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[processed_frames, output_image, error_output, frame_slider])

        frame_slider.change(update_frame,
                            inputs=[frame_slider, processed_frames],
                            outputs=[output_image])

        def use_sample_video():
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "Elephant")

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[processed_frames, output_image, error_output, frame_slider])

    return app
if __name__ == "__main__":
app = gradio_app()
app.launch(share=True)
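    # share=True requests a temporary public gradio.live URL; it is unnecessary
    # when the app already runs on Hugging Face Spaces.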