import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib
matplotlib.use("Agg")  # headless backend so figures render safely in server threads
import matplotlib.pyplot as plt
import tempfile
import shutil

# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16")
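
# OWLv2 performs zero-shot detection: the text query supplied at inference
# time (e.g. "Elephant") determines what is detected
TEMP_DIRS = []  # temp directories created per processed video, removed on shutdown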

# Use half precision on the GPU to reduce memory use; keep full precision on
# the CPU, where half-precision inference is poorly supported
if device.type == "cuda":
    try:
        model = model.to(device).half()
    except RuntimeError:
        print("GPU out of memory, using CPU instead")
        device = torch.device("cpu")
        model = model.to(device)
else:
    model = model.to(device)


def process_video(video_path, target, progress=gr.Progress()):
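    """Sample the video at one frame per second, run OWLv2 zero-shot detection
    for `target` on each sampled frame, draw the detections, and save the
    annotated frames to a temp directory.

    Returns (frame_paths, per-frame max confidence scores, error message).
    """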
    if video_path is None:
        return None, None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if original_fps <= 0:
        cap.release()
        return None, None, "Error: Unable to read the video's frame rate"
    output_fps = 1  # sample one frame per second of video
    frame_duration = 1 / output_fps
    video_duration = frame_count / original_fps

    frame_scores = []
    temp_dir = tempfile.mkdtemp()
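    TEMP_DIRS.append(temp_dir)  # register so the directory is removed on shutdown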
    frame_paths = []

    # The model and device are already configured at module level and are
    # reused here for every call

    batch_size = 1  # frames per forward pass; increase if GPU memory allows
    batch_frames = []
    batch_indices = []

    for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
        frame_number = int(time * original_fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, img = cap.read()
        if not ret:
            break

        # Convert BGR (OpenCV) to RGB for PIL; frames could be downscaled here
        # first (e.g. cv2.resize(img, (1280, 720))) to speed up inference
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        
        batch_frames.append(pil_img)
        batch_indices.append(i)

        if len(batch_frames) == batch_size or i == int(video_duration / frame_duration) - 1:
            # Process the batch: one list of text queries per image
            inputs = processor(text=[[target]] * len(batch_frames), images=batch_frames, return_tensors="pt", padding=True).to(device)
            # Match the input dtype to the model's (float16 when on GPU)
            inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

            with torch.no_grad():
                outputs = model(**inputs)

            # (height, width) of each frame, used to rescale boxes to pixel coordinates
            target_sizes = torch.tensor([frame.size[::-1] for frame in batch_frames]).to(device)
            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

            for idx, (pil_img, result) in enumerate(zip(batch_frames, results)):
                # "RGBA" mode lets semi-transparent fills blend into the RGB frame
                draw = ImageDraw.Draw(pil_img, "RGBA")
                max_score = 0

                # Load the annotation font once per frame; fall back to the
                # default bitmap font if Arial is not installed
                font_size = 30
                try:
                    font = ImageFont.truetype("arial.ttf", font_size)
                except IOError:
                    font = ImageFont.load_default()

                boxes, scores, labels = result["boxes"], result["scores"], result["labels"]

                for box, score, label in zip(boxes, scores, labels):
                    if score.item() >= 0.5:
                        box = [round(i, 2) for i in box.tolist()]
                        confidence = round(score.item(), 3)
                        annotation = f"{target}: {confidence}"

                        # Wide red outline for the bounding box
                        draw.rectangle(box, outline="red", width=4)

                        # Semi-transparent background behind the label for readability
                        text_position = (box[0], box[1] - font_size - 5)
                        text_bbox = draw.textbbox(text_position, annotation, font=font)
                        draw.rectangle(text_bbox, fill=(0, 0, 0, 128))
                        draw.text(text_position, annotation, fill="red", font=font)

                        max_score = max(max_score, confidence)

                # Save the annotated frame to disk
                frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
                pil_img.save(frame_path)
                frame_paths.append(frame_path)
                frame_scores.append(max_score)

            # Clear batch
            batch_frames = []
            batch_indices = []

        # Periodically release cached GPU memory
        if device.type == "cuda" and i % 10 == 0:
            torch.cuda.empty_cache()

    cap.release()
    return frame_paths, frame_scores, None

def create_heatmap(frame_scores, current_frame):
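    """Render per-frame max detection scores as a 1-D heatmap, mark the
    current frame with a vertical line, and return the path to a saved PNG."""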
    plt.figure(figsize=(16, 4))
    plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
    plt.title('Object Detection Heatmap', fontsize=14)
    plt.xlabel('Frame', fontsize=12)
    plt.yticks([])

    # Add more frame numbers on x-axis
    num_frames = len(frame_scores)
    step = max(1, num_frames // 20)  # Show at most 20 frame numbers
    frame_numbers = range(0, num_frames, step)
    plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=45, ha='right')

    # Add vertical line for current frame
    plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)

    plt.tight_layout()
    
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
    plt.close()
    
    return tmp_file.name

def load_sample_frame(video_path):
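    """Return the first frame of the video as an RGB array, or None on failure."""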
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb

def update_frame_and_heatmap(frame_index, frame_paths, scores):
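    """Load the annotated frame at `frame_index` and redraw the heatmap with
    the cursor at that frame."""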
    if frame_paths and 0 <= frame_index < len(frame_paths):
        frame = Image.open(frame_paths[frame_index])
        heatmap_path = create_heatmap(scores, frame_index)
        return np.array(frame), heatmap_path
    return None, None

def gradio_app():
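    """Assemble the Blocks UI: video upload, target query, frame slider, and
    linked heatmap/frame views."""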
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with OWLv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")
        output_image = gr.Image(label="Processed Frame", render=False)  # placed in the layout below
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame", render=False)
        use_sample_button = gr.Button("Use Sample Video", render=False)

        frame_paths = gr.State([])
        frame_scores = gr.State([])

        def process_and_update(video, target, progress=gr.Progress()):
            # Run detection, then seed the UI: first annotated frame, heatmap,
            # and a slider resized to the number of processed frames
            paths, scores, error = process_video(video, target, progress)
            if paths:
                heatmap_path = create_heatmap(scores, 0)
                first_frame = Image.open(paths[0])
                return paths, scores, np.array(first_frame), heatmap_path, error, gr.Slider(maximum=len(paths) - 1, value=0)
            return None, None, None, None, error, gr.Slider(maximum=100, value=0)

        video_input.upload(process_and_update, 
                           inputs=[video_input, target_input], 
                           outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        frame_slider.change(update_frame_and_heatmap, 
                            inputs=[frame_slider, frame_paths, frame_scores], 
                            outputs=[output_image, heatmap_output])

        def use_sample_video():
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "Elephant")

        use_sample_button.click(use_sample_video, 
                                inputs=None, 
                                outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        # Layout: components created with render=False above are placed here
        with gr.Row():
            with gr.Column(scale=2):
                output_image.render()
            with gr.Column(scale=1):
                sample_video_frame.render()
                use_sample_button.render()

    return app

if __name__ == "__main__":
    app = gradio_app()
    try:
        app.launch(share=True)  # blocks until the server is shut down
    finally:
        # Remove the temporary frame directories created during this run
        for d in TEMP_DIRS:
            shutil.rmtree(d, ignore_errors=True)