File size: 5,039 Bytes
81e2598
36dd82f
81e2598
 
 
36dd82f
 
81e2598
ebc376e
 
81e2598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9be460
81e2598
 
 
 
 
 
 
 
 
 
 
 
36dd82f
 
81e2598
36dd82f
81e2598
36dd82f
 
81e2598
36dd82f
 
 
81e2598
36dd82f
3f2cadc
69dc1f7
 
 
7f942f1
69dc1f7
f05ca8c
69dc1f7
 
36dd82f
 
 
3f2cadc
81e2598
7f942f1
 
3f2cadc
36dd82f
7f942f1
36dd82f
 
 
 
 
 
 
 
 
 
 
 
 
 
69dc1f7
81e2598
36dd82f
a702d47
7f942f1
 
36dd82f
81e2598
 
 
 
7f942f1
 
81e2598
7f942f1
 
53eff3d
 
7f942f1
 
 
 
 
36dd82f
 
81e2598
53eff3d
7f942f1
 
 
 
36dd82f
 
81e2598
53eff3d
36dd82f
 
 
53eff3d
36dd82f
 
 
 
 
7f942f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import functools
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Pick the compute device once at import time: a CUDA GPU when PyTorch can
# see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the OWLv2 processor and detection model a single time at import; the
# model weights are moved onto the selected device.
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16"
processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)
model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT)
model = model.to(device)

