File size: 9,562 Bytes
6321191
81e2598
36dd82f
81e2598
 
 
36dd82f
 
343407e
fa3925e
5dde850
81e2598
f69b877
81e2598
 
c9f1714
 
f69b877
c9f1714
81e2598
36dd82f
343407e
36dd82f
 
343407e
36dd82f
 
 
343407e
36dd82f
3f2cadc
69dc1f7
ac7c798
343407e
 
69dc1f7
343407e
5dde850
 
f05ca8c
7e7ddb7
c9f1714
 
 
5dde850
343407e
 
36dd82f
 
 
3f2cadc
6321191
f662a68
7d47fdc
c9f1714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f7d8f
08a3f43
 
 
 
 
 
e4f7d8f
08a3f43
2a7def2
e4f7d8f
 
 
 
08a3f43
 
 
 
e4f7d8f
 
 
 
 
 
 
 
 
 
 
 
 
 
08a3f43
 
e4f7d8f
08a3f43
c9f1714
 
 
 
 
 
 
 
 
3f2cadc
5dde850
 
 
 
36dd82f
5dde850
 
8714cd1
ec0d71c
5dde850
ec0d71c
 
343407e
8714cd1
 
08a3f43
8714cd1
2a7def2
8714cd1
 
 
343407e
 
fa3925e
8714cd1
343407e
 
fa3925e
36dd82f
05f5d03
36dd82f
 
 
4417183
05f5d03
 
4417183
05f5d03
 
4417183
36dd82f
 
 
 
 
 
 
5dde850
 
 
 
 
 
 
36dd82f
 
cc209a2
81e2598
36dd82f
a702d47
7f942f1
343407e
8714cd1
36dd82f
05f5d03
 
 
 
81e2598
 
 
5dde850
343407e
7f942f1
81e2598
5dde850
 
 
 
 
343407e
7f942f1
36dd82f
81e2598
5dde850
7f942f1
8714cd1
5dde850
8714cd1
36dd82f
 
81e2598
53eff3d
36dd82f
 
 
5dde850
36dd82f
3cb7297
8714cd1
 
3cb7297
8714cd1
3cb7297
 
8714cd1
36dd82f
 
 
 
08a3f43
5dde850
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import atexit
import os
import shutil
import tempfile

# NOTE(review): `spaces` is kept first among third-party imports; HF ZeroGPU
# presumably needs it imported before torch initialises CUDA — verify before reordering.
import spaces
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import matplotlib.pyplot as plt

# Inference device; the model weights are moved onto it once, at import time.
device = "cuda"

# OWLv2 base / patch-16 checkpoint. The processor prepares text+image inputs and
# post-processes raw outputs into boxes; the model runs the text-queried detector.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

def _load_annotation_font(font_size):
    """Return Arial at *font_size*, falling back to PIL's built-in default font."""
    try:
        return ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        return ImageFont.load_default()


def _annotate_frame(pil_img, result, target, score_threshold):
    """Draw detections from *result* scoring >= *score_threshold* onto *pil_img* in place.

    Returns the highest (rounded) confidence drawn, or 0 when nothing passes.
    """
    # "RGBA" draw mode blends colours that carry alpha into the RGB frame, so the
    # (0, 0, 0, 128) label background is actually semi-transparent; in the default
    # mode the alpha is ignored and the background comes out solid black.
    draw = ImageDraw.Draw(pil_img, "RGBA")
    font = None  # loaded lazily: only frames with a detection need it
    max_score = 0

    for box, score in zip(result["boxes"], result["scores"]):
        if score.item() < score_threshold:
            continue
        box = [round(coord, 2) for coord in box.tolist()]
        confidence = round(score.item(), 3)
        annotation = f"{target}: {confidence}"

        draw.rectangle(box, outline="red", width=3)

        if font is None:
            # 3% of the smaller frame dimension, never 0 (truetype rejects size 0).
            font = _load_annotation_font(max(1, int(min(pil_img.size) * 0.03)))

        # Label sits inside the top-left corner of the box. Measure the text at the
        # position it is drawn so the background matches the rendered glyphs.
        text_position = (box[0], box[1])
        draw.rectangle(draw.textbbox(text_position, annotation, font=font), fill=(0, 0, 0, 128))
        draw.text(text_position, annotation, fill="red", font=font)

        max_score = max(max_score, confidence)

    return max_score


