oxkitsune committed
Commit 4be800e · 1 parent: 0dd7180

batch model inference

Files changed (1)
  1. app.py +65 -52
app.py CHANGED
@@ -42,16 +42,23 @@ def resize_image(image_buffer, max_size=256):
 
 
 @spaces.GPU(duration=20)
-def predict_depth(input_image):
+def predict_depth(input_images):
     # Preprocess the image
-    result = depth_pro.load_rgb(input_image)
-    image = result[0]
-    f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
-    image = transform(image)
-    image = image.to(device)
+    results = [depth_pro.load_rgb(image) for image in input_images]
+
+    # load_rgb returns a tuple whose first item is the image and whose last
+    # item is f_px; transform each image to a tensor, then stack into a batch
+    images = torch.stack([transform(result[0]) for result in results])
+    f_px = torch.tensor([result[-1] for result in results])
+
+    images = images.to(device)
+    f_px = f_px.to(device)
 
     # Run inference
-    prediction = model.infer(image, f_px=f_px)
+    prediction = model.infer(images, f_px=f_px)
     depth = prediction["depth"]  # Depth in [m]
     focallength_px = prediction["focallength_px"]  # Focal length in pixels
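For reference, here is a minimal standalone sketch of the batched path the new `predict_depth` takes. It keeps the original code's hedge that `depth_pro.load_rgb` returns a tuple whose first item is the image and whose last item is `f_px`, and the commit's own assumption that `model.infer` accepts a stacked batch; the frame paths are hypothetical.

```python
# Hedged sketch of batched Depth Pro inference, not the app's exact code.
import torch
import depth_pro

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device).eval()

paths = ["frame_000.png", "frame_001.png"]  # hypothetical example frames
results = [depth_pro.load_rgb(p) for p in paths]

# Transform each image to a tensor first, then stack into a single batch;
# f_px is assumed to be the last item of each load_rgb tuple.
images = torch.stack([transform(r[0]) for r in results]).to(device)
f_px = torch.tensor([float(r[-1]) for r in results]).to(device)

with torch.no_grad():
    prediction = model.infer(images, f_px=f_px)  # assumes batched infer

depth = prediction["depth"]                    # depth in meters, per frame
focallength_px = prediction["focallength_px"]  # focal length in pixels
```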
 
@@ -107,62 +114,68 @@ def run_rerun(path_to_video):
 
     # limit the number of frames to 10 seconds of video
     max_frames = min(10 * fps_video, num_frames)
+    # range() below needs an integer step, so coerce the batch size to int
+    batch_size = int(min(16, max_frames))
 
-    for i in range(len(frame_timestamps_ns)):
+    # go through all the frames in the video, batch_size frames at a time
+    for i in range(0, int(max_frames), batch_size):
         if i >= max_frames:
             raise gr.Error("Reached the maximum number of frames to process")
 
-        ret, frame = cap.read()
-        if not ret:
-            break
+        # read up to batch_size frames; the last batch may be shorter
+        frames = []
+        for _ in range(batch_size):
+            ret, frame = cap.read()
+            if not ret:
+                break
+            frames.append(frame)
 
-        temp_file = None
+        temp_files = []
         try:
-            # Resize the image to make the inference faster
-            temp_file = resize_image(frame, max_size=256)
-
-            depth, focal_length = predict_depth(temp_file)
-
-            # find x and y scale factors, which can be applied to image
-            x_scale = depth.shape[1] / frame.shape[1]
-            y_scale = depth.shape[0] / frame.shape[0]
-
-            rr.set_time_nanos("video_time", frame_timestamps_ns[i])
-            rr.log(
-                "world/camera/depth",
-                rr.DepthImage(depth, meter=1),
-            )
-
-            rr.log(
-                "world/camera/frame",
-                rr.VideoFrameReference(
-                    timestamp=rr.components.VideoTimestamp(
-                        nanoseconds=frame_timestamps_ns[i]
-                    ),
-                    video_reference="world/video",
-                ),
-                rr.Transform3D(scale=(x_scale, y_scale, 1)),
-            )
-
-            rr.log(
-                "world/camera",
-                rr.Pinhole(
-                    focal_length=focal_length,
-                    width=depth.shape[1],
-                    height=depth.shape[0],
-                    principal_point=(depth.shape[1] / 2, depth.shape[0] / 2),
-                    camera_xyz=rr.ViewCoordinates.FLU,
-                    image_plane_distance=depth.max(),
-                ),
-            )
-
-            yield stream.read()
+            # Resize the images to make the inference faster
+            temp_files = [resize_image(frame, max_size=256) for frame in frames]
+
+            depths, focal_lengths = predict_depth(temp_files)
+
+            # frame j of this batch corresponds to timestamp index i + j
+            for j, (depth, focal_length) in enumerate(zip(depths, focal_lengths)):
+                # find x and y scale factors, which can be applied to image
+                x_scale = depth.shape[1] / frames[j].shape[1]
+                y_scale = depth.shape[0] / frames[j].shape[0]
+
+                rr.set_time_nanos("video_time", frame_timestamps_ns[i + j])
+                rr.log(
+                    "world/camera/depth",
+                    rr.DepthImage(depth, meter=1),
+                )
+
+                rr.log(
+                    "world/camera/frame",
+                    rr.VideoFrameReference(
+                        timestamp=rr.components.VideoTimestamp(
+                            nanoseconds=frame_timestamps_ns[i + j]
+                        ),
+                        video_reference="world/video",
+                    ),
+                    rr.Transform3D(scale=(x_scale, y_scale, 1)),
+                )
+
+                rr.log(
+                    "world/camera",
+                    rr.Pinhole(
+                        focal_length=focal_length,
+                        width=depth.shape[1],
+                        height=depth.shape[0],
+                        principal_point=(depth.shape[1] / 2, depth.shape[0] / 2),
+                        camera_xyz=rr.ViewCoordinates.FLU,
+                        image_plane_distance=depth.max(),
+                    ),
+                )
+
+                yield stream.read()
         except Exception as e:
             raise gr.Error(f"An error has occurred: {e}")
         finally:
-            # Clean up the temporary file
-            if temp_file and os.path.exists(temp_file):
-                os.remove(temp_file)
+            # Clean up the temporary files
+            for temp_file in temp_files:
+                if temp_file and os.path.exists(temp_file):
+                    os.remove(temp_file)
 
         yield stream.read()
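One detail worth flagging: as committed, every frame in a batch was logged with the batch's first timestamp, `frame_timestamps_ns[i]`, and the per-frame `rr.set_time_nanos` call was dropped; the hunk above restores per-frame indexing with `enumerate`. That pattern in isolation (the function name is illustrative, not part of the app):

```python
import rerun as rr

def log_depth_batch(depths, frame_timestamps_ns, i):
    """Log one batch of depth maps; batch element j maps to timestamp i + j."""
    for j, depth in enumerate(depths):
        rr.set_time_nanos("video_time", frame_timestamps_ns[i + j])
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))
```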
 
 
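Finally, the frame-gathering step extracted as a helper, since the loop's end condition is easy to miss: `cap.read()` returns `ret=False` at end of stream, so the final batch can hold fewer than `batch_size` frames and downstream code should not assume full batches. The helper name is illustrative.

```python
import cv2

def read_frame_batch(cap: cv2.VideoCapture, batch_size: int) -> list:
    """Read up to batch_size frames; returns a short list at end of video."""
    frames = []
    for _ in range(batch_size):
        ret, frame = cap.read()
        if not ret:  # end of stream reached mid-batch
            break
        frames.append(frame)
    return frames

# Usage sketch:
# cap = cv2.VideoCapture("input.mp4")
# while batch := read_frame_batch(cap, 16):
#     ...  # resize, call predict_depth, log to Rerun
```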