Spaces: Running on Zero
File size: 4,256 Bytes
37840e7 5212158 5514789 37840e7 5514789 5212158 c5c5a80 5212158 5514789 956147e 6af6ea2 5212158 5514789 c5c5a80 5212158 5514789 5212158 5514789 8f8d235 5514789 5212158 5514789 c5c5a80 8f8d235 5514789 5212158 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
# HACK: replace torch.jit.script with an identity decorator. This must run
# before the zoedepth imports below so anything they decorate with
# @torch.jit.script at import time is left as plain Python.
# NOTE(review): presumably added to avoid TorchScript compilation issues in
# this environment — confirm the reason.
torch.jit.script = lambda f: f
# NOTE(review): get_config and build_model are not referenced anywhere else in
# this file as shown; the model is loaded via torch.hub instead.
from zoedepth.utils.config import get_config
from zoedepth.models.builder import build_model
# NOTE: the save_raw_16bit imported here is shadowed by a local function of
# the same name defined further down in this file.
from zoedepth.utils.misc import colorize, save_raw_16bit
from zoedepth.utils.geometry import depth_to_points, create_triangles
import gradio as gr
import spaces  # HuggingFace Spaces helpers; provides the @spaces.GPU decorator
from PIL import Image
import numpy as np
import trimesh
from functools import partial
import tempfile
css = """
#img-display-container {
max-height: 50vh;
}
#img-display-input {
max-height: 40vh;
}
#img-display-output {
max-height: 40vh;
}
"""
# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = 'cuda'
model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
# ----------- Depth functions
def save_raw_16bit(depth, fpath="raw.png"):
    """Convert a depth map into a 16-bit integer array scaled by 256.

    Despite its name, this function writes no file: it only returns the
    scaled array. ``fpath`` is accepted but ignored (presumably kept so the
    signature matches the zoedepth helper of the same name that this
    definition shadows).

    Args:
        depth: 2D ``numpy.ndarray``, or a ``torch.Tensor`` that becomes 2D
            after squeezing singleton dimensions.
        fpath: Unused.

    Returns:
        ``numpy.ndarray`` of dtype ``uint16`` holding ``depth * 256``,
        truncated toward zero and clipped to [0, 65535].

    Raises:
        TypeError: if ``depth`` is neither a torch tensor nor a numpy array.
        ValueError: if ``depth`` is not 2D.
    """
    if isinstance(depth, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        depth = depth.detach().squeeze().cpu().numpy()
    if not isinstance(depth, np.ndarray):
        raise TypeError("Depth must be a torch tensor or numpy array")
    if depth.ndim != 2:
        raise ValueError("Depth must be 2D")
    # Scale in float64. A uint8 input multiplied by 256 in its own dtype
    # overflows (an OverflowError under NumPy 2 promotion rules), and
    # out-of-range values would otherwise wrap on the uint16 cast.
    scaled = depth.astype(np.float64) * 256  # scale for 16-bit png
    return np.clip(scaled, 0, np.iinfo(np.uint16).max).astype(np.uint16)
@spaces.GPU(enable_queue=True)
def process_image(model, image: Image.Image):
    """Predict depth for ``image`` and return it as a 16-bit PIL image.

    Args:
        model: Depth model exposing ``to(device)`` and ``infer_pil(image)``.
            The UI binds the module-level ZoeDepth model via functools.partial.
        image: Input PIL image; converted to RGB before inference. ``None``
            when the user clicks "Generate" without uploading an image.

    Returns:
        PIL image built from the first channel of ``colorize(depth)``,
        scaled by 256 into a uint16 array by ``save_raw_16bit``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an input image first.")
    image = image.convert("RGB")
    model.to(DEVICE)  # the model is loaded on the CPU at module level
    depth = model.infer_pil(image)
    # Only the first channel of the colorized depth map is kept.
    # NOTE(review): this encodes a colormap channel rather than raw depth
    # values — confirm that this is the intended API output.
    first_channel = colorize(depth)[:, :, 0]
    return Image.fromarray(save_raw_16bit(first_channel))
# ----------- Depth functions
# ----------- Mesh functions
def depth_edges_mask(depth, threshold=0.05):
    """Return a mask of edges (depth discontinuities) in a depth map.

    A pixel is flagged when the magnitude of its depth gradient, computed
    with ``np.gradient``, is strictly greater than ``threshold``.

    Args:
        depth: 2D numpy float array of shape (H, W); each dimension needs at
            least 2 elements for ``np.gradient``.
        threshold: Gradient-magnitude cutoff. The default of 0.05 is the
            value this function has always used.

    Returns:
        mask: 2D numpy array of shape (H, W) with dtype bool.
    """
    # np.gradient returns one array per axis, in axis order: axis 0 (rows, y)
    # first, then axis 1 (columns, x).
    depth_dy, depth_dx = np.gradient(depth)
    return np.hypot(depth_dx, depth_dy) > threshold
@spaces.GPU(enable_queue=True)
def predict_depth(model, image):
    """Run depth inference for ``image`` on DEVICE.

    Args:
        model: Object exposing ``to(device)`` and ``infer_pil(image)``.
        image: PIL image forwarded unchanged to ``model.infer_pil``.

    Returns:
        Whatever ``model.infer_pil`` returns; ``get_mesh`` treats it as a 2D
        depth array.
    """
    model.to(DEVICE)  # return value deliberately ignored, as before
    return model.infer_pil(image)
@spaces.GPU(enable_queue=True)
def get_mesh(model, image: Image.Image, keep_edges=True):
    """Build a vertex-coloured triangle mesh from one image and save it as GLB.

    The image is downscaled (aspect ratio preserved) to fit within
    1024x1024, its depth is predicted with ``model``, every pixel is turned
    into a 3D point by ``depth_to_points``, and faces come from
    ``create_triangles``. Each vertex is coloured with its pixel's RGB value.

    Args:
        model: Depth model forwarded to ``predict_depth``.
        image: Input PIL image. It is converted to RGB on a copy, so the
            caller's object is not modified.
        keep_edges: When False, ``~depth_edges_mask(depth)`` is passed as
            ``mask`` to ``create_triangles`` (presumably so faces spanning
            occlusion boundaries are dropped).

    Returns:
        Path to a ``.glb`` file in the temp directory. The file is not
        deleted automatically.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    # convert() returns a new image, so the in-place thumbnail() below cannot
    # mutate the caller's object. RGB also guarantees exactly 3 channels for
    # the per-vertex colours.
    image = image.convert("RGB")
    image.thumbnail((1024, 1024))  # limit the size of the input image

    depth = predict_depth(model, image)
    verts = depth_to_points(depth[None]).reshape(-1, 3)

    pixels = np.array(image)
    height, width = pixels.shape[:2]
    if keep_edges:
        triangles = create_triangles(height, width)
    else:
        triangles = create_triangles(height, width, mask=~depth_edges_mask(depth))
    colors = pixels.reshape(-1, 3)  # one RGB colour per vertex (pixel)
    mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)

    # Only a unique path is needed. Close the handle right away so no file
    # descriptor is leaked per request, then let trimesh write by path.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_file:
        glb_path = glb_file.name
    mesh.export(glb_path)
    return glb_path
# ----------- Mesh functions
title = "# ZoeDepth"
description = """Unofficial demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""

# Two tabs, each bound to the module-level `model` through functools.partial.
# Each click handler is also exposed as a named API endpoint (api_name).
# NOTE(review): the original indentation was lost; the Row/Column nesting
# below is a reconstruction — confirm the intended layout.
with gr.Blocks(css=css) as API:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tab("Depth Prediction"):
        with gr.Row():
            inputs = gr.Image(label="Input Image", type='pil', height=500)  # Input is an image
            outputs = gr.Image(label="Depth Map", type='pil', height=500)  # Output is also an image
        generate_btn = gr.Button(value="Generate")
        generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")

    with gr.Tab("Image to 3D"):
        with gr.Row():
            with gr.Column():
                inputs = [
                    gr.Image(label="Input Image", type='pil', height=500),
                    gr.Checkbox(label="Keep occlusion edges", value=True),
                ]
            outputs = gr.Model3D(label="3D Mesh", clear_color=[1.0, 1.0, 1.0, 1.0], height=500)
        generate_btn = gr.Button(value="Generate")
        generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")

if __name__ == '__main__':
    API.launch()