yfdeng commited on
Commit
45a206f
·
1 Parent(s): 744eb4e

finish zerogpu debug

Browse files
Files changed (39) hide show
  1. Anymate/.gitignore → .gitignore +3 -1
  2. Anymate/model.py +16 -0
  3. Anymate/models/conn.py +4 -0
  4. Anymate/models/joint.py +2 -1
  5. Anymate/models/skin.py +3 -0
  6. Anymate/utils/ui_utils.py +17 -5
  7. Anymate/utils/utils.py +20 -17
  8. ThirdParty/Rignet_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  9. ThirdParty/Rignet_utils/__pycache__/binvox_rw.cpython-310.pyc +0 -0
  10. ThirdParty/__pycache__/__init__.cpython-310.pyc +0 -0
  11. ThirdParty/eg3d/__pycache__/__init__.cpython-310.pyc +0 -0
  12. ThirdParty/eg3d/dnnlib/__pycache__/__init__.cpython-310.pyc +0 -0
  13. ThirdParty/eg3d/dnnlib/__pycache__/util.cpython-310.pyc +0 -0
  14. ThirdParty/eg3d/torch_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  15. ThirdParty/eg3d/torch_utils/__pycache__/custom_ops.cpython-310.pyc +0 -0
  16. ThirdParty/eg3d/torch_utils/__pycache__/misc.cpython-310.pyc +0 -0
  17. ThirdParty/eg3d/torch_utils/__pycache__/persistence.cpython-310.pyc +0 -0
  18. ThirdParty/eg3d/torch_utils/ops/__pycache__/__init__.cpython-310.pyc +0 -0
  19. ThirdParty/eg3d/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc +0 -0
  20. ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc +0 -0
  21. ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_resample.cpython-310.pyc +0 -0
  22. ThirdParty/eg3d/torch_utils/ops/__pycache__/fma.cpython-310.pyc +0 -0
  23. ThirdParty/eg3d/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc +0 -0
  24. ThirdParty/eg3d/training/__pycache__/__init__.cpython-310.pyc +0 -0
  25. ThirdParty/eg3d/training/__pycache__/networks_stylegan2.cpython-310.pyc +0 -0
  26. ThirdParty/michelangelo/__pycache__/__init__.cpython-310.pyc +0 -0
  27. ThirdParty/michelangelo/graphics/__pycache__/__init__.cpython-310.pyc +0 -0
  28. ThirdParty/michelangelo/graphics/primitives/__pycache__/__init__.cpython-310.pyc +0 -0
  29. ThirdParty/michelangelo/graphics/primitives/__pycache__/mesh.cpython-310.pyc +0 -0
  30. ThirdParty/michelangelo/graphics/primitives/__pycache__/volume.cpython-310.pyc +0 -0
  31. ThirdParty/michelangelo/models/__pycache__/__init__.cpython-310.pyc +0 -0
  32. ThirdParty/michelangelo/models/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  33. ThirdParty/michelangelo/models/modules/__pycache__/checkpoint.cpython-310.pyc +0 -0
  34. ThirdParty/michelangelo/models/modules/__pycache__/distributions.cpython-310.pyc +0 -0
  35. ThirdParty/michelangelo/models/modules/__pycache__/embedder.cpython-310.pyc +0 -0
  36. ThirdParty/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-310.pyc +0 -0
  37. ThirdParty/michelangelo/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  38. ThirdParty/michelangelo/utils/__pycache__/misc.cpython-310.pyc +0 -0
  39. app.py +101 -128
Anymate/.gitignore → .gitignore RENAMED
@@ -23,4 +23,6 @@ blender-*
23
  *.csv
24
  *.tga
25
  *.png
26
- *.jpg
 
 
 
23
  *.csv
24
  *.tga
25
  *.png
26
+ *.jpg
27
+
28
+ core*
Anymate/model.py CHANGED
@@ -352,6 +352,22 @@ class EncoderDecoder(nn.Module):
352
  return cond
353
 
