ThomasSimonini (HF staff) committed
Commit 26fc4a2 · verified · 1 Parent(s): ef24b41

Update app.py

Files changed (1)
  1. app.py +305 -245
app.py CHANGED
@@ -19,10 +19,10 @@ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
  import shutil
  from src.utils.train_util import instantiate_from_config
  from src.utils.camera_util import (
- FOV_to_intrinsics,
  get_zero123plus_input_cameras,
  get_circular_camera_poses,
- )
  from src.utils.mesh_util import save_obj, save_glb
  from src.utils.infer_util import remove_background, resize_foreground, images_to_video
@@ -44,7 +44,9 @@ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
  ###############################################################################
  # Configuration for InstantMesh
  ###############################################################################
- def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
  """
  Get the rendering camera parameters.
  """
@@ -54,88 +56,110 @@ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexi
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
  else:
  extrinsics = c2ws.flatten(-2)
- intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
  cameras = torch.cat([extrinsics, intrinsics], dim=-1)
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
- return cameras


- def images_to_video(images, output_path, fps=30):
  # images: (N, C, H, W)
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
  frames = []
  for i in range(images.shape[0]):
- frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
- assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
- f"Frame shape mismatch: {frame.shape} vs {images.shape}"
- assert frame.min() >= 0 and frame.max() <= 255, \
- f"Frame value out of range: {frame.min()} ~ {frame.max()}"
  frames.append(frame)
- imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')

- def find_cuda():
  # Check if CUDA_HOME or CUDA_PATH environment variables are set
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')

  if cuda_home and os.path.exists(cuda_home):
  return cuda_home

  # Search for the nvcc executable in the system's PATH
- nvcc_path = shutil.which('nvcc')

  if nvcc_path:
  # Remove the 'bin/nvcc' part to get the CUDA installation path
  cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
  return cuda_path

- return None

- cuda_path = find_cuda()

- if cuda_path:
- print(f"CUDA installation found at: {cuda_path}")
- else:
- print("CUDA installation not found")

- config_path = 'configs/instant-mesh-large.yaml'
- config = OmegaConf.load(config_path)
- config_name = os.path.basename(config_path).replace('.yaml', '')
- model_config = config.model_config
- infer_config = config.infer_config

- IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False

- device = torch.device('cuda')

  # load diffusion model
- print('Loading diffusion model ...')
  pipeline = DiffusionPipeline.from_pretrained(
- "sudo-ai/zero123plus-v1.2",
  custom_pipeline="zero123plus",
  torch_dtype=torch.float16,
- )
  pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
- pipeline.scheduler.config, timestep_spacing='trailing'
- )

  # load custom white-background UNet
- unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
  pipeline.unet.load_state_dict(state_dict, strict=True)

  pipeline = pipeline.to(device)

  # load reconstruction model
- print('Loading reconstruction model ...')
- model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
  model = instantiate_from_config(model_config)
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
  model.load_state_dict(state_dict, strict=True)

  model = model.to(device)

- print('Loading Finished!')


  def check_input_image(input_image):
@@ -143,58 +167,63 @@ def check_input_image(input_image):
  raise gr.Error("No image uploaded!")


- def preprocess(input_image, do_remove_background):

- rembg_session = rembg.new_session() if do_remove_background else None

- if do_remove_background:
- input_image = remove_background(input_image, rembg_session)
- input_image = resize_foreground(input_image, 0.85)

- return input_image


- @spaces.GPU
- def generate_mvs(input_image, sample_steps, sample_seed):

- seed_everything(sample_seed)
-
  # sampling
- z123_image = pipeline(
- input_image,
- num_inference_steps=sample_steps
- ).images[0]

  show_image = np.asarray(z123_image, dtype=np.uint8)
- show_image = torch.from_numpy(show_image) # (960, 640, 3)
- show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
- show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
  show_image = Image.fromarray(show_image.numpy())

  return z123_image, show_image


