reab5555 committed on
Commit
08a3f43
·
verified ·
1 Parent(s): ec0d71c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -46
app.py CHANGED
@@ -6,7 +6,6 @@ from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
  import numpy as np
7
  import os
8
  import matplotlib.pyplot as plt
9
- from io import BytesIO
10
  import tempfile
11
  import shutil
12
 
@@ -24,7 +23,6 @@ except RuntimeError:
24
  device = torch.device("cpu")
25
  model = model.to(device)
26
 
27
-
28
  def process_video(video_path, target, progress=gr.Progress()):
29
  if video_path is None:
30
  return None, None, "Error: No video uploaded"
@@ -46,16 +44,7 @@ def process_video(video_path, target, progress=gr.Progress()):
46
  temp_dir = tempfile.mkdtemp()
47
  frame_paths = []
48
 
49
- # Try to use GPU with half precision, fall back to CPU if out of memory
50
- try:
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
- model.to(device).half() # Convert model to half precision
53
- except RuntimeError:
54
- print("GPU out of memory, falling back to CPU")
55
- device = torch.device("cpu")
56
- model.to(device)
57
-
58
- batch_size = 1
59
  batch_frames = []
60
  batch_indices = []
61
 
@@ -67,8 +56,8 @@ def process_video(video_path, target, progress=gr.Progress()):
67
  break
68
 
69
  # Resize the frame
70
- #img_resized = cv2.resize(img, (1280, 720))
71
- pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
72
 
73
  batch_frames.append(pil_img)
74
  batch_indices.append(i)
@@ -87,42 +76,36 @@ def process_video(video_path, target, progress=gr.Progress()):
87
  draw = ImageDraw.Draw(pil_img)
88
  max_score = 0
89
 
90
- try:
91
- font = ImageFont.truetype("arial.ttf", 20)
92
- except IOError:
93
- font = ImageFont.load_default()
94
-
95
  boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
96
 
97
- for box, score, label in zip(boxes, scores, labels):
98
- if score.item() >= 0.5:
99
- box = [round(i, 2) for i in box.tolist()]
100
- object_label = target
101
- confidence = round(score.item(), 3)
102
- annotation = f"{object_label}: {confidence}"
103
 
104
- # Increase line width for the bounding box
105
- draw.rectangle(box, outline="red", width=4)
106
 
107
- # Increase font size and change color to red
108
- font_size = 30 # Increased from 20
109
- try:
110
- font = ImageFont.truetype("arial.ttf", font_size)
111
- except IOError:
112
- font = ImageFont.load_default()
113
 
114
- text_position = (box[0], box[1] - font_size - 5)
115
-
116
- # Add a semi-transparent background for better text visibility
117
- text_bbox = draw.textbbox(text_position, annotation, font=font)
118
- draw.rectangle(text_bbox, fill=(0, 0, 0, 128))
119
 
120
- # Draw text in red
121
- draw.text(text_position, annotation, fill="red", font=font)
122
 
123
- max_score = max(max_score, confidence)
124
 
125
- # Save frame to disk
126
  frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
127
  pil_img.save(frame_path)
128
  frame_paths.append(frame_path)
@@ -146,13 +129,11 @@ def create_heatmap(frame_scores, current_frame):
146
  plt.xlabel('Frame', fontsize=12)
147
  plt.yticks([])
148
 
149
- # Add more frame numbers on x-axis
150
  num_frames = len(frame_scores)
151
- step = max(1, num_frames // 20) # Show at most 20 frame numbers
152
  frame_numbers = range(0, num_frames, step)
153
  plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=45, ha='right')
154
 
155
- # Add vertical line for current frame
156
  plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
157
 
158
  plt.tight_layout()
@@ -234,7 +215,7 @@ def gradio_app():
234
 
235
  if __name__ == "__main__":
236
  app = gradio_app()
237
- app.launch(share=True)
238
 
239
  # Cleanup temporary files
240
  def cleanup():
 
6
  import numpy as np
7
  import os
8
  import matplotlib.pyplot as plt
 
9
  import tempfile
10
  import shutil
11
 
 
23
  device = torch.device("cpu")
24
  model = model.to(device)
25
 
 
26
  def process_video(video_path, target, progress=gr.Progress()):
27
  if video_path is None:
28
  return None, None, "Error: No video uploaded"
 
44
  temp_dir = tempfile.mkdtemp()
45
  frame_paths = []
46
 
47
+ batch_size = 4 # Process 4 frames at a time
 
 
 
 
 
 
 
 
 
48
  batch_frames = []
49
  batch_indices = []
50
 
 
56
  break
57
 
58
  # Resize the frame
59
+ img_resized = cv2.resize(img, (640, 360))
60
+ pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
61
 
62
  batch_frames.append(pil_img)
63
  batch_indices.append(i)
 
76
  draw = ImageDraw.Draw(pil_img)
77
  max_score = 0
78
 
 
 
 
 
 
79
  boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
80
 
81
+ for box, score, label in zip(boxes, scores, labels):
82
+ if score.item() >= 0.5:
83
+ box = [round(i, 2) for i in box.tolist()]
84
+ object_label = target
85
+ confidence = round(score.item(), 3)
86
+ annotation = f"{object_label}: {confidence}"
87
 
88
+ # Increase line width for the bounding box
89
+ draw.rectangle(box, outline="red", width=4)
90
 
91
+ # Increase font size and change color to red
92
+ font_size = 30
93
+ try:
94
+ font = ImageFont.truetype("arial.ttf", font_size)
95
+ except IOError:
96
+ font = ImageFont.load_default()
97
 
98
+ text_position = (box[0], box[1] - font_size - 5)
99
+
100
+ # Add a semi-transparent background for better text visibility
101
+ text_bbox = draw.textbbox(text_position, annotation, font=font)
102
+ draw.rectangle(text_bbox, fill=(0, 0, 0, 128))
103
 
104
+ # Draw text in red
105
+ draw.text(text_position, annotation, fill="red", font=font)
106
 
107
+ max_score = max(max_score, confidence)
108
 
 
109
  frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
110
  pil_img.save(frame_path)
111
  frame_paths.append(frame_path)
 
129
  plt.xlabel('Frame', fontsize=12)
130
  plt.yticks([])
131
 
 
132
  num_frames = len(frame_scores)
133
+ step = max(1, num_frames // 20)
134
  frame_numbers = range(0, num_frames, step)
135
  plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=45, ha='right')
136
 
 
137
  plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
138
 
139
  plt.tight_layout()
 
215
 
216
  if __name__ == "__main__":
217
  app = gradio_app()
218
+ app.launch()
219
 
220
  # Cleanup temporary files
221
  def cleanup():