CharlieAmalet committed (verified)
Commit 5514789 · Parent(s): 956147e

Update app.py

Files changed (1):
  app.py (+101 -15)
app.py CHANGED
@@ -1,12 +1,15 @@
 from zoedepth.utils.config import get_config
 from zoedepth.models.builder import build_model
-
-import gradio as gr
+from zoedepth.utils.misc import colorize, save_raw_16bit
+from zoedepth.utils.geometry import depth_to_points, create_triangles
 import torch
-torch.jit.script = lambda f: f
-
-from depth import depth_interface
-from mesh import mesh_interface
+import gradio as gr
+# import spaces
+from PIL import Image
+import numpy as np
+import trimesh
+from functools import partial
+import tempfile
 
 
 css = """
@@ -19,24 +22,107 @@ css = """
 #img-display-output {
     max-height: 40vh;
 }
-
 """
+
 # DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 DEVICE = 'cuda'
 model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()
 
-# title = "# ZoeDepth"
-# description = """Official demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""
+# ----------- Depth functions
+def save_raw_16bit(depth, fpath="raw.png"):
+    if isinstance(depth, torch.Tensor):
+        depth = depth.squeeze().cpu().numpy()
+
+    assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
+    assert depth.ndim == 2, "Depth must be 2D"
+    depth = depth * 256 # scale for 16-bit png
+    depth = depth.astype(np.uint16)
+    return depth
+
+@spaces.GPU(enable_queue=True)
+def process_image(model, image: Image.Image):
+    image = image.convert("RGB")
+
+    model.to(DEVICE)
+    out = model.infer_pil(image)
+
+    processed_array = save_raw_16bit(colorize(out)[:, :, 0])
+    return Image.fromarray(processed_array)
+
+# ----------- Depth functions
+
+# ----------- Mesh functions
+
+def depth_edges_mask(depth):
+    """Returns a mask of edges in the depth map.
+    Args:
+        depth: 2D numpy array of shape (H, W) with dtype float32.
+    Returns:
+        mask: 2D numpy array of shape (H, W) with dtype bool.
+    """
+    # Compute the x and y gradients of the depth map.
+    depth_dx, depth_dy = np.gradient(depth)
+    # Compute the gradient magnitude.
+    depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
+    # Compute the edge mask.
+    mask = depth_grad > 0.05
+    return mask
+
+@spaces.GPU(enable_queue=True)
+def predict_depth(model, image):
+    model.to(DEVICE)
+    depth = model.infer_pil(image)
+    return depth
+
+@spaces.GPU(enable_queue=True)
+def get_mesh(model, image: Image.Image, keep_edges=True):
+    image.thumbnail((1024,1024)) # limit the size of the input image
+
+    depth = predict_depth(model, image)
+    pts3d = depth_to_points(depth[None])
+    pts3d = pts3d.reshape(-1, 3)
+
+    # Create a trimesh mesh from the points
+    # Each pixel is connected to its 4 neighbors
+    # colors are the RGB values of the image
+
+    verts = pts3d.reshape(-1, 3)
+    image = np.array(image)
+    if keep_edges:
+        triangles = create_triangles(image.shape[0], image.shape[1])
+    else:
+        triangles = create_triangles(image.shape[0], image.shape[1], mask=~depth_edges_mask(depth))
+
+    colors = image.reshape(-1, 3)
+    mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
+
+    # Save as glb
+    glb_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
+    glb_path = glb_file.name
+    mesh.export(glb_path)
+    return glb_path
+
+# ----------- Mesh functions
+
+title = "# ZoeDepth"
+description = """Official demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""
 
 with gr.Blocks(css=css) as API:
-    # gr.Markdown(title)
-    # gr.Markdown(description)
+    gr.Markdown(title)
+    gr.Markdown(description)
     with gr.Tab("Depth Prediction"):
-        depth_interface(model, DEVICE)
+        with gr.Row():
+            inputs=gr.Image(label="Input Image", type='pil') # Input is an image
+            outputs=gr.Image(label="Depth Map", type='pil') # Output is also an image
+        generate_btn = gr.Button(value="Generate")
+        generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
+
    with gr.Tab("Image to 3D"):
-        mesh_interface(model, DEVICE)
-    # with gr.Tab("360 Panorama to 3D"):
-    #     create_pano_to_3d_demo(model)
+        with gr.Row():
+            inputs=[gr.Image(label="Input Image", type='pil'), gr.Checkbox(label="Keep occlusion edges", value=True)]
+            outputs=gr.Model3D(label="3D Mesh", clear_color=[1.0, 1.0, 1.0, 1.0])
+        generate_btn = gr.Button(value="Generate")
+        generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")
 
 if __name__ == '__main__':
     API.launch()
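
Both new handlers reduce to a single call, model.infer_pil(...), on the ZoeD_N model that app.py loads through torch.hub. A minimal standalone sketch of that inference path outside Gradio, assuming the isl-org/ZoeDepth hub entry point and a placeholder local test image; the commit keeps the model on CPU until a request arrives, so this sketch stays on CPU as well:

import torch
from PIL import Image

# same hub call as app.py; downloads the ZoeD_N weights on first use
model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to("cpu").eval()

img = Image.open("room.jpg").convert("RGB")   # placeholder input image
depth = model.infer_pil(img)                  # numpy array of shape (H, W) with metric depth

print(depth.shape, float(depth.min()), float(depth.max()))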
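
The depth_edges_mask helper is what the "Keep occlusion edges" checkbox toggles: with the box unchecked, get_mesh passes mask=~depth_edges_mask(depth) to create_triangles, so faces spanning large depth discontinuities are dropped from the mesh. A small NumPy-only sketch of the same gradient-threshold idea on a synthetic depth map (the 0.05 threshold is the value hard-coded in the commit):

import numpy as np

def depth_edges_mask(depth, threshold=0.05):
    # gradient magnitude of the depth map; pixels above the
    # threshold are treated as occlusion edges
    depth_dx, depth_dy = np.gradient(depth)
    depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
    return depth_grad > threshold

# synthetic scene: a 1 m square object in front of a 3 m background
depth = np.full((8, 8), 3.0, dtype=np.float32)
depth[2:6, 2:6] = 1.0

edges = depth_edges_mask(depth)
print(edges.sum(), "edge pixels")       # pixels along the 1 m / 3 m boundary
print((~edges).sum(), "smooth pixels")  # region kept when the checkbox is unchecked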
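
Because both generate_btn.click(...) calls set api_name ("generate_depth" and "generate_mesh"), the Space exposes named endpoints that can be called programmatically as well as through the UI. A hedged client-side sketch with gradio_client; the Space id and input file name are placeholders, and how image inputs are passed depends on the installed gradio_client version (recent versions use handle_file, older ones accept a plain file path):

from gradio_client import Client, handle_file

client = Client("CharlieAmalet/<space-name>")   # placeholder Space id

# "Depth Prediction" tab: image in, depth image out (downloaded to a temp file)
depth_png = client.predict(
    handle_file("room.jpg"),                    # placeholder input image
    api_name="/generate_depth",
)
print("depth map:", depth_png)

# "Image to 3D" tab: image + "Keep occlusion edges" checkbox in, .glb path out
glb_path = client.predict(
    handle_file("room.jpg"),
    True,                                       # keep occlusion edges
    api_name="/generate_mesh",
)
print("mesh:", glb_path)

The returned values are local paths to files the client downloads; the .glb can be opened with trimesh.load(glb_path) or any glTF viewer.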