reab5555 committed on
Commit
7d47fdc
·
verified ·
1 Parent(s): 3cb7297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -45
app.py CHANGED
@@ -16,7 +16,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
17
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
18
 
19
- @spaces.GPU(duration=200)
20
  def process_video(video_path, target, progress=gr.Progress()):
21
  if video_path is None:
22
  return None, None, "Error: No video uploaded"
@@ -36,9 +36,6 @@ def process_video(video_path, target, progress=gr.Progress()):
36
 
37
  processed_frames = []
38
  frame_scores = []
39
- batch_size = 2
40
- batch_frames = []
41
- batch_times = []
42
 
43
  for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
44
  frame_number = int(time * original_fps)
@@ -48,47 +45,39 @@ def process_video(video_path, target, progress=gr.Progress()):
48
  break
49
 
50
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
51
- batch_frames.append(pil_img)
52
- batch_times.append(time)
53
-
54
- if len(batch_frames) == batch_size or time + frame_duration >= video_duration:
55
- # Process the batch
56
- texts = [[target]] * len(batch_frames)
57
- inputs = processor(text=texts, images=batch_frames, return_tensors="pt", padding=True).to(device)
58
- outputs = model(**inputs)
59
-
60
- for i, (image, batch_time) in enumerate(zip(batch_frames, batch_times)):
61
- target_sizes = torch.Tensor([image.size[::-1]])
62
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
63
-
64
- draw = ImageDraw.Draw(image)
65
- max_score = 0
66
-
67
- try:
68
- font = ImageFont.truetype("arial.ttf", 40)
69
- except IOError:
70
- font = ImageFont.load_default()
71
-
72
- boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
73
-
74
- for box, score, label in zip(boxes, scores, labels):
75
- if score.item() >= 0.25:
76
- box = [round(i, 2) for i in box.tolist()]
77
- object_label = target
78
- confidence = round(score.item(), 3)
79
- annotation = f"{object_label}: {confidence}"
80
-
81
- draw.rectangle(box, outline="red", width=2)
82
- text_position = (box[0], box[1] - 30)
83
- draw.text(text_position, annotation, fill="white", font=font)
84
-
85
- max_score = max(max_score, confidence)
86
-
87
- processed_frames.append(np.array(image))
88
- frame_scores.append(max_score)
89
-
90
- batch_frames = []
91
- batch_times = []
92
 
93
  cap.release()
94
  return processed_frames, frame_scores, None
 
16
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
17
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
18
 
19
+ @spaces.GPU(duration=120)
20
  def process_video(video_path, target, progress=gr.Progress()):
21
  if video_path is None:
22
  return None, None, "Error: No video uploaded"
 
36
 
37
  processed_frames = []
38
  frame_scores = []
 
 
 
39
 
40
  for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
41
  frame_number = int(time * original_fps)
 
45
  break
46
 
47
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
48
+
49
+ # Process single image
50
+ inputs = processor(text=[target], images=pil_img, return_tensors="pt", padding=True).to(device)
51
+ outputs = model(**inputs)
52
+
53
+ target_sizes = torch.Tensor([pil_img.size[::-1]])
54
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
55
+
56
+ draw = ImageDraw.Draw(pil_img)
57
+ max_score = 0
58
+
59
+ try:
60
+ font = ImageFont.truetype("arial.ttf", 40)
61
+ except IOError:
62
+ font = ImageFont.load_default()
63
+
64
+ boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
65
+
66
+ for box, score, label in zip(boxes, scores, labels):
67
+ if score.item() >= 0.25:
68
+ box = [round(i, 2) for i in box.tolist()]
69
+ object_label = target
70
+ confidence = round(score.item(), 3)
71
+ annotation = f"{object_label}: {confidence}"
72
+
73
+ draw.rectangle(box, outline="red", width=2)
74
+ text_position = (box[0], box[1] - 30)
75
+ draw.text(text_position, annotation, fill="white", font=font)
76
+
77
+ max_score = max(max_score, confidence)
78
+
79
+ processed_frames.append(np.array(pil_img))
80
+ frame_scores.append(max_score)
 
 
 
 
 
 
 
 
81
 
82
  cap.release()
83
  return processed_frames, frame_scores, None