CharlieAmalet committed
Commit 212d34e · verified · 1 Parent(s): b772fd8

Update app.py

Files changed (1)
  1. app.py +8 -86
app.py CHANGED
@@ -5,14 +5,9 @@ import spaces
 
 from zoedepth.utils.misc import colorize, save_raw_16bit
 from zoedepth.utils.geometry import depth_to_points, create_triangles
- from marigold_depth_estimation import MarigoldPipeline
 
 from PIL import Image
 import numpy as np
- import trimesh
- from functools import partial
- import tempfile
-
 
 css = """
 img {
@@ -21,97 +16,34 @@ img {
 }
 """
 
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
- model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).to(DEVICE).eval()
-
- CHECKPOINT = "prs-eth/marigold-v1-0"
- pipe = MarigoldPipeline.from_pretrained(CHECKPOINT)
+ # DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ MODEL = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True)#.to(DEVICE).eval()
 
 # ----------- Depth functions
- @spaces.GPU(enable_queue=True)
 def save_raw_16bit(depth, fpath="raw.png"):
     if isinstance(depth, torch.Tensor):
         depth = depth.squeeze().cpu().numpy()
 
-     assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
-     assert depth.ndim == 2, "Depth must be 2D"
+     # assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
+     # assert depth.ndim == 2, "Depth must be 2D"
     depth = depth * 256 # scale for 16-bit png
     depth = depth.astype(np.uint16)
     return depth
 
 @spaces.GPU(enable_queue=True)
 def process_image(image: Image.Image):
-     global model
+     global MODEL
     image = image.convert("RGB")
 
-     # model.to(DEVICE)
-     depth = model.infer_pil(image)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     MODEL.to(device)
+     depth = MODEL.infer_pil(image)
 
     processed_array = save_raw_16bit(colorize(depth)[:, :, 0])
     return Image.fromarray(processed_array)
 
-     # model.to(device)
-
-     # processed_array = pipe(image)["depth"]
-     # return Image.fromarray(processed_array)
-
 # ----------- Depth functions
 
- # ----------- Mesh functions
- @spaces.GPU(enable_queue=True)
- def depth_edges_mask(depth):
-     global model
-     """Returns a mask of edges in the depth map.
-     Args:
-         depth: 2D numpy array of shape (H, W) with dtype float32.
-     Returns:
-         mask: 2D numpy array of shape (H, W) with dtype bool.
-     """
-     # Compute the x and y gradients of the depth map.
-     depth_dx, depth_dy = np.gradient(depth)
-     # Compute the gradient magnitude.
-     depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
-     # Compute the edge mask.
-     mask = depth_grad > 0.05
-     return mask
-
- @spaces.GPU(enable_queue=True)
- def predict_depth(image):
-     global model
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model.to(device)
-     depth = model.infer_pil(image)
-     return depth
-
- @spaces.GPU(enable_queue=True)
- def get_mesh(image: Image.Image, keep_edges=True):
-     image.thumbnail((1024,1024))  # limit the size of the input image
-
-     depth = predict_depth(image)
-     pts3d = depth_to_points(depth[None])
-     pts3d = pts3d.reshape(-1, 3)
-
-     # Create a trimesh mesh from the points
-     # Each pixel is connected to its 4 neighbors
-     # colors are the RGB values of the image
-
-     verts = pts3d.reshape(-1, 3)
-     image = np.array(image)
-     if keep_edges:
-         triangles = create_triangles(image.shape[0], image.shape[1])
-     else:
-         triangles = create_triangles(image.shape[0], image.shape[1], mask=~depth_edges_mask(depth))
-
-     colors = image.reshape(-1, 3)
-     mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
-
-     # Save as glb
-     glb_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
-     glb_path = glb_file.name
-     mesh.export(glb_path)
-     return glb_path
-
- # ----------- Mesh functions
 
 title = "# ZoeDepth"
 description = """Unofficial demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""
@@ -124,17 +56,7 @@ with gr.Blocks(css=css) as API:
         inputs=gr.Image(label="Input Image", type='pil', height=500) # Input is an image
         outputs=gr.Image(label="Depth Map", type='pil', height=500) # Output is also an image
         generate_btn = gr.Button(value="Generate")
-         # generate_btn.click(partial(process_image, model), inputs=inputs, outputs=outputs, api_name="generate_depth")
         generate_btn.click(process_image, inputs=inputs, outputs=outputs, api_name="generate_depth")
-
-     with gr.Tab("Image to 3D"):
-         with gr.Row():
-             with gr.Column():
-                 inputs=[gr.Image(label="Input Image", type='pil', height=500), gr.Checkbox(label="Keep occlusion edges", value=True)]
-                 outputs=gr.Model3D(label="3D Mesh", clear_color=[1.0, 1.0, 1.0, 1.0], height=500)
-                 generate_btn = gr.Button(value="Generate")
-                 # generate_btn.click(partial(get_mesh, model), inputs=inputs, outputs=outputs, api_name="generate_mesh")
-                 generate_btn.click(get_mesh, inputs=inputs, outputs=outputs, api_name="generate_mesh")
 
 if __name__ == '__main__':
     API.launch()
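
For context, a minimal sketch of the depth path the updated handler follows: load ZoeD_N from torch.hub, run infer_pil, and apply the same ×256 scaling that save_raw_16bit uses before casting to uint16. The image and output filenames are placeholders, and this assumes the isl-org/ZoeDepth hub entry and its dependencies are installed; it is illustration only, not code from this commit.

import torch
import numpy as np
from PIL import Image

# Load ZoeD_N the same way app.py does, then pick the available device.
model = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

# "input.jpg" is a placeholder test image.
image = Image.open("input.jpg").convert("RGB")
depth = model.infer_pil(image)  # float32 (H, W) metric depth map

# Same scaling save_raw_16bit applies before casting to uint16.
depth_16bit = (depth * 256).astype(np.uint16)
Image.fromarray(depth_16bit).save("raw_depth.png")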
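
Similarly, a hypothetical client-side call against the one endpoint that survives this change (the Image to 3D tab and its generate_mesh endpoint are removed). The Space id is a placeholder and a recent gradio_client is assumed; nothing below is defined by the commit.

from gradio_client import Client, handle_file

# "owner/space-name" stands in for the actual Space id (placeholder).
client = Client("owner/space-name")
result = client.predict(
    handle_file("input.jpg"),    # placeholder input image
    api_name="/generate_depth",  # endpoint name registered in app.py
)
print(result)  # local path to the returned depth-map image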