- @spaces.GPU
- def make3d(images):

- global model
- if IS_FLEXICUBES:
- model.init_flexicubes_geometry(device, use_renderer=False)
- model = model.eval()

- images = np.asarray(images, dtype=np.float32) / 255.0
- images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)

  input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
- render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)

  images = images.unsqueeze(0).to(device)
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)

  mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
  print(mesh_fpath)
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
  mesh_dirname = os.path.dirname(mesh_fpath)
  video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
  mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
@@ -202,41 +231,40 @@ def check_input_image(input_image):
  with torch.no_grad():
  # get triplane
  planes = model.forward_planes(images, input_cameras)
-
  # get mesh
  mesh_out = model.extract_mesh(
  planes,
  use_texture_map=False,
  **infer_config,
- )

  vertices, faces, vertex_colors = mesh_out
  vertices = vertices[:, [1, 2, 0]]
-
  save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
  save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
  print(f"Mesh saved to {mesh_fpath}")

- return mesh_fpath, mesh_glb_fpath


  ###############################################################################
  # Configuration for MeshAnythingv2
  ###############################################################################
  model = load_v2()
- device = torch.device('cuda')
  accelerator = Accelerator(
  mixed_precision="fp16",
- )
  model = accelerator.prepare(model)
  model.eval()
  print("Model loaded to device")

  def wireframe_render(mesh):
- views = [
- (90, 20), (270, 20)
- ]
  mesh.vertices = mesh.vertices[:, [0, 2, 1]]

  bounding_box = mesh.bounds
@@ -247,7 +275,7 @@ def wireframe_render(mesh):

  # Function to render and return each view as an image
  def render_view(mesh, azimuth, elevation):
- ax = fig.add_subplot(111, projection='3d')
  ax.set_axis_off()

  # Extract vertices and faces for plotting
@@ -255,12 +283,14 @@ def wireframe_render(mesh):
  faces = mesh.faces

  # Plot faces
- ax.add_collection3d(Poly3DCollection(
- vertices[faces],
- facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow
- edgecolors='k',
- linewidths=0.5,
- ))

  # Set limits and center the view on the object
  ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
@@ -272,7 +302,7 @@ def wireframe_render(mesh):

  # Save the figure to a buffer
  buf = io.BytesIO()
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
  plt.clf()
  buf.seek(0)

@@ -286,7 +316,7 @@ def wireframe_render(mesh):
  total_width = sum(widths)
  max_height = max(heights)

- combined_image = Image.new('RGBA', (total_width, max_height))

  x_offset = 0
  for img in images:
@@ -300,14 +330,17 @@ def wireframe_render(mesh):
  plt.close(fig)
  return save_path

- @spaces.GPU(duration=360)
- def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
- set_seed(sample_seed)
- print("Seed value:", sample_seed)

- input_mesh = trimesh.load(input_3d)
- pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
- pc_normal = pc_list[0] # 4096, 6
  mesh = mesh_list[0]
  vertices = mesh.vertices

@@ -330,20 +363,26 @@ def wireframe_render(mesh):
  try:
  if mesh.visual.vertex_colors is not None:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
-
- mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
  else:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
- mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
- except Exception as e:
- print(e)
- input_save_name = f"processed_input_{int(time.time())}.obj"
- mesh.export(input_save_name)
- input_render_res = wireframe_render(mesh)

- pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99 # input should be from -1 to 1

- assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
  normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

  input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
@@ -352,17 +391,18 @@ def wireframe_render(mesh):
  # with accelerator.autocast():
  with accelerator.autocast():
  outputs = model(input, do_sampling)
- print("Model inference done")
- recon_mesh = outputs[0]

- valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
  recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3
  vertices = recon_mesh.reshape(-1, 3).cpu()
  vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
  triangles = vertices_index.reshape(-1, 3)

- artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
- merge_primitives=True)

  artist_mesh.merge_vertices()
  artist_mesh.update_faces(artist_mesh.nondegenerate_faces())
@@ -373,40 +413,45 @@ def wireframe_render(mesh):
  if artist_mesh.visual.vertex_colors is not None:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)

- artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
  else:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
- artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))

- num_faces = len(artist_mesh.faces)

- brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
- face_colors = np.tile(brown_color, (num_faces, 1))

- artist_mesh.visual.face_colors = face_colors
  # add time stamp to avoid cache
  save_name = f"output_{int(time.time())}.obj"
  artist_mesh.export(save_name)
  output_render = wireframe_render(artist_mesh)
  return input_save_name, input_render_res, save_name, output_render

  # Output gradio
  output_model_obj = gr.Model3D(
  label="Generated Mesh (OBJ Format)",
  display_mode="wireframe",
  clear_color=[1, 1, 1, 1],
- )
  preprocess_model_obj = gr.Model3D(
  label="Processed Input Mesh (OBJ Format)",
  display_mode="wireframe",
  clear_color=[1, 1, 1, 1],
- )
  input_image_render = gr.Image(
  label="Wireframe Render of Processed Input Mesh",
- )
  output_image_render = gr.Image(
  label="Wireframe Render of Generated Mesh",
- )

  ###############################################################################
  # Gradio
