reab5555 committed on
Commit
343407e
·
verified ·
1 Parent(s): 53eff3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -5,6 +5,8 @@ import torch
5
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
  import numpy as np
7
  import os
 
 
8
 
9
  # Check if CUDA is available, otherwise use CPU
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -24,7 +26,7 @@ def detect_objects_in_frame(image, target):
24
  color_map = {target: "red"}
25
 
26
  try:
27
- font = ImageFont.truetype("arial.ttf", 15)
28
  except IOError:
29
  font = ImageFont.load_default()
30
 
@@ -32,6 +34,7 @@ def detect_objects_in_frame(image, target):
32
  text = texts[i]
33
  boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
34
 
 
35
  for box, score, label in zip(boxes, scores, labels):
36
  if score.item() >= 0.25:
37
  box = [round(i, 2) for i in box.tolist()]
@@ -39,42 +42,64 @@ def detect_objects_in_frame(image, target):
39
  confidence = round(score.item(), 3)
40
  annotation = f"{object_label}: {confidence}"
41
 
42
- draw.rectangle(box, outline=color_map.get(object_label, "red"), width=2)
43
- text_position = (box[0], box[1] - 10)
44
  draw.text(text_position, annotation, fill="white", font=font)
45
 
46
- return image
 
 
47
 
48
  def process_video(video_path, target, progress=gr.Progress()):
49
  if video_path is None:
50
- return None, "Error: No video uploaded"
51
 
52
  if not os.path.exists(video_path):
53
- return None, f"Error: Video file not found at {video_path}"
54
 
55
  cap = cv2.VideoCapture(video_path)
56
  if not cap.isOpened():
57
- return None, f"Error: Unable to open video file at {video_path}"
58
 
59
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
60
  original_fps = int(cap.get(cv2.CAP_PROP_FPS))
61
  output_fps = 3
 
 
62
 
63
  processed_frames = []
64
- frame_interval = max(1, round(original_fps / output_fps))
65
 
66
- for frame in progress.tqdm(range(0, frame_count, frame_interval)):
67
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
 
68
  ret, img = cap.read()
69
  if not ret:
70
  break
71
 
72
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
73
- annotated_img = detect_objects_in_frame(pil_img, target)
74
  processed_frames.append(np.array(annotated_img))
 
75
 
76
  cap.release()
77
- return processed_frames, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def load_sample_frame(video_path):
80
  cap = cv2.VideoCapture(video_path)
@@ -95,18 +120,21 @@ def gradio_app():
95
  target_input = gr.Textbox(label="Target Object", value="Elephant")
96
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
97
  output_image = gr.Image(label="Processed Frame")
 
98
  error_output = gr.Textbox(label="Error Messages", visible=False)
99
  sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
100
  use_sample_button = gr.Button("Use Sample Video")
101
  progress_bar = gr.Progress()
102
 
103
  processed_frames = gr.State([])
 
104
 
105
  def process_and_update(video, target):
106
- frames, error = process_video(video, target, progress_bar)
107
  if frames is not None:
108
- return frames, frames[0], error, gr.Slider(maximum=len(frames) - 1, value=0)
109
- return None, None, error, gr.Slider(maximum=100, value=0)
 
110
 
111
  def update_frame(frame_index, frames):
112
  if frames and 0 <= frame_index < len(frames):
@@ -115,7 +143,7 @@ def gradio_app():
115
 
116
  video_input.upload(process_and_update,
117
  inputs=[video_input, target_input],
118
- outputs=[processed_frames, output_image, error_output, frame_slider])
119
 