@functools.lru_cache(maxsize=None)
def _load_label_font(size=15):
    """Return the caption font: Arial if it can be loaded, else PIL's default font.

    Cached so the font file is read once instead of on every video frame.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        return ImageFont.load_default()


def detect_objects_in_frame(image, target, score_threshold=0.25, box_color="red"):
    """Run OWLv2 on *image* for the text query *target* and draw the detections.

    Boxes and "label: score" captions are drawn directly onto *image*
    (it is modified in place) and the same object is returned.

    Args:
        image: PIL image to search and annotate.
        target: Free-text description of the object to look for.
        score_threshold: Minimum confidence a detection needs to be drawn.
        box_color: Outline colour used for the drawn boxes.

    Returns:
        The annotated PIL image (the same object that was passed in).
    """
    texts = [[target]]
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)

    # Pure inference: don't build an autograd graph for every frame.
    with torch.no_grad():
        outputs = model(**inputs)

    # OWLv2 pads the input to a square (bottom/right) before resizing, so the
    # predicted boxes are relative to that padded square. Scaling by the
    # longest side maps them back onto the original image; using (h, w) would
    # distort boxes on non-square frames.
    longest_side = max(image.size)
    target_sizes = torch.Tensor([[longest_side, longest_side]])
    # Post-process permissively, then apply the caller's threshold below.
    results = processor.post_process_object_detection(
        outputs=outputs,
        threshold=min(0.1, score_threshold),
        target_sizes=target_sizes,
    )

    draw = ImageDraw.Draw(image)
    font = _load_label_font()
    queries = texts[0]  # single image -> single list of text queries
    detections = results[0]

    for box, score, label in zip(detections["boxes"], detections["scores"], detections["labels"]):
        confidence = score.item()
        if confidence < score_threshold:
            continue

        box = [round(coord, 2) for coord in box.tolist()]
        object_label = queries[int(label)]
        annotation = f"{object_label}: {round(confidence, 3)}"

        draw.rectangle(box, outline=box_color, width=2)
        # Caption sits just above the box; clamp so boxes at the top edge keep it visible.
        text_position = (box[0], max(0, box[1] - 10))
        draw.text(text_position, annotation, fill="white", font=font)

    return image

def process_video(video_path, target, progress=gr.Progress(), output_fps=3):
    """Sample *video_path* at roughly ``output_fps`` frames/second and annotate each sample.

    Each sampled frame is converted from OpenCV's BGR to RGB, passed through
    ``detect_objects_in_frame`` with the text query *target*, and collected
    as a numpy array.

    Args:
        video_path: Path to the video file, or ``None`` when nothing was uploaded.
        target: Text description of the object to detect.
        progress: Gradio progress tracker; its ``tqdm`` wraps the frame loop.
        output_fps: Approximate number of frames to keep per second of video.

    Returns:
        ``(frames, None)`` with a non-empty list of annotated RGB frames on
        success, or ``(None, message)`` describing why processing failed.

    Raises:
        ValueError: If ``output_fps`` is not positive.
    """
    if output_fps <= 0:
        raise ValueError(f"output_fps must be positive, got {output_fps!r}")

    if video_path is None:
        return None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"

    if not target or not target.strip():
        return None, "Error: No target object specified"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Unable to open video file at {video_path}"

    processed_frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep the fractional rate (e.g. 29.97) rather than truncating it; a
        # reported FPS of 0 degrades to an interval of 1 (every frame).
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        frame_interval = max(1, round(original_fps / output_fps))

        for frame_index in progress.tqdm(range(0, frame_count, frame_interval)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, img = cap.read()
            if not ret:
                break

            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            annotated_img = detect_objects_in_frame(pil_img, target)
            processed_frames.append(np.array(annotated_img))
    finally:
        # Release the capture even if detection raises part-way through.
        cap.release()

    # An empty result is a failure, not a success: callers index frames[0].
    if not processed_frames:
        return None, f"Error: No frames could be read from {video_path}"
    return processed_frames, None

def load_sample_frame(video_path):
    """Return the first frame of *video_path* as an RGB array, or None.

    None is returned when the file cannot be opened or its first frame
    cannot be decoded.
    """
    capture = cv2.VideoCapture(video_path)
    if capture.isOpened():
        ok, bgr_frame = capture.read()
        capture.release()
        # OpenCV decodes to BGR; convert so image widgets display true colours.
        return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) if ok else None
    return None

def gradio_app():
    """Build the Gradio Blocks UI for running OWLv2 detection over a video.

    Uploading a video, or pressing "Use Sample Video", runs ``process_video``
    with the text in the "Target Object" box. The annotated frames are stored
    in session state and browsed with the frame slider. Processing errors are
    shown in a textbox that stays hidden while there is nothing to report.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Bundled sample clip: used for the preview image and by the sample button.
    sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"

    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2 (3 FPS)")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        output_image = gr.Image(label="Processed Frame")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(value=load_sample_frame(sample_video_path), label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")

        # Per-session list of annotated frames (RGB numpy arrays).
        processed_frames = gr.State([])

        # Gradio only tracks progress when gr.Progress() is declared as a default
        # argument of the event handler itself; an instance created elsewhere and
        # passed in by hand is never connected to the UI.
        def process_and_update(video, target, progress=gr.Progress()):
            frames, error = process_video(video, target, progress)
            # The error box starts hidden, so reveal it only when there is a message.
            error_box = gr.Textbox(value=error or "", visible=error is not None)
            # Guard against an empty list as well as None: frames[0] would raise.
            if frames:
                return frames, frames[0], error_box, gr.Slider(maximum=len(frames) - 1, value=0)
            return [], None, error_box, gr.Slider(maximum=100, value=0)

        def update_frame(frame_index, frames):
            if not frames or frame_index is None:
                return None
            # Slider values may arrive as floats; list indices must be ints.
            index = int(frame_index)
            if 0 <= index < len(frames):
                return frames[index]
            return None

        processing_outputs = [processed_frames, output_image, error_output, frame_slider]

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=processing_outputs)

        frame_slider.change(update_frame,
                            inputs=[frame_slider, processed_frames],
                            outputs=[output_image])

        def use_sample_video(target, progress=gr.Progress()):
            # Honour the user's current query rather than always searching for "Elephant".
            return process_and_update(sample_video_path, target, progress)

        use_sample_button.click(use_sample_video,
                                inputs=[target_input],
                                outputs=processing_outputs)

    return app

if __name__ == "__main__":
    # Build the UI and serve it; share=True asks Gradio for a public share link.
    gradio_app().launch(share=True)