YulianSa commited on
Commit
e82e682
·
1 Parent(s): ccc749d
Files changed (2) hide show
  1. infer_api.py +3 -7
  2. refine/mesh_refine.py +5 -0
infer_api.py CHANGED
@@ -619,30 +619,26 @@ def infer_refine(meshes, imgs):
619
 
620
  # my mesh flow weight by nearest vertexs
621
  if fixed_v is not None and fixed_f is not None and level == 1:
622
- t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
623
-
624
  fixed_v_cpu = fixed_v.cpu().numpy()
625
  kdtree_anchor = KDTree(fixed_v_cpu)
626
  kdtree_mesh_v = KDTree(mesh_v)
627
  _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
628
  _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
629
  idx_anchor = idx_anchor.squeeze()
630
- neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
631
  # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
632
- neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
633
  neighbor_dists[neighbor_dists > 0.06] = 114514.
634
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
635
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
636
  anchors = fixed_v[idx_anchor] # V, 3
637
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
638
- dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
639
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
640
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
641
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
642
  mesh_v += weighted_vec_anchor.cpu().numpy()
643
 
644
- t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
645
-
646
  mesh_v = torch.tensor(mesh_v, dtype=torch.float32)
647
  mesh_f = torch.tensor(mesh_f)
648
 
 
619
 
620
  # my mesh flow weight by nearest vertexs
621
  if fixed_v is not None and fixed_f is not None and level == 1:
 
 
622
  fixed_v_cpu = fixed_v.cpu().numpy()
623
  kdtree_anchor = KDTree(fixed_v_cpu)
624
  kdtree_mesh_v = KDTree(mesh_v)
625
  _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
626
  _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
627
  idx_anchor = idx_anchor.squeeze()
628
+ neighbors = torch.tensor(mesh_v)[idx_mesh_v] # V, 25, 3
629
  # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
630
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v)[:, None], dim=-1)
631
  neighbor_dists[neighbor_dists > 0.06] = 114514.
632
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
633
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
634
  anchors = fixed_v[idx_anchor] # V, 3
635
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
636
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v)) * anchor_normals).sum(-1), min=0) + 0.01
637
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
638
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
639
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
640
  mesh_v += weighted_vec_anchor.cpu().numpy()
641
 
 
 
642
  mesh_v = torch.tensor(mesh_v, dtype=torch.float32)
643
  mesh_f = torch.tensor(mesh_f)
644
 
refine/mesh_refine.py CHANGED
@@ -268,6 +268,11 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], fixed_v=None, fixe
268
 
269
  def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
270
  distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
 
 
 
 
 
271
  vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight, fixed_v=fixed_v, fixed_f=fixed_f,
272
  distract_mask=distract_mask, distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
273
  vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)
 
268
 
269
  def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
270
  distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
271
+ print(mesh_v.device, mesh_f.device)
272
+ if fixed_v is not None:
273
+ print('fixed_v', fixed_v.shape, fixed_v.device)
274
+ if fixed_f is not None:
275
+ print('fixed_f', fixed_f.shape, fixed_f.device)
276
  vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight, fixed_v=fixed_v, fixed_f=fixed_f,
277
  distract_mask=distract_mask, distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
278
  vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)