@@ -456,135 +501,150 @@ STEP4_HEADER = """
  with gr.Blocks() as demo:
  gr.Markdown(HEADER)
  gr.Markdown(STEP1_HEADER)
- with gr.Row(variant = "panel"):
  with gr.Column():
  with gr.Row():
  input_image = gr.Image(
- label = "Input Image",
- image_mode = "RGBA",
- sources = "upload",
  type="pil",
- elem_id="content_image"
- )
- processed_image = gr.Image(label="Processed Image",
  image_mode="RGBA",
  type="pil",
- interactive=False
  )
- with gr.Row():
- with gr.Group():
- do_remove_background = gr.Checkbox(
- label="Remove Background",
- value=True)
- sample_seed = gr.Number(
- value=42,
- label="Seed Value",
- precision=0
- )
- sample_steps = gr.Slider(
- label="Sample Steps",
- minimum=30,
- maximum=75,
- value=75,
- step=5
- )
- with gr.Row():
- step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
- with gr.Column():
- with gr.Row():
- with gr.Column():
- mv_show_images = gr.Image(
- label="Generated Multi-views",
- type="pil",
- width=379,
- interactive=False
- )
- with gr.Column():
- with gr.Tab("OBJ"):
- output_model_obj = gr.Model3D(
- label = "Output Model (OBJ Format)",
- interactive = False,
- )
- gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
- with gr.Tab("GLB"):
- output_model_glb = gr.Model3D(
- label="Output Model (GLB Format)",
- interactive=False,
- )
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
- gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
- with gr.Row():
- gr.Markdown(STEP2_HEADER)
- with gr.Row(variant="panel"):
- with gr.Column():
- with gr.Row():
- input_3d = gr.Model3D(
- label="Input Mesh",
- display_mode="wireframe",
- clear_color=[1,1,1,1],
- )
-
- with gr.Row():
- with gr.Group():
- do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
- do_sampling = gr.Checkbox(label="Random Sampling", value=False)
- sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
-
- with gr.Row():
- step2_submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
- with gr.Row(variant="panel"):
- mesh_examples = gr.Examples(
- examples=[
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
- ],
- inputs=input_3d,
- outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
- fn=do_inference,
- cache_examples = False,
- examples_per_page=10
- )
-
- with gr.Column():
- with gr.Row():
- input_image_render.render()
- with gr.Row():
- with gr.Tab("OBJ"):
- preprocess_model_obj.render()
- with gr.Row():
- output_image_render.render()
- with gr.Row():
- with gr.Tab("OBJ"):
- output_model_obj.render()
- with gr.Row():
- gr.Markdown('''Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying''')
-
- gr.Markdown(STEP3_HEADER)
- gr.Markdown(STEP4_HEADER)
-
- mv_images = gr.State()
-
- step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
- fn=preprocess,
- inputs=[input_image, do_remove_background],
- outputs=[processed_image],
- ).success(
- fn=generate_mvs,
- inputs=[processed_image, sample_steps, sample_seed],
- outputs=[mv_images, mv_show_images],
- ).success(
- fn=make3d,
- inputs=[mv_images],
- outputs=[output_model_obj, output_model_glb]
- )
-
- step2_submit.click(
- fn=do_inference,
- inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
- outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
- )
-
-
-
- demo.queue(max_size=10)
- demo.launch()

  import shutil
  from src.utils.train_util import instantiate_from_config
  from src.utils.camera_util import (
+ FOV_to_intrinsics,
  get_zero123plus_input_cameras,
  get_circular_camera_poses,
+ )
  from src.utils.mesh_util import save_obj, save_glb
  from src.utils.infer_util import remove_background, resize_foreground, images_to_video

  ###############################################################################
  # Configuration for InstantMesh
  ###############################################################################
+ def get_render_cameras(
+ batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False
+ ):
  """
  Get the rendering camera parameters.
  """

  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
  else:
  extrinsics = c2ws.flatten(-2)
+ intrinsics = (
+ FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+ )
  cameras = torch.cat([extrinsics, intrinsics], dim=-1)
  cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+ return cameras


+ def images_to_video(images, output_path, fps=30):
  # images: (N, C, H, W)
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
  frames = []
  for i in range(images.shape[0]):
+ frame = (
+ (images[i].permute(1, 2, 0).cpu().numpy() * 255)
+ .astype(np.uint8)
+ .clip(0, 255)
+ )
+ assert (
+ frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3]
+ ), f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+ assert (
+ frame.min() >= 0 and frame.max() <= 255
+ ), f"Frame value out of range: {frame.min()} ~ {frame.max()}"
  frames.append(frame)
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec="h264")
+

+ def find_cuda():
  # Check if CUDA_HOME or CUDA_PATH environment variables are set
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")

  if cuda_home and os.path.exists(cuda_home):
  return cuda_home

  # Search for the nvcc executable in the system's PATH
+ nvcc_path = shutil.which("nvcc")

  if nvcc_path:
  # Remove the 'bin/nvcc' part to get the CUDA installation path
  cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
  return cuda_path

+ return None


+ cuda_path = find_cuda()
+
+ if cuda_path:
+ print(f"CUDA installation found at: {cuda_path}")
+ else:
+ print("CUDA installation not found")

+ config_path = "configs/instant-mesh-large.yaml"
+ config = OmegaConf.load(config_path)
+ config_name = os.path.basename(config_path).replace(".yaml", "")
+ model_config = config.model_config
+ infer_config = config.infer_config

+ IS_FLEXICUBES = True if config_name.startswith("instant-mesh") else False

+ device = torch.device("cuda")

  # load diffusion model
+ print("Loading diffusion model ...")
  pipeline = DiffusionPipeline.from_pretrained(
+ "sudo-ai/zero123plus-v1.2",
  custom_pipeline="zero123plus",
  torch_dtype=torch.float16,
+ )
  pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+ pipeline.scheduler.config, timestep_spacing="trailing"
+ )

  # load custom white-background UNet
+ unet_ckpt_path = hf_hub_download(
+ repo_id="TencentARC/InstantMesh",
+ filename="diffusion_pytorch_model.bin",
+ repo_type="model",
+ )
+ state_dict = torch.load(unet_ckpt_path, map_location="cpu")
  pipeline.unet.load_state_dict(state_dict, strict=True)

  pipeline = pipeline.to(device)

  # load reconstruction model
+ print("Loading reconstruction model ...")
+ model_ckpt_path = hf_hub_download(
+ repo_id="TencentARC/InstantMesh",
+ filename="instant_mesh_large.ckpt",
+ repo_type="model",
+ )
  model = instantiate_from_config(model_config)
+ state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
+ state_dict = {
+ k[14:]: v
+ for k, v in state_dict.items()
+ if k.startswith("lrm_generator.") and "source_camera" not in k
+ }
  model.load_state_dict(state_dict, strict=True)

  model = model.to(device)

+ print("Loading Finished!")


  def check_input_image(input_image):

  raise gr.Error("No image uploaded!")


+ def preprocess(input_image, do_remove_background):
+
+ rembg_session = rembg.new_session() if do_remove_background else None

+ if do_remove_background:
+ input_image = remove_background(input_image, rembg_session)
+ input_image = resize_foreground(input_image, 0.85)

+ return input_image



+ @spaces.GPU
+ def generate_mvs(input_image, sample_steps, sample_seed):

+ seed_everything(sample_seed)

  # sampling
+ z123_image = pipeline(input_image, num_inference_steps=sample_steps).images[0]

  show_image = np.asarray(z123_image, dtype=np.uint8)
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
+ show_image = rearrange(show_image, "(n h) (m w) c -> (n m) h w c", n=3, m=2)
+ show_image = rearrange(show_image, "(n m) h w c -> (n h) (m w) c", n=2, m=3)
  show_image = Image.fromarray(show_image.numpy())

  return z123_image, show_image


+ @spaces.GPU
+ def make3d(images):

+ global model
+ if IS_FLEXICUBES:
+ model.init_flexicubes_geometry(device, use_renderer=False)
+ model = model.eval()

+ images = np.asarray(images, dtype=np.float32) / 255.0
+ images = (
+ torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
+ ) # (3, 960, 640)
+ images = rearrange(
+ images, "c (n h) (m w) -> (n m) c h w", n=3, m=2
+ ) # (6, 3, 320, 320)

  input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
+ render_cameras = get_render_cameras(
+ batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES
+ ).to(device)

  images = images.unsqueeze(0).to(device)
+ images = v2.functional.resize(
+ images, (320, 320), interpolation=3, antialias=True
+ ).clamp(0, 1)

  mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
  print(mesh_fpath)
+ mesh_basename = os.path.basename(mesh_fpath).split(".")[0]
  mesh_dirname = os.path.dirname(mesh_fpath)
  video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
  mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")

  with torch.no_grad():
  # get triplane
  planes = model.forward_planes(images, input_cameras)
+
  # get mesh
  mesh_out = model.extract_mesh(
  planes,
  use_texture_map=False,
  **infer_config,
+ )

  vertices, faces, vertex_colors = mesh_out
  vertices = vertices[:, [1, 2, 0]]
+
  save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
  save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
  print(f"Mesh saved to {mesh_fpath}")

+ return mesh_fpath, mesh_glb_fpath


  ###############################################################################
  # Configuration for MeshAnythingv2
  ###############################################################################
  model = load_v2()
+ device = torch.device("cuda")
  accelerator = Accelerator(
  mixed_precision="fp16",
+ )
  model = accelerator.prepare(model)
  model.eval()
  print("Model loaded to device")

+
  def wireframe_render(mesh):
+ views = [(90, 20), (270, 20)]
  mesh.vertices = mesh.vertices[:, [0, 2, 1]]

  bounding_box = mesh.bounds


  # Function to render and return each view as an image
  def render_view(mesh, azimuth, elevation):
+ ax = fig.add_subplot(111, projection="3d")
  ax.set_axis_off()

  # Extract vertices and faces for plotting

  faces = mesh.faces

  # Plot faces
+ ax.add_collection3d(
+ Poly3DCollection(
+ vertices[faces],
+ facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow
+ edgecolors="k",
+ linewidths=0.5,
+ )
+ )

  # Set limits and center the view on the object
  ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)


  # Save the figure to a buffer
  buf = io.BytesIO()
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=300)
  plt.clf()
  buf.seek(0)


  total_width = sum(widths)
  max_height = max(heights)

+ combined_image = Image.new("RGBA", (total_width, max_height))

  x_offset = 0
  for img in images:

  plt.close(fig)
  return save_path


+ @spaces.GPU(duration=360)
+ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
+ set_seed(sample_seed)
+ print("Seed value:", sample_seed)
+
+ input_mesh = trimesh.load(input_3d)
+ pc_list, mesh_list = process_mesh_to_pc(
+ [input_mesh], marching_cubes=do_marching_cubes
+ )
+ pc_normal = pc_list[0] # 4096, 6
  mesh = mesh_list[0]
  vertices = mesh.vertices


  try:
  if mesh.visual.vertex_colors is not None:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
+
+ mesh.visual.vertex_colors = np.tile(
+ orange_color, (mesh.vertices.shape[0], 1)
+ )
  else:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
+ mesh.visual.vertex_colors = np.tile(
+ orange_color, (mesh.vertices.shape[0], 1)
+ )
+ except Exception as e:
+ print(e)
+ input_save_name = f"processed_input_{int(time.time())}.obj"
+ mesh.export(input_save_name)
+ input_render_res = wireframe_render(mesh)

+ pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99 # input should be from -1 to 1

+ assert (
+ np.linalg.norm(normals, axis=-1) > 0.99
+ ).all(), "normals should be unit vectors, something wrong"
  normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

  input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

  # with accelerator.autocast():
  with accelerator.autocast():
  outputs = model(input, do_sampling)
+ print("Model inference done")
+ recon_mesh = outputs[0]

+ valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
  recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3
  vertices = recon_mesh.reshape(-1, 3).cpu()
  vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
  triangles = vertices_index.reshape(-1, 3)

+ artist_mesh = trimesh.Trimesh(
+ vertices=vertices, faces=triangles, force="mesh", merge_primitives=True
+ )

  artist_mesh.merge_vertices()
  artist_mesh.update_faces(artist_mesh.nondegenerate_faces())

  if artist_mesh.visual.vertex_colors is not None:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)

+ artist_mesh.visual.vertex_colors = np.tile(
+ orange_color, (artist_mesh.vertices.shape[0], 1)
+ )
  else:
  orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
+ artist_mesh.visual.vertex_colors = np.tile(
+ orange_color, (artist_mesh.vertices.shape[0], 1)
+ )

+ num_faces = len(artist_mesh.faces)

+ brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
+ face_colors = np.tile(brown_color, (num_faces, 1))

+ artist_mesh.visual.face_colors = face_colors
  # add time stamp to avoid cache
  save_name = f"output_{int(time.time())}.obj"
  artist_mesh.export(save_name)
  output_render = wireframe_render(artist_mesh)
  return input_save_name, input_render_res, save_name, output_render

+
  # Output gradio
  output_model_obj = gr.Model3D(
  label="Generated Mesh (OBJ Format)",
  display_mode="wireframe",
  clear_color=[1, 1, 1, 1],
+ )
  preprocess_model_obj = gr.Model3D(
  label="Processed Input Mesh (OBJ Format)",
  display_mode="wireframe",
  clear_color=[1, 1, 1, 1],
+ )
  input_image_render = gr.Image(
  label="Wireframe Render of Processed Input Mesh",
+ )
  output_image_render = gr.Image(
  label="Wireframe Render of Generated Mesh",
+ )

  ###############################################################################
  # Gradio

  with gr.Blocks() as demo:
  gr.Markdown(HEADER)
  gr.Markdown(STEP1_HEADER)
+ with gr.Row(variant="panel"):
  with gr.Column():
  with gr.Row():
  input_image = gr.Image(
+ label="Input Image",
+ image_mode="RGBA",
+ sources="upload",
  type="pil",
+ elem_id="content_image",
+ )
+ processed_image = gr.Image(
+ label="Processed Image",
  image_mode="RGBA",
  type="pil",
+ interactive=False,
+ )
+ with gr.Row():
+ with gr.Group():
+ do_remove_background = gr.Checkbox(
+ label="Remove Background", value=True
+ )
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+ sample_steps = gr.Slider(
+ label="Sample Steps", minimum=30, maximum=75, value=75, step=5
  )
+ with gr.Row():
+ step1_submit = gr.Button(
+ "Generate", elem_id="generate", variant="primary"
+ )
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ mv_show_images = gr.Image(
+ label="Generated Multi-views",
+ type="pil",
+ width=379,
+ interactive=False,
+ )
+ with gr.Column():
+ with gr.Tab("OBJ"):
+ output_model_obj = gr.Model3D(
+ label="Output Model (OBJ Format)",
+ interactive=False,
+ )
+ gr.Markdown(
+ "Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage."
+ )
+ with gr.Tab("GLB"):
+ output_model_glb = gr.Model3D(
+ label="Output Model (GLB Format)",
+ interactive=False,
+ )
+ gr.Markdown(
+ "Note: The model shown here has a darker appearance. Download to get correct results."
+ )
+ gr.Markdown(
+ """Try a different <b>seed value</b> if the result is unsatisfying (Default: 42)."""
+ )
+
+ gr.Markdown(STEP2_HEADER)
+ with gr.Row(variant="panel"):
+ with gr.Column():
+ with gr.Row():
+ input_3d = gr.Model3D(
+ label="Input Mesh",
+ display_mode="wireframe",
+ clear_color=[1, 1, 1, 1],
+ )
+
+ with gr.Row():
+ with gr.Group():
+ do_marching_cubes = gr.Checkbox(
+ label="Preprocess with Marching Cubes", value=False
+ )
+ do_sampling = gr.Checkbox(label="Random Sampling", value=False)
+ sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
+
+ with gr.Row():
+ step2_submit = gr.Button(
+ "Generate", elem_id="generate", variant="primary"
+ )
+
+ with gr.Row(variant="panel"):
+ mesh_examples = gr.Examples(
+ examples=[
+ os.path.join("examples", img_name)
+ for img_name in sorted(os.listdir("examples"))
+ ],
+ inputs=input_3d,
+ outputs=[
+ preprocess_model_obj,
+ input_image_render,
+ output_model_obj,
+ output_image_render,
+ ],
+ fn=do_inference,
+ cache_examples=False,
+ examples_per_page=10,
+ )
+
+ with gr.Column():
+ with gr.Row():
+ input_image_render.render()
+ with gr.Row():
+ with gr.Tab("OBJ"):
+ preprocess_model_obj.render()
+ with gr.Row():
+ output_image_render.render()
+ with gr.Row():
+ with gr.Tab("OBJ"):
+ output_model_obj.render()
+ with gr.Row():
+ gr.Markdown(
+ """Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying"""
+ )
+
+ gr.Markdown(STEP3_HEADER)
+ gr.Markdown(STEP4_HEADER)
+
+ mv_images = gr.State()
+
+ step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
+ fn=preprocess,
+ inputs=[input_image, do_remove_background],
+ outputs=[processed_image],
+ ).success(
+ fn=generate_mvs,
+ inputs=[processed_image, sample_steps, sample_seed],
+ outputs=[mv_images, mv_show_images],
+ ).success(
+ fn=make3d, inputs=[mv_images], outputs=[output_model_obj, output_model_glb]
+ )
+
+ step2_submit.click(
+ fn=do_inference,
+ inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
+ outputs=[
+ preprocess_model_obj,
+ input_image_render,
+ output_model_obj,
+ output_image_render,
+ ],
+ )
+
+
+ demo.queue(max_size=10)
+ demo.launch()