reab5555 committed · verified
Commit 3f2cadc · 1 Parent(s): 5813a90

Update app.py

Files changed (1):
  1. app.py  +182 -84
app.py CHANGED
@@ -1,95 +1,171 @@
- import gradio as gr
  import cv2
- from PIL import Image, ImageDraw, ImageFont
- import torch
- from transformers import Owlv2Processor, Owlv2ForObjectDetection
  import numpy as np
  import os
-
- # Check if CUDA is available, otherwise use CPU
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
- model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
-
- def detect_objects_in_frame(image, target):
-     draw = ImageDraw.Draw(image)
-     texts = [[target]]
-     inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
-     outputs = model(**inputs)
-
-     target_sizes = torch.Tensor([image.size[::-1]])
-     results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
-
-     color_map = {target: "red"}
-
-     try:
-         font = ImageFont.truetype("arial.ttf", 15)
-     except IOError:
-         font = ImageFont.load_default()
-
-     i = 0
-     text = texts[i]
-     boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
-
-     for box, score, label in zip(boxes, scores, labels):
-         if score.item() >= 0.25:
-             box = [round(i, 2) for i in box.tolist()]
-             object_label = text[label]
-             confidence = round(score.item(), 3)
-             annotation = f"{object_label}: {confidence}"
-
-             draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
-             text_position = (box[0], box[1] - 10)
-             draw.text(text_position, annotation, fill="white", font=font)
-
      return image

- def process_video(video_path, target, progress=gr.Progress()):
      if video_path is None:
-         return None, "Error: No video uploaded"

      if not os.path.exists(video_path):
-         return None, f"Error: Video file not found at {video_path}"

      cap = cv2.VideoCapture(video_path)
      if not cap.isOpened():
-         return None, f"Error: Unable to open video file at {video_path}"

-     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
      original_fps = int(cap.get(cv2.CAP_PROP_FPS))
      original_duration = frame_count / original_fps
-     output_fps = 5

-     output_path = "output_video.mp4"
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     out = cv2.VideoWriter(output_path, fourcc, output_fps, (int(cap.get(3)), int(cap.get(4))))
-
-     batch_size = 64
      frames = []

-     for frame in progress.tqdm(range(frame_count)):
          ret, img = cap.read()
          if not ret:
              break

-         if frame % (original_fps // output_fps) != 0:
-             continue

-         pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-         frames.append(pil_img)

-         if len(frames) == batch_size or frame == frame_count - 1:
-             annotated_frames = [detect_objects_in_frame(frame, target) for frame in frames]
-             for annotated_img in annotated_frames:
-                 annotated_frame = cv2.cvtColor(np.array(annotated_img), cv2.COLOR_RGB2BGR)
-                 out.write(annotated_frame)
-             frames = []

      cap.release()
-     out.release()

-     return output_path, None

  def load_sample_frame(video_path):
      cap = cv2.VideoCapture(video_path)
@@ -104,35 +180,57 @@ def load_sample_frame(video_path):

  def gradio_app():
      with gr.Blocks() as app:
-         gr.Markdown("# Video Object Detection with Owlv2")
-
          video_input = gr.Video(label="Upload Video")
-         target_input = gr.Textbox(label="Target Object")
-         output_video = gr.Video(label="Output Video")
          error_output = gr.Textbox(label="Error Messages", visible=False)
-         sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
-         use_sample_button = gr.Button("Use Sample Video")
-
          video_path = gr.State(None)
-         def process_and_update(video, target):
-             output_video_path, error = process_video(video, target)
-             if error:
-                 error_output.visible = True
-             else:
                  error_output.visible = False
-             return output_video_path, error

          video_input.upload(process_and_update,
-                            inputs=[video_input, target_input],
-                            outputs=[output_video, error_output])

          def use_sample_video():
-             sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
-             return process_and_update(sample_video_path, "animal")

          use_sample_button.click(use_sample_video,
                                  inputs=None,
-                                 outputs=[output_video, error_output])

      return app

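The removed path ran OWLv2 zero-shot detection on every sampled frame. For reference, a minimal standalone sketch of that call pattern on a single image, reusing the checkpoint name and the 0.1 post-processing threshold from the removed code above (the file name and the "animal" query are placeholders, and device handling is illustrative):

import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

image = Image.open("frame.jpg").convert("RGB")   # any single video frame
texts = [["animal"]]                             # one list of text queries per image
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes)
for box, score, label in zip(results[0]["boxes"], results[0]["scores"], results[0]["labels"]):
    print(texts[0][label], round(score.item(), 3), [round(v, 2) for v in box.tolist()])

The commit replaces this per-frame detector entirely with OpenPifPaf pose estimation, as the added code below shows.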
  import cv2
  import numpy as np
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ from moviepy.editor import *
  import os
+ import torch
+ import openpifpaf
+
+ # Ensure NumPy is available
+ try:
+     import numpy as np
+ except ImportError:
+     os.system('pip install numpy')
+     import numpy as np
+
+ # OpenPifPaf configuration
+ predictor = openpifpaf.Predictor(checkpoint='shufflenetv2k16')
+
+ def preprocess(image):
+     input_size = (192, 256)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, input_size)
      return image

+ def total_body_movement(current_poses, prev_poses):
+     if not current_poses or not prev_poses:
+         return 0
+     total_movement = 0
+     for current_pose in current_poses:
+         for prev_pose in prev_poses:
+             movement = np.sum(np.sqrt(np.sum((current_pose - prev_pose)**2, axis=1)))
+             total_movement += movement
+     return total_movement / (len(current_poses) * len(prev_poses))
+
+ def process_video(video_path, progress=gr.Progress(), batch_size=64):
      if video_path is None:
+         return None, None, None, None, None, None, "Error: No video uploaded"

      if not os.path.exists(video_path):
+         return None, None, None, None, None, None, f"Error: Video file not found at {video_path}"

      cap = cv2.VideoCapture(video_path)
      if not cap.isOpened():
+         return None, None, None, None, None, None, f"Error: Unable to open video file at {video_path}"

      original_fps = int(cap.get(cv2.CAP_PROP_FPS))
+     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
      original_duration = frame_count / original_fps
+
+     frame_interval = max(1, round(original_fps / 10))  # Process 10 frames per second
+
+     body_movements = []
+     time_points = []

+     prev_poses = None
      frames = []
+     frame_indices = []

+     for frame in progress.tqdm(range(0, frame_count, frame_interval)):
+         cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
          ret, img = cap.read()
          if not ret:
              break
+         frames.append(img)
+         frame_indices.append(frame)

+         if len(frames) == batch_size:
+             process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps)
+             frames = []

+     # Process any remaining frames
+     if frames:
+         process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps)

+     cap.release()
+
+     fig, ax = plt.subplots(figsize=(10, 6), dpi=500)
+     ax.plot(time_points, body_movements, "-", linewidth=0.5)
+     ax.set_xlim(0, original_duration)
+     ax.set_xlabel("Time")
+     ax.set_ylabel("Body Movement")
+     ax.set_title("Body Movement Analysis")
+
+     num_labels = 50
+     label_positions = np.linspace(0, original_duration, num_labels)
+     label_texts = [f"{int(t//60):02d}:{int(t%60):02d}" for t in label_positions]
+     ax.set_xticks(label_positions)
+     ax.set_xticklabels(label_texts, rotation=90, ha='right')
+     plt.tight_layout()
+
+     return fig, ax, time_points, body_movements, video_path, original_duration, None
+
+ def process_batch(frames, frame_indices, prev_poses, body_movements, time_points, original_fps):
+     batch_preds = predictor.numpy_images(frames)
+
+     for i, (predictions, frame_index) in enumerate(zip(batch_preds, frame_indices)):
+         pose_coords = [pred.data for pred in predictions]
+
+         if prev_poses is not None:
+             movement = total_body_movement(pose_coords, prev_poses)
+             body_movements.append(movement)
+         else:
+             body_movements.append(0)
+
+         prev_poses = pose_coords
+         time_points.append(frame_index / original_fps)
+
+ def update_video(video_path, time):
+     if video_path is None:
+         return None
+
+     if not os.path.exists(video_path):
+         return None

+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return None
+
+     original_fps = int(cap.get(cv2.CAP_PROP_FPS))
+     frame_number = int(time * original_fps)
+     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+     ret, img = cap.read()
      cap.release()

+     if not ret:
+         return None
+
+     predictions, _, _ = predictor.numpy_image(img)
+     pose_coords = [pred.data for pred in predictions]
+
+     for coords in pose_coords:
+         for i in range(len(coords)):
+             x, y = coords[i]
+             if x > 0 and y > 0:
+                 cv2.circle(img, (int(x), int(y)), 3, (0, 255, 0), -1)
+
+     for pred in predictions:
+         skeleton = pred.data[:, :2]
+         for i, j in pred.skeleton:
+             if skeleton[i, 0] > 0 and skeleton[i, 1] > 0 and skeleton[j, 0] > 0 and skeleton[j, 1] > 0:
+                 cv2.line(img, (int(skeleton[i, 0]), int(skeleton[i, 1])), (int(skeleton[j, 0]), int(skeleton[j, 1])), (255, 0, 0), 2)
+
+     return img
+
+ def update_graph(fig, ax, time_points, body_movements, current_time, video_duration):
+     ax.clear()
+     ax.plot(time_points, body_movements, "-", linewidth=0.5)
+     ax.axvline(x=current_time, color='r', linestyle='--')
+
+     minutes, seconds = divmod(int(current_time), 60)
+     timecode = f"{minutes:02d}:{seconds:02d}"
+     ax.text(current_time, ax.get_ylim()[1], timecode,
+             verticalalignment='top', horizontalalignment='right',
+             color='r', fontweight='bold', bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
+
+     ax.set_xlabel("Time")
+     ax.set_ylabel("Body Movement")
+     ax.set_title("Body Movement Analysis")
+
+     num_labels = 80
+     label_positions = np.linspace(0, video_duration, num_labels)
+     label_texts = [f"{int(t//60):02d}:{int(t%60):02d}" for t in label_positions]
+     ax.set_xticks(label_positions)
+     ax.set_xticklabels(label_texts, rotation=90, ha='right')
+     ax.set_xlim(0, video_duration)
+     plt.tight_layout()
+     return fig

  def load_sample_frame(video_path):
      cap = cv2.VideoCapture(video_path)

  [lines 172-179 unchanged, not shown]

  def gradio_app():
      with gr.Blocks() as app:
+         gr.Markdown("# Multi-Person Body Movement Analysis")
+
          video_input = gr.Video(label="Upload Video")
+         graph_output = gr.Plot()
+         time_slider = gr.Slider(label="Time (seconds)", minimum=0, maximum=100, step=0.1)
+         video_output = gr.Image(label="Body Posture")
+
+         with gr.Row():
+             sample_video_frame = gr.Image(value=load_sample_frame("IL_Dancing_Sample.mp4"), label="Sample Video Frame")
+             use_sample_button = gr.Button("Use Sample Video")
+
          error_output = gr.Textbox(label="Error Messages", visible=False)
+
          video_path = gr.State(None)
+         fig_state = gr.State(None)
+         ax_state = gr.State(None)
+         time_points_state = gr.State(None)
+         body_movements_state = gr.State(None)
+         video_duration_state = gr.State(None)
+
+         def process_and_update(video):
+             fig, ax, time_points, body_movements, video_path_value, video_duration, error = process_video(video)
+             if fig is not None:
+                 time_slider.maximum = video_duration
                  error_output.visible = False
+             else:
+                 error_output.visible = True
+             return fig, video, error, video_path_value, fig, ax, time_points, body_movements, video_duration

          video_input.upload(process_and_update,
+                            inputs=video_input,
+                            outputs=[graph_output, video_output, error_output, video_path,
+                                     fig_state, ax_state, time_points_state, body_movements_state, video_duration_state])
+
+         def update_video_and_graph(video_path_value, current_time, fig, ax, time_points, body_movements, video_duration):
+             updated_frame = update_video(video_path_value, current_time)
+             updated_fig = update_graph(fig, ax, time_points, body_movements, current_time, video_duration)
+             return updated_frame, updated_fig
+
+         time_slider.change(update_video_and_graph,
+                            inputs=[video_path, time_slider, fig_state, ax_state, time_points_state, body_movements_state, video_duration_state],
+                            outputs=[video_output, graph_output])

          def use_sample_video():
+             sample_video_path = "IL_Dancing_Sample.mp4"
+             return process_and_update(sample_video_path)

          use_sample_button.click(use_sample_video,
                                  inputs=None,
+                                 outputs=[graph_output, video_output, error_output, video_path,
+                                          fig_state, ax_state, time_points_state, body_movements_state, video_duration_state])

      return app
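The new app scores each sampled frame with total_body_movement: the summed per-keypoint Euclidean displacement between every (current person, previous person) pairing, averaged over the number of pairings. A small self-contained numpy check of that arithmetic on synthetic poses (the (17, 3) x/y/confidence layout matches what OpenPifPaf's pred.data typically returns; note that, as written, the confidence column also enters the distance):

import numpy as np

# Two synthetic "people": 17 COCO keypoints each, columns (x, y, confidence).
prev_pose = np.zeros((17, 3))
curr_pose = prev_pose.copy()
curr_pose[:, 0] += 2.0          # every keypoint moves 2 px to the right

prev_poses = [prev_pose]        # poses detected in the previous sampled frame
current_poses = [curr_pose]     # poses detected in the current sampled frame

# Same arithmetic as total_body_movement() in the new app.py.
total = 0.0
for cur in current_poses:
    for prev in prev_poses:
        total += np.sum(np.sqrt(np.sum((cur - prev) ** 2, axis=1)))
score = total / (len(current_poses) * len(prev_poses))

print(score)  # 17 keypoints x 2 px each -> 34.0

With more than one person in frame, every cross-person pairing contributes, so the score mixes genuine motion with distances between unrelated people; that is a property of the committed code, not of this sketch.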