oxkitsune committed
Commit ad0f5c7 · Parent: 4be800e

batch model inference

Files changed (1): app.py (+20 -19)
app.py CHANGED
@@ -43,24 +43,15 @@ def resize_image(image_buffer, max_size=256):
 
 @spaces.GPU(duration=20)
 def predict_depth(input_images):
-    # Preprocess the image
     results = [depth_pro.load_rgb(image) for image in input_images]
-
-    # assume load_rgb returns a tuple of (image, f_px)
-    # stack the images and f_px values into tensors
-    images, f_px = zip(*results)
-    images = torch.stack(images)
-    f_px = torch.tensor(f_px)
-
-    images = transform(images)
-
+    images = torch.stack([transform(result[0]) for result in results])
     images = images.to(device)
-    f_px = f_px.to(device)
 
     # Run inference
-    prediction = model.infer(images, f_px=f_px)
-    depth = prediction["depth"]  # Depth in [m]
-    focallength_px = prediction["focallength_px"]  # Focal length in pixels
+    with torch.no_grad():
+        prediction = model.infer(images)
+        depth = prediction["depth"]  # Depth in [m]
+        focallength_px = prediction["focallength_px"]  # Focal length in pixels
 
     # Convert depth to numpy array if it's a torch tensor
     if isinstance(depth, torch.Tensor):
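The net effect of this first hunk: frames are preprocessed and run through the model once per batch rather than once per image, under torch.no_grad(). A minimal standalone sketch of that path, assuming (as the new code does) that depth_pro.load_rgb returns a tuple whose first element is the image; predict_depth_batched is a hypothetical name, and transform, model, and device stand in for the globals app.py defines elsewhere:

import torch
import depth_pro

def predict_depth_batched(paths, transform, model, device):
    # load_rgb is assumed to return a tuple with the image first
    results = [depth_pro.load_rgb(p) for p in paths]
    # one (N, C, H, W) tensor -> one forward pass for the whole batch
    images = torch.stack([transform(r[0]) for r in results]).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        prediction = model.infer(images)
    return prediction["depth"], prediction["focallength_px"]

Note that calling model.infer(images) without f_px presumably leaves focal-length estimation to the network, which matches focallength_px now coming back as one value per image.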
@@ -68,9 +59,9 @@ def predict_depth(input_images):
 
     # Convert focal length to a float if it's a torch tensor
     if isinstance(focallength_px, torch.Tensor):
-        focallength_px = focallength_px.item()
+        focallength_px = [focal_length.item() for focal_length in focallength_px]
 
-    # Ensure depth is a 2D numpy array
+    # Ensure depth is a BxHxW tensor
     if depth.ndim != 2:
         depth = depth.squeeze()
 
@@ -114,7 +105,13 @@ def run_rerun(path_to_video):
 
     # limit the number of frames to 10 seconds of video
     max_frames = min(10 * fps_video, num_frames)
-    batch_size = min(16, max_frames)
+
+    torch.cuda.empty_cache()
+    free_vram, _ = torch.cuda.mem_get_info(device)
+    free_vram = free_vram / 1024 / 1024 / 1024
+
+    # batch size is determined by the amount of free vram
+    batch_size = int(min(free_vram // 4, max_frames))
 
     # go through all the frames in the video, using the batch size
     for i in range(0, int(max_frames), batch_size):
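The fixed batch_size = 16 is replaced by a VRAM-derived one: torch.cuda.mem_get_info returns (free, total) in bytes, and free_vram // 4 encodes an assumption of roughly 4 GiB of GPU memory per frame. One caveat worth flagging: with under 4 GiB free this evaluates to 0, and range(0, max_frames, 0) raises ValueError. A sketch with a floor of 1 (the max(1, ...) guard and the gib_per_frame knob are additions for illustration, not part of the commit):

import torch

def pick_batch_size(device, max_frames, gib_per_frame=4):
    torch.cuda.empty_cache()
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    free_gib = free_bytes / 1024**3
    # assume ~gib_per_frame GiB per frame; never return 0, which would
    # break range(0, n, batch_size)
    return max(1, int(min(free_gib // gib_per_frame, max_frames)))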
@@ -122,6 +119,7 @@
             raise gr.Error("Reached the maximum number of frames to process")
 
         frames = []
+        frame_indices = list(range(i, min(i + batch_size, int(max_frames))))
         for _ in range(batch_size):
             ret, frame = cap.read()
             if not ret:
@@ -135,11 +133,14 @@
 
         depths, focal_lengths = predict_depth(temp_files)
 
-        for depth, focal_length in zip(depths, focal_lengths):
+        for depth, focal_length, frame_idx in zip(
+            depths, focal_lengths, frame_indices
+        ):
             # find x and y scale factors, which can be applied to image
             x_scale = depth.shape[1] / frames[0].shape[1]
             y_scale = depth.shape[0] / frames[0].shape[0]
 
+            rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
             rr.log(
                 "world/camera/depth",
                 rr.DepthImage(depth, meter=1),
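Because the loop now advances batch_size frames at a time, the loop index i only identifies the start of a batch; frame_indices re-aligns each per-frame result with its absolute frame number so the right entry of frame_timestamps_ns can be looked up, and rr.set_time_nanos stamps the subsequent rr.log calls onto the video_time timeline at that instant. A toy illustration of the alignment (all values are illustrative, not from app.py):

batch_start = 32                  # i for the third batch when batch_size == 16
depths = ["d32", "d33", "d34"]    # stand-ins for per-frame model outputs
frame_indices = range(batch_start, batch_start + len(depths))
for depth, frame_idx in zip(depths, frame_indices):
    # frame_idx, not the batch counter, selects the frame's timestamp
    print(frame_idx, depth)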
@@ -149,7 +150,7 @@
                 "world/camera/frame",
                 rr.VideoFrameReference(
                     timestamp=rr.components.VideoTimestamp(
-                        nanoseconds=frame_timestamps_ns[i]
+                        nanoseconds=frame_timestamps_ns[frame_idx]
                     ),
                     video_reference="world/video",
                 ),
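Taken together, the inner loop of run_rerun now has this shape (a condensed sketch, not a drop-in replacement: error paths are elided, and grab_batch_as_temp_files is a hypothetical helper standing in for the cap.read() loop and temp-file handling):

for i in range(0, int(max_frames), batch_size):
    frame_indices = list(range(i, min(i + batch_size, int(max_frames))))
    temp_files = grab_batch_as_temp_files(cap, batch_size)  # hypothetical helper
    depths, focal_lengths = predict_depth(temp_files)
    for depth, focal_length, frame_idx in zip(depths, focal_lengths, frame_indices):
        rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))

Since rr.set_time_nanos applies to every rr.log call that follows it, the depth image and the video frame reference logged afterwards both land on the video_time timeline at the same frame timestamp.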