def _process_batch(frames, indices, target, score_threshold, temp_dir):
    """Detect *target* in *frames*, annotate them in place and save each as a PNG.

    *indices* are the sample numbers used in the file names. Returns
    ``(paths, max_scores)`` in the same order as *frames*.
    """
    inputs = processor(text=[target] * len(frames), images=frames, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_object_detection takes (height, width); PIL's size is (width, height).
    # NOTE(review): OWLv2 pads inputs to a square; whether post-processing compensates
    # for that padding depends on the installed transformers version — verify the boxes
    # on non-square video.
    target_sizes = torch.Tensor([frame.size[::-1] for frame in frames]).to(device)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

    paths, scores = [], []
    for frame, index, result in zip(frames, indices, results):
        scores.append(_annotate_frame(frame, result, target, score_threshold))
        frame_path = os.path.join(temp_dir, f"frame_{index:04d}.png")
        frame.save(frame_path)
        paths.append(frame_path)
    return paths, scores


def process_video(video_path, target, progress=gr.Progress(), *,
                  output_fps=1, score_threshold=0.5, batch_size=1):
    """Sample *video_path* at *output_fps*, detect *target* with OWLv2 and save annotated frames.

    Args:
        video_path: Path to a video file readable by OpenCV, or ``None`` if nothing was uploaded.
        target: Free-text query for the object to detect; also used as the box label.
        progress: Gradio progress tracker wrapped around the sampling loop.
        output_fps: Frames sampled per second of video.
        score_threshold: Minimum confidence for a detection to be drawn and scored.
        batch_size: Number of sampled frames sent to the model per forward pass.

    Returns:
        ``(frame_paths, frame_scores, None)`` on success. ``frame_paths`` are annotated
        PNGs in a fresh temp directory that is removed at interpreter exit.
        ``frame_scores`` holds the per-frame max confidence (0 if nothing passed).
        On failure, returns ``(None, None, error_message)``.

    Raises:
        ValueError: If *output_fps* or *batch_size* is not positive.

    NOTE(review): `spaces` is imported at module level, but this function is not
    decorated with `@spaces.GPU`. If this runs on HF ZeroGPU hardware, that decorator
    is presumably required — confirm the deployment target.
    """
    if output_fps <= 0 or batch_size < 1:
        raise ValueError("output_fps must be > 0 and batch_size must be >= 1")
    if video_path is None:
        return None, None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None, None, f"Error: Unable to open video file at {video_path}"

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep FPS as a float: truncating e.g. 29.97 to 29 makes the sampled frames
        # drift away from the requested timestamps over a long video.
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if frame_count <= 0 or original_fps <= 0:
            # Unreadable metadata would otherwise divide by zero below.
            return None, None, f"Error: Unable to read frame count or FPS from {video_path}"
        video_duration = frame_count / original_fps

        temp_dir = tempfile.mkdtemp()
        # The PNGs must outlive this call (the UI reads them later), so removal is
        # deferred to interpreter exit rather than leaking the directory forever.
        atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

        frame_paths, frame_scores = [], []
        batch_frames, batch_indices = [], []

        sample_times = np.arange(0, video_duration, 1 / output_fps)
        for i, time_s in enumerate(progress.tqdm(sample_times)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(time_s * original_fps))
            ret, img = cap.read()
            if not ret:
                break

            # Convert to RGB without resizing.
            batch_frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            batch_indices.append(i)

            if len(batch_frames) == batch_size:
                paths, scores = _process_batch(batch_frames, batch_indices, target, score_threshold, temp_dir)
                frame_paths.extend(paths)
                frame_scores.extend(scores)
                batch_frames, batch_indices = [], []

            # Clear the GPU cache every 10 sampled frames.
            if i % 10 == 0:
                torch.cuda.empty_cache()

        # Flush the final partially-filled batch, including one cut short by a failed read.
        if batch_frames:
            paths, scores = _process_batch(batch_frames, batch_indices, target, score_threshold, temp_dir)
            frame_paths.extend(paths)
            frame_scores.extend(scores)
    finally:
        cap.release()

    if not frame_paths:
        return None, None, f"Error: No frames could be read from {video_path}"
    return frame_paths, frame_scores, None

def create_heatmap(frame_scores, current_frame):
    """Render per-frame detection scores as a one-row heatmap PNG and mark *current_frame*.

    Args:
        frame_scores: Non-empty sequence of per-frame max confidences (assumed to lie in [0, 1]).
        current_frame: Index of the frame to highlight with a vertical dashed line.

    Returns:
        Path to a newly created PNG file. The caller owns the file; it is not deleted here.
    """
    # An explicit figure (instead of pyplot's implicit "current figure") keeps the
    # drawing calls and the close() bound to this figure only.
    fig, ax = plt.subplots(figsize=(16, 4))
    try:
        # A fixed [0, 1] colour scale keeps a weak detection from looking as "hot" as a
        # strong one. Autoscaling to the data's own max would hide that difference.
        ax.imshow([frame_scores], cmap='hot_r', aspect='auto', vmin=0, vmax=1)
        ax.set_title('Object Detection Heatmap', fontsize=14)
        ax.set_xlabel('Frame', fontsize=12)
        ax.set_yticks([])

        # Show at most ~20 tick labels regardless of video length.
        num_frames = len(frame_scores)
        step = max(1, num_frames // 20)
        frame_numbers = list(range(0, num_frames, step))
        ax.set_xticks(frame_numbers)
        # Vertical labels are centred under their tick; 'right' would shift them left of it.
        ax.set_xticklabels([str(i) for i in frame_numbers], rotation=90, ha='center')

        ax.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
        fig.tight_layout()

        # Write through the open handle rather than reopening the file by name, which
        # fails on platforms that lock open temp files (e.g. Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
            fig.savefig(tmp_file, format='png', dpi=400, bbox_inches='tight')
    finally:
        plt.close(fig)

    return tmp_file.name

def load_sample_frame(video_path, target_frame=87, original_fps=30, processing_fps=1):
    """Read the frame of *video_path* that corresponds to a processed-frame index.

    *target_frame* counts frames sampled at *processing_fps*. It is mapped to a frame
    number in the original video via *original_fps*. Pass ``original_fps=None`` to use
    the FPS reported by the file instead of assuming a fixed rate.

    Returns:
        The frame as an RGB image array, or ``None`` if the video cannot be opened,
        its FPS cannot be read (when ``original_fps=None``), or the frame read fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None

        if original_fps is None:
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            if original_fps <= 0:
                return None

        # Frame number in the original video for the requested processed-frame index.
        original_frame_number = int(target_frame * (original_fps / processing_fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, original_frame_number)
        ret, frame = cap.read()
    finally:
        # Release the capture even if OpenCV raises part-way through.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Return the annotated frame at *frame_index* plus a heatmap marking it.

    Returns:
        ``(frame_array, heatmap_path)``, or ``(None, None)`` when there are no processed
        frames or the index is missing or out of range.
    """
    if frame_index is None or not frame_paths:
        return None, None
    # Guard against float slider values (e.g. 3.0); list indexing needs an int.
    frame_index = int(frame_index)
    if not 0 <= frame_index < len(frame_paths):
        return None, None

    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(frame_paths[frame_index]) as frame:
        frame_array = np.array(frame)
    return frame_array, create_heatmap(scores, frame_index)

def gradio_app():
    """Build the Gradio Blocks UI for OWLv2 video object detection.

    Returns:
        The configured ``gr.Blocks`` app. It is not launched here.
    """
    sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"

    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")

        # Layout: a component is placed by being *created* inside a layout context.
        # Merely naming an existing component inside `with gr.Row()` does not move it.
        with gr.Row():
            with gr.Column(scale=2):
                output_image = gr.Image(label="Processed Frame")
            with gr.Column(scale=1):
                sample_video_frame = gr.Image(
                    value=load_sample_frame(sample_video_path, target_frame=87),
                    label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame (Frame 87 at 1 FPS)"
                )
                use_sample_button = gr.Button("Use Sample Video")

        # Hidden until a handler has an error to show (visibility is toggled below).
        error_output = gr.Textbox(label="Error Messages", visible=False)

        frame_paths = gr.State([])
        frame_scores = gr.State([])
        outputs = [frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider]

        def process_and_update(video, target, progress=gr.Progress()):
            # Gradio injects a live tracker only into a `gr.Progress()` default argument
            # of the event handler itself, so it is declared here and passed down.
            paths, scores, error = process_video(video, target, progress)
            if not paths:
                message = error or "Error: No frames could be processed"
                return ([], [], None, None,
                        gr.Textbox(value=message, visible=True),
                        gr.Slider(maximum=100, value=0))

            heatmap_path = create_heatmap(scores, 0)
            with Image.open(paths[0]) as first_frame:
                first_frame_array = np.array(first_frame)
            return (paths, scores, first_frame_array, heatmap_path,
                    gr.Textbox(value="", visible=False),
                    gr.Slider(maximum=len(paths) - 1, value=0))

        def use_sample_video(progress=gr.Progress()):
            return process_and_update(sample_video_path, "Elephant", progress)

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=outputs)

        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, frame_paths, frame_scores],
                            outputs=[output_image, heatmap_output])

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=outputs)

    return app

if __name__ == "__main__":
    app = gradio_app()
    # launch() normally blocks for as long as the server runs when started as a script.
    app.launch()
    # No cleanup happens here. Annotated-frame directories are created per call inside
    # process_video (tempfile.mkdtemp), and their lifetime is managed there.
    # frame_paths is a gr.State local to gradio_app and is not reachable from this scope.