354
  def forward(self, data, device='cuda', downsample=False, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  latents = self.encode(data, device)
356
  # print('latents shape', latents.shape)
357
 
 
352
  return cond
353
 
354
  def forward(self, data, device='cuda', downsample=False, **kwargs):
355
+ # data['points_cloud'] = data['points_cloud'].to(self.dtype).to(device)
356
+ # if 'vertices' in data.keys():
357
+ # data['vertices'] = torch.tensor(data['vertices']).to(self.dtype).to(device).unsqueeze(0)
358
+ # if 'joints' in data.keys():
359
+ # data['joints'] = torch.tensor(data['joints'], dtype=torch.float32).to(device).unsqueeze(0)
360
+ # if 'bones' in data.keys():
361
+ # data['bones'] = torch.tensor(data['bones'], dtype=torch.float32).to(device).unsqueeze(0)
362
+ # if 'joints_mask' in data.keys():
363
+ # data['joints_mask'] = torch.tensor(data['joints_mask'], dtype=torch.float32).to(device).unsqueeze(0)
364
+ # if 'bones_mask' in data.keys():
365
+ # data['bones_mask'] = torch.tensor(data['bones_mask'], dtype=torch.float32).to(device).unsqueeze(0)
366
+
367
+ for key in data.keys():
368
+ if key in ['points_cloud', 'vertices', 'joints', 'bones', 'joints_mask', 'bones_mask']:
369
+ data[key] = data[key].to(torch.float32).to(device).unsqueeze(0)
370
+
371
  latents = self.encode(data, device)
372
  # print('latents shape', latents.shape)
373
 
Anymate/models/conn.py CHANGED
@@ -24,6 +24,7 @@ class AttendjointsDecoder_con_combine(nn.Module):
24
  self.use_checkpoint = use_checkpoint
25
  self.separate = separate
26
  self.use_mask = use_mask
 
27
  # self.num_latents = num_latents
28
 
29
  # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
@@ -104,6 +105,9 @@ class AttendjointsDecoder_con_combine(nn.Module):
104
  mask = data['joints_mask'].to(device)
105
  logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
106
 
 
 
 
107
  return logits
108
 
109
  class AttendjointsDecoder_con_token(nn.Module):
 
24
  self.use_checkpoint = use_checkpoint
25
  self.separate = separate
26
  self.use_mask = use_mask
27
+ self.inference = False
28
  # self.num_latents = num_latents
29
 
30
  # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
 
105
  mask = data['joints_mask'].to(device)
106
  logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
107
 
108
+ if self.inference:
109
+ logits = torch.argmax(logits, dim=-1).squeeze(0).to('cpu')
110
+
111
  return logits
112
 
113
  class AttendjointsDecoder_con_token(nn.Module):
Anymate/models/joint.py CHANGED
@@ -8,6 +8,7 @@ from Anymate.utils.vol_utils import get_co, sample_from_planes, generate_planes
8
  from einops import repeat
9
  from sklearn.cluster import DBSCAN
10
  from Anymate.utils.vol_utils import extract_keypoints
 
11
 
12
  class TransformerDecoder(nn.Module):
13
  def __init__(self,
@@ -66,7 +67,7 @@ class TransformerDecoder(nn.Module):
66
  cluster_centers = []
67
  for cluster in set(clustering.labels_):
68
  cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
69
- return cluster_centers
70
  return logits
71
 
72
 
 
8
  from einops import repeat
9
  from sklearn.cluster import DBSCAN
10
  from Anymate.utils.vol_utils import extract_keypoints
11
+ import numpy as np
12
 
13
  class TransformerDecoder(nn.Module):
14
  def __init__(self,
 
67
  cluster_centers = []
68
  for cluster in set(clustering.labels_):
69
  cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
70
+ return torch.tensor(np.array(cluster_centers)).to('cpu')
71
  return logits
72
 
73
 
Anymate/models/skin.py CHANGED
@@ -140,6 +140,9 @@ class AttendjointsDecoder_combine(nn.Module):
140
 
141
  if downsample and not self.inference:
142
  return logits, idx
 
 
 
143
 
144
  return logits
145
 
 
140
 
141
  if downsample and not self.inference:
142
  return logits, idx
143
+
144
+ if self.inference:
145
+ logits = logits.softmax(dim=-1).squeeze(0).to('cpu')
146
 
147
  return logits
148
 
Anymate/utils/ui_utils.py CHANGED
@@ -198,6 +198,17 @@ def vis_skinning(normalized_mesh_file, joints, conns, skins):
198
  vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
199
  return vis_file, vis_file
200
 
 
 
 
 
 
 
 
 
 
 
 
201
  def prepare_blender_file(normalized_mesh_file):
202
  if normalized_mesh_file is None:
203
  return None
@@ -245,11 +256,11 @@ def process_input(mesh_file):
245
 
246
  vis_file = visualize_results(mesh_file=normalized_mesh_file)
247
  pc = process_mesh_to_pc(normalized_mesh_file)
248
- pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
249
 
250
  # print(pc.shape, pc.max(dim=0), pc.min(dim=0))
251
 
252
- return normalized_mesh_file, vis_file, vis_file, None, pc, None, None, None
253
 
254
 
255
  def get_model(checkpoint):
@@ -265,9 +276,10 @@ def get_result_connectivity(mesh_file, model, pc, joints):
265
  def get_result_skinning(mesh_file, model, pc, joints, conns):
266
  # mesh = trimesh.load(mesh_file)
267
  mesh = obj2mesh(mesh_file)
268
- vertices = torch.from_numpy(mesh.vertices).to(anymate_args.device).to(torch.float32)
269
- vertex_normals = torch.from_numpy(mesh.vertex_normals).to(anymate_args.device).to(torch.float32)
270
- vertices = torch.cat([vertices, vertex_normals], dim=-1)
 
271
  return get_skinning(pc, joints, conns, model, vertices=vertices, device=anymate_args.device, save=mesh_file.replace('object.obj', 'skins.pt'))
272
 
273
  def get_all_models(checkpoint_joint, checkpoint_conn, checkpoint_skin):
 
198
  vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
199
  return vis_file, vis_file
200
 
201
+ def vis_all(normalized_mesh_file):
202
+ if normalized_mesh_file is None:
203
+ return None, None
204
+ joints = torch.load(normalized_mesh_file.replace('object.obj', 'joints.pt'))
205
+ conns = torch.load(normalized_mesh_file.replace('object.obj', 'conns.pt'))
206
+ skins = torch.load(normalized_mesh_file.replace('object.obj', 'skins.pt'))
207
+ print("All results loaded")
208
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
209
+ print("Finish Visualization")
210
+ return vis_file, vis_file
211
+
212
  def prepare_blender_file(normalized_mesh_file):
213
  if normalized_mesh_file is None:
214
  return None
 
256
 
257
  vis_file = visualize_results(mesh_file=normalized_mesh_file)
258
  pc = process_mesh_to_pc(normalized_mesh_file)
259
+ pc = torch.from_numpy(pc).to(torch.float32)
260
 
261
  # print(pc.shape, pc.max(dim=0), pc.min(dim=0))
262
 
263
+ return normalized_mesh_file, vis_file, vis_file, None, pc
264
 
265
 
266
  def get_model(checkpoint):
 
276
  def get_result_skinning(mesh_file, model, pc, joints, conns):
277
  # mesh = trimesh.load(mesh_file)
278
  mesh = obj2mesh(mesh_file)
279
+ # vertices = torch.from_numpy(mesh.vertices, device='cpu')
280
+ # vertex_normals = torch.from_numpy(mesh.vertex_normals, device='cpu')
281
+ # vertices = torch.cat([vertices, vertex_normals], dim=-1)
282
+ vertices = torch.tensor(np.concatenate([mesh.vertices, mesh.vertex_normals], axis=-1))
283
  return get_skinning(pc, joints, conns, model, vertices=vertices, device=anymate_args.device, save=mesh_file.replace('object.obj', 'skins.pt'))
284
 
285
  def get_all_models(checkpoint_joint, checkpoint_conn, checkpoint_skin):
Anymate/utils/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from Anymate.model import EncoderDecoder
3
  from sklearn.cluster import DBSCAN
4
 
@@ -25,26 +26,28 @@ def load_checkpoint(path, device, num_joints):
25
 
26
  def get_joint(pc, model, device='cuda', save=None, vox=None, eps=0.03, min_samples=1):
27
  model.eval()
28
- data = {'points_cloud': pc.unsqueeze(0)}
29
  if vox is not None:
30
  data['vox'] = vox.unsqueeze(0)
31
  with torch.no_grad():
32
  model.decoder.inference_mode(eps=eps, min_samples=min_samples)
33
  joints = model(data, device=device)
34
- joints = torch.tensor(joints, dtype=torch.float32).to(device)
35
 
36
  if save is not None:
37
  torch.save(joints, save)
38
 
39
  return joints
40
-
41
  def get_connectivity(pc, joints, model, device='cuda',return_prob=False, save=None):
42
  model.eval()
43
- data = {'points_cloud': pc.unsqueeze(0), 'joints': joints.unsqueeze(0), 'joints_num': torch.tensor([joints.shape[0]]),
44
- 'joints_mask': torch.ones(joints.shape[0], device=device).unsqueeze(0)}
45
  with torch.no_grad():
46
- conns = model(data, device=device).softmax(dim=-1)
47
- conns = conns.squeeze(0) if return_prob else torch.argmax(conns, dim=-1).squeeze(0)
 
 
48
 
49
  if save is not None:
50
  torch.save(conns, save)
@@ -56,20 +59,20 @@ def get_skinning(pc, joints, conns, model, vertices=None, bones=None, device='cu
56
 
57
  if bones is None:
58
  bones = []
59
- for i in range(joints.shape[0]):
60
  if conns[i] != i:
61
- bones.append(torch.cat((joints[conns[i]], joints[i]), dim=-1))
62
- bones = torch.stack(bones, dim=0)
63
 
64
- data = {'points_cloud': pc.unsqueeze(0), 'bones': bones.unsqueeze(0), 'bones_num': torch.tensor([bones.shape[0]]),
65
- 'bones_mask': torch.ones(bones.shape[0], device=device).unsqueeze(0)}
66
-
67
- if vertices is not None:
68
- data['vertices'] = vertices.unsqueeze(0)
69
- model.decoder.inference = True
70
 
71
  with torch.no_grad():
72
- skins = model(data, device=device).softmax(dim=-1).squeeze(0)
 
 
 
 
73
 
74
  if save is not None:
75
  torch.save(skins, save)
 
1
  import torch
2
+ import numpy as np
3
  from Anymate.model import EncoderDecoder
4
  from sklearn.cluster import DBSCAN
5
 
 
26
 
27
  def get_joint(pc, model, device='cuda', save=None, vox=None, eps=0.03, min_samples=1):
28
  model.eval()
29
+ data = {'points_cloud': pc}
30
  if vox is not None:
31
  data['vox'] = vox.unsqueeze(0)
32
  with torch.no_grad():
33
  model.decoder.inference_mode(eps=eps, min_samples=min_samples)
34
  joints = model(data, device=device)
35
+ # joints = torch.tensor(joints, dtype=torch.float32).to(device)
36
 
37
  if save is not None:
38
  torch.save(joints, save)
39
 
40
  return joints
41
+
42
  def get_connectivity(pc, joints, model, device='cuda',return_prob=False, save=None):
43
  model.eval()
44
+ data = {'points_cloud': pc, 'joints': joints, 'joints_num': torch.tensor([len(joints)]),
45
+ 'joints_mask': torch.ones(len(joints))}
46
  with torch.no_grad():
47
+ if not return_prob:
48
+ model.decoder.inference = True
49
+ conns = model(data, device=device)
50
+ # conns = conns.squeeze(0) if return_prob else torch.argmax(conns, dim=-1).squeeze(0)
51
 
52
  if save is not None:
53
  torch.save(conns, save)
 
59
 
60
  if bones is None:
61
  bones = []
62
+ for i in range(len(joints)):
63
  if conns[i] != i:
64
+ bones.append(np.concatenate((joints[conns[i]], joints[i]), axis=-1))
65
+ bones = torch.tensor(np.array(bones))
66
 
67
+ data = {'points_cloud': pc, 'bones': bones, 'bones_num': torch.tensor([len(bones)]),
68
+ 'bones_mask': torch.ones(len(bones))}
 
 
 
 
69
 
70
  with torch.no_grad():
71
+ if vertices is not None:
72
+ data['vertices'] = vertices
73
+ model.decoder.inference = True
74
+
75
+ skins = model(data, device=device)
76
 
77
  if save is not None:
78
  torch.save(skins, save)
ThirdParty/Rignet_utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (156 Bytes)
 
ThirdParty/Rignet_utils/__pycache__/binvox_rw.cpython-310.pyc DELETED
Binary file (6.49 kB)
 
ThirdParty/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
ThirdParty/eg3d/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (148 Bytes)
 
ThirdParty/eg3d/dnnlib/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (222 Bytes)
 
ThirdParty/eg3d/dnnlib/__pycache__/util.cpython-310.pyc DELETED
Binary file (14 kB)
 
ThirdParty/eg3d/torch_utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (160 Bytes)
 
ThirdParty/eg3d/torch_utils/__pycache__/custom_ops.cpython-310.pyc DELETED
Binary file (3.69 kB)
 
ThirdParty/eg3d/torch_utils/__pycache__/misc.cpython-310.pyc DELETED
Binary file (9.39 kB)
 
ThirdParty/eg3d/torch_utils/__pycache__/persistence.cpython-310.pyc DELETED
Binary file (8.71 kB)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (164 Bytes)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc DELETED
Binary file (8.29 kB)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc DELETED
Binary file (7.11 kB)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_resample.cpython-310.pyc DELETED
Binary file (4.44 kB)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/fma.cpython-310.pyc DELETED
Binary file (1.71 kB)
 
ThirdParty/eg3d/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc DELETED
Binary file (14.1 kB)
 
ThirdParty/eg3d/training/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (157 Bytes)
 
ThirdParty/eg3d/training/__pycache__/networks_stylegan2.cpython-310.pyc DELETED
Binary file (22.3 kB)
 
ThirdParty/michelangelo/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (156 Bytes)
 
ThirdParty/michelangelo/graphics/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (165 Bytes)
 
ThirdParty/michelangelo/graphics/primitives/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (314 Bytes)
 
ThirdParty/michelangelo/graphics/primitives/__pycache__/mesh.cpython-310.pyc DELETED
Binary file (2.96 kB)
 
ThirdParty/michelangelo/graphics/primitives/__pycache__/volume.cpython-310.pyc DELETED
Binary file (1.27 kB)
 
ThirdParty/michelangelo/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (163 Bytes)
 
ThirdParty/michelangelo/models/modules/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (202 Bytes)
 
ThirdParty/michelangelo/models/modules/__pycache__/checkpoint.cpython-310.pyc DELETED
Binary file (2.65 kB)
 
ThirdParty/michelangelo/models/modules/__pycache__/distributions.cpython-310.pyc DELETED
Binary file (3.88 kB)
 
ThirdParty/michelangelo/models/modules/__pycache__/embedder.cpython-310.pyc DELETED
Binary file (9.04 kB)
 
ThirdParty/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-310.pyc DELETED
Binary file (9.2 kB)
 
ThirdParty/michelangelo/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (257 Bytes)
 
ThirdParty/michelangelo/utils/__pycache__/misc.cpython-310.pyc DELETED
Binary file (2.81 kB)
 
app.py CHANGED
@@ -1,8 +1,32 @@
 
1
  import gradio as gr
2
  import os
3
- from Anymate.args import ui_args
4
- from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, prepare_blender_file
5
- from Anymate.utils.ui_utils import get_model, get_result_joint, get_result_connectivity, get_result_skinning, get_all_models, get_all_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  with gr.Blocks() as demo:
8
  gr.Markdown("""
@@ -12,13 +36,13 @@ with gr.Blocks() as demo:
12
  pc = gr.State(value=None)
13
  normalized_mesh_file = gr.State(value=None)
14
 
15
- result_joint = gr.State(value=None)
16
- result_connectivity = gr.State(value=None)
17
- result_skinning = gr.State(value=None)
18
 
19
- model_joint = gr.State(value=None)
20
- model_connectivity = gr.State(value=None)
21
- model_skinning = gr.State(value=None)
22
 
23
  with gr.Row():
24
  with gr.Column():
@@ -36,7 +60,8 @@ with gr.Blocks() as demo:
36
  sample_dropdown = gr.Dropdown(
37
  label="Select Sample Object",
38
  choices=sample_objects,
39
- interactive=True
 
40
  )
41
 
42
  load_sample_btn = gr.Button("Load Sample")
@@ -54,51 +79,51 @@ with gr.Blocks() as demo:
54
  blender_file = gr.File(label="Output Blender File", scale=1)
55
 
56
  # Checkpoint paths
57
- joint_models_dir = 'Anymate/checkpoints/joint'
58
- joint_models = [os.path.join(joint_models_dir, f) for f in os.listdir(joint_models_dir)
59
- if os.path.isfile(os.path.join(joint_models_dir, f))]
60
- with gr.Row():
61
- joint_checkpoint = gr.Dropdown(
62
- label="Joint Checkpoint",
63
- choices=joint_models,
64
- value=ui_args.checkpoint_joint,
65
- interactive=True
66
- )
67
- joint_status = gr.Checkbox(label="Joint Model Status", value=False, interactive=False, scale=0.3)
68
  # with gr.Column():
69
  # with gr.Row():
70
  # load_joint_btn = gr.Button("Load", scale=0.3)
71
 
72
  # process_joint_btn = gr.Button("Process", scale=0.3)
73
 
74
- conn_models_dir = 'Anymate/checkpoints/conn'
75
- conn_models = [os.path.join(conn_models_dir, f) for f in os.listdir(conn_models_dir)
76
- if os.path.isfile(os.path.join(conn_models_dir, f))]
77
- with gr.Row():
78
- conn_checkpoint = gr.Dropdown(
79
- label="Connection Checkpoint",
80
- choices=conn_models,
81
- value=ui_args.checkpoint_conn,
82
- interactive=True
83
- )
84
- conn_status = gr.Checkbox(label="Connectivity Model Status", value=False, interactive=False, scale=0.3)
85
  # with gr.Column():
86
  # with gr.Row():
87
  # load_conn_btn = gr.Button("Load", scale=0.3)
88
 
89
  # process_conn_btn = gr.Button("Process", scale=0.3)
90
 
91
- skin_models_dir = 'Anymate/checkpoints/skin'
92
- skin_models = [os.path.join(skin_models_dir, f) for f in os.listdir(skin_models_dir)
93
- if os.path.isfile(os.path.join(skin_models_dir, f))]
94
- with gr.Row():
95
- skin_checkpoint = gr.Dropdown(
96
- label="Skin Checkpoint",
97
- choices=skin_models,
98
- value=ui_args.checkpoint_skin,
99
- interactive=True
100
- )
101
- skin_status = gr.Checkbox(label="Skinning Model Status", value=False, interactive=False, scale=0.3)
102
  # with gr.Column():
103
  # with gr.Row():
104
  # load_skin_btn = gr.Button("Load", scale=0.3)
@@ -106,7 +131,7 @@ with gr.Blocks() as demo:
106
  # process_skin_btn = gr.Button("Process", scale=0.3)
107
 
108
  with gr.Row():
109
- load_all_btn = gr.Button("Load all models", scale=1)
110
  process_all_btn = gr.Button("Run all models", scale=1)
111
  # download_btn = gr.DownloadButton("Blender File Not Ready", scale=0.3)
112
  # blender_file = gr.File(label="Blender File", scale=1)
@@ -117,7 +142,7 @@ with gr.Blocks() as demo:
117
  mesh_input.change(
118
  process_input,
119
  inputs=mesh_input,
120
- outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc, result_joint, result_connectivity, result_skinning]
121
  )
122
 
123
  load_sample_btn.click(
@@ -127,7 +152,7 @@ with gr.Blocks() as demo:
127
  ).then(
128
  process_input,
129
  inputs=mesh_input,
130
- outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc, result_joint, result_connectivity, result_skinning]
131
  )
132
 
133
  normalized_mesh_file.change(
@@ -136,95 +161,43 @@ with gr.Blocks() as demo:
136
  outputs=mesh_input
137
  )
138
 
139
- result_joint.change(
140
- vis_joint,
141
- inputs=[normalized_mesh_file, result_joint],
142
- outputs=[mesh_output, mesh_output2]
143
- )
144
-
145
- result_connectivity.change(
146
- vis_connectivity,
147
- inputs=[normalized_mesh_file, result_joint, result_connectivity],
148
- outputs=[mesh_output, mesh_output2]
149
- )
150
-
151
- result_skinning.change(
152
- vis_skinning,
153
- inputs=[normalized_mesh_file, result_joint, result_connectivity, result_skinning],
154
- outputs=[mesh_output, mesh_output2]
155
- )
156
-
157
- result_skinning.change(
158
- prepare_blender_file,
159
- inputs=[normalized_mesh_file],
160
- outputs=blender_file
161
- )
162
-
163
- joint_checkpoint.change(
164
- get_model,
165
- inputs=joint_checkpoint,
166
- outputs=[model_joint, joint_status]
167
- )
168
-
169
- conn_checkpoint.change(
170
- get_model,
171
- inputs=conn_checkpoint,
172
- outputs=[model_connectivity, conn_status]
173
- )
174
-
175
- skin_checkpoint.change(
176
- get_model,
177
- inputs=skin_checkpoint,
178
- outputs=[model_skinning, skin_status]
179
- )
180
-
181
- load_all_btn.click(
182
- get_all_models,
183
- inputs=[joint_checkpoint, conn_checkpoint, skin_checkpoint],
184
- outputs=[model_joint, model_connectivity, model_skinning, joint_status, conn_status, skin_status]
185
- )
186
-
187
- process_all_btn.click(
188
- get_all_results,
189
- inputs=[normalized_mesh_file, model_joint, model_connectivity, model_skinning, pc, eps, min_samples],
190
- outputs=[result_joint, result_connectivity, result_skinning]
191
- )
192
-
193
- # load_joint_btn.click(
194
- # fn=get_model,
195
- # inputs=joint_checkpoint,
196
- # outputs=[model_joint, joint_status]
197
  # )
198
 
199
- # load_conn_btn.click(
200
- # fn=get_model,
201
- # inputs=conn_checkpoint,
202
- # outputs=[model_connectivity, conn_status]
203
  # )
204
 
205
- # load_skin_btn.click(
206
- # fn=get_model,
207
- # inputs=skin_checkpoint,
208
- # outputs=[model_skinning, skin_status]
209
  # )
210
 
211
- # process_joint_btn.click(
212
- # fn=get_result_joint,
213
- # inputs=[normalized_mesh_file, model_joint, pc, eps, min_samples],
214
- # outputs=result_joint
215
- # )
216
-
217
- # process_conn_btn.click(
218
- # fn=get_result_connectivity,
219
- # inputs=[normalized_mesh_file, model_connectivity, pc, result_joint],
220
- # outputs=result_connectivity
221
- # )
222
-
223
- # process_skin_btn.click(
224
- # fn=get_result_skinning,
225
- # inputs=[normalized_mesh_file, model_skinning, pc, result_joint, result_connectivity],
226
- # outputs=result_skinning
227
  # )
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  if __name__ == "__main__":
230
  demo.launch()
 
1
+ import spaces
2
  import gradio as gr
3
  import os
4
+ from Anymate.args import ui_args, anymate_args
5
+ from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, vis_all, prepare_blender_file
6
+ from Anymate.utils.ui_utils import get_result_joint, get_result_connectivity, get_result_skinning
7
+
8
+ from Anymate.utils.utils import load_checkpoint
9
+
10
+ # Check if checkpoints exist, if not download them
11
+ if not (os.path.exists(ui_args.checkpoint_joint) and
12
+ os.path.exists(ui_args.checkpoint_conn) and
13
+ os.path.exists(ui_args.checkpoint_skin)):
14
+ print("Missing checkpoints, downloading them...")
15
+ os.system("bash Anymate/get_checkpoints.sh")
16
+
17
+ model_joint = load_checkpoint(ui_args.checkpoint_joint, 'cpu', anymate_args.num_joints).to(anymate_args.device)
18
+ model_connectivity = load_checkpoint(ui_args.checkpoint_conn, 'cpu', anymate_args.num_joints).to(anymate_args.device)
19
+ model_skinning = load_checkpoint(ui_args.checkpoint_skin, 'cpu', anymate_args.num_joints).to(anymate_args.device)
20
+
21
+ @spaces.GPU
22
+ def get_all_results(mesh_file, pc, eps=0.03, min_samples=1):
23
+ # pc = pc.to(anymate_args.device)
24
+ joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
25
+ conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
26
+ skins = get_result_skinning(mesh_file, model_skinning, pc, joints, conns)
27
+ print("Finish Inference")
28
+ return
29
+
30
 
31
  with gr.Blocks() as demo:
32
  gr.Markdown("""
 
36
  pc = gr.State(value=None)
37
  normalized_mesh_file = gr.State(value=None)
38
 
39
+ # result_joint = gr.State(value=None)
40
+ # result_connectivity = gr.State(value=None)
41
+ # result_skinning = gr.State(value=None)
42
 
43
+ # model_joint = gr.State(value=model_joint)
44
+ # model_connectivity = gr.State(value=model_connectivity)
45
+ # model_skinning = gr.State(value=model_skinning)
46
 
47
  with gr.Row():
48
  with gr.Column():
 
60
  sample_dropdown = gr.Dropdown(
61
  label="Select Sample Object",
62
  choices=sample_objects,
63
+ interactive=True,
64
+ value=sample_objects[0]
65
  )
66
 
67
  load_sample_btn = gr.Button("Load Sample")
 
79
  blender_file = gr.File(label="Output Blender File", scale=1)
80
 
81
  # Checkpoint paths
82
+ # joint_models_dir = 'Anymate/checkpoints/joint'
83
+ # joint_models = [os.path.join(joint_models_dir, f) for f in os.listdir(joint_models_dir)
84
+ # if os.path.isfile(os.path.join(joint_models_dir, f))]
85
+ # with gr.Row():
86
+ # joint_checkpoint = gr.Dropdown(
87
+ # label="Joint Checkpoint",
88
+ # choices=joint_models,
89
+ # value=ui_args.checkpoint_joint,
90
+ # interactive=True
91
+ # )
92
+ # joint_status = gr.Checkbox(label="Joint Model Status", value=False, interactive=False, scale=0.3)
93
  # with gr.Column():
94
  # with gr.Row():
95
  # load_joint_btn = gr.Button("Load", scale=0.3)
96
 
97
  # process_joint_btn = gr.Button("Process", scale=0.3)
98
 
99
+ # conn_models_dir = 'Anymate/checkpoints/conn'
100
+ # conn_models = [os.path.join(conn_models_dir, f) for f in os.listdir(conn_models_dir)
101
+ # if os.path.isfile(os.path.join(conn_models_dir, f))]
102
+ # with gr.Row():
103
+ # conn_checkpoint = gr.Dropdown(
104
+ # label="Connection Checkpoint",
105
+ # choices=conn_models,
106
+ # value=ui_args.checkpoint_conn,
107
+ # interactive=True
108
+ # )
109
+ # conn_status = gr.Checkbox(label="Connectivity Model Status", value=False, interactive=False, scale=0.3)
110
  # with gr.Column():
111
  # with gr.Row():
112
  # load_conn_btn = gr.Button("Load", scale=0.3)
113
 
114
  # process_conn_btn = gr.Button("Process", scale=0.3)
115
 
116
+ # skin_models_dir = 'Anymate/checkpoints/skin'
117
+ # skin_models = [os.path.join(skin_models_dir, f) for f in os.listdir(skin_models_dir)
118
+ # if os.path.isfile(os.path.join(skin_models_dir, f))]
119
+ # with gr.Row():
120
+ # skin_checkpoint = gr.Dropdown(
121
+ # label="Skin Checkpoint",
122
+ # choices=skin_models,
123
+ # value=ui_args.checkpoint_skin,
124
+ # interactive=True
125
+ # )
126
+ # skin_status = gr.Checkbox(label="Skinning Model Status", value=False, interactive=False, scale=0.3)
127
  # with gr.Column():
128
  # with gr.Row():
129
  # load_skin_btn = gr.Button("Load", scale=0.3)
 
131
  # process_skin_btn = gr.Button("Process", scale=0.3)
132
 
133
  with gr.Row():
134
+ # load_all_btn = gr.Button("Load all models", scale=1)
135
  process_all_btn = gr.Button("Run all models", scale=1)
136
  # download_btn = gr.DownloadButton("Blender File Not Ready", scale=0.3)
137
  # blender_file = gr.File(label="Blender File", scale=1)
 
142
  mesh_input.change(
143
  process_input,
144
  inputs=mesh_input,
145
+ outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
146
  )
147
 
148
  load_sample_btn.click(
 
152
  ).then(
153
  process_input,
154
  inputs=mesh_input,
155
+ outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
156
  )
157
 
158
  normalized_mesh_file.change(
 
161
  outputs=mesh_input
162
  )
163
 
164
+ # result_joint.change(
165
+ # vis_joint,
166
+ # inputs=[normalized_mesh_file, result_joint],
167
+ # outputs=[mesh_output, mesh_output2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # )
169
 
170
+ # result_connectivity.change(
171
+ # vis_connectivity,
172
+ # inputs=[normalized_mesh_file, result_joint, result_connectivity],
173
+ # outputs=[mesh_output, mesh_output2]
174
  # )
175
 
176
+ # result_skinning.change(
177
+ # vis_skinning,
178
+ # inputs=[normalized_mesh_file, result_joint, result_connectivity, result_skinning],
179
+ # outputs=[mesh_output, mesh_output2]
180
  # )
181
 
182
+ # result_skinning.change(
183
+ # prepare_blender_file,
184
+ # inputs=[normalized_mesh_file],
185
+ # outputs=blender_file
 
 
 
 
 
 
 
 
 
 
 
 
186
  # )
187
 
188
+ process_all_btn.click(
189
+ get_all_results,
190
+ inputs=[normalized_mesh_file, pc, eps, min_samples],
191
+ outputs=[]
192
+ ).then(
193
+ vis_all,
194
+ inputs=[normalized_mesh_file],
195
+ outputs=[mesh_output, mesh_output2]
196
+ ).then(
197
+ prepare_blender_file,
198
+ inputs=[normalized_mesh_file],
199
+ outputs=blender_file
200
+ )
201
+
202
  if __name__ == "__main__":
203
  demo.launch()