reab5555 committed
Commit 81e2598 · verified · 1 Parent(s): 3f2cadc

Update app.py
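
Replaces the OpenPifPaf multi-person body-movement analysis pipeline with OWLv2 zero-shot object detection: the user names a target object, frames are sampled at 5 fps, annotated in batches of 64, and written out to output_video.mp4.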

Files changed (1): app.py +85 -182
app.py CHANGED
@@ -1,171 +1,95 @@
  import cv2
  import numpy as np
- import matplotlib.pyplot as plt
- import gradio as gr
- from moviepy.editor import *
  import os
- import torch
- import openpifpaf
-
- # Ensure NumPy is available
- try:
-     import numpy as np
- except ImportError:
-     os.system('pip install numpy')
-     import numpy as np
-
- # OpenPifPaf configuration
- predictor = openpifpaf.Predictor(checkpoint='shufflenetv2k16')
-
- def preprocess(image):
-     input_size = (192, 256)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     image = cv2.resize(image, input_size)
      return image

- def total_body_movement(current_poses, prev_poses):
-     if not current_poses or not prev_poses:
-         return 0
-     total_movement = 0
-     for current_pose in current_poses:
-         for prev_pose in prev_poses:
-             movement = np.sum(np.sqrt(np.sum((current_pose - prev_pose)**2, axis=1)))
-             total_movement += movement
-     return total_movement / (len(current_poses) * len(prev_poses))
-
- def process_video(video_path, progress=gr.Progress(), batch_size=64):
      if video_path is None:
-         return None, None, None, None, None, None, "Error: No video uploaded"

      if not os.path.exists(video_path):
-         return None, None, None, None, None, None, f"Error: Video file not found at {video_path}"

      cap = cv2.VideoCapture(video_path)
      if not cap.isOpened():
-         return None, None, None, None, None, None, f"Error: Unable to open video file at {video_path}"

-     original_fps = int(cap.get(cv2.CAP_PROP_FPS))
      frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
      original_duration = frame_count / original_fps
-
-     frame_interval = max(1, round(original_fps / 10))  # Process 10 frames per second
-
-     body_movements = []
-     time_points = []

-     prev_poses = None
      frames = []
-     frame_indices = []

-     for frame in progress.tqdm(range(0, frame_count, frame_interval)):
-         cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
          ret, img = cap.read()
          if not ret:
              break
-         frames.append(img)
-         frame_indices.append(frame)
-
-         if len(frames) == batch_size:
-             process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps)
-             frames = []
-
-     # Process any remaining frames
-     if frames:
-         process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps)
-
-     cap.release()

-     fig, ax = plt.subplots(figsize=(10, 6), dpi=500)
-     ax.plot(time_points, body_movements, "-", linewidth=0.5)
-     ax.set_xlim(0, original_duration)
-     ax.set_xlabel("Time")
-     ax.set_ylabel("Body Movement")
-     ax.set_title("Body Movement Analysis")
-
-     num_labels = 50
-     label_positions = np.linspace(0, original_duration, num_labels)
-     label_texts = [f"{int(t//60):02d}:{int(t%60):02d}" for t in label_positions]
-     ax.set_xticks(label_positions)
-     ax.set_xticklabels(label_texts, rotation=90, ha='right')
-     plt.tight_layout()
-
-     return fig, ax, time_points, body_movements, video_path, original_duration, None
-
- def process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps):
-     batch_preds = predictor.numpy_images(frames)
-
-     for i, (predictions, frame_index) in enumerate(zip(batch_preds, frame_indices)):
-         pose_coords = [pred.data for pred in predictions]
-
-         if prev_poses is not None:
-             movement = total_body_movement(pose_coords, prev_poses)
-             body_movements.append(movement)
-         else:
-             body_movements.append(0)
-
-         prev_poses = pose_coords
-         time_points.append(frame_index / original_fps)
-
- def update_video(video_path, time):
-     if video_path is None:
-         return None
-
-     if not os.path.exists(video_path):
-         return None
-
-     cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened():
-         return None
-
-     original_fps = int(cap.get(cv2.CAP_PROP_FPS))
-     frame_number = int(time * original_fps)
-     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-     ret, img = cap.read()
      cap.release()

-     if not ret:
-         return None
-
-     predictions, _, _ = predictor.numpy_image(img)
-     pose_coords = [pred.data for pred in predictions]
-
-     for coords in pose_coords:
-         for i in range(len(coords)):
-             x, y = coords[i]
-             if x > 0 and y > 0:
-                 cv2.circle(img, (int(x), int(y)), 3, (0, 255, 0), -1)
-
-     for pred in predictions:
-         skeleton = pred.data[:, :2]
-         for i, j in pred.skeleton:
-             if skeleton[i, 0] > 0 and skeleton[i, 1] > 0 and skeleton[j, 0] > 0 and skeleton[j, 1] > 0:
-                 cv2.line(img, (int(skeleton[i, 0]), int(skeleton[i, 1])), (int(skeleton[j, 0]), int(skeleton[j, 1])), (255, 0, 0), 2)
-
-     return img
-
- def update_graph(fig, ax, time_points, body_movements, current_time, video_duration):
-     ax.clear()
-     ax.plot(time_points, body_movements, "-", linewidth=0.5)
-     ax.axvline(x=current_time, color='r', linestyle='--')
-
-     minutes, seconds = divmod(int(current_time), 60)
-     timecode = f"{minutes:02d}:{seconds:02d}"
-     ax.text(current_time, ax.get_ylim()[1], timecode,
-             verticalalignment='top', horizontalalignment='right',
-             color='r', fontweight='bold', bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
-
-     ax.set_xlabel("Time")
-     ax.set_ylabel("Body Movement")
-     ax.set_title("Body Movement Analysis")
-
-     num_labels = 80
-     label_positions = np.linspace(0, video_duration, num_labels)
-     label_texts = [f"{int(t//60):02d}:{int(t%60):02d}" for t in label_positions]
-     ax.set_xticks(label_positions)
-     ax.set_xticklabels(label_texts, rotation=90, ha='right')
-     ax.set_xlim(0, video_duration)
-     plt.tight_layout()
-     return fig

  def load_sample_frame(video_path):
      cap = cv2.VideoCapture(video_path)
@@ -180,57 +104,36 @@ def load_sample_frame(video_path):

  def gradio_app():
      with gr.Blocks() as app:
-         gr.Markdown("# Multi-Person Body Movement Analysis")
-
          video_input = gr.Video(label="Upload Video")
-         graph_output = gr.Plot()
-         time_slider = gr.Slider(label="Time (seconds)", minimum=0, maximum=100, step=0.1)
-         video_output = gr.Image(label="Body Posture")
-
-         with gr.Row():
-             sample_video_frame = gr.Image(value=load_sample_frame("IL_Dancing_Sample.mp4"), label="Sample Video Frame")
-             use_sample_button = gr.Button("Use Sample Video")
-
          error_output = gr.Textbox(label="Error Messages", visible=False)
-
          video_path = gr.State(None)
-         fig_state = gr.State(None)
-         ax_state = gr.State(None)
-         time_points_state = gr.State(None)
-         body_movements_state = gr.State(None)
-         video_duration_state = gr.State(None)
-
-         def process_and_update(video):
-             fig, ax, time_points, body_movements, video_path_value, video_duration, error = process_video(video)
-             if fig is not None:
-                 time_slider.maximum = video_duration
-                 error_output.visible = False
-             else:
                  error_output.visible = True
-             return fig, video, error, video_path_value, fig, ax, time_points, body_movements, video_duration

          video_input.upload(process_and_update,
-                            inputs=video_input,
-                            outputs=[graph_output, video_output, error_output, video_path,
-                                     fig_state, ax_state, time_points_state, body_movements_state, video_duration_state])
-
-         def update_video_and_graph(video_path_value, current_time, fig, ax, time_points, body_movements, video_duration):
-             updated_frame = update_video(video_path_value, current_time)
-             updated_fig = update_graph(fig, ax, time_points, body_movements, current_time, video_duration)
-             return updated_frame, updated_fig
-
-         time_slider.change(update_video_and_graph,
-                            inputs=[video_path, time_slider, fig_state, ax_state, time_points_state, body_movements_state, video_duration_state],
-                            outputs=[video_output, graph_output])

          def use_sample_video():
-             sample_video_path = "IL_Dancing_Sample.mp4"
-             return process_and_update(sample_video_path)

          use_sample_button.click(use_sample_video,
                                  inputs=None,
-                                 outputs=[graph_output, video_output, error_output, video_path,
-                                          fig_state, ax_state, time_points_state, body_movements_state, video_duration_state])

      return app
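
Note on the removed pipeline: `process_batch` rebinds its `prev_poses` parameter locally, so the caller's `prev_poses = None` in `process_video` is never updated between batches, and `frame_indices` keeps growing while `frames` is reset, misaligning the `zip` for every batch after the first. Had this version been kept, a minimal fix (a sketch only, reusing the names above) would thread the state back through the return value:

    # Sketch: return the updated pose state instead of rebinding a parameter.
    def process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps):
        batch_preds = predictor.numpy_images(frames)
        for predictions, frame_index in zip(batch_preds, frame_indices):
            pose_coords = [pred.data for pred in predictions]
            # zero movement until there is a previous pose set to compare against
            body_movements.append(total_body_movement(pose_coords, prev_poses) if prev_poses else 0)
            prev_poses = pose_coords
            time_points.append(frame_index / original_fps)
        return prev_poses  # caller: prev_poses = process_batch(...); frames, frame_indices = [], []

The committed replacement drops this pipeline entirely in favor of OWLv2 zero-shot object detection: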
+ import gradio as gr
  import cv2
+ from PIL import Image, ImageDraw, ImageFont
+ import torch
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
  import numpy as np
  import os
+
+ # Check if CUDA is available, otherwise use CPU
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
+ model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+
+ def detect_objects_in_frame(image, target):
+     draw = ImageDraw.Draw(image)
+     texts = [[target]]
+     inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
+     outputs = model(**inputs)
+
+     target_sizes = torch.Tensor([image.size[::-1]])
+     results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
+
+     color_map = {target: "red"}
+
+     try:
+         font = ImageFont.truetype("arial.ttf", 15)
+     except IOError:
+         font = ImageFont.load_default()
+
+     i = 0
+     text = texts[i]
+     boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
+
+     for box, score, label in zip(boxes, scores, labels):
+         if score.item() >= 0.25:
+             box = [round(i, 2) for i in box.tolist()]
+             object_label = text[label]
+             confidence = round(score.item(), 3)
+             annotation = f"{object_label}: {confidence}"
+
+             draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
+             text_position = (box[0], box[1] - 10)
+             draw.text(text_position, annotation, fill="white", font=font)
+
      return image
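
Note: `detect_objects_in_frame` runs OWLv2 zero-shot detection with the free-text `target` as the only query, post-processes at a 0.1 threshold, and draws only the boxes scoring at least 0.25. A standalone smoke test could look like the sketch below; "frame.jpg" is a placeholder path, and the call is wrapped in `torch.no_grad()` since inference needs no autograd state. On a CUDA device, the model outputs and the CPU-side `target_sizes` tensor may also need to be moved to the same device before post-processing.

    # Hypothetical smoke test for the detector; frame.jpg is a placeholder path.
    from PIL import Image
    import torch

    frame = Image.open("frame.jpg").convert("RGB")
    with torch.no_grad():                      # inference only; no gradients needed
        annotated = detect_objects_in_frame(frame, "animal")
    annotated.save("frame_annotated.jpg")      # the same image, boxes drawn in place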
+ def process_video(video_path, target, progress=gr.Progress()):
      if video_path is None:
+         return None, "Error: No video uploaded"

      if not os.path.exists(video_path):
+         return None, f"Error: Video file not found at {video_path}"

      cap = cv2.VideoCapture(video_path)
      if not cap.isOpened():
+         return None, f"Error: Unable to open video file at {video_path}"

      frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     original_fps = int(cap.get(cv2.CAP_PROP_FPS))
      original_duration = frame_count / original_fps
+     output_fps = 5

+     output_path = "output_video.mp4"
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, output_fps, (int(cap.get(3)), int(cap.get(4))))
+
+     batch_size = 64
      frames = []

+     for frame in progress.tqdm(range(frame_count)):
          ret, img = cap.read()
          if not ret:
              break

+         if frame % (original_fps // output_fps) != 0:
+             continue

+         pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+         frames.append(pil_img)

+         if len(frames) == batch_size or frame == frame_count - 1:
+             annotated_frames = [detect_objects_in_frame(frame, target) for frame in frames]
+             for annotated_img in annotated_frames:
+                 annotated_frame = cv2.cvtColor(np.array(annotated_img), cv2.COLOR_RGB2BGR)
+                 out.write(annotated_frame)
+             frames = []

      cap.release()
+     out.release()

+     return output_path, None
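
Note: two edge cases in this loop. `original_fps // output_fps` is 0 for sources slower than 5 fps, so the modulo raises ZeroDivisionError, and frames still buffered when the loop exits early (a failed read, or a final frame filtered out by the stride) are never written. A defensive variant of the sampling-and-flush pattern, as a sketch (`annotate_video_frames` is a hypothetical helper and assumes the imports and `detect_objects_in_frame` from app.py above):

    # Sketch: guarded stride plus an explicit flush once the loop exits.
    def annotate_video_frames(cap, out, target, frame_count, original_fps,
                              output_fps=5, batch_size=64):
        stride = max(1, original_fps // output_fps)  # never zero, even for low-fps sources
        frames = []

        def flush():
            for annotated in (detect_objects_in_frame(f, target) for f in frames):
                out.write(cv2.cvtColor(np.array(annotated), cv2.COLOR_RGB2BGR))
            frames.clear()

        for frame in range(frame_count):
            ret, img = cap.read()
            if not ret:
                break
            if frame % stride != 0:
                continue
            frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            if len(frames) == batch_size:
                flush()
        flush()  # write whatever is still buffered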
 
  def load_sample_frame(video_path):
      cap = cv2.VideoCapture(video_path)

  def gradio_app():
      with gr.Blocks() as app:
+         gr.Markdown("# Video Object Detection with Owlv2")
+
          video_input = gr.Video(label="Upload Video")
+         target_input = gr.Textbox(label="Target Object")
+         output_video = gr.Video(label="Output Video")
          error_output = gr.Textbox(label="Error Messages", visible=False)
+         sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
+         use_sample_button = gr.Button("Use Sample Video")
+         progress_bar = gr.Progress()
+
          video_path = gr.State(None)
+         def process_and_update(video, target):
+             output_video_path, error = process_video(video, target, progress_bar)
+             if error:
                  error_output.visible = True
+             else:
+                 error_output.visible = False
+             return output_video_path, error

          video_input.upload(process_and_update,
+                            inputs=[video_input, target_input],
+                            outputs=[output_video, error_output])

          def use_sample_video():
+             sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
+             return process_and_update(sample_video_path, "animal")

          use_sample_button.click(use_sample_video,
                                  inputs=None,
+                                 outputs=[output_video, error_output])

      return app
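
Note: the hunks end at `return app`, so the entry point is not visible in this diff. A conventional Gradio launch (illustrative only, not confirmed by the commit) would be:

    # Illustrative entry point; not shown in the committed hunks.
    if __name__ == "__main__":
        app = gradio_app()
        app.launch()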