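# Gradio app: zero-shot object detection on video with OWLv2.
# Upload a video and a target phrase; the app samples one frame per second,
# annotates detections scoring at least 0.5, and plots a per-frame score heatmap.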
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib.pyplot as plt
from io import BytesIO
import tempfile
import shutil
import atexit
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16")

# Try to move the model to the GPU in half precision; fall back to full-precision CPU
try:
    if device.type == "cuda":
        model = model.to(device).half()
    else:
        model = model.to(device)
except RuntimeError:
    print("GPU out of memory, using CPU instead")
    device = torch.device("cpu")
    model = model.to(device)

# Temporary frame directories created by process_video, removed by cleanup() on exit
TEMP_DIRS = []
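# Sample the video at output_fps, run OWLv2 zero-shot detection for the target
# phrase on each sampled frame, draw boxes for scores >= 0.5, and save the
# annotated frames to a temporary directory.
# Returns (frame_paths, frame_scores, error_message).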
def process_video(video_path, target, progress=gr.Progress()):
    if video_path is None:
        return None, None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = cap.get(cv2.CAP_PROP_FPS)  # keep the exact fps (may be fractional, e.g. 29.97)
    output_fps = 1  # sample one frame per second of video
    frame_duration = 1 / output_fps
    video_duration = frame_count / original_fps

    frame_scores = []
    temp_dir = tempfile.mkdtemp()
    TEMP_DIRS.append(temp_dir)  # remember the directory so cleanup() can remove it
    frame_paths = []

    # The model and device are already configured at module level
    batch_size = 1
    batch_frames = []
    batch_indices = []
    for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
        frame_number = int(time * original_fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, img = cap.read()
        if not ret:
            break

        # Resize the frame (disabled)
        # img_resized = cv2.resize(img, (1280, 720))
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        batch_frames.append(pil_img)
        batch_indices.append(i)

        if len(batch_frames) == batch_size or i == int(video_duration / frame_duration) - 1:
            # Run the detector on the batch; one text query per image
            inputs = processor(text=[target] * len(batch_frames), images=batch_frames, return_tensors="pt", padding=True).to(device)
            # Match the input dtype to the model (needed when the model is in half precision)
            inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
            with torch.no_grad():
                outputs = model(**inputs)

            # Convert raw logits to boxes in pixel coordinates, one result per image
            target_sizes = torch.tensor([frame.size[::-1] for frame in batch_frames]).to(device)
            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

            for idx, (frame_img, result) in enumerate(zip(batch_frames, results)):
                draw = ImageDraw.Draw(frame_img)
                max_score = 0
                font_size = 30
                try:
                    font = ImageFont.truetype("arial.ttf", font_size)
                except IOError:
                    font = ImageFont.load_default()

                boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
                for box, score, label in zip(boxes, scores, labels):
                    if score.item() >= 0.5:
                        box = [round(coord, 2) for coord in box.tolist()]
                        confidence = round(score.item(), 3)
                        annotation = f"{target}: {confidence}"

                        # Draw the bounding box with a wide red outline
                        draw.rectangle(box, outline="red", width=4)

                        # Draw a black background behind the label for readability
                        # (an alpha channel in the fill would be ignored on RGB images)
                        text_position = (box[0], box[1] - font_size - 5)
                        text_bbox = draw.textbbox(text_position, annotation, font=font)
                        draw.rectangle(text_bbox, fill=(0, 0, 0))
                        draw.text(text_position, annotation, fill="red", font=font)

                        max_score = max(max_score, confidence)

                # Save the annotated frame to disk
                frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
                frame_img.save(frame_path)
                frame_paths.append(frame_path)
                frame_scores.append(max_score)

            # Clear the batch
            batch_frames = []
            batch_indices = []

        # Release cached GPU memory every 10 frames
        if device.type == "cuda" and i % 10 == 0:
            torch.cuda.empty_cache()

    cap.release()
    return frame_paths, frame_scores, None
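# Render the per-frame max detection scores as a one-row heatmap, mark the
# currently selected frame, and return the path of the saved PNG.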
def create_heatmap(frame_scores, current_frame):
    plt.figure(figsize=(16, 4))
    plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
    plt.title('Object Detection Heatmap', fontsize=14)
    plt.xlabel('Frame', fontsize=12)
    plt.yticks([])

    # Label the x-axis with at most 20 frame numbers
    num_frames = len(frame_scores)
    step = max(1, num_frames // 20)
    frame_numbers = range(0, num_frames, step)
    plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=45, ha='right')

    # Add a vertical line for the current frame
    plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)

    plt.tight_layout()
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
    plt.close()
    return tmp_file.name
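# Read the first frame of a video and return it as an RGB array (None on failure).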
def load_sample_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb
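# Slider callback: load the saved frame at frame_index and redraw the heatmap
# with the selection marker at that frame.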
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    if frame_paths and 0 <= frame_index < len(frame_paths):
        frame = Image.open(frame_paths[frame_index])
        heatmap_path = create_heatmap(scores, frame_index)
        return np.array(frame), heatmap_path
    return None, None
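# Build the Gradio Blocks UI and wire up the upload, slider, and sample-video events.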
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")

        # Components render where they are defined, so the layout wraps the definitions
        with gr.Row():
            with gr.Column(scale=2):
                output_image = gr.Image(label="Processed Frame")
            with gr.Column(scale=1):
                sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"),
                                              label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame")
                use_sample_button = gr.Button("Use Sample Video")

        error_output = gr.Textbox(label="Error Messages", visible=False)
        progress_bar = gr.Progress()

        frame_paths = gr.State([])
        frame_scores = gr.State([])

        def process_and_update(video, target):
            paths, scores, error = process_video(video, target, progress_bar)
            if paths is not None:
                heatmap_path = create_heatmap(scores, 0)
                first_frame = Image.open(paths[0])
                return paths, scores, np.array(first_frame), heatmap_path, error, gr.Slider(maximum=len(paths) - 1, value=0)
            return None, None, None, None, error, gr.Slider(maximum=100, value=0)

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, frame_paths, frame_scores],
                            outputs=[output_image, heatmap_output])

        def use_sample_video():
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "Elephant")

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])

    return app
# Remove the temporary frame directories recorded in TEMP_DIRS when the
# interpreter exits
def cleanup():
    for temp_dir in TEMP_DIRS:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)

atexit.register(cleanup)

if __name__ == "__main__":
    app = gradio_app()
    app.launch(share=True)