reab5555 committed
Commit 36dd82f · verified · 1 Parent(s): f0b1428

Create app.py

Files changed (1)
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Use CUDA if available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

def detect_objects_in_frame(image, target):
    """Run OWLv2 open-vocabulary detection for `target` and draw boxes on the frame."""
    draw = ImageDraw.Draw(image)
    texts = [[target]]
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    # post_process_object_detection expects (height, width) per image
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs=outputs, threshold=0.1, target_sizes=target_sizes)

    color_map = {target: "red"}

    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()

    text = texts[0]
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    for box, score, label in zip(boxes, scores, labels):
        if score.item() >= 0.25:
            box = [round(coord, 2) for coord in box.tolist()]
            object_label = text[label]
            confidence = round(score.item(), 3)
            annotation = f"{object_label}: {confidence}"

            draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
            text_position = (box[0], box[1] - 10)
            draw.text(text_position, annotation, fill="white", font=font)

    return image

def process_video(video_path, target, progress=gr.Progress()):
    if video_path is None:
        return None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Unable to open video file at {video_path}"

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back if FPS is unreported
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_path = "output_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, original_fps, (width, height))

    for _ in progress.tqdm(range(frame_count)):
        ret, img = cap.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; PIL and the model expect RGB
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        annotated_img = detect_objects_in_frame(pil_img, target)
        annotated_frame = cv2.cvtColor(np.array(annotated_img), cv2.COLOR_RGB2BGR)
        out.write(annotated_frame)

    cap.release()
    out.release()

    return output_path, None

def load_sample_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with OWLv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object")
        output_video = gr.Video(label="Output Video")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(value=load_sample_frame("IL_Dancing_Sample.mp4"),
                                      label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")

        def process_and_update(video, target):
            output_video_path, error = process_video(video, target)
            # Setting error_output.visible inside a handler has no effect;
            # return a gr.update() so Gradio actually toggles the textbox.
            return output_video_path, gr.update(value=error, visible=error is not None)

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[output_video, error_output])

        def use_sample_video():
            return process_and_update("IL_Dancing_Sample.mp4", "person")

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[output_video, error_output])

    return app

if __name__ == "__main__":
    app = gradio_app()
    app.launch(share=True)
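
For a quick sanity check of the detection step outside the Gradio UI, the frame-level function can be exercised on a single image. This is a minimal sketch, not part of the commit; `test_frame.jpg` and `annotated_frame.jpg` are placeholder file names:

    from PIL import Image

    img = Image.open("test_frame.jpg").convert("RGB")   # placeholder input frame
    annotated = detect_objects_in_frame(img, "person")  # query OWLv2 for "person"
    annotated.save("annotated_frame.jpg")               # boxes and scores drawn on top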