Spaces:
Running
on
Zero
Running
on
Zero
batch model inference
Browse files
app.py
CHANGED
@@ -42,16 +42,23 @@ def resize_image(image_buffer, max_size=256):
|
|
42 |
|
43 |
|
44 |
@spaces.GPU(duration=20)
|
45 |
-
def predict_depth(
|
46 |
# Preprocess the image
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# Run inference
|
54 |
-
prediction = model.infer(
|
55 |
depth = prediction["depth"] # Depth in [m]
|
56 |
focallength_px = prediction["focallength_px"] # Focal length in pixels
|
57 |
|
@@ -107,62 +114,68 @@ def run_rerun(path_to_video):
|
|
107 |
|
108 |
# limit the number of frames to 10 seconds of video
|
109 |
max_frames = min(10 * fps_video, num_frames)
|
|
|
110 |
|
111 |
-
|
|
|
112 |
if i >= max_frames:
|
113 |
raise gr.Error("Reached the maximum number of frames to process")
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
118 |
|
119 |
-
|
120 |
try:
|
121 |
-
# Resize the
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
141 |
),
|
142 |
-
|
143 |
-
)
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
),
|
157 |
-
)
|
158 |
|
159 |
-
|
160 |
except Exception as e:
|
161 |
raise gr.Error(f"An error has occurred: {e}")
|
162 |
finally:
|
163 |
-
# Clean up the temporary
|
164 |
-
|
165 |
-
os.
|
|
|
166 |
|
167 |
yield stream.read()
|
168 |
|
|
|
42 |
|
43 |
|
44 |
@spaces.GPU(duration=20)
|
45 |
+
def predict_depth(input_images):
|
46 |
# Preprocess the image
|
47 |
+
results = [depth_pro.load_rgb(image) for image in input_images]
|
48 |
+
|
49 |
+
# assume load_rgb returns a tuple of (image, f_px)
|
50 |
+
# stack the images and f_px values into tensors
|
51 |
+
images, f_px = zip(*results)
|
52 |
+
images = torch.stack(images)
|
53 |
+
f_px = torch.tensor(f_px)
|
54 |
+
|
55 |
+
images = transform(images)
|
56 |
+
|
57 |
+
images = images.to(device)
|
58 |
+
f_px = f_px.to(device)
|
59 |
|
60 |
# Run inference
|
61 |
+
prediction = model.infer(images, f_px=f_px)
|
62 |
depth = prediction["depth"] # Depth in [m]
|
63 |
focallength_px = prediction["focallength_px"] # Focal length in pixels
|
64 |
|
|
|
114 |
|
115 |
# limit the number of frames to 10 seconds of video
|
116 |
max_frames = min(10 * fps_video, num_frames)
|
117 |
+
batch_size = min(16, max_frames)
|
118 |
|
119 |
+
# go through all the frames in the video, using the batch size
|
120 |
+
for i in range(0, int(max_frames), batch_size):
|
121 |
if i >= max_frames:
|
122 |
raise gr.Error("Reached the maximum number of frames to process")
|
123 |
|
124 |
+
frames = []
|
125 |
+
for _ in range(batch_size):
|
126 |
+
ret, frame = cap.read()
|
127 |
+
if not ret:
|
128 |
+
break
|
129 |
+
frames.append(frame)
|
130 |
|
131 |
+
temp_files = []
|
132 |
try:
|
133 |
+
# Resize the images to make the inference faster
|
134 |
+
temp_files = [resize_image(frame, max_size=256) for frame in frames]
|
135 |
+
|
136 |
+
depths, focal_lengths = predict_depth(temp_files)
|
137 |
+
|
138 |
+
for depth, focal_length in zip(depths, focal_lengths):
|
139 |
+
# find x and y scale factors, which can be applied to image
|
140 |
+
x_scale = depth.shape[1] / frames[0].shape[1]
|
141 |
+
y_scale = depth.shape[0] / frames[0].shape[0]
|
142 |
+
|
143 |
+
rr.log(
|
144 |
+
"world/camera/depth",
|
145 |
+
rr.DepthImage(depth, meter=1),
|
146 |
+
)
|
147 |
+
|
148 |
+
rr.log(
|
149 |
+
"world/camera/frame",
|
150 |
+
rr.VideoFrameReference(
|
151 |
+
timestamp=rr.components.VideoTimestamp(
|
152 |
+
nanoseconds=frame_timestamps_ns[i]
|
153 |
+
),
|
154 |
+
video_reference="world/video",
|
155 |
),
|
156 |
+
rr.Transform3D(scale=(x_scale, y_scale, 1)),
|
157 |
+
)
|
158 |
+
|
159 |
+
rr.log(
|
160 |
+
"world/camera",
|
161 |
+
rr.Pinhole(
|
162 |
+
focal_length=focal_length,
|
163 |
+
width=depth.shape[1],
|
164 |
+
height=depth.shape[0],
|
165 |
+
principal_point=(depth.shape[1] / 2, depth.shape[0] / 2),
|
166 |
+
camera_xyz=rr.ViewCoordinates.FLU,
|
167 |
+
image_plane_distance=depth.max(),
|
168 |
+
),
|
169 |
+
)
|
|
|
|
|
170 |
|
171 |
+
yield stream.read()
|
172 |
except Exception as e:
|
173 |
raise gr.Error(f"An error has occurred: {e}")
|
174 |
finally:
|
175 |
+
# Clean up the temporary files
|
176 |
+
for temp_file in temp_files:
|
177 |
+
if temp_file and os.path.exists(temp_file):
|
178 |
+
os.remove(temp_file)
|
179 |
|
180 |
yield stream.read()
|
181 |
|