reab5555 commited on
Commit
4f98062
·
verified ·
1 Parent(s): 2934aa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -40
app.py CHANGED
@@ -15,42 +15,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
16
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
17
 
18
- def detect_objects_in_frame(image, target):
19
- draw = ImageDraw.Draw(image)
20
- texts = [[target]]
21
- inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
22
- outputs = model(**inputs)
23
-
24
- target_sizes = torch.Tensor([image.size[::-1]])
25
- results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
26
-
27
- color_map = {target: "red"}
28
-
29
- try:
30
- font = ImageFont.truetype("arial.ttf", 30)
31
- except IOError:
32
- font = ImageFont.load_default()
33
-
34
- i = 0
35
- text = texts[i]
36
- boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
37
-
38
- max_score = 0
39
- for box, score, label in zip(boxes, scores, labels):
40
- if score.item() >= 0.25:
41
- box = [round(i, 2) for i in box.tolist()]
42
- object_label = text[label]
43
- confidence = round(score.item(), 3)
44
- annotation = f"{object_label}: {confidence}"
45
-
46
- draw.rectangle(box, outline=color_map.get(object_label, "red"), width=4)
47
- text_position = (box[0], box[1] - 30)
48
- draw.text(text_position, annotation, fill="white", font=font)
49
-
50
- max_score = max(max_score, confidence)
51
-
52
- return image, max_score
53
-
54
  def process_video(video_path, target, progress=gr.Progress()):
55
  if video_path is None:
56
  return None, None, "Error: No video uploaded"
@@ -70,6 +34,9 @@ def process_video(video_path, target, progress=gr.Progress()):
70
 
71
  processed_frames = []
72
  frame_scores = []
 
 
 
73
 
74
  for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
75
  frame_number = int(time * original_fps)
@@ -79,9 +46,47 @@ def process_video(video_path, target, progress=gr.Progress()):
79
  break
80
 
81
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
82
- annotated_img, max_score = detect_objects_in_frame(pil_img, target)
83
- processed_frames.append(np.array(annotated_img))
84
- frame_scores.append(max_score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  cap.release()
87
  return processed_frames, frame_scores, None
@@ -115,7 +120,7 @@ def load_sample_frame(video_path):
115
 
116
  def gradio_app():
117
  with gr.Blocks() as app:
118
- gr.Markdown("# Video Object Detection with Owlv2 (3 FPS)")
119
 
120
  video_input = gr.Video(label="Upload Video")
121
  target_input = gr.Textbox(label="Target Object", value="Elephant")
 
15
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
16
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def process_video(video_path, target, progress=gr.Progress()):
19
  if video_path is None:
20
  return None, None, "Error: No video uploaded"
 
34
 
35
  processed_frames = []
36
  frame_scores = []
37
+ batch_size = 32
38
+ batch_frames = []
39
+ batch_times = []
40
 
41
  for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
42
  frame_number = int(time * original_fps)
 
46
  break
47
 
48
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
49
+ batch_frames.append(pil_img)
50
+ batch_times.append(time)
51
+
52
+ if len(batch_frames) == batch_size or time + frame_duration >= video_duration:
53
+ # Process the batch
54
+ texts = [[target]] * len(batch_frames)
55
+ inputs = processor(text=texts, images=batch_frames, return_tensors="pt", padding=True).to(device)
56
+ outputs = model(**inputs)
57
+
58
+ for i, (image, batch_time) in enumerate(zip(batch_frames, batch_times)):
59
+ target_sizes = torch.Tensor([image.size[::-1]])
60
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
61
+
62
+ draw = ImageDraw.Draw(image)
63
+ max_score = 0
64
+
65
+ try:
66
+ font = ImageFont.truetype("arial.ttf", 30)
67
+ except IOError:
68
+ font = ImageFont.load_default()
69
+
70
+ boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
71
+
72
+ for box, score, label in zip(boxes, scores, labels):
73
+ if score.item() >= 0.25:
74
+ box = [round(i, 2) for i in box.tolist()]
75
+ object_label = target
76
+ confidence = round(score.item(), 3)
77
+ annotation = f"{object_label}: {confidence}"
78
+
79
+ draw.rectangle(box, outline="red", width=4)
80
+ text_position = (box[0], box[1] - 30)
81
+ draw.text(text_position, annotation, fill="white", font=font)
82
+
83
+ max_score = max(max_score, confidence)
84
+
85
+ processed_frames.append(np.array(image))
86
+ frame_scores.append(max_score)
87
+
88
+ batch_frames = []
89
+ batch_times = []
90
 
91
  cap.release()
92
  return processed_frames, frame_scores, None
 
120
 
121
  def gradio_app():
122
  with gr.Blocks() as app:
123
+ gr.Markdown("# Video Object Detection with Owlv2 (3 FPS, Batch Size 32)")
124
 
125
  video_input = gr.Video(label="Upload Video")
126
  target_input = gr.Textbox(label="Target Object", value="Elephant")