Spaces:
Running
on
Zero
Running
on
Zero
batch model inference
Browse files
app.py
CHANGED
@@ -43,24 +43,15 @@ def resize_image(image_buffer, max_size=256):
|
|
43 |
|
44 |
@spaces.GPU(duration=20)
|
45 |
def predict_depth(input_images):
|
46 |
-
# Preprocess the image
|
47 |
results = [depth_pro.load_rgb(image) for image in input_images]
|
48 |
-
|
49 |
-
# assume load_rgb returns a tuple of (image, f_px)
|
50 |
-
# stack the images and f_px values into tensors
|
51 |
-
images, f_px = zip(*results)
|
52 |
-
images = torch.stack(images)
|
53 |
-
f_px = torch.tensor(f_px)
|
54 |
-
|
55 |
-
images = transform(images)
|
56 |
-
|
57 |
images = images.to(device)
|
58 |
-
f_px = f_px.to(device)
|
59 |
|
60 |
# Run inference
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
# Convert depth to numpy array if it's a torch tensor
|
66 |
if isinstance(depth, torch.Tensor):
|
@@ -68,9 +59,9 @@ def predict_depth(input_images):
|
|
68 |
|
69 |
# Convert focal length to a float if it's a torch tensor
|
70 |
if isinstance(focallength_px, torch.Tensor):
|
71 |
-
focallength_px =
|
72 |
|
73 |
-
# Ensure depth is a
|
74 |
if depth.ndim != 2:
|
75 |
depth = depth.squeeze()
|
76 |
|
@@ -114,7 +105,13 @@ def run_rerun(path_to_video):
|
|
114 |
|
115 |
# limit the number of frames to 10 seconds of video
|
116 |
max_frames = min(10 * fps_video, num_frames)
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# go through all the frames in the video, using the batch size
|
120 |
for i in range(0, int(max_frames), batch_size):
|
@@ -122,6 +119,7 @@ def run_rerun(path_to_video):
|
|
122 |
raise gr.Error("Reached the maximum number of frames to process")
|
123 |
|
124 |
frames = []
|
|
|
125 |
for _ in range(batch_size):
|
126 |
ret, frame = cap.read()
|
127 |
if not ret:
|
@@ -135,11 +133,14 @@ def run_rerun(path_to_video):
|
|
135 |
|
136 |
depths, focal_lengths = predict_depth(temp_files)
|
137 |
|
138 |
-
for depth, focal_length in zip(
|
|
|
|
|
139 |
# find x and y scale factors, which can be applied to image
|
140 |
x_scale = depth.shape[1] / frames[0].shape[1]
|
141 |
y_scale = depth.shape[0] / frames[0].shape[0]
|
142 |
|
|
|
143 |
rr.log(
|
144 |
"world/camera/depth",
|
145 |
rr.DepthImage(depth, meter=1),
|
@@ -149,7 +150,7 @@ def run_rerun(path_to_video):
|
|
149 |
"world/camera/frame",
|
150 |
rr.VideoFrameReference(
|
151 |
timestamp=rr.components.VideoTimestamp(
|
152 |
-
nanoseconds=frame_timestamps_ns[
|
153 |
),
|
154 |
video_reference="world/video",
|
155 |
),
|
|
|
43 |
|
44 |
@spaces.GPU(duration=20)
|
45 |
def predict_depth(input_images):
|
|
|
46 |
results = [depth_pro.load_rgb(image) for image in input_images]
|
47 |
+
images = torch.stack([transform(result[0]) for result in results])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
images = images.to(device)
|
|
|
49 |
|
50 |
# Run inference
|
51 |
+
with torch.no_grad():
|
52 |
+
prediction = model.infer(images)
|
53 |
+
depth = prediction["depth"] # Depth in [m]
|
54 |
+
focallength_px = prediction["focallength_px"] # Focal length in pixels
|
55 |
|
56 |
# Convert depth to numpy array if it's a torch tensor
|
57 |
if isinstance(depth, torch.Tensor):
|
|
|
59 |
|
60 |
# Convert focal length to a float if it's a torch tensor
|
61 |
if isinstance(focallength_px, torch.Tensor):
|
62 |
+
focallength_px = [focal_length.item() for focal_length in focallength_px]
|
63 |
|
64 |
+
# Ensure depth is a BxHxW tensor
|
65 |
if depth.ndim != 2:
|
66 |
depth = depth.squeeze()
|
67 |
|
|
|
105 |
|
106 |
# limit the number of frames to 10 seconds of video
|
107 |
max_frames = min(10 * fps_video, num_frames)
|
108 |
+
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
free_vram, _ = torch.cuda.mem_get_info(device)
|
111 |
+
free_vram = free_vram / 1024 / 1024 / 1024
|
112 |
+
|
113 |
+
# batch size is determined by the amount of free vram
|
114 |
+
batch_size = int(min(free_vram // 4, max_frames))
|
115 |
|
116 |
# go through all the frames in the video, using the batch size
|
117 |
for i in range(0, int(max_frames), batch_size):
|
|
|
119 |
raise gr.Error("Reached the maximum number of frames to process")
|
120 |
|
121 |
frames = []
|
122 |
+
frame_indices = list(range(i, min(i + batch_size, int(max_frames))))
|
123 |
for _ in range(batch_size):
|
124 |
ret, frame = cap.read()
|
125 |
if not ret:
|
|
|
133 |
|
134 |
depths, focal_lengths = predict_depth(temp_files)
|
135 |
|
136 |
+
for depth, focal_length, frame_idx in zip(
|
137 |
+
depths, focal_lengths, frame_indices
|
138 |
+
):
|
139 |
# find x and y scale factors, which can be applied to image
|
140 |
x_scale = depth.shape[1] / frames[0].shape[1]
|
141 |
y_scale = depth.shape[0] / frames[0].shape[0]
|
142 |
|
143 |
+
rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
|
144 |
rr.log(
|
145 |
"world/camera/depth",
|
146 |
rr.DepthImage(depth, meter=1),
|
|
|
150 |
"world/camera/frame",
|
151 |
rr.VideoFrameReference(
|
152 |
timestamp=rr.components.VideoTimestamp(
|
153 |
+
nanoseconds=frame_timestamps_ns[frame_idx]
|
154 |
),
|
155 |
video_reference="world/video",
|
156 |
),
|