File size: 6,990 Bytes
97afa8e
81e2598
36dd82f
81e2598
 
 
36dd82f
 
343407e
 
fa3925e
81e2598
ebc376e
 
81e2598
 
 
 
7d47fdc
81e2598
36dd82f
343407e
36dd82f
 
343407e
36dd82f
 
 
343407e
36dd82f
3f2cadc
69dc1f7
ac7c798
343407e
 
69dc1f7
7f942f1
343407e
f05ca8c
343407e
 
 
36dd82f
 
 
3f2cadc
81e2598
7d47fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f2cadc
36dd82f
343407e
 
8714cd1
 
 
 
 
 
343407e
 
 
8714cd1
 
 
 
 
 
 
 
 
 
343407e
 
fa3925e
8714cd1
343407e
 
fa3925e
36dd82f
 
 
 
 
 
 
 
 
 
 
 
 
 
cc209a2
81e2598
36dd82f
a702d47
7f942f1
343407e
8714cd1
36dd82f
8714cd1
81e2598
 
 
7f942f1
343407e
7f942f1
81e2598
343407e
7f942f1
8714cd1
fa3925e
343407e
7f942f1
8714cd1
7f942f1
8714cd1
 
 
36dd82f
 
81e2598
343407e
7f942f1
8714cd1
 
 
36dd82f
 
81e2598
53eff3d
36dd82f
 
 
343407e
36dd82f
3cb7297
8714cd1
 
3cb7297
8714cd1
3cb7297
 
8714cd1
36dd82f
 
 
 
7f942f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import spaces
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib.pyplot as plt
from io import BytesIO
import tempfile

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# One checkpoint id shared by the processor and the detection model so the two
# can never drift apart.
_MODEL_ID = "google/owlv2-base-patch16"
processor = Owlv2Processor.from_pretrained(_MODEL_ID)
model = Owlv2ForObjectDetection.from_pretrained(_MODEL_ID).to(device)