120
  frame_slider.change(update_frame,
121
  inputs=[frame_slider, processed_frames],
@@ -127,7 +155,7 @@ def gradio_app():
127
 
128
  use_sample_button.click(use_sample_video,
129
  inputs=None,
130
- outputs=[processed_frames, output_image, error_output, frame_slider])
131
 
132
  return app
133
 
 
5
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
  import numpy as np
7
  import os
8
+ import matplotlib.pyplot as plt
9
+ from io import BytesIO
10
 
11
  # Check if CUDA is available, otherwise use CPU
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
26
  color_map = {target: "red"}
27
 
28
  try:
29
+ font = ImageFont.truetype("arial.ttf", 30)
30
  except IOError:
31
  font = ImageFont.load_default()
32
 
 
34
  text = texts[i]
35
  boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
36
 
37
+ max_score = 0
38
  for box, score, label in zip(boxes, scores, labels):
39
  if score.item() >= 0.25:
40
  box = [round(i, 2) for i in box.tolist()]
 
42
  confidence = round(score.item(), 3)
43
  annotation = f"{object_label}: {confidence}"
44
 
45
+ draw.rectangle(box, outline=color_map.get(object_label, "red"), width=4)
46
+ text_position = (box[0], box[1] - 30)
47
  draw.text(text_position, annotation, fill="white", font=font)
48
 
49
+ max_score = max(max_score, confidence)
50
+
51
+ return image, max_score
52
 
53
def process_video(video_path, target, progress=gr.Progress()):
    """Sample a video at ~3 FPS and run object detection on each sampled frame.

    Args:
        video_path: Path to the input video file (may be None when nothing
            was uploaded).
        target: Text label of the object class to detect in each frame.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        A 3-tuple ``(processed_frames, frame_scores, error_message)``:
        lists of annotated RGB frames (numpy arrays) and of the max detection
        confidence per frame, with ``error_message`` None on success. On
        failure returns ``(None, None, "Error: ...")``.
    """
    if video_path is None:
        return None, None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Some containers report 0 FPS; fall back to a common default so the
        # duration math below cannot raise ZeroDivisionError.
        if original_fps <= 0:
            original_fps = 30
        output_fps = 3  # sample the source at roughly 3 frames per second
        frame_duration = 1 / output_fps
        video_duration = frame_count / original_fps

        processed_frames = []
        frame_scores = []

        # Seek by timestamp-derived frame index so sampling stays uniform
        # regardless of the source frame rate. ("t" rather than "time" to
        # avoid shadowing the stdlib module name.)
        for t in progress.tqdm(np.arange(0, video_duration, frame_duration)):
            frame_number = int(t * original_fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, img = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; PIL and the detector expect RGB.
            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            annotated_img, max_score = detect_objects_in_frame(pil_img, target)
            processed_frames.append(np.array(annotated_img))
            frame_scores.append(max_score)

        return processed_frames, frame_scores, None
    finally:
        # Release the capture even if detection raises mid-loop.
        cap.release()
87
+
88
def create_heatmap(frame_scores):
    """Render a one-row heatmap of per-frame detection confidences.

    Args:
        frame_scores: Sequence of max confidence values, one per sampled frame.

    Returns:
        An RGB(A) numpy array of the rendered heatmap image, a type that
        ``gr.Image`` can display directly.
    """
    plt.figure(figsize=(10, 2))
    plt.imshow([frame_scores], cmap='hot', aspect='auto')
    plt.colorbar(label='Confidence')
    plt.title('Object Detection Heatmap')
    plt.xlabel('Frame')
    plt.yticks([])  # single row: the y axis carries no information
    plt.tight_layout()

    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()  # free the figure so repeated calls do not leak memory

    # gr.Image accepts numpy arrays, PIL images, or file paths — not raw
    # BytesIO buffers — so decode the PNG before returning.
    return np.array(Image.open(buf))
103
 
104
  def load_sample_frame(video_path):
105
  cap = cv2.VideoCapture(video_path)
 
120
  target_input = gr.Textbox(label="Target Object", value="Elephant")
121
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
122
  output_image = gr.Image(label="Processed Frame")
123
+ heatmap_output = gr.Image(label="Detection Heatmap")
124
  error_output = gr.Textbox(label="Error Messages", visible=False)
125
  sample_video_frame = gr.Image(value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4"), label="Sample Video Frame")
126
  use_sample_button = gr.Button("Use Sample Video")
127
  progress_bar = gr.Progress()
128
 
129
  processed_frames = gr.State([])
130
+ frame_scores = gr.State([])
131
 
132
def process_and_update(video, target):
    """Run detection over the uploaded video and refresh every output widget.

    Returns values for: processed_frames state, frame_scores state, the
    first annotated frame, the heatmap image, the error textbox, and an
    updated frame slider.
    """
    frames, scores, error = process_video(video, target, progress_bar)
    # Guard against both a failed run (frames is None) and a run that
    # produced zero frames — frames[0] would raise IndexError and the
    # slider maximum would become -1.
    if frames:
        heatmap = create_heatmap(scores)
        return frames, scores, frames[0], heatmap, error, gr.Slider(maximum=len(frames) - 1, value=0)
    if error is None:
        error = "Error: No frames could be processed from the video"
    return None, None, None, None, error, gr.Slider(maximum=100, value=0)
138
 
139
  def update_frame(frame_index, frames):
140
  if frames and 0 <= frame_index < len(frames):
 
143
 
144
  video_input.upload(process_and_update,
145
  inputs=[video_input, target_input],
146
+ outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
147
 
148
  frame_slider.change(update_frame,
149
  inputs=[frame_slider, processed_frames],
 
155
 
156
  use_sample_button.click(use_sample_video,
157
  inputs=None,
158
+ outputs=[processed_frames, frame_scores, output_image, heatmap_output, error_output, frame_slider])
159
 
160
  return app
161