Spaces:
Runtime error
Runtime error
FrozenBurning
commited on
Commit
·
9d573a0
1
Parent(s):
ae08466
update inference
Browse files- inference.py +15 -3
inference.py
CHANGED
|
@@ -85,6 +85,21 @@ def resize_foreground(
|
|
| 85 |
def extract_texmesh(args, model, output_path, device):
|
| 86 |
# Prepare directory
|
| 87 |
ins_dir = output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Get SDFs
|
| 90 |
with torch.no_grad():
|
|
@@ -350,9 +365,6 @@ if __name__ == "__main__":
|
|
| 350 |
# manually enable tf32 to get speedup on A100 GPUs
|
| 351 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 352 |
torch.backends.cudnn.allow_tf32 = True
|
| 353 |
-
os.environ["CC"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/gcc"
|
| 354 |
-
os.environ["CPP"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++"
|
| 355 |
-
os.environ["CXX"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++"
|
| 356 |
# set config
|
| 357 |
config = OmegaConf.load(str(sys.argv[1]))
|
| 358 |
config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])
|
|
|
|
| 85 |
def extract_texmesh(args, model, output_path, device):
|
| 86 |
# Prepare directory
|
| 87 |
ins_dir = output_path
|
| 88 |
+
# Noise Filter
|
| 89 |
+
srt_param = model.srt_param.clone()
|
| 90 |
+
prim_position = srt_param[:, 1:4]
|
| 91 |
+
prim_scale = srt_param[:, 0:1]
|
| 92 |
+
dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
|
| 93 |
+
dist += torch.eye(prim_position.shape[0]).to(srt_param)
|
| 94 |
+
min_dist, min_indices = dist.min(1)
|
| 95 |
+
dst_prim_scale = prim_scale[min_indices, :]
|
| 96 |
+
min_scale_converage = prim_scale * 1.414 + dst_prim_scale * 1.414
|
| 97 |
+
prim_mask = min_dist < min_scale_converage[:, 0]
|
| 98 |
+
filtered_srt_param = srt_param[prim_mask, :]
|
| 99 |
+
filtered_feat_param = model.feat_param.clone()[prim_mask, ...]
|
| 100 |
+
model.srt_param.data = filtered_srt_param
|
| 101 |
+
model.feat_param.data = filtered_feat_param
|
| 102 |
+
print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
|
| 103 |
|
| 104 |
# Get SDFs
|
| 105 |
with torch.no_grad():
|
|
|
|
| 365 |
# manually enable tf32 to get speedup on A100 GPUs
|
| 366 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 367 |
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
| 368 |
# set config
|
| 369 |
config = OmegaConf.load(str(sys.argv[1]))
|
| 370 |
config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])
|