def _load_label_font(size=40):
    """Return Arial at *size* points, or Pillow's default font if Arial is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except IOError:
        return ImageFont.load_default()


def _detect_and_annotate(pil_img, target, font, score_threshold=0.25):
    """Detect *target* in *pil_img*, draw the kept boxes onto it in place.

    Args:
        pil_img: RGB PIL image; it is modified in place by the drawing calls.
        target: Text query passed to OWLv2.
        font: Font used for the box labels.
        score_threshold: Minimum confidence for a detection to be drawn.

    Returns:
        The highest confidence (rounded to 3 decimals) among the drawn boxes,
        or 0 when no detection reaches ``score_threshold``.
    """
    inputs = processor(text=[target], images=pil_img, return_tensors="pt", padding=True).to(device)
    # Pure inference: skip autograd bookkeeping to save GPU memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # OWLv2's processor pads every image to a square (bottom/right) before
    # inference, so predicted boxes are normalized to that padded square.
    # Scaling by the padded side length maps them back onto the original frame;
    # scaling by (height, width) squashes boxes on non-square (e.g. 16:9) frames.
    side = max(pil_img.size)
    target_sizes = torch.tensor([[side, side]], dtype=torch.float32)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
    result = results[0]

    draw = ImageDraw.Draw(pil_img)
    max_score = 0
    for box, score in zip(result["boxes"], result["scores"]):
        if score.item() < score_threshold:
            continue
        box = [round(coord, 2) for coord in box.tolist()]
        confidence = round(score.item(), 3)

        draw.rectangle(box, outline="red", width=2)
        # Put the label above the box, but keep it on-canvas for boxes that
        # touch the top edge of the frame.
        text_position = (box[0], max(0, box[1] - 30))
        draw.text(text_position, f"{target}: {confidence}", fill="white", font=font)

        max_score = max(max_score, confidence)
    return max_score


@spaces.GPU(duration=120)
def process_video(video_path, target, progress=gr.Progress()):
    """Run OWLv2 detection on one frame per second of video and annotate it.

    Args:
        video_path: Path to the video file, or None when nothing was uploaded.
        target: Text query passed to OWLv2 (e.g. "Elephant").
        progress: Gradio progress tracker; its ``tqdm`` wraps the frame loop.

    Returns:
        ``(processed_frames, frame_scores, error)``. On success
        ``processed_frames`` is a list of annotated RGB frames (numpy arrays),
        ``frame_scores`` holds the highest kept confidence per frame (0 when
        nothing passed the threshold) and ``error`` is None. On failure the
        first two are None and ``error`` is a message string.
    """
    if video_path is None:
        return None, None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    processed_frames = []
    frame_scores = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep FPS as a float: truncating e.g. 29.97 to 29 makes each sampled
        # frame drift further from its intended timestamp.
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        # Written as "not > 0" so a NaN FPS is rejected as well as 0.
        if not original_fps > 0:
            return None, None, f"Error: Unable to read the frame rate of video file at {video_path}"

        output_fps = 1
        frame_duration = 1 / output_fps
        video_duration = frame_count / original_fps

        # Loop-invariant: load the label font once rather than per frame.
        font = _load_label_font()

        for timestamp in progress.tqdm(np.arange(0, video_duration, frame_duration)):
            frame_number = int(timestamp * original_fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, img = cap.read()
            if not ret:
                break

            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            max_score = _detect_and_annotate(pil_img, target, font)

            processed_frames.append(np.array(pil_img))
            frame_scores.append(max_score)
    finally:
        # Release the capture even if decoding or inference raises.
        cap.release()

    if not processed_frames:
        # Callers index frames[0]; report an error instead of an empty result.
        return None, None, f"Error: No frames could be read from video file at {video_path}"

    return processed_frames, frame_scores, None

def create_heatmap(frame_scores, current_frame):
    """Render per-frame detection confidences as a one-row heatmap PNG.

    Args:
        frame_scores: Per-frame confidence values, one per sampled frame.
        current_frame: Index of the frame to mark with a dashed vertical line.

    Returns:
        Path of a PNG in a temporary file created with ``delete=False``; the
        file is not removed by this function.
    """
    # Use a standalone Figure instead of pyplot. pyplot keeps global "current
    # figure" state, so concurrent Gradio requests could draw into each other's
    # plot, and a failed savefig would leak the figure. A bare Figure is private
    # to this call and needs no explicit close.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 3))
    ax = fig.add_subplot()
    heat = ax.imshow([frame_scores], cmap='hot_r', aspect='auto')  # 'hot_r' for reversed hot colormap
    cbar = fig.colorbar(heat, ax=ax, label='Confidence')
    cbar.ax.yaxis.set_ticks_position('left')
    cbar.ax.yaxis.set_label_position('left')
    ax.set_title('Object Detection Heatmap')
    ax.set_xlabel('Frame')
    ax.set_yticks([])

    # Label roughly every tenth of the timeline (never more than 19 ticks).
    num_frames = len(frame_scores)
    step = max(1, num_frames // 10)
    frame_numbers = list(range(0, num_frames, step))
    ax.set_xticks(frame_numbers)
    ax.set_xticklabels([str(i) for i in frame_numbers])

    # Dashed marker for the frame currently shown in the UI.
    ax.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)

    fig.tight_layout()

    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        # Write through the open handle rather than reopening the file by name,
        # which fails where open temp files are locked (e.g. Windows).
        fig.savefig(tmp_file, format='png', dpi=400, bbox_inches='tight')

    return tmp_file.name

def load_sample_frame(video_path):
    """Return the first frame of *video_path* as an RGB array.

    Returns None when OpenCV cannot open the file or cannot read a frame.
    """
    capture = cv2.VideoCapture(video_path)
    if capture.isOpened():
        ok, bgr_frame = capture.read()
        capture.release()
        # OpenCV decodes as BGR; convert so image widgets show true colours.
        return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) if ok else None
    return None

def gradio_app():
    """Build the Blocks UI for video object detection with OWLv2.

    Uploading a video, or clicking "Use Sample Video", runs ``process_video``.
    The annotated frames and per-frame scores are stored in session state.
    The first frame is shown with a confidence heatmap, and the frame slider
    is resized to the number of frames. Moving the slider shows the chosen
    frame and re-renders the heatmap with a marker at that frame.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # Shared by the sample preview image and the "Use Sample Video" handler.
    sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
    default_target = "Elephant"

    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value=default_target)
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")
        # Starts hidden; shown whenever a processing call reports an error.
        error_output = gr.Textbox(label="Error Messages", visible=False)

        # Gradio renders a component where it is *created*. Merely referencing
        # an existing component inside a Row/Column does not move it, so the
        # side-by-side layout must instantiate its components here.
        with gr.Row():
            with gr.Column(scale=2):
                output_image = gr.Image(label="Processed Frame")
            with gr.Column(scale=1):
                gr.Image(
                    value=load_sample_frame(sample_video_path),
                    label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame",
                )
                use_sample_button = gr.Button("Use Sample Video")

        # Per-session storage of the annotated frames and their scores.
        processed_frames = gr.State([])
        frame_scores = gr.State([])

        # Both processing triggers update the same components, in this order.
        processing_outputs = [processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider]

        def error_update(error):
            """Show the error box containing *error*, or hide it when *error* is None."""
            return gr.Textbox(value=error, visible=error is not None)

        # Gradio attaches a progress tracker to an event only when the handler
        # declares ``progress=gr.Progress()`` as a default argument. The tracker
        # is then passed on so process_video's frame loop can report through it.
        def process_and_update(video, target, progress=gr.Progress()):
            frames, scores, error = process_video(video, target, progress)
            if frames:
                heatmap_path = create_heatmap(scores, 0)  # Initial heatmap with current frame at 0
                return (frames, scores, frames[0], heatmap_path, error_update(error),
                        gr.Slider(maximum=len(frames) - 1, value=0))
            # No frames: never index frames[0], and always surface a message.
            return ([], [], None, None, error_update(error or "Error: No frames could be processed"),
                    gr.Slider(maximum=100, value=0))

        def update_frame_and_heatmap(frame_index, frames, scores):
            if frame_index is None or not frames:
                return None, None
            # The slider value may arrive as a float; list indexing needs an int.
            index = int(frame_index)
            if 0 <= index < len(frames):
                return frames[index], create_heatmap(scores, index)
            return None, None

        def use_sample_video(target, progress=gr.Progress()):
            # Respect the user's query; fall back to the default for an empty box.
            return process_and_update(sample_video_path, target or default_target, progress)

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=processing_outputs)

        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, processed_frames, frame_scores],
                            outputs=[output_image, heatmap_output])

        use_sample_button.click(use_sample_video,
                                inputs=[target_input],
                                outputs=processing_outputs)

    return app

if __name__ == "__main__":
    # Build the Blocks UI and start serving it, with Gradio's share option enabled.
    demo = gradio_app()
    demo.launch(share=True)