Commit c614b0f
Parent(s): 7d7a0a6

add wheels

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
- .gitattributes +1 -0
- LHM/__init__.py +15 -0
- LHM/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/__pycache__/launch.cpython-310.pyc +0 -0
- LHM/datasets/__init__.py +16 -0
- LHM/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/cam_utils.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/mixer.cpython-310.pyc +0 -0
- LHM/datasets/base.py +70 -0
- LHM/datasets/bedlam.py +493 -0
- LHM/datasets/bedlam_util.py +306 -0
- LHM/datasets/cam_utils.py +205 -0
- LHM/datasets/mixer.py +120 -0
- LHM/launch.py +35 -0
- LHM/losses/__init__.py +20 -0
- LHM/losses/ball_loss.py +54 -0
- LHM/losses/offset_loss.py +52 -0
- LHM/losses/perceptual.py +70 -0
- LHM/losses/pixelwise.py +58 -0
- LHM/losses/tvloss.py +55 -0
- LHM/models/ESRGANer_utils.py +482 -0
- LHM/models/__init__.py +30 -0
- LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/__pycache__/arcface_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/embedder.cpython-310.pyc +0 -0
- LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer_dit.cpython-310.pyc +0 -0
- LHM/models/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/arcface_utils.py +360 -0
- LHM/models/block.py +124 -0
- LHM/models/discriminator.py +120 -0
- LHM/models/embedder.py +37 -0
- LHM/models/encoders/__init__.py +15 -0
- LHM/models/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc +0 -0
- LHM/models/encoders/dino_wrapper.py +68 -0
- LHM/models/encoders/dinov2/__init__.py +15 -0
- LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__init__.py +4 -0
- LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/backbones.py +166 -0
- LHM/models/encoders/dinov2/hub/classifiers.py +268 -0
- LHM/models/encoders/dinov2/hub/depth/__init__.py +7 -0
- LHM/models/encoders/dinov2/hub/depth/decode_heads.py +747 -0
- LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.whl filter=lfs diff=lfs merge=lfs -text
LHM/__init__.py ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Empty
LHM/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes).

LHM/__pycache__/launch.cpython-310.pyc ADDED
Binary file (743 Bytes).
LHM/datasets/__init__.py ADDED
@@ -0,0 +1,16 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .mixer import MixerDataset
LHM/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (204 Bytes).

LHM/datasets/__pycache__/cam_utils.cpython-310.pyc ADDED
Binary file (5.43 kB).

LHM/datasets/__pycache__/mixer.cpython-310.pyc ADDED
Binary file (2.91 kB).
LHM/datasets/base.py ADDED
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Organization  : Alibaba XR-Lab
# @Author        : Peihao Li & Lingteng Qiu & Xiaodong Gu & Qi Zuo
# @Email         : [email protected]
# @Time          : 2025-03-10 18:47:56
# @Function      : dataset base

import json
import pdb
import traceback
from abc import ABC, abstractmethod

import numpy as np
import torch
from megfile import smart_exists, smart_open, smart_path_join
from PIL import Image


class BaseDataset(torch.utils.data.Dataset, ABC):
    def __init__(self, root_dirs: str, meta_path: str):
        super().__init__()
        self.root_dirs = root_dirs
        self.uids = self._load_uids(meta_path)

    def __len__(self):
        return len(self.uids)

    @abstractmethod
    def inner_get_item(self, idx):
        pass

    def __getitem__(self, idx):
        try:
            return self.inner_get_item(idx)
        except Exception as e:
            traceback.print_exc()
            print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
            # raise e
            return self.__getitem__((idx + 1) % self.__len__())

    @staticmethod
    def _load_uids(meta_path: str):
        # meta_path is a json file
        if meta_path == None:
            uids = []
        else:
            with open(meta_path, "r") as f:
                uids = json.load(f)

        return uids

    @staticmethod
    def _load_rgba_image(file_path, bg_color: float = 1.0):
        """Load and blend RGBA image to RGB with certain background, 0-1 scaled"""
        rgba = np.array(Image.open(smart_open(file_path, "rb")))
        rgba = torch.from_numpy(rgba).float() / 255.0
        rgba = rgba.permute(2, 0, 1).unsqueeze(0)
        rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (
            1 - rgba[:, 3:, :, :]
        )
        # rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...])
        return rgb

    @staticmethod
    def _locate_datadir(root_dirs, uid, locator: str):
        for root_dir in root_dirs:
            datadir = smart_path_join(root_dir, uid, locator)
            if smart_exists(datadir):
                return root_dir
        raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
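Editor's note: a minimal sketch (not part of this commit) of how a concrete dataset could plug into the BaseDataset above. Only the interface used here (inner_get_item, _load_rgba_image, _locate_datadir, uids) comes from the file; the subclass name and the on-disk layout are hypothetical.

# Hypothetical subclass, assuming each uid folder contains an "rgba.png".
import torch
from megfile import smart_path_join
from LHM.datasets.base import BaseDataset


class ToyRGBADataset(BaseDataset):
    def __init__(self, root_dirs, meta_path, bg_color: float = 1.0):
        super().__init__(root_dirs, meta_path)
        self.bg_color = bg_color

    def inner_get_item(self, idx):
        uid = self.uids[idx]
        # locate the root that actually holds this uid (assumed layout)
        root_dir = self._locate_datadir(self.root_dirs, uid, locator="rgba.png")
        rgb = self._load_rgba_image(
            smart_path_join(root_dir, uid, "rgba.png"), bg_color=self.bg_color
        )
        # BaseDataset.__getitem__ will skip to the next uid if this raises
        return {"uid": uid, "rgb": rgb[0]}  # [3, H, W], values in [0, 1]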
LHM/datasets/bedlam.py ADDED
@@ -0,0 +1,493 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob

# from megfile import smart_path_join, smart_open
import json
import os
import random
from collections import defaultdict
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image

from LHM.datasets.base import BaseDataset
from LHM.datasets.cam_utils import (
    build_camera_principle,
    build_camera_standard,
    camera_normalization_objaverse,
)
from LHM.utils.proxy import no_proxy

__all__ = ["BedlamDataset"]


class BedlamDataset(BaseDataset):

    def __init__(
        self,
        root_dirs: str,
        meta_path: str,
        sample_side_views: int,
        render_image_res_low: int,
        render_image_res_high: int,
        render_region_size: int,
        source_image_res: int,
        repeat_num=1,
        crop_range_ratio_hw=[1.0, 1.0],
        valid_area_ratio=0.4,
        debug=False,
        **kwargs,
    ):
        super().__init__(root_dirs, meta_path)
        self.sample_side_views = sample_side_views
        self.render_image_res_low = render_image_res_low
        self.render_image_res_high = render_image_res_high
        if not (
            isinstance(render_region_size, list)
            or isinstance(render_region_size, tuple)
        ):
            render_region_size = render_region_size, render_region_size  # [H, W]
        self.render_region_size = render_region_size
        self.source_image_res = source_image_res

        self.uids = self.uids * repeat_num
        self.crop_range_ratio_hw = crop_range_ratio_hw
        self.debug = debug
        self.valid_area_ratio = valid_area_ratio
        print(
            f"BedlamDataset, data_len:{len(self.uids)}, repeat_num:{repeat_num}, debug:{debug}"
        )
        self.multiply = kwargs.get("multiply", 14)

    @staticmethod
    def _load_pose(pose):
        intrinsic = torch.eye(4)
        intrinsic[0, 0] = pose["focal"][0]
        intrinsic[1, 1] = pose["focal"][1]
        intrinsic[0, 2] = pose["princpt"][0]
        intrinsic[1, 2] = pose["princpt"][1]
        intrinsic = intrinsic.float()

        c2w = torch.eye(4)
        # c2w[:3, :3] = torch.tensor(pose["R"])
        # c2w[3, :3] = torch.tensor(pose["t"])
        c2w = c2w.float()

        return c2w, intrinsic

    def load_rgb_image_with_aug_bg(self, rgb_path, mask_path, bg_color):
        rgb = np.array(Image.open(rgb_path))
        rgb = torch.from_numpy(rgb).float() / 255.0
        rgb = rgb.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
        mask = None

        if mask_path is not None:
            mask = np.array(Image.open(mask_path))
            mask = torch.from_numpy(mask).float() / 255.0
            mask = (mask > 0.5).float()
            if len(mask.shape) == 3:
                mask = mask[:, :, 0:1]
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(-1)
            mask = mask.permute(2, 0, 1).unsqueeze(0)  # [1, 1, H, W]
            rgb = torch.cat([rgb, mask], dim=1)  # [1, 4, H, W]
        else:
            mask = rgb[:, 3:4, :, :]

            # erode mask
            mask_np = (mask[0, 0].numpy() * 255).astype(np.uint8)
            kernel_size, iterations = 3, 1
            kernel = np.ones((kernel_size, kernel_size), np.uint8)
            mask_np = cv2.erode(mask_np, kernel, iterations=iterations)
            mask = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0) / 255.0
            mask = (mask > 0.5).float()
            rgb = torch.cat([rgb[:, :3], mask], dim=1)  # [1, 4, H, W]

        if rgb.shape[1] == 4:
            rgb = rgb[:, :3, :, :] * rgb[:, 3:4, :, :] + bg_color * (
                1 - rgb[:, 3:, :, :]
            )

        return rgb, mask

    def scale_intrs(self, intrs, ratio_x, ratio_y):
        intrs[:, 0] = intrs[:, 0] * ratio_x
        intrs[:, 1] = intrs[:, 1] * ratio_y
        return intrs

    def uniform_sample_in_chunk(self, sample_num, sample_data):
        chunks = np.array_split(sample_data, sample_num)
        select_list = []
        for chunk in chunks:
            select_list.append(np.random.choice(chunk))
        return select_list

    @no_proxy
    def inner_get_item(self, idx):
        """
        Loaded contents:
            rgbs: [M, 3, H, W]
            poses: [M, 3, 4], [R|t]
            intrinsics: [3, 2], [[fx, fy], [cx, cy], [weight, height]]
        """
        uid = self.uids[idx]
        seq_id = uid["seq_id"]
        all_frame_info = uid["all_frame_info"]
        uid = os.path.join(self.root_dirs, seq_id)
        valid_imgs = [
            e["frame_name"]
            for e in all_frame_info
            if e["valid_area_ratio"] > self.valid_area_ratio
        ]
        assert len(valid_imgs) >= 1

        if self.sample_side_views + 1 <= len(valid_imgs):
            cam_id_list = np.random.choice(
                valid_imgs, self.sample_side_views + 1, replace=False
            )
        else:
            cam_id_list = np.random.choice(
                valid_imgs, self.sample_side_views + 1, replace=True
            )

        assert self.sample_side_views + 1 == len(cam_id_list)
        crop_ratio_h, crop_ratio_w = self.crop_range_ratio_hw

        frame_id_list = cam_id_list

        # source images
        c2ws, intrs, rgbs, bg_colors, masks = [], [], [], [], []
        source_c2ws, source_intrs, source_rgbs = [], [], []
        smplx_params = []
        shape_param = None
        for cam_id, frame_id in zip(cam_id_list, frame_id_list):
            frame_path = os.path.join(uid, cam_id + ".png")
            frame_name = os.path.splitext(os.path.basename(frame_path))[0]
            smplx_path = os.path.join(
                uid.replace("/png_post/", "/smplx/"), f"{frame_name}.json"
            )

            with open(smplx_path) as f:
                smplx_param = {
                    k: torch.FloatTensor(v)
                    for k, v in json.load(f).items()
                    if "valid_area_ratio" not in k
                }

            # if cam_id == 0:
            shape_param = smplx_param["betas"]

            c2w, intrinsic = self._load_pose(smplx_param)

            bg_color = random.choice([0.0, 0.5, 1.0])
            rgb, mask = self.load_rgb_image_with_aug_bg(
                frame_path, mask_path=None, bg_color=bg_color
            )

            # crop image to enlarge human area.
            if (crop_ratio_h < 1.0) or (crop_ratio_w < 1.0):
                img_size_hw = rgb.shape[2], rgb.shape[3]
                h_crop, w_crop = round(img_size_hw[0] * crop_ratio_h), round(
                    img_size_hw[1] * crop_ratio_w
                )
                h_crop_offset, w_crop_offset = round(
                    (img_size_hw[0] - h_crop) / 2
                ), round((img_size_hw[1] - w_crop) / 2)
                rgb = rgb[:, :, h_crop_offset : h_crop_offset + h_crop, w_crop_offset : w_crop_offset + w_crop]
                mask = mask[:, :, h_crop_offset : h_crop_offset + h_crop, w_crop_offset : w_crop_offset + w_crop]
                intrinsic[0, 2] -= w_crop_offset
                intrinsic[1, 2] -= h_crop_offset

            assert (
                abs(intrinsic[0, 2] * 2 - rgb.shape[-1]) <= 1
            ), f"{intrinsic[0, 2] * 2}, {rgb.shape[-1]}"

            c2ws.append(c2w)
            rgbs.append(rgb)
            bg_colors.append(bg_color)
            intrs.append(intrinsic)
            smplx_params.append(smplx_param)
            masks.append(mask)

        c2ws = torch.stack(c2ws, dim=0)  # [N, 4, 4]
        intrs = torch.stack(intrs, dim=0)  # [N, 4, 4]
        rgbs = torch.cat(rgbs, dim=0)  # [N, 3, H, W]
        bg_colors = (
            torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3)
        )  # [N, 3]
        masks = torch.cat(masks, dim=0)  # [N, 1, H, W]

        smplx_params_tmp = defaultdict(list)
        for smplx in smplx_params:
            for k, v in smplx.items():
                smplx_params_tmp[k].append(v)
        for k, v in smplx_params_tmp.items():
            smplx_params_tmp[k] = torch.stack(v)
        smplx_params = smplx_params_tmp
        # TODO check different betas for same person
        smplx_params["betas"] = shape_param

        # reference images
        # TODO check prob
        ref_idx = np.random.choice(self.sample_side_views + 1)

        cam_id_source_list = cam_id_list[ref_idx : ref_idx + 1]
        frame_id_source_list = frame_id_list[ref_idx : ref_idx + 1]

        for cam_id, frame_id in zip(cam_id_source_list, frame_id_source_list):
            frame_path = os.path.join(uid, cam_id + ".png")
            frame_name = os.path.splitext(os.path.basename(frame_path))[0]
            smplx_path = os.path.join(
                uid.replace("/png_post/", "/smplx/"), f"{frame_name}.json"
            )

            with open(smplx_path) as f:
                smplx_param = {
                    k: torch.FloatTensor(v)
                    for k, v in json.load(f).items()
                    if "valid_area_ratio" not in k
                }

            c2w, intrinsic = self._load_pose(smplx_param)

            bg_color = 1.0
            rgb, mask = self.load_rgb_image_with_aug_bg(
                frame_path, mask_path=None, bg_color=bg_color
            )

            # crop image to enlarge human area.
            if (crop_ratio_h < 1.0) or (crop_ratio_w < 1.0):
                img_size_hw = rgb.shape[2], rgb.shape[3]
                h_crop, w_crop = round(img_size_hw[0] * crop_ratio_h), round(
                    img_size_hw[1] * crop_ratio_w
                )
                h_crop_offset, w_crop_offset = round(
                    (img_size_hw[0] - h_crop) / 2
                ), round((img_size_hw[1] - w_crop) / 2)
                rgb = rgb[:, :, h_crop_offset : h_crop_offset + h_crop, w_crop_offset : w_crop_offset + w_crop]
                mask = mask[:, :, h_crop_offset : h_crop_offset + h_crop, w_crop_offset : w_crop_offset + w_crop]
                intrinsic[0, 2] -= w_crop_offset
                intrinsic[1, 2] -= h_crop_offset

            assert (
                abs(intrinsic[0, 2] * 2 - rgb.shape[-1]) <= 1
            ), f"{intrinsic[0, 2] * 2}, {rgb.shape[-1]}"

            source_c2ws.append(c2w)
            source_intrs.append(intrinsic)
            source_rgbs.append(rgb)

        source_c2ws = torch.stack(source_c2ws, dim=0)
        source_intrs = torch.stack(source_intrs, dim=0)
        source_rgbs = torch.cat(source_rgbs, dim=0)

        # adjust source image resolution
        # TODO check 224x224 need to padding?
        # ratio_x, ratio_y = self.source_image_res / source_rgbs.shape[3], self.source_image_res / source_rgbs.shape[2]
        ratio = self.source_image_res / min(source_rgbs.shape[2:])
        tgt_size = int(ratio * source_rgbs.shape[2]), int(ratio * source_rgbs.shape[3])
        multiply = self.multiply
        tgt_size = (
            int(tgt_size[0] / multiply) * multiply,
            int(tgt_size[1] / multiply) * multiply,
        )
        ratio_y, ratio_x = (
            tgt_size[0] / source_rgbs.shape[2],
            tgt_size[1] / source_rgbs.shape[3],
        )
        source_rgbs = torch.nn.functional.interpolate(
            source_rgbs, size=tgt_size, mode="bicubic", align_corners=True
        )
        source_rgbs = torch.clamp(source_rgbs, 0, 1)
        source_intrs = self.scale_intrs(source_intrs, ratio_x=ratio_x, ratio_y=ratio_y)

        # adjust render image resolution and sample intended rendering region
        render_image_res = np.random.randint(
            self.render_image_res_low, self.render_image_res_high + 1
        )
        ratio = render_image_res / min(rgbs.shape[2:])
        tgt_size = int(ratio * rgbs.shape[2]), int(ratio * rgbs.shape[3])
        # multiply = 14
        # tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
        # ratio_y, ratio_x = tgt_size[0] / rgbs.shape[2], tgt_size[1] / rgbs.shape[3]
        render_image = torch.nn.functional.interpolate(
            rgbs, size=tgt_size, mode="bicubic", align_corners=True
        )
        render_image = torch.clamp(render_image, 0, 1)
        intrs = self.scale_intrs(intrs, ratio_x=ratio, ratio_y=ratio)

        render_mask = torch.nn.functional.interpolate(
            masks, size=tgt_size, mode="bicubic", align_corners=True
        )
        render_mask = torch.clamp(render_mask, 0, 1)

        assert (
            abs(intrs[0, 0, 2] * 2 - render_image.shape[3]) <= 1.1
        ), f"{intrs[0, 0, 2] * 2}, {render_image.shape}"
        assert (
            abs(intrs[0, 1, 2] * 2 - render_image.shape[2]) <= 1.1
        ), f"{intrs[0, 1, 2] * 2}, {render_image.shape}"

        # anchors = torch.randint(
        #     0, render_image_res - min(self.render_region_size) + 1, size=(self.sample_side_views + 1, 2))
        # crop_indices_h = torch.arange(0, self.render_region_size[0], device=render_image.device)
        # crop_indices_w = torch.arange(0, self.render_region_size[1], device=render_image.device)
        # index_h = (anchors[:, 0].unsqueeze(1) + crop_indices_h).view(-1, self.render_region_size[0], 1)
        # index_w = (anchors[:, 1].unsqueeze(1) + crop_indices_w).view(-1, 1, self.render_region_size[1])
        # batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
        # cropped_render_image = render_image[batch_indices, :, index_h, index_w].permute(0, 3, 1, 2)

        ret = {
            "uid": uid,
            "source_c2ws": source_c2ws,  # [N1, 4, 4]
            "source_intrs": source_intrs,  # [N1, 4, 4]
            "source_rgbs": source_rgbs,  # [N1, 3, H, W]
            "render_image": render_image,  # [N, 3, H, W]
            "render_mask": render_mask,  # [N, 1, H, W]
            "c2ws": c2ws,  # [N, 4, 4]
            "intrs": intrs,  # [N, 4, 4]
            # 'render_anchors': anchors, # [N, 2]
            "render_full_resolutions": torch.tensor(
                [tgt_size], dtype=torch.float32
            ).repeat(
                self.sample_side_views + 1, 1
            ),  # [N, 2]
            "render_bg_colors": bg_colors,  # [N, 3]
        }

        # ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas']
        # 'smplx_params': smplx_params, # dict: body_pose:[N, 21, 3],
        ret.update(smplx_params)

        return ret


if __name__ == "__main__":
    import cv2

    root_dir = "./train_data/bedlam/data/"
    meta_path = "./train_data/bedlam/data/annots/valid_list.json"
    dataset = BedlamDataset(
        root_dirs=root_dir,
        meta_path=meta_path,
        sample_side_views=3,
        render_image_res_low=384,
        render_image_res_high=384,
        render_region_size=(682, 384),
        source_image_res=384,
        valid_area_ratio=0.1,
        debug=False,
    )

    for data in dataset:
        print("source_c2ws.shape", data["source_c2ws"].shape)
        print("source_intrs.shape", data["source_intrs"].shape)
        print("source_rgbs.shape", data["source_rgbs"].shape)
        print("render_image.shape", data["render_image"].shape)
        print("c2ws.shape", data["c2ws"].shape)
        print("intrs.shape", data["intrs"].shape)
        # print("render_anchors.shape", data["render_anchors"].shape)
        print("render_full_resolutions.shape", data["render_full_resolutions"].shape)
        print("render_bg_colors.shape", data["render_bg_colors"].shape)
        # print("smplx_params", data["smplx_params"].keys())
        print("smplx_params.body_pose.shape", data["body_pose"].shape)
        print("smplx_params.expr.shape", data["expr"].shape)
        print("smplx_params.betas.shape", data["betas"].shape)
        os.makedirs("debug_vis/dataloader", exist_ok=True)
        for i in range(data["source_rgbs"].shape[0]):
            cv2.imwrite(
                f"debug_vis/dataloader/source_rgbs_{i}.jpg",
                (data["source_rgbs"][i].permute(1, 2, 0).numpy()[:, :, (2, 1, 0)] * 255).astype(np.uint8),
            )
            print("source_rgbs", data["source_rgbs"].shape)
            print("source_intrs", data["source_intrs"][i])

        for i in range(data["render_image"].shape[0]):
            cv2.imwrite(
                f"debug_vis/dataloader/rgbs{i}.jpg",
                (data["render_image"][i].permute(1, 2, 0).numpy()[:, :, (2, 1, 0)] * 255).astype(np.uint8),
            )
            print("render_image", data["render_image"].shape)
            print("render_full_resolutions", data["render_full_resolutions"][i])
            # print("render_anchors", data["render_anchors"][i])
            print("intrs", data["intrs"][i])
        xx
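Editor's note: a minimal usage sketch (not part of this commit) of batching BedlamDataset with a standard PyTorch DataLoader. The constructor arguments mirror the __main__ block above and the output keys follow the dict returned by inner_get_item; the batch size, worker count, and the assumption that default collation works for these fixed-size tensors are illustrative only.

# Hypothetical sketch, assuming the ./train_data/bedlam/ layout from the __main__ block exists.
from torch.utils.data import DataLoader
from LHM.datasets.bedlam import BedlamDataset

dataset = BedlamDataset(
    root_dirs="./train_data/bedlam/data/",
    meta_path="./train_data/bedlam/data/annots/valid_list.json",
    sample_side_views=3,          # N = sample_side_views + 1 = 4 views per sample
    render_image_res_low=384,
    render_image_res_high=384,
    render_region_size=(682, 384),
    source_image_res=384,
    valid_area_ratio=0.1,
)
loader = DataLoader(dataset, batch_size=2, num_workers=4, shuffle=True)
batch = next(iter(loader))
# e.g. batch["source_rgbs"]: [B, 1, 3, H, W], batch["render_image"]: [B, 4, 3, H', W'],
# plus the per-frame SMPL-X keys merged in by ret.update(smplx_params)
print(sorted(batch.keys()))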
LHM/datasets/bedlam_util.py ADDED
@@ -0,0 +1,306 @@
# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license

import os
# os.environ["PYOPENGL_PLATFORM"] = "egl"
# os.environ['EGL_DEVICE_ID'] = '0'

import warnings
import pickle
import torch
import smplx
from tqdm import tqdm
import sys
import numpy as np
from PIL import Image, ImageOps, ImageFile
import random
import json
import tqdm
import cv2
import traceback
ImageFile.LOAD_TRUNCATED_IMAGES = True  # to avoid "OSError: image file is truncated"
from torch.utils.data import Dataset

BEDLAM_DIR = "./train_data/bedlam/data"
SMPLX_DIR = "./pretrained_models/human_model_files"
ANNOT_DIR = "./train_data/bedlam/data/annots"

class BEDLAMSeg(Dataset):
    def __init__(self,
                 split='training',
                 training=False,
                 img_size=512,
                 root_dir=BEDLAM_DIR,
                 force_build_dataset=0,
                 n_iter=None,
                 subsample=1,
                 extension='png',
                 crops=[0],
                 flip=1,
                 res=None,
                 n=-1,
                 ):
        super().__init__()

        self.name = 'bedlam'
        self.annotations_dir = ANNOT_DIR
        self.training = training
        self.img_size = img_size
        self.n_iter = n_iter
        self.subsample = subsample
        self.crops = crops  # 0 is the default
        self.flip = flip  # 1 by default

        assert split in ['training', 'validation']

        self.root_dir = root_dir
        self.split = split
        self.image_dir = os.path.join(self.root_dir, f"{self.split}")
        self.mask_dir = os.path.join(self.root_dir, "masks")

        self.annot_file = os.path.join(self.annotations_dir, f"{self.name}_{split}.pkl")
        # self.force_build_dataset = force_build_dataset

        self.annots = None
        # if self.force_build_dataset or not os.path.isfile(self.annot_file):
        #     self.annots = self.build_dataset()
        if self.annots is None:
            with open(self.annot_file, 'rb') as f:
                self.annots = pickle.load(f)

        self.imagenames = list(self.annots.keys())
        self.imagenames.sort()


    def __len__(self):
        return len(self.imagenames)

    def __repr__(self):
        return f"{self.name}: split={self.split} - N={len(self.imagenames)}"

    def save_smplx_params_to_json(self, person, focal, princpt, valid_area_ratio, save_path):
        smplx_params = {}
        smplx_params["betas"] = person['smplx_shape'].reshape(11).tolist()
        smplx_params["root_pose"] = person["smplx_root_pose"].reshape(3).tolist()
        smplx_params['body_pose'] = person["smplx_body_pose"].tolist()
        smplx_params['jaw_pose'] = person["smplx_jaw_pose"].reshape(3).tolist()
        smplx_params['leye_pose'] = person["smplx_leye_pose"].reshape(3).tolist()
        smplx_params['reye_pose'] = person["smplx_reye_pose"].reshape(3).tolist()
        smplx_params['lhand_pose'] = person["smplx_left_hand_pose"].tolist()
        smplx_params['rhand_pose'] = person["smplx_right_hand_pose"].tolist()
        smplx_params['trans'] = person["smplx_transl"].reshape(3).tolist()
        smplx_params['expr'] = np.zeros(10).tolist()

        smplx_params['focal'] = focal
        smplx_params['princpt'] = princpt
        smplx_params['valid_area_ratio'] = valid_area_ratio

        # for k, v in smplx_params.items():
        #     print(k, np.array(v).shape)

        with open(save_path, 'w') as fp:
            json.dump(smplx_params, fp)

        return smplx_params


    def center_crop_and_resize(self, img, mask, princpt_x, princpt_y, fx, fy, area_ratio):

        ys, xs = np.where(mask > 0)

        if len(xs) == 0 or len(ys) == 0:
            print(f"unvalid: no body")
            return None

        x_min = np.min(xs)
        x_max = np.max(xs)
        y_min = np.min(ys)
        y_max = np.max(ys)

        center_x, center_y = img.shape[1]//2, img.shape[0]//2

        half_w = max(abs(center_x - x_min), abs(center_x - x_max))
        half_h = max(abs(center_y - y_min), abs(center_y - y_max))
        ratio = half_h / half_w
        ratio_standard = 1280 / 720
        if ratio >= 1:
            if ratio >= ratio_standard:
                half_w = round(half_h / ratio_standard)
            else:
                half_h = round(half_w * ratio_standard)
        else:
            print(f"unvalid: h/w ratio:{ratio}")
            return None

        assert abs(half_h / half_w - ratio_standard) < 0.1
        offset_x = center_x - half_w
        offset_y = center_y - half_h

        new_img = img[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
        new_mask = mask[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]

        princpt_x -= offset_x
        princpt_y -= offset_y

        new_img = cv2.resize(new_img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
        new_mask = cv2.resize(new_mask, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_NEAREST)

        valid_area_ratio = np.sum(new_mask > 0) / new_mask.shape[0] / new_mask.shape[1]
        if valid_area_ratio < area_ratio:
            print(f"unvalid: area ratio:{valid_area_ratio}")
            return None

        scale = img.shape[0] / 2. / half_h

        fx *= scale
        princpt_x *= scale

        fy *= scale
        princpt_y *= scale

        new_img = np.concatenate([new_img, new_mask[:, :, None]], axis=2)

        return new_img, princpt_x, princpt_y, fx, fy, valid_area_ratio


    def __getitem__(self, idx):
        imagename = self.imagenames[idx]
        annot = self.annots[imagename].copy()
        annot['imagename'] = imagename

        # find appropriate image_dir
        img_path = os.path.join(self.image_dir, imagename)

        mask_path = os.path.join(self.mask_dir, imagename.replace("_6fps/png/", "/masks/").replace(".png", "_env.png"))
        assert os.path.exists(mask_path), f"mask_path:{mask_path}"

        # Original size
        real_width, real_height = annot['size']

        # preprocessing the image
        img_pil = Image.open(img_path)
        if img_pil.mode != 'RGB':
            img_pil = img_pil.convert('RGB')

        # BEDLAM specifc to correct the rotation issue
        # https://github.com/pixelite1201/BEDLAM/blob/ebf8bb14a43de46cc74dca4c00c13e571b325726/visualize_ground_truth.py#L183
        if self.name == 'bedlam' and 'closeup' in imagename and self.split != 'test':
            img_pil = img_pil.rotate(-90, expand=True)


        # preprocessing the image
        mask_pil = Image.open(mask_path)
        # if mask_pil.mode != 'RGB':
        #     img_pil = img_pil.convert('RGB')

        # BEDLAM specifc to correct the rotation issue
        # https://github.com/pixelite1201/BEDLAM/blob/ebf8bb14a43de46cc74dca4c00c13e571b325726/visualize_ground_truth.py#L183
        if self.name == 'bedlam' and 'closeup' in imagename and self.split != 'test':
            mask_pil = mask_pil.rotate(-90, expand=True)

        img = np.asarray(img_pil)
        mask = np.asarray(mask_pil)
        mask = 255 * (mask < 1).astype(np.uint8)

        princpt, focal = annot['princpt'], annot['focal']

        ret = self.center_crop_and_resize(img, mask, princpt[0], princpt[1], focal[0], focal[1], area_ratio=0.05)
        if ret is None:
            print(f"unvalid, img_path:{img_path}")
            return

        new_img, princpt_x, princpt_y, fx, fy, valid_area_ratio = ret
        # print(new_img.shape, princpt_x, princpt_y, fx, fy, "ori", princpt, focal)

        princpt = princpt_x, princpt_y
        focal = fx, fy


        save_path = img_path.replace("/png/", "/png_post/")
        save_vis_path = img_path.replace("/png/", "/png_post_vis/")
        save_smplx_path = img_path.replace("/png/", "/smplx/").replace(".png", ".json")

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        os.makedirs(os.path.dirname(save_vis_path), exist_ok=True)
        os.makedirs(os.path.dirname(save_smplx_path), exist_ok=True)

        cv2.imwrite(save_path, new_img[:, :, (2, 1, 0, 3)])
        cv2.imwrite(save_vis_path, np.hstack([np.concatenate([img, 255 * np.ones_like(mask[:,:, None])], axis=2), new_img])[:, :, (2, 1, 0, 3)])

        # Humans
        _humans = annot['humans'].copy()
        # annot.pop('humans')
        # if self.training:
        humans = [hum for hum in _humans if hum['smplx_transl'][-1] > 0.01]  # the person should be in front of the camera
        # else:
        #     humans = [hum for hum in _humans]

        assert len(humans) == 1

        self.save_smplx_params_to_json(humans[0], focal, princpt, valid_area_ratio, save_smplx_path)

        # return img_array, annot

def create_annots(splits=['validation', 'training']):
    for split in splits:
        dataset = BEDLAM(split=split, force_build_dataset=1)


def visualize(split='validation', i=1500, res=None, extension='png', training=0, img_size=800):
    # training - 52287 for a closeup
    from utils import render_meshes, demo_color
    model_neutral = smplx.create(SMPLX_DIR, 'smplx', gender='neutral', num_betas=11, use_pca=False, flat_hand_mean=True)

    dataset = BEDLAM(split=split, force_build_dataset=0,
                     res=res, extension=extension,
                     training=training,
                     img_size=img_size,
                     )
    print(dataset)

    img_array, annot = dataset.__getitem__(i)

    img_array = denormalize_rgb(img_array, imagenet_normalization=1)
    verts_list = []
    for person in annot['humans']:
        with torch.no_grad():
            verts = model_neutral(
                global_orient=torch.from_numpy(person['smplx_root_pose']).reshape(1,-1),
                body_pose=torch.from_numpy(person['smplx_body_pose']).reshape(1,-1),
                jaw_pose=torch.from_numpy(person['smplx_jaw_pose']).reshape(1,-1),
                leye_pose=torch.from_numpy(person['smplx_leye_pose']).reshape(1,-1),
                reye_pose=torch.from_numpy(person['smplx_reye_pose']).reshape(1,-1),
                left_hand_pose=torch.from_numpy(person['smplx_left_hand_pose']).reshape(1,-1),
                right_hand_pose=torch.from_numpy(person['smplx_right_hand_pose']).reshape(1,-1),
                betas=torch.from_numpy(person['smplx_shape']).reshape(1,-1),
                transl=torch.from_numpy(person['smplx_transl']).reshape(1,-1),
            ).vertices.cpu().numpy().reshape(-1,3)
        verts_list.append(verts)
    faces_list = [model_neutral.faces for _ in annot['humans']]
    _color = [demo_color[0] for _ in annot['humans']]
    pred_rend_array = render_meshes(img_array.copy(),
                                    verts_list,
                                    faces_list,
                                    {'focal': annot['K'][[0,1],[0,1]],
                                     'princpt': annot['K'][[0,1],[-1,-1]]},
                                    alpha=0.7,
                                    color=_color)
    img_array = np.concatenate([img_array, np.asarray(pred_rend_array)], 1)

    fn = f"{dataset.name}_{split}_{i}.jpg"
    Image.fromarray(img_array).save(fn)
    print(f"open {fn}")
    return 1



if __name__ == "__main__":
    # exec(sys.argv[1])
    dataset = BEDLAMSeg(split="validation")
    for i in tqdm.tqdm(range(len(dataset))):
        try:
            dataset.__getitem__(i)
        except:
            traceback.print_exc()
            continue
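Editor's note: the preprocessing script above writes one SMPL-X JSON per frame into a /smplx/ folder, which is the format BedlamDataset._load_pose later reads. A minimal sketch (not part of this commit) of inspecting such a file; the file name is hypothetical and only the keys written by save_smplx_params_to_json are assumed.

# Hypothetical round-trip check, assuming "frame_0000.json" was written by the script above.
import json
import torch

with open("frame_0000.json") as f:
    params = json.load(f)

# rebuild the pinhole intrinsics exactly as BedlamDataset._load_pose does
intrinsic = torch.eye(4)
intrinsic[0, 0], intrinsic[1, 1] = params["focal"]
intrinsic[0, 2], intrinsic[1, 2] = params["princpt"]

print(sorted(params.keys()))
# expected keys: betas, root_pose, body_pose, jaw_pose, leye_pose, reye_pose,
# lhand_pose, rhand_pose, trans, expr, focal, princpt, valid_area_ratio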
LHM/datasets/cam_utils.py ADDED
@@ -0,0 +1,205 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch

"""
R: (N, 3, 3)
T: (N, 3)
E: (N, 4, 4)
vector: (N, 3)
"""


def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from R and T.
    Batched I/O.
    """
    RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
    return compose_extrinsic_RT(RT)


def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from RT.
    Batched I/O.
    """
    return torch.cat([
        RT,
        torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
    ], dim=1)


def decompose_extrinsic_R_T(E: torch.Tensor):
    """
    Decompose the standard extrinsic matrix into R and T.
    Batched I/O.
    """
    RT = decompose_extrinsic_RT(E)
    return RT[:, :, :3], RT[:, :, 3]


def decompose_extrinsic_RT(E: torch.Tensor):
    """
    Decompose the standard extrinsic matrix into RT.
    Batched I/O.
    """
    return E[:, :3, :]


def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
    assert normed_dist_to_center is not None
    pivotal_pose = compose_extrinsic_RT(poses[:1])
    dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
        if normed_dist_to_center == 'auto' else normed_dist_to_center

    # compute camera norm (new version)
    canonical_camera_extrinsics = torch.tensor([[
        [1, 0, 0, 0],
        [0, 0, -1, -dist_to_center],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]], dtype=torch.float32)
    pivotal_pose_inv = torch.inverse(pivotal_pose)
    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)

    # normalize all views
    poses = compose_extrinsic_RT(poses)
    poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
    poses = decompose_extrinsic_RT(poses)

    if ret_transform:
        return poses, camera_norm_matrix.squeeze(dim=0)
    return poses


def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Return batched fx, fy, cx, cy
    """
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
    fx, fy = fx / width, fy / height
    cx, cy = cx / width, cy / height
    return fx, fy, cx, cy


def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    """
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    return torch.cat([
        RT.reshape(-1, 12),
        fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
    ], dim=-1)


def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    """
    E = compose_extrinsic_RT(RT)
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    I = torch.stack([
        torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
        torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
        torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
    ], dim=1)
    return torch.cat([
        E.reshape(-1, 16),
        I.reshape(-1, 9),
    ], dim=-1)


def center_looking_at_camera_pose(
    camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
    device: torch.device = torch.device('cpu'),
):
    """
    camera_position: (M, 3)
    look_at: (3)
    up_world: (3)
    return: (M, 3, 4)
    """
    # by default, looking at the origin and world up is pos-z
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
    up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    z_axis = camera_position - look_at
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    x_axis = torch.cross(up_world, z_axis)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return extrinsics


def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
    """
    n_views: number of surrounding views
    radius: camera dist to center
    height: height of the camera
    return: (M, 3, 4)
    """
    assert n_views > 0
    assert radius > 0

    theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
    projected_radius = math.sqrt(radius ** 2 - height ** 2)
    x = torch.cos(theta) * projected_radius
    y = torch.sin(theta) * projected_radius
    z = torch.full((n_views,), height, device=device)

    camera_positions = torch.stack([x, y, z], dim=1)
    extrinsics = center_looking_at_camera_pose(camera_positions, device=device)

    return extrinsics


def create_intrinsics(
    f: float,
    c: float = None, cx: float = None, cy: float = None,
    w: float = 1., h: float = 1.,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device('cpu'),
):
    """
    return: (3, 2)
    """
    fx = fy = f
    if c is not None:
        assert cx is None and cy is None, "c and cx/cy cannot be used together"
        cx = cy = c
    else:
        assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
    fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
    intrinsics = torch.tensor([
        [fx, fy],
        [cx, cy],
        [w, h],
    ], dtype=dtype, device=device)
    return intrinsics
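Editor's note: a minimal usage sketch (not part of this commit) tying the helpers above together: sample orbiting camera extrinsics, normalize a pinhole intrinsics matrix, and flatten both into the per-view conditioning vector produced by build_camera_principle. The focal length, principal point, and image size are illustrative values only.

# Hypothetical example using only functions defined in cam_utils.py above.
import torch
from LHM.datasets.cam_utils import (
    build_camera_principle,
    create_intrinsics,
    surrounding_views_linspace,
)

extrinsics = surrounding_views_linspace(n_views=8)               # [8, 3, 4] camera-to-world [R|t]
intrinsics = create_intrinsics(f=1000.0, c=256.0, w=512.0, h=512.0)  # [3, 2], already normalized by w/h
intrinsics = intrinsics.unsqueeze(0).repeat(8, 1, 1)             # [8, 3, 2], one per view
camera = build_camera_principle(extrinsics, intrinsics)          # [8, 12 + 4] = flattened RT + fx, fy, cx, cy
print(camera.shape)  # torch.Size([8, 16])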
LHM/datasets/mixer.py ADDED
@@ -0,0 +1,120 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import pdb
from functools import partial

import torch

__all__ = ["MixerDataset"]


class MixerDataset(torch.utils.data.Dataset):
    """Reference"""

    def __init__(
        self,
        split: str,
        subsets: dict,
        **dataset_kwargs,
    ):

        self.subsets = [
            self._dataset_fn(subset, split)(
                use_flame=subset["use_flame"],
                src_head_size=subset.get("src_head_size", 448),
                **dataset_kwargs,
            )
            for subset in subsets
        ]
        self.virtual_lens = [
            math.ceil(subset_config["sample_rate"] * len(subset_obj))
            for subset_config, subset_obj in zip(subsets, self.subsets)
        ]

    @staticmethod
    def _dataset_fn(subset_config: dict, split: str):
        name = subset_config["name"]

        dataset_cls = None
        if name == "exavatar":
            from .exavatar import ExAvatarDataset

            dataset_cls = ExAvatarDataset
        elif name == "humman":
            from .humman import HuMManDataset

            dataset_cls = HuMManDataset
        elif name == "static_human":
            from .static_human import StaticHumanDataset

            dataset_cls = StaticHumanDataset
        elif name == "singleview_human":
            from .singleview_human import SingleViewHumanDataset

            dataset_cls = SingleViewHumanDataset
        elif name == "singleview_square_human":
            from .singleview_square_human import SingleViewSquareHumanDataset

            dataset_cls = SingleViewSquareHumanDataset
        elif name == "bedlam":
            from .bedlam import BedlamDataset

            dataset_cls = BedlamDataset
        elif name == "dna_human":
            from .dna import DNAHumanDataset

            dataset_cls = DNAHumanDataset
        elif name == "video_human":
            from .video_human import VideoHumanDataset

            dataset_cls = VideoHumanDataset
        elif name == "video_human_flame":
            from .video_human_flame import VideoHumanFlameDataset

            dataset_cls = VideoHumanFlameDataset
        elif name == "video_human_flame_dp":
            from .video_human_flame_df import VideoHumanFlameDFDataset

            # add deepfashon random sample in video_human_flame
            dataset_cls = VideoHumanFlameDFDataset
        elif name == "objaverse":
            from .objaverse import ObjaverseDataset

            dataset_cls = ObjaverseDataset
        # elif name == 'mvimgnet':
        #     from .mvimgnet import MVImgNetDataset
        #     dataset_cls = MVImgNetDataset
        else:
            raise NotImplementedError(f"Dataset {name} not implemented")

        return partial(
            dataset_cls,
            root_dirs=subset_config["root_dirs"],
            meta_path=subset_config["meta_path"][split],
        )

    def __len__(self):
        return sum(self.virtual_lens)

    def __getitem__(self, idx):
        subset_idx = 0
        virtual_idx = idx
        while virtual_idx >= self.virtual_lens[subset_idx]:
            virtual_idx -= self.virtual_lens[subset_idx]
            subset_idx += 1
        real_idx = virtual_idx % len(self.subsets[subset_idx])
        return self.subsets[subset_idx][real_idx]
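Editor's note: a hypothetical configuration sketch (not part of this commit) showing what a `subsets` entry for MixerDataset might look like, based only on the keys read above ("name", "root_dirs", "meta_path"[split], "sample_rate", "use_flame", optional "src_head_size"). The paths, split names, and sample rate are illustrative assumptions, not values from the repository's actual configs.

# Hypothetical sketch, assuming a single BEDLAM subset and the split key "train".
from LHM.datasets.mixer import MixerDataset

subsets = [
    {
        "name": "bedlam",
        "root_dirs": "./train_data/bedlam/data/",
        "meta_path": {
            "train": "./train_data/bedlam/data/annots/train_list.json",
            "val": "./train_data/bedlam/data/annots/valid_list.json",
        },
        "sample_rate": 1.0,   # virtual length = sample_rate * len(subset)
        "use_flame": False,
    },
]

train_set = MixerDataset(
    split="train",
    subsets=subsets,
    # remaining kwargs are forwarded to each subset dataset (here BedlamDataset)
    sample_side_views=3,
    render_image_res_low=384,
    render_image_res_high=384,
    render_region_size=(682, 384),
    source_image_res=384,
)
print(len(train_set))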
LHM/launch.py ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pdb

from LHM.runners import REGISTRY_RUNNERS


def main():

    parser = argparse.ArgumentParser(description="OpenLRM launcher")
    parser.add_argument("runner", type=str, help="Runner to launch")
    args, unknown = parser.parse_known_args()

    if args.runner not in REGISTRY_RUNNERS:
        raise ValueError("Runner {} not found".format(args.runner))

    RunnerClass = REGISTRY_RUNNERS[args.runner]
    with RunnerClass() as runner:
        runner.run()


if __name__ == "__main__":
    main()
LHM/losses/__init__.py
ADDED
@@ -0,0 +1,20 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from .ball_loss import *
|
17 |
+
from .offset_loss import *
|
18 |
+
from .perceptual import *
|
19 |
+
from .pixelwise import *
|
20 |
+
from .tvloss import *
|
LHM/losses/ball_loss.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : Alibaba XR-Lab
|
3 |
+
# @Author : Lingteng Qiu
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Time : 2025-03-10 19:08:35
|
6 |
+
# @Function : ASAP loss
|
7 |
+
import pdb
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
__all__ = ["ASAP_Loss", "Heuristic_ASAP_Loss"]
|
13 |
+
|
14 |
+
|
15 |
+
class ASAP_Loss(nn.Module):
|
16 |
+
|
17 |
+
def forward(self, scaling, r=1, **params):
|
18 |
+
"""where r is the radius of the ball between max-axis and min-axis."""
|
19 |
+
raise NotImplementedError(
|
20 |
+
"ASAP_Loss is not implemented yet in Inference version"
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class Heuristic_ASAP_Loss(nn.Module):
|
25 |
+
def __init__(self, group_dict, group_body_mapping):
|
26 |
+
super(Heuristic_ASAP_Loss, self).__init__()
|
27 |
+
|
28 |
+
self.group_dict = group_dict  # register weights for different body parts
|
29 |
+
self.group_body_mapping = group_body_mapping  # mapping of body parts to groups
|
30 |
+
|
31 |
+
def _heurisitic_loss(self, _ball_loss):
|
32 |
+
|
33 |
+
_loss = 0.0
|
34 |
+
for key in self.group_dict.keys():
|
35 |
+
key_weights = self.group_dict[key]
|
36 |
+
group_mapping_idx = self.group_body_mapping[key]
|
37 |
+
_loss += key_weights * _ball_loss[:, group_mapping_idx].mean()
|
38 |
+
|
39 |
+
return _loss
|
40 |
+
|
41 |
+
def forward(self, scaling, r=5, **params):
|
42 |
+
"""where r is the radius of the ball between max-axis and min-axis."""
|
43 |
+
"human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"
|
44 |
+
|
45 |
+
_scale = scaling
|
46 |
+
|
47 |
+
_scale_min = torch.min(_scale, dim=-1)[0]
|
48 |
+
_scale_max = torch.max(_scale, dim=-1)[0]
|
49 |
+
|
50 |
+
scale_ratio = _scale_max / (_scale_min + 1e-6)
|
51 |
+
|
52 |
+
_ball_loss = torch.clamp(scale_ratio, min=r) - r
|
53 |
+
|
54 |
+
return self._heurisitic_loss(_ball_loss)
|
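A hypothetical usage sketch of Heuristic_ASAP_Loss: the group weights and the body-part-to-Gaussian-index mapping below are invented for illustration; the real groupings come from the body-part layout used elsewhere in the project.

import torch
from LHM.losses.ball_loss import Heuristic_ASAP_Loss

group_dict = {"head": 2.0, "body": 1.0}                    # per-group loss weights (hypothetical)
group_body_mapping = {"head": [0, 1], "body": [2, 3, 4]}   # Gaussian indices per group (hypothetical)

loss_fn = Heuristic_ASAP_Loss(group_dict, group_body_mapping)
scaling = torch.rand(4, 5, 3) + 0.1   # [batch, num_gaussians, 3 scale axes]
loss = loss_fn(scaling, r=5)          # zero unless some max/min axis ratio exceeds r
print(loss)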
LHM/losses/offset_loss.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : Alibaba XR-Lab
|
3 |
+
# @Author : Lingteng Qiu
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Time : 2025-03-10 19:08:56
|
6 |
+
# @Function : ACAP Loss
|
7 |
+
import pdb
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
__all__ = ["ACAP_Loss", "Heuristic_ACAP_Loss"]
|
14 |
+
|
15 |
+
|
16 |
+
class ACAP_Loss(nn.Module):
|
17 |
+
"""As close as possibel loss"""
|
18 |
+
|
19 |
+
def forward(self, offset, d=0.05625, **params):
|
20 |
+
"""Empirically, where d is the thresold of distance points leave from 1.8/32 = 0.0562."""
|
21 |
+
|
22 |
+
offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
|
23 |
+
|
24 |
+
return offset_loss.mean()
|
25 |
+
|
26 |
+
|
27 |
+
class Heuristic_ACAP_Loss(nn.Module):
|
28 |
+
"""As close as possibel loss"""
|
29 |
+
|
30 |
+
def __init__(self, group_dict, group_body_mapping):
|
31 |
+
super(Heuristic_ACAP_Loss, self).__init__()
|
32 |
+
|
33 |
+
self.group_dict = group_dict  # register weights for different body parts
|
34 |
+
self.group_body_mapping = group_body_mapping  # mapping of body parts to groups
|
35 |
+
|
36 |
+
def _heurisitic_loss(self, _offset_loss):
|
37 |
+
|
38 |
+
_loss = 0.0
|
39 |
+
for key in self.group_dict.keys():
|
40 |
+
key_weights = self.group_dict[key]
|
41 |
+
group_mapping_idx = self.group_body_mapping[key]
|
42 |
+
_loss += key_weights * _offset_loss[:, group_mapping_idx].mean()
|
43 |
+
|
44 |
+
return _loss
|
45 |
+
|
46 |
+
def forward(self, offset, d=0.05625, **params):
|
47 |
+
"""Empirically, where d is the thresold of distance points leave from human prior model, 1.8/32 = 0.0562."""
|
48 |
+
"human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"
|
49 |
+
|
50 |
+
_offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
|
51 |
+
|
52 |
+
return self._heurisitic_loss(_offset_loss)
|
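A usage sketch of the plain ACAP_Loss on per-point offsets (shapes are illustrative only): only the part of each offset norm that exceeds d is penalized.

import torch
from LHM.losses.offset_loss import ACAP_Loss

loss_fn = ACAP_Loss()
offset = 0.1 * torch.randn(4, 1024, 3)   # [batch, num_points, xyz] offsets from the prior surface
print(loss_fn(offset, d=0.05625))        # mean of the offset norms clamped at the threshold d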
LHM/losses/perceptual.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
__all__ = ['LPIPSLoss']
|
20 |
+
|
21 |
+
|
22 |
+
class LPIPSLoss(nn.Module):
|
23 |
+
"""
|
24 |
+
Compute LPIPS loss between two images.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, device, prefech: bool = False):
|
28 |
+
super().__init__()
|
29 |
+
self.device = device
|
30 |
+
self.cached_models = {}
|
31 |
+
if prefech:
|
32 |
+
self.prefetch_models()
|
33 |
+
|
34 |
+
def _get_model(self, model_name: str):
|
35 |
+
if model_name not in self.cached_models:
|
36 |
+
import warnings
|
37 |
+
with warnings.catch_warnings():
|
38 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
39 |
+
import lpips
|
40 |
+
_model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
|
41 |
+
_model = torch.compile(_model)
|
42 |
+
self.cached_models[model_name] = _model
|
43 |
+
return self.cached_models[model_name]
|
44 |
+
|
45 |
+
def prefetch_models(self):
|
46 |
+
_model_names = ['alex', 'vgg']
|
47 |
+
for model_name in _model_names:
|
48 |
+
self._get_model(model_name)
|
49 |
+
|
50 |
+
def forward(self, x, y, is_training: bool = True):
|
51 |
+
"""
|
52 |
+
Assume images are 0-1 scaled and channel first.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
x: [N, M, C, H, W]
|
56 |
+
y: [N, M, C, H, W]
|
57 |
+
is_training: if True, use the VGG backbone; otherwise AlexNet.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Mean-reduced LPIPS loss across batch.
|
61 |
+
"""
|
62 |
+
model_name = 'vgg' if is_training else 'alex'
|
63 |
+
loss_fn = self._get_model(model_name)
|
64 |
+
N, M, C, H, W = x.shape
|
65 |
+
x = x.reshape(N*M, C, H, W)
|
66 |
+
y = y.reshape(N*M, C, H, W)
|
67 |
+
image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
|
68 |
+
batch_loss = image_loss.reshape(N, M).mean(dim=1)
|
69 |
+
all_loss = batch_loss.mean()
|
70 |
+
return all_loss
|
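A usage sketch of LPIPSLoss, assuming the `lpips` package is installed and a PyTorch version with torch.compile is available; images follow the documented [N, M, C, H, W], 0-1 layout.

import torch
from LHM.losses.perceptual import LPIPSLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_fn = LPIPSLoss(device)
x = torch.rand(2, 4, 3, 64, 64, device=device)   # [N, M, C, H, W], values in [0, 1]
y = torch.rand(2, 4, 3, 64, 64, device=device)
print(loss_fn(x, y, is_training=True))           # VGG backbone; is_training=False uses AlexNet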
LHM/losses/pixelwise.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
__all__ = ['PixelLoss']
|
20 |
+
|
21 |
+
|
22 |
+
class PixelLoss(nn.Module):
|
23 |
+
"""
|
24 |
+
Pixel-wise loss between two images.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, option: str = 'mse'):
|
28 |
+
super().__init__()
|
29 |
+
self.loss_fn = self._build_from_option(option)
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def _build_from_option(option: str, reduction: str = 'none'):
|
33 |
+
if option == 'mse':
|
34 |
+
return nn.MSELoss(reduction=reduction)
|
35 |
+
elif option == 'l1':
|
36 |
+
return nn.L1Loss(reduction=reduction)
|
37 |
+
else:
|
38 |
+
raise NotImplementedError(f'Unknown pixel loss option: {option}')
|
39 |
+
|
40 |
+
@torch.compile
|
41 |
+
def forward(self, x, y):
|
42 |
+
"""
|
43 |
+
Assume images are channel first.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x: [N, M, C, H, W]
|
47 |
+
y: [N, M, C, H, W]
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
Mean-reduced pixel loss across batch.
|
51 |
+
"""
|
52 |
+
N, M, C, H, W = x.shape
|
53 |
+
x = x.reshape(N*M, C, H, W)
|
54 |
+
y = y.reshape(N*M, C, H, W)
|
55 |
+
image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3])
|
56 |
+
batch_loss = image_loss.reshape(N, M).mean(dim=1)
|
57 |
+
all_loss = batch_loss.mean()
|
58 |
+
return all_loss
|
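A usage sketch of PixelLoss with the same [N, M, C, H, W] layout; switching to option="l1" gives the L1 variant.

import torch
from LHM.losses.pixelwise import PixelLoss

mse = PixelLoss(option="mse")
x = torch.rand(2, 4, 3, 64, 64)   # [N, M, C, H, W]
y = torch.rand(2, 4, 3, 64, 64)
print(mse(x, y))                  # mean over pixels, then views, then batch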
LHM/losses/tvloss.py
ADDED
@@ -0,0 +1,55 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
__all__ = ['TVLoss']
|
20 |
+
|
21 |
+
|
22 |
+
class TVLoss(nn.Module):
|
23 |
+
"""
|
24 |
+
Total variance loss.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
def numel_excluding_first_dim(self, x):
|
31 |
+
return x.numel() // x.shape[0]
|
32 |
+
|
33 |
+
@torch.compile
|
34 |
+
def forward(self, x):
|
35 |
+
"""
|
36 |
+
Assume batched and channel first with inner sizes.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
x: [N, M, C, H, W]
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Mean-reduced TV loss with element-level scaling.
|
43 |
+
"""
|
44 |
+
N, M, C, H, W = x.shape
|
45 |
+
x = x.reshape(N*M, C, H, W)
|
46 |
+
diff_i = x[..., 1:, :] - x[..., :-1, :]
|
47 |
+
diff_j = x[..., :, 1:] - x[..., :, :-1]
|
48 |
+
div_i = self.numel_excluding_first_dim(diff_i)
|
49 |
+
div_j = self.numel_excluding_first_dim(diff_j)
|
50 |
+
tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
|
51 |
+
tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
|
52 |
+
tv = tv_i + tv_j
|
53 |
+
batch_tv = tv.reshape(N, M).mean(dim=1)
|
54 |
+
all_tv = batch_tv.mean()
|
55 |
+
return all_tv
|
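A quick sanity check of TVLoss: a constant image has zero total variation, while a noisy one does not.

import torch
from LHM.losses.tvloss import TVLoss

tv = TVLoss()
flat = torch.ones(1, 1, 3, 32, 32)    # constant image -> zero TV
noisy = torch.rand(1, 1, 3, 32, 32)   # random image -> positive TV
print(tv(flat), tv(noisy))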
LHM/models/ESRGANer_utils.py
ADDED
@@ -0,0 +1,482 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : Alibaba XR-Lab
|
3 |
+
# @Author : Lingteng Qiu
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Time : 2025-03-1 17:39:52
|
6 |
+
# @Function : Function to improve face quality when training.
|
7 |
+
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import queue
|
11 |
+
import sys
|
12 |
+
|
13 |
+
sys.path.append("./")
|
14 |
+
import threading
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from basicsr.utils.download_util import load_file_from_url
|
20 |
+
from torch.nn import functional as F
|
21 |
+
|
22 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
23 |
+
import pdb
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
27 |
+
|
28 |
+
|
29 |
+
def avaliable_device():
|
30 |
+
if torch.cuda.is_available():
|
31 |
+
current_device_id = torch.cuda.current_device()
|
32 |
+
device = f"cuda:{current_device_id}"
|
33 |
+
else:
|
34 |
+
device = "cpu"
|
35 |
+
|
36 |
+
return device
|
37 |
+
|
38 |
+
|
39 |
+
class RealESRGANer:
|
40 |
+
"""A helper class for upsampling images with RealESRGAN.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
44 |
+
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
45 |
+
model (nn.Module): The defined network. Default: None.
|
46 |
+
tile (int): Too-large images cause out-of-GPU-memory issues, so this tile option will first crop
|
47 |
+
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
48 |
+
0 disables tiling. Default: 0.
|
49 |
+
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
50 |
+
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
51 |
+
half (bool): Whether to use half precision during inference. Default: False.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
scale,
|
57 |
+
model_path,
|
58 |
+
dni_weight=None,
|
59 |
+
model=None,
|
60 |
+
tile=0,
|
61 |
+
tile_pad=10,
|
62 |
+
pre_pad=10,
|
63 |
+
half=False,
|
64 |
+
device=None,
|
65 |
+
gpu_id=None,
|
66 |
+
):
|
67 |
+
self.scale = scale
|
68 |
+
self.tile_size = tile
|
69 |
+
self.tile_pad = tile_pad
|
70 |
+
self.pre_pad = pre_pad
|
71 |
+
self.mod_scale = None
|
72 |
+
self.half = half
|
73 |
+
|
74 |
+
# initialize model
|
75 |
+
if gpu_id:
|
76 |
+
self.device = (
|
77 |
+
torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
|
78 |
+
if device is None
|
79 |
+
else device
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
self.device = (
|
83 |
+
torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
+
if device is None
|
85 |
+
else device
|
86 |
+
)
|
87 |
+
|
88 |
+
if isinstance(model_path, list):
|
89 |
+
# dni
|
90 |
+
assert len(model_path) == len(
|
91 |
+
dni_weight
|
92 |
+
), "model_path and dni_weight should have the save length."
|
93 |
+
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
94 |
+
else:
|
95 |
+
# if the model_path starts with https, it will first download models to the folder: weights
|
96 |
+
if model_path.startswith("https://"):
|
97 |
+
model_path = load_file_from_url(
|
98 |
+
url=model_path,
|
99 |
+
model_dir=os.path.join(ROOT_DIR, "weights"),
|
100 |
+
progress=True,
|
101 |
+
file_name=None,
|
102 |
+
)
|
103 |
+
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
104 |
+
|
105 |
+
# prefer to use params_ema
|
106 |
+
if "params_ema" in loadnet:
|
107 |
+
keyname = "params_ema"
|
108 |
+
else:
|
109 |
+
keyname = "params"
|
110 |
+
model.load_state_dict(loadnet[keyname], strict=True)
|
111 |
+
|
112 |
+
model.eval()
|
113 |
+
self.model = model.to(self.device)
|
114 |
+
if self.half:
|
115 |
+
self.model = self.model.half()
|
116 |
+
|
117 |
+
def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
|
118 |
+
"""Deep network interpolation.
|
119 |
+
|
120 |
+
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
|
121 |
+
"""
|
122 |
+
net_a = torch.load(net_a, map_location=torch.device(loc))
|
123 |
+
net_b = torch.load(net_b, map_location=torch.device(loc))
|
124 |
+
for k, v_a in net_a[key].items():
|
125 |
+
net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
|
126 |
+
return net_a
|
127 |
+
|
128 |
+
def pre_process(self, img):
|
129 |
+
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
|
130 |
+
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
131 |
+
self.img = img.unsqueeze(0).to(self.device)
|
132 |
+
if self.half:
|
133 |
+
self.img = self.img.half()
|
134 |
+
|
135 |
+
# pre_pad
|
136 |
+
if self.pre_pad != 0:
|
137 |
+
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
|
138 |
+
# mod pad for divisible borders
|
139 |
+
if self.scale == 2:
|
140 |
+
self.mod_scale = 2
|
141 |
+
elif self.scale == 1:
|
142 |
+
self.mod_scale = 4
|
143 |
+
if self.mod_scale is not None:
|
144 |
+
self.mod_pad_h, self.mod_pad_w = 0, 0
|
145 |
+
_, _, h, w = self.img.size()
|
146 |
+
if h % self.mod_scale != 0:
|
147 |
+
self.mod_pad_h = self.mod_scale - h % self.mod_scale
|
148 |
+
if w % self.mod_scale != 0:
|
149 |
+
self.mod_pad_w = self.mod_scale - w % self.mod_scale
|
150 |
+
self.img = F.pad(
|
151 |
+
self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
|
152 |
+
)
|
153 |
+
|
154 |
+
def process(self):
|
155 |
+
# model inference
|
156 |
+
self.output = self.model(self.img)
|
157 |
+
|
158 |
+
def tile_process(self):
|
159 |
+
"""It will first crop input images to tiles, and then process each tile.
|
160 |
+
Finally, all the processed tiles are merged into one image.
|
161 |
+
|
162 |
+
Modified from: https://github.com/ata4/esrgan-launcher
|
163 |
+
"""
|
164 |
+
batch, channel, height, width = self.img.shape
|
165 |
+
output_height = height * self.scale
|
166 |
+
output_width = width * self.scale
|
167 |
+
output_shape = (batch, channel, output_height, output_width)
|
168 |
+
|
169 |
+
# start with black image
|
170 |
+
self.output = self.img.new_zeros(output_shape)
|
171 |
+
tiles_x = math.ceil(width / self.tile_size)
|
172 |
+
tiles_y = math.ceil(height / self.tile_size)
|
173 |
+
|
174 |
+
# loop over all tiles
|
175 |
+
for y in range(tiles_y):
|
176 |
+
for x in range(tiles_x):
|
177 |
+
# extract tile from input image
|
178 |
+
ofs_x = x * self.tile_size
|
179 |
+
ofs_y = y * self.tile_size
|
180 |
+
# input tile area on total image
|
181 |
+
input_start_x = ofs_x
|
182 |
+
input_end_x = min(ofs_x + self.tile_size, width)
|
183 |
+
input_start_y = ofs_y
|
184 |
+
input_end_y = min(ofs_y + self.tile_size, height)
|
185 |
+
|
186 |
+
# input tile area on total image with padding
|
187 |
+
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
188 |
+
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
189 |
+
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
190 |
+
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
191 |
+
|
192 |
+
# input tile dimensions
|
193 |
+
input_tile_width = input_end_x - input_start_x
|
194 |
+
input_tile_height = input_end_y - input_start_y
|
195 |
+
tile_idx = y * tiles_x + x + 1
|
196 |
+
input_tile = self.img[
|
197 |
+
:,
|
198 |
+
:,
|
199 |
+
input_start_y_pad:input_end_y_pad,
|
200 |
+
input_start_x_pad:input_end_x_pad,
|
201 |
+
]
|
202 |
+
|
203 |
+
# upscale tile
|
204 |
+
try:
|
205 |
+
with torch.no_grad():
|
206 |
+
output_tile = self.model(input_tile)
|
207 |
+
except RuntimeError as error:
|
208 |
+
print("Error", error)
|
209 |
+
print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
|
210 |
+
|
211 |
+
# output tile area on total image
|
212 |
+
output_start_x = input_start_x * self.scale
|
213 |
+
output_end_x = input_end_x * self.scale
|
214 |
+
output_start_y = input_start_y * self.scale
|
215 |
+
output_end_y = input_end_y * self.scale
|
216 |
+
|
217 |
+
# output tile area without padding
|
218 |
+
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
219 |
+
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
220 |
+
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
221 |
+
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
222 |
+
|
223 |
+
# put tile into output image
|
224 |
+
self.output[
|
225 |
+
:, :, output_start_y:output_end_y, output_start_x:output_end_x
|
226 |
+
] = output_tile[
|
227 |
+
:,
|
228 |
+
:,
|
229 |
+
output_start_y_tile:output_end_y_tile,
|
230 |
+
output_start_x_tile:output_end_x_tile,
|
231 |
+
]
|
232 |
+
|
233 |
+
def post_process(self):
|
234 |
+
# remove extra pad
|
235 |
+
if self.mod_scale is not None:
|
236 |
+
_, _, h, w = self.output.size()
|
237 |
+
self.output = self.output[
|
238 |
+
:,
|
239 |
+
:,
|
240 |
+
0 : h - self.mod_pad_h * self.scale,
|
241 |
+
0 : w - self.mod_pad_w * self.scale,
|
242 |
+
]
|
243 |
+
# remove prepad
|
244 |
+
if self.pre_pad != 0:
|
245 |
+
_, _, h, w = self.output.size()
|
246 |
+
self.output = self.output[
|
247 |
+
:,
|
248 |
+
:,
|
249 |
+
0 : h - self.pre_pad * self.scale,
|
250 |
+
0 : w - self.pre_pad * self.scale,
|
251 |
+
]
|
252 |
+
return self.output
|
253 |
+
|
254 |
+
@torch.no_grad()
|
255 |
+
def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
|
256 |
+
h_input, w_input = img.shape[0:2]
|
257 |
+
# img: numpy
|
258 |
+
img = img.astype(np.float32)
|
259 |
+
if np.max(img) > 256: # 16-bit image
|
260 |
+
max_range = 65535
|
261 |
+
print("\tInput is a 16-bit image")
|
262 |
+
else:
|
263 |
+
max_range = 255
|
264 |
+
img = img / max_range
|
265 |
+
if len(img.shape) == 2: # gray image
|
266 |
+
img_mode = "L"
|
267 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
268 |
+
elif img.shape[2] == 4: # RGBA image with alpha channel
|
269 |
+
img_mode = "RGBA"
|
270 |
+
alpha = img[:, :, 3]
|
271 |
+
img = img[:, :, 0:3]
|
272 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
273 |
+
if alpha_upsampler == "realesrgan":
|
274 |
+
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
275 |
+
else:
|
276 |
+
img_mode = "RGB"
|
277 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
278 |
+
|
279 |
+
# ------------------- process image (without the alpha channel) ------------------- #
|
280 |
+
self.pre_process(img)
|
281 |
+
if self.tile_size > 0:
|
282 |
+
self.tile_process()
|
283 |
+
else:
|
284 |
+
self.process()
|
285 |
+
output_img = self.post_process()
|
286 |
+
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
287 |
+
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
288 |
+
if img_mode == "L":
|
289 |
+
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
290 |
+
|
291 |
+
# ------------------- process the alpha channel if necessary ------------------- #
|
292 |
+
if img_mode == "RGBA":
|
293 |
+
if alpha_upsampler == "realesrgan":
|
294 |
+
self.pre_process(alpha)
|
295 |
+
if self.tile_size > 0:
|
296 |
+
self.tile_process()
|
297 |
+
else:
|
298 |
+
self.process()
|
299 |
+
output_alpha = self.post_process()
|
300 |
+
output_alpha = (
|
301 |
+
output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
302 |
+
)
|
303 |
+
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
304 |
+
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
305 |
+
else: # use the cv2 resize for alpha channel
|
306 |
+
h, w = alpha.shape[0:2]
|
307 |
+
output_alpha = cv2.resize(
|
308 |
+
alpha,
|
309 |
+
(w * self.scale, h * self.scale),
|
310 |
+
interpolation=cv2.INTER_LINEAR,
|
311 |
+
)
|
312 |
+
|
313 |
+
# merge the alpha channel
|
314 |
+
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
315 |
+
output_img[:, :, 3] = output_alpha
|
316 |
+
|
317 |
+
# ------------------------------ return ------------------------------ #
|
318 |
+
if max_range == 65535: # 16-bit image
|
319 |
+
output = (output_img * 65535.0).round().astype(np.uint16)
|
320 |
+
else:
|
321 |
+
output = (output_img * 255.0).round().astype(np.uint8)
|
322 |
+
|
323 |
+
if outscale is not None and outscale != float(self.scale):
|
324 |
+
output = cv2.resize(
|
325 |
+
output,
|
326 |
+
(
|
327 |
+
int(w_input * outscale),
|
328 |
+
int(h_input * outscale),
|
329 |
+
),
|
330 |
+
interpolation=cv2.INTER_LANCZOS4,
|
331 |
+
)
|
332 |
+
|
333 |
+
return output, img_mode
|
334 |
+
|
335 |
+
|
336 |
+
class PrefetchReader(threading.Thread):
|
337 |
+
"""Prefetch images.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
img_list (list[str]): A list of image paths to be read.
|
341 |
+
num_prefetch_queue (int): Size of the prefetch queue.
|
342 |
+
"""
|
343 |
+
|
344 |
+
def __init__(self, img_list, num_prefetch_queue):
|
345 |
+
super().__init__()
|
346 |
+
self.que = queue.Queue(num_prefetch_queue)
|
347 |
+
self.img_list = img_list
|
348 |
+
|
349 |
+
def run(self):
|
350 |
+
for img_path in self.img_list:
|
351 |
+
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
352 |
+
self.que.put(img)
|
353 |
+
|
354 |
+
self.que.put(None)
|
355 |
+
|
356 |
+
def __next__(self):
|
357 |
+
next_item = self.que.get()
|
358 |
+
if next_item is None:
|
359 |
+
raise StopIteration
|
360 |
+
return next_item
|
361 |
+
|
362 |
+
def __iter__(self):
|
363 |
+
return self
|
364 |
+
|
365 |
+
|
366 |
+
class IOConsumer(threading.Thread):
|
367 |
+
|
368 |
+
def __init__(self, opt, que, qid):
|
369 |
+
super().__init__()
|
370 |
+
self._queue = que
|
371 |
+
self.qid = qid
|
372 |
+
self.opt = opt
|
373 |
+
|
374 |
+
def run(self):
|
375 |
+
while True:
|
376 |
+
msg = self._queue.get()
|
377 |
+
if isinstance(msg, str) and msg == "quit":
|
378 |
+
break
|
379 |
+
|
380 |
+
output = msg["output"]
|
381 |
+
save_path = msg["save_path"]
|
382 |
+
cv2.imwrite(save_path, output)
|
383 |
+
print(f"IO worker {self.qid} is done.")
|
384 |
+
|
385 |
+
|
386 |
+
class ESRGANEasyModel:
|
387 |
+
def __init__(
|
388 |
+
self, model_path="./pretrained_models/RealESRGAN_x4plus.pth", face_enhance=True
|
389 |
+
):
|
390 |
+
model = RRDBNet(
|
391 |
+
num_in_ch=3,
|
392 |
+
num_out_ch=3,
|
393 |
+
num_feat=64,
|
394 |
+
num_block=23,
|
395 |
+
num_grow_ch=32,
|
396 |
+
scale=4,
|
397 |
+
)
|
398 |
+
self.net_scale = 4
|
399 |
+
file_url = [
|
400 |
+
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
401 |
+
]
|
402 |
+
if model_path is None:
|
403 |
+
model_path = os.path.join("weights", args.model_name + ".pth")
|
404 |
+
if not os.path.isfile(model_path):
|
405 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
406 |
+
for url in file_url:
|
407 |
+
# model_path will be updated
|
408 |
+
model_path = load_file_from_url(
|
409 |
+
url=url,
|
410 |
+
model_dir=os.path.join("./", "pretrained_models"),
|
411 |
+
progress=True,
|
412 |
+
file_name=None,
|
413 |
+
)
|
414 |
+
self.face_enhance = face_enhance
|
415 |
+
|
416 |
+
dni_weight = None
|
417 |
+
|
418 |
+
self.upsampler = RealESRGANer(
|
419 |
+
scale=self.net_scale,
|
420 |
+
model_path=model_path,
|
421 |
+
dni_weight=dni_weight,
|
422 |
+
model=model,
|
423 |
+
tile=0,
|
424 |
+
tile_pad=10,
|
425 |
+
pre_pad=0,
|
426 |
+
half=False,
|
427 |
+
)
|
428 |
+
|
429 |
+
self.upsampler.model.to(avaliable_device())
|
430 |
+
if face_enhance: # Use GFPGAN for face enhancement
|
431 |
+
from gfpgan import GFPGANer
|
432 |
+
|
433 |
+
self.face_enhancer = GFPGANer(
|
434 |
+
model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
435 |
+
upscale=4,
|
436 |
+
arch="clean",
|
437 |
+
channel_multiplier=2,
|
438 |
+
bg_upsampler=self.upsampler,
|
439 |
+
)
|
440 |
+
else:
|
441 |
+
self.face_enhancer = None
|
442 |
+
|
443 |
+
@torch.no_grad()
|
444 |
+
def __call__(self, img):
|
445 |
+
if self.face_enhancer is not None:
|
446 |
+
_, _, output = self.face_enhancer.enhance(
|
447 |
+
img, has_aligned=False, only_center_face=False, paste_back=True
|
448 |
+
)
|
449 |
+
else:
|
450 |
+
output, _ = self.upsampler.enhance(img, outscale=4)
|
451 |
+
return output
|
452 |
+
|
453 |
+
def __repr__(self):
|
454 |
+
return f"ESRGANEasyModel:\n {self.upsampler}"
|
455 |
+
|
456 |
+
|
457 |
+
if __name__ == "__main__":
|
458 |
+
|
459 |
+
import time
|
460 |
+
|
461 |
+
model = ESRGANEasyModel(face_enhance=True)
|
462 |
+
input_img = "./debug/face_debug/gt/head_gt_0.png"
|
463 |
+
|
464 |
+
img_np = cv2.imread(input_img)
|
465 |
+
set1 = [
|
466 |
+
"./debug/face_debug/gt/head_gt_0.png",
|
467 |
+
"./debug/face_debug/gt/head_gt_1.png",
|
468 |
+
"./debug/face_debug/gt/head_gt_2.png",
|
469 |
+
"./debug/face_debug/gt/head_gt_3.png",
|
470 |
+
"./debug/face_debug/gt/head_gt_4.png",
|
471 |
+
"./debug/face_debug/gt/head_gt_5.png",
|
472 |
+
"./debug/face_debug/gt/head_gt_6.png",
|
473 |
+
"./debug/face_debug/gt/head_gt_0.png",
|
474 |
+
]
|
475 |
+
img_set1 = [cv2.imread(img_path) for img_path in set1]
|
476 |
+
|
477 |
+
sr = model(img_set1[0])
|
478 |
+
|
479 |
+
s0 = time.time()
|
480 |
+
for img in img_set1:
|
481 |
+
|
482 |
+
sr = model(img)
|
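A usage sketch of ESRGANEasyModel on a single OpenCV image. The input path is hypothetical; the RealESRGAN weights are downloaded to ./pretrained_models on first use if missing, and face_enhance=True additionally requires the `gfpgan` package.

import cv2
from LHM.models.ESRGANer_utils import ESRGANEasyModel

model = ESRGANEasyModel(face_enhance=False)   # plain 4x RealESRGAN upsampling
img = cv2.imread("input.png")                 # HxWx3 uint8 BGR image (hypothetical path)
sr = model(img)                               # ~4x larger BGR output
cv2.imwrite("input_x4.png", sr)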
LHM/models/__init__.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from .modeling_human_lrm import (
|
17 |
+
ModelHumanLRM,
|
18 |
+
ModelHumanLRMSapdinoBodyHeadSD3,
|
19 |
+
ModelHumanLRMSapdinoBodyHeadSD3_5,
|
20 |
+
ModelHumanLRMSapdinoSD3,
|
21 |
+
ModelHumanLRMSD3,
|
22 |
+
)
|
23 |
+
|
24 |
+
model_dict = {
|
25 |
+
"human_lrm": ModelHumanLRM,
|
26 |
+
"human_lrm_sd3": ModelHumanLRMSD3,
|
27 |
+
"human_lrm_sapdino_sd3": ModelHumanLRMSapdinoSD3,
|
28 |
+
"human_lrm_sapdino_bh_sd3": ModelHumanLRMSapdinoBodyHeadSD3,
|
29 |
+
"human_lrm_sapdino_bh_sd3_5": ModelHumanLRMSapdinoBodyHeadSD3_5,
|
30 |
+
}
|
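A sketch of how model_dict is typically consumed: a model name from the config selects the class to instantiate (the name below is one of the registered keys; constructor arguments are omitted because they depend on the config schema).

from LHM.models import model_dict

model_name = "human_lrm_sapdino_bh_sd3_5"   # one of the registered keys above
ModelClass = model_dict[model_name]
print(ModelClass.__name__)                  # ModelHumanLRMSapdinoBodyHeadSD3_5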
LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc
ADDED
Binary file (11.9 kB). View file
|
|
LHM/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (508 Bytes). View file
|
|
LHM/models/__pycache__/arcface_utils.cpython-310.pyc
ADDED
Binary file (9.71 kB). View file
|
|
LHM/models/__pycache__/embedder.cpython-310.pyc
ADDED
Binary file (1.01 kB). View file
|
|
LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc
ADDED
Binary file (23.9 kB). View file
|
|
LHM/models/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (6.82 kB). View file
|
|
LHM/models/__pycache__/transformer_dit.cpython-310.pyc
ADDED
Binary file (15 kB). View file
|
|
LHM/models/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.42 kB). View file
|
|
LHM/models/arcface_utils.py
ADDED
@@ -0,0 +1,360 @@
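A sketch of using the frozen ResNetArcFace features as an identity-similarity term between two grayscale face crops, mirroring the self-test above. The weight path must exist locally, and the 128x128 crop size is an assumption chosen to match the 512*8*8 fc5 layout.

import torch
import torch.nn.functional as F
from LHM.models.arcface_utils import ResNetArcFace

net = ResNetArcFace(pretrain_model="./pretrained_models/arcface_resnet18.pth")
gray_a = torch.rand(1, 1, 128, 128)   # grayscale face crops in [0, 1]
gray_b = torch.rand(1, 1, 128, 128)
id_loss = F.l1_loss(net(gray_a), net(gray_b))   # distance between 512-d identity features
print(id_loss)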
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : Alibaba XR-Lab
|
3 |
+
# @Author : Lingteng Qiu
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Time : 2025-03-10 17:38:29
|
6 |
+
# @Function : Arc-Similarity Loss
|
7 |
+
import sys
|
8 |
+
|
9 |
+
sys.path.append(".")
|
10 |
+
|
11 |
+
import pdb
|
12 |
+
from copy import deepcopy
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
18 |
+
|
19 |
+
|
20 |
+
def conv3x3(inplanes, outplanes, stride=1):
|
21 |
+
"""A simple wrapper for 3x3 convolution with padding.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
inplanes (int): Channel number of inputs.
|
25 |
+
outplanes (int): Channel number of outputs.
|
26 |
+
stride (int): Stride in convolution. Default: 1.
|
27 |
+
"""
|
28 |
+
return nn.Conv2d(
|
29 |
+
inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
class BasicBlock(nn.Module):
|
34 |
+
"""Basic residual block used in the ResNetArcFace architecture.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
inplanes (int): Channel number of inputs.
|
38 |
+
planes (int): Channel number of outputs.
|
39 |
+
stride (int): Stride in convolution. Default: 1.
|
40 |
+
downsample (nn.Module): The downsample module. Default: None.
|
41 |
+
"""
|
42 |
+
|
43 |
+
expansion = 1 # output channel expansion ratio
|
44 |
+
|
45 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
46 |
+
super(BasicBlock, self).__init__()
|
47 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
48 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
49 |
+
self.relu = nn.ReLU(inplace=True)
|
50 |
+
self.conv2 = conv3x3(planes, planes)
|
51 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
52 |
+
self.downsample = downsample
|
53 |
+
self.stride = stride
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
residual = x
|
57 |
+
|
58 |
+
out = self.conv1(x)
|
59 |
+
out = self.bn1(out)
|
60 |
+
out = self.relu(out)
|
61 |
+
|
62 |
+
out = self.conv2(out)
|
63 |
+
out = self.bn2(out)
|
64 |
+
|
65 |
+
if self.downsample is not None:
|
66 |
+
residual = self.downsample(x)
|
67 |
+
|
68 |
+
out += residual
|
69 |
+
out = self.relu(out)
|
70 |
+
|
71 |
+
return out
|
72 |
+
|
73 |
+
|
74 |
+
class IRBlock(nn.Module):
|
75 |
+
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
inplanes (int): Channel number of inputs.
|
79 |
+
planes (int): Channel number of outputs.
|
80 |
+
stride (int): Stride in convolution. Default: 1.
|
81 |
+
downsample (nn.Module): The downsample module. Default: None.
|
82 |
+
use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
|
83 |
+
"""
|
84 |
+
|
85 |
+
expansion = 1 # output channel expansion ratio
|
86 |
+
|
87 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
88 |
+
super(IRBlock, self).__init__()
|
89 |
+
self.bn0 = nn.BatchNorm2d(inplanes)
|
90 |
+
self.conv1 = conv3x3(inplanes, inplanes)
|
91 |
+
self.bn1 = nn.BatchNorm2d(inplanes)
|
92 |
+
self.prelu = nn.PReLU()
|
93 |
+
self.conv2 = conv3x3(inplanes, planes, stride)
|
94 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
95 |
+
self.downsample = downsample
|
96 |
+
self.stride = stride
|
97 |
+
self.use_se = use_se
|
98 |
+
if self.use_se:
|
99 |
+
self.se = SEBlock(planes)
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
residual = x
|
103 |
+
out = self.bn0(x)
|
104 |
+
out = self.conv1(out)
|
105 |
+
out = self.bn1(out)
|
106 |
+
out = self.prelu(out)
|
107 |
+
|
108 |
+
out = self.conv2(out)
|
109 |
+
out = self.bn2(out)
|
110 |
+
if self.use_se:
|
111 |
+
out = self.se(out)
|
112 |
+
|
113 |
+
if self.downsample is not None:
|
114 |
+
residual = self.downsample(x)
|
115 |
+
|
116 |
+
out += residual
|
117 |
+
out = self.prelu(out)
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
|
122 |
+
class Bottleneck(nn.Module):
|
123 |
+
"""Bottleneck block used in the ResNetArcFace architecture.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
inplanes (int): Channel number of inputs.
|
127 |
+
planes (int): Channel number of outputs.
|
128 |
+
stride (int): Stride in convolution. Default: 1.
|
129 |
+
downsample (nn.Module): The downsample module. Default: None.
|
130 |
+
"""
|
131 |
+
|
132 |
+
expansion = 4 # output channel expansion ratio
|
133 |
+
|
134 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
135 |
+
super(Bottleneck, self).__init__()
|
136 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
137 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
138 |
+
self.conv2 = nn.Conv2d(
|
139 |
+
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
140 |
+
)
|
141 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
142 |
+
self.conv3 = nn.Conv2d(
|
143 |
+
planes, planes * self.expansion, kernel_size=1, bias=False
|
144 |
+
)
|
145 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
146 |
+
self.relu = nn.ReLU(inplace=True)
|
147 |
+
self.downsample = downsample
|
148 |
+
self.stride = stride
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
residual = x
|
152 |
+
|
153 |
+
out = self.conv1(x)
|
154 |
+
out = self.bn1(out)
|
155 |
+
out = self.relu(out)
|
156 |
+
|
157 |
+
out = self.conv2(out)
|
158 |
+
out = self.bn2(out)
|
159 |
+
out = self.relu(out)
|
160 |
+
|
161 |
+
out = self.conv3(out)
|
162 |
+
out = self.bn3(out)
|
163 |
+
|
164 |
+
if self.downsample is not None:
|
165 |
+
residual = self.downsample(x)
|
166 |
+
|
167 |
+
out += residual
|
168 |
+
out = self.relu(out)
|
169 |
+
|
170 |
+
return out
|
171 |
+
|
172 |
+
|
173 |
+
class SEBlock(nn.Module):
|
174 |
+
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
channel (int): Channel number of inputs.
|
178 |
+
reduction (int): Channel reduction ration. Default: 16.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(self, channel, reduction=16):
|
182 |
+
super(SEBlock, self).__init__()
|
183 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(
|
184 |
+
1
|
185 |
+
) # pool to 1x1 without spatial information
|
186 |
+
self.fc = nn.Sequential(
|
187 |
+
nn.Linear(channel, channel // reduction),
|
188 |
+
nn.PReLU(),
|
189 |
+
nn.Linear(channel // reduction, channel),
|
190 |
+
nn.Sigmoid(),
|
191 |
+
)
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
b, c, _, _ = x.size()
|
195 |
+
y = self.avg_pool(x).view(b, c)
|
196 |
+
y = self.fc(y).view(b, c, 1, 1)
|
197 |
+
return x * y
|
198 |
+
|
199 |
+
|
200 |
+
class ResNetArcFace(nn.Module):
|
201 |
+
"""ArcFace with ResNet architectures.
|
202 |
+
|
203 |
+
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
block (str): Block used in the ArcFace architecture.
|
207 |
+
layers (tuple(int)): Block numbers in each layer.
|
208 |
+
use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
block="IRBlock",
|
214 |
+
layers=[2, 2, 2, 2],
|
215 |
+
use_se=False,
|
216 |
+
pretrain_model="./pretrained_models/arcface_resnet18.pth",
|
217 |
+
):
|
218 |
+
if block == "IRBlock":
|
219 |
+
block = IRBlock
|
220 |
+
self.inplanes = 64
|
221 |
+
self.use_se = use_se
|
222 |
+
super(ResNetArcFace, self).__init__()
|
223 |
+
|
224 |
+
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
225 |
+
self.bn1 = nn.BatchNorm2d(64)
|
226 |
+
self.prelu = nn.PReLU()
|
227 |
+
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
|
228 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
229 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
230 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
231 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
232 |
+
self.bn4 = nn.BatchNorm2d(512)
|
233 |
+
self.dropout = nn.Dropout()
|
234 |
+
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
235 |
+
self.bn5 = nn.BatchNorm1d(512)
|
236 |
+
|
237 |
+
# initialization
|
238 |
+
for m in self.modules():
|
239 |
+
if isinstance(m, nn.Conv2d):
|
240 |
+
nn.init.xavier_normal_(m.weight)
|
241 |
+
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
242 |
+
nn.init.constant_(m.weight, 1)
|
243 |
+
nn.init.constant_(m.bias, 0)
|
244 |
+
elif isinstance(m, nn.Linear):
|
245 |
+
nn.init.xavier_normal_(m.weight)
|
246 |
+
nn.init.constant_(m.bias, 0)
|
247 |
+
|
248 |
+
if pretrain_model is not None:
|
249 |
+
self.load_network(self, pretrain_model, strict=True, param_key=None)
|
250 |
+
else:
|
251 |
+
raise ValueError("Please specify the pretrain model path.")
|
252 |
+
|
253 |
+
self.freeze()
|
254 |
+
|
255 |
+
@staticmethod
|
256 |
+
def load_network(net, load_path, strict=True, param_key=None):
|
257 |
+
|
258 |
+
def get_bare_model(net):
|
259 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
260 |
+
net = net.module
|
261 |
+
return net
|
262 |
+
|
263 |
+
net = get_bare_model(net)
|
264 |
+
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
|
265 |
+
if param_key is not None:
|
266 |
+
if param_key not in load_net and "params" in load_net:
|
267 |
+
param_key = "params"
|
268 |
+
load_net = load_net[param_key]
|
269 |
+
# remove unnecessary 'module.'
|
270 |
+
for k, v in deepcopy(load_net).items():
|
271 |
+
if k.startswith("module."):
|
272 |
+
load_net[k[7:]] = v
|
273 |
+
load_net.pop(k)
|
274 |
+
ret = net.load_state_dict(load_net, strict=strict)
|
275 |
+
print(ret)
|
276 |
+
|
277 |
+
def _make_layer(self, block, planes, num_blocks, stride=1):
|
278 |
+
downsample = None
|
279 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
280 |
+
downsample = nn.Sequential(
|
281 |
+
nn.Conv2d(
|
282 |
+
self.inplanes,
|
283 |
+
planes * block.expansion,
|
284 |
+
kernel_size=1,
|
285 |
+
stride=stride,
|
286 |
+
bias=False,
|
287 |
+
),
|
288 |
+
nn.BatchNorm2d(planes * block.expansion),
|
289 |
+
)
|
290 |
+
layers = []
|
291 |
+
layers.append(
|
292 |
+
block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
|
293 |
+
)
|
294 |
+
self.inplanes = planes
|
295 |
+
for _ in range(1, num_blocks):
|
296 |
+
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
297 |
+
|
298 |
+
return nn.Sequential(*layers)
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
x = self.conv1(x)
|
302 |
+
x = self.bn1(x)
|
303 |
+
x = self.prelu(x)
|
304 |
+
x = self.maxpool(x)
|
305 |
+
|
306 |
+
x = self.layer1(x)
|
307 |
+
x = self.layer2(x)
|
308 |
+
x = self.layer3(x)
|
309 |
+
x = self.layer4(x)
|
310 |
+
x = self.bn4(x)
|
311 |
+
x = self.dropout(x)
|
312 |
+
x = x.view(x.size(0), -1)
|
313 |
+
x = self.fc5(x)
|
314 |
+
x = self.bn5(x)
|
315 |
+
|
316 |
+
return x
|
317 |
+
|
318 |
+
def freeze(self):
|
319 |
+
self.eval()
|
320 |
+
for param in self.parameters():
|
321 |
+
param.requires_grad = False
|
322 |
+
|
323 |
+
|
324 |
+
if __name__ == "__main__":
|
325 |
+
model = ResNetArcFace()
|
326 |
+
model.cuda()
|
327 |
+
model.eval()
|
328 |
+
# model.eval()
|
329 |
+
|
330 |
+
set1 = [
|
331 |
+
"./debug/face_debug/gt/head_gt_0.png",
|
332 |
+
"./debug/face_debug/gt/head_gt_1.png",
|
333 |
+
"./debug/face_debug/gt/head_gt_2.png",
|
334 |
+
"./debug/face_debug/gt/head_gt_3.png",
|
335 |
+
"./debug/face_debug/gt/head_gt_4.png",
|
336 |
+
"./debug/face_debug/gt/head_gt_5.png",
|
337 |
+
"./debug/face_debug/gt/head_gt_6.png",
|
338 |
+
]
|
339 |
+
import cv2
|
340 |
+
|
341 |
+
img_set1 = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in set1]
|
342 |
+
|
343 |
+
F1_list = []
|
344 |
+
|
345 |
+
f1_scores = []
|
346 |
+
for img in img_set1:
|
347 |
+
img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) / 255.0
|
348 |
+
img = img.cuda()
|
349 |
+
F1 = model(img)
|
350 |
+
F1_list.append(F1)
|
351 |
+
for i in range(len(F1_list)):
|
352 |
+
for j in range(len(F1_list)):
|
353 |
+
f1_scores.append(F.l1_loss(F1_list[i], F1_list[j]))
|
354 |
+
|
355 |
+
print(len(f1_scores))
|
356 |
+
|
357 |
+
f1_scores = torch.tensor(f1_scores)
|
358 |
+
print(f1_scores)
|
359 |
+
f1_scores = f1_scores.view(len(F1_list), len(F1_list))
|
360 |
+
print(f1_scores)
|
LHM/models/block.py
ADDED
@@ -0,0 +1,124 @@
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
from .modulate import ModLN
|
19 |
+
|
20 |
+
|
21 |
+
class BasicBlock(nn.Module):
|
22 |
+
"""
|
23 |
+
Transformer block that is in its simplest form.
|
24 |
+
Designed for PF-LRM architecture.
|
25 |
+
"""
|
26 |
+
# Block contains a self-attention layer and an MLP
|
27 |
+
def __init__(self, inner_dim: int, num_heads: int, eps: float,
|
28 |
+
attn_drop: float = 0., attn_bias: bool = False,
|
29 |
+
mlp_ratio: float = 4., mlp_drop: float = 0.):
|
30 |
+
super().__init__()
|
31 |
+
self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
|
32 |
+
self.self_attn = nn.MultiheadAttention(
|
33 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
34 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
35 |
+
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
|
36 |
+
self.mlp = nn.Sequential(
|
37 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
38 |
+
nn.GELU(),
|
39 |
+
nn.Dropout(mlp_drop),
|
40 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
41 |
+
nn.Dropout(mlp_drop),
|
42 |
+
)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
# x: [N, L, D]
|
46 |
+
before_sa = self.norm1(x)
|
47 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
|
48 |
+
x = x + self.mlp(self.norm2(x))
|
49 |
+
return x
|
50 |
+
|
51 |
+
|
52 |
+
class ConditionBlock(nn.Module):
|
53 |
+
"""
|
54 |
+
Transformer block that takes in a cross-attention condition.
|
55 |
+
Designed for SparseLRM architecture.
|
56 |
+
"""
|
57 |
+
# Block contains a cross-attention layer, a self-attention layer, and an MLP
|
58 |
+
def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
|
59 |
+
attn_drop: float = 0., attn_bias: bool = False,
|
60 |
+
mlp_ratio: float = 4., mlp_drop: float = 0.):
|
61 |
+
super().__init__()
|
62 |
+
self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
|
63 |
+
self.cross_attn = nn.MultiheadAttention(
|
64 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
65 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
66 |
+
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
|
67 |
+
self.self_attn = nn.MultiheadAttention(
|
68 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
69 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
70 |
+
self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
|
71 |
+
self.mlp = nn.Sequential(
|
72 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
73 |
+
nn.GELU(),
|
74 |
+
nn.Dropout(mlp_drop),
|
75 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
76 |
+
nn.Dropout(mlp_drop),
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x, cond):
|
80 |
+
# x: [N, L, D]
|
81 |
+
# cond: [N, L_cond, D_cond]
|
82 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
|
83 |
+
before_sa = self.norm2(x)
|
84 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
|
85 |
+
x = x + self.mlp(self.norm3(x))
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class ConditionModulationBlock(nn.Module):
|
90 |
+
"""
|
91 |
+
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
92 |
+
Designed for raw LRM architecture.
|
93 |
+
"""
|
94 |
+
# Block contains a cross-attention layer, a self-attention layer, and an MLP
|
95 |
+
def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
|
96 |
+
attn_drop: float = 0., attn_bias: bool = False,
|
97 |
+
mlp_ratio: float = 4., mlp_drop: float = 0.):
|
98 |
+
super().__init__()
|
99 |
+
self.norm1 = ModLN(inner_dim, mod_dim, eps)
|
100 |
+
self.cross_attn = nn.MultiheadAttention(
|
101 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
102 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
103 |
+
self.norm2 = ModLN(inner_dim, mod_dim, eps)
|
104 |
+
self.self_attn = nn.MultiheadAttention(
|
105 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
106 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
107 |
+
self.norm3 = ModLN(inner_dim, mod_dim, eps)
|
108 |
+
self.mlp = nn.Sequential(
|
109 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
110 |
+
nn.GELU(),
|
111 |
+
nn.Dropout(mlp_drop),
|
112 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
113 |
+
nn.Dropout(mlp_drop),
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x, cond, mod):
|
117 |
+
# x: [N, L, D]
|
118 |
+
# cond: [N, L_cond, D_cond]
|
119 |
+
# mod: [N, D_mod]
|
120 |
+
x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
|
121 |
+
before_sa = self.norm2(x, mod)
|
122 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
|
123 |
+
x = x + self.mlp(self.norm3(x, mod))
|
124 |
+
return x
|
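A shape sketch for ConditionBlock: latent tokens first cross-attend to a conditioning sequence, then self-attend, then pass through the MLP; all dimensions below are arbitrary examples.

import torch
from LHM.models.block import ConditionBlock

block = ConditionBlock(inner_dim=256, cond_dim=128, num_heads=8, eps=1e-6)
x = torch.randn(2, 576, 256)      # [N, L, D] latent tokens
cond = torch.randn(2, 197, 128)   # [N, L_cond, D_cond] image features
print(block(x, cond).shape)       # torch.Size([2, 576, 256])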
LHM/models/discriminator.py
ADDED
@@ -0,0 +1,120 @@
1 |
+
"""
|
2 |
+
Ported from Paella
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.models.modeling_utils import ModelMixin
|
10 |
+
|
11 |
+
import functools
|
12 |
+
# import torch.nn as nn
|
13 |
+
from taming.modules.util import ActNorm
|
14 |
+
|
15 |
+
|
16 |
+
# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
|
17 |
+
class Discriminator(ModelMixin, ConfigMixin):
|
18 |
+
@register_to_config
|
19 |
+
def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
|
20 |
+
super().__init__()
|
21 |
+
d = max(depth - 3, 3)
|
22 |
+
layers = [
|
23 |
+
nn.utils.spectral_norm(
|
24 |
+
nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
|
25 |
+
),
|
26 |
+
nn.LeakyReLU(0.2),
|
27 |
+
]
|
28 |
+
for i in range(depth - 1):
|
29 |
+
c_in = hidden_channels // (2 ** max((d - i), 0))
|
30 |
+
c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
|
31 |
+
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
32 |
+
layers.append(nn.InstanceNorm2d(c_out))
|
33 |
+
layers.append(nn.LeakyReLU(0.2))
|
34 |
+
self.encoder = nn.Sequential(*layers)
|
35 |
+
self.shuffle = nn.Conv2d(
|
36 |
+
(hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
|
37 |
+
)
|
38 |
+
# self.logits = nn.Sigmoid()
|
39 |
+
|
40 |
+
|
41 |
+
def forward(self, x, cond=None):
|
42 |
+
x = self.encoder(x)
|
43 |
+
if cond is not None:
|
44 |
+
cond = cond.view(
|
45 |
+
cond.size(0),
|
46 |
+
cond.size(1),
|
47 |
+
1,
|
48 |
+
1,
|
49 |
+
).expand(-1, -1, x.size(-2), x.size(-1))
|
50 |
+
x = torch.cat([x, cond], dim=1)
|
51 |
+
x = self.shuffle(x)
|
52 |
+
# x = self.logits(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
def weights_init(m):
|
59 |
+
classname = m.__class__.__name__
|
60 |
+
if classname.find('Conv') != -1:
|
61 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
62 |
+
elif classname.find('BatchNorm') != -1:
|
63 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
64 |
+
nn.init.constant_(m.bias.data, 0)
|
65 |
+
|
66 |
+
|
67 |
+
class NLayerDiscriminator(nn.Module):
|
68 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
69 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
70 |
+
"""
|
71 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
72 |
+
"""Construct a PatchGAN discriminator
|
73 |
+
Parameters:
|
74 |
+
input_nc (int) -- the number of channels in input images
|
75 |
+
ndf (int) -- the number of filters in the last conv layer
|
76 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
77 |
+
norm_layer -- normalization layer
|
78 |
+
"""
|
79 |
+
super(NLayerDiscriminator, self).__init__()
|
80 |
+
if not use_actnorm:
|
81 |
+
# norm_layer = nn.BatchNorm2d
|
82 |
+
norm_layer = nn.InstanceNorm2d
|
83 |
+
else:
|
84 |
+
norm_layer = ActNorm
|
85 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
86 |
+
# use_bias = norm_layer.func != nn.BatchNorm2d
|
87 |
+
use_bias = norm_layer.func != nn.InstanceNorm2d
|
88 |
+
else:
|
89 |
+
# use_bias = norm_layer != nn.BatchNorm2d
|
90 |
+
use_bias = norm_layer != nn.InstanceNorm2d
|
91 |
+
|
92 |
+
kw = 4
|
93 |
+
padw = 1
|
94 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
|
95 |
+
nf_mult = 1
|
96 |
+
nf_mult_prev = 1
|
97 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
98 |
+
nf_mult_prev = nf_mult
|
99 |
+
nf_mult = min(2 ** n, 8)
|
100 |
+
sequence += [
|
101 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
102 |
+
norm_layer(ndf * nf_mult),
|
103 |
+
nn.LeakyReLU(0.2, False)
|
104 |
+
]
|
105 |
+
|
106 |
+
nf_mult_prev = nf_mult
|
107 |
+
nf_mult = min(2 ** n_layers, 8)
|
108 |
+
sequence += [
|
109 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
110 |
+
norm_layer(ndf * nf_mult),
|
111 |
+
nn.LeakyReLU(0.2, False)
|
112 |
+
]
|
113 |
+
|
114 |
+
sequence += [
|
115 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
116 |
+
self.main = nn.Sequential(*sequence)
|
117 |
+
|
118 |
+
def forward(self, input):
|
119 |
+
"""Standard forward."""
|
120 |
+
return self.main(input)
|
LHM/models/embedder.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class CameraEmbedder(nn.Module):
|
21 |
+
"""
|
22 |
+
Embed camera features to a high-dimensional vector.
|
23 |
+
|
24 |
+
Reference:
|
25 |
+
DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
|
26 |
+
"""
|
27 |
+
def __init__(self, raw_dim: int, embed_dim: int):
|
28 |
+
super().__init__()
|
29 |
+
self.mlp = nn.Sequential(
|
30 |
+
nn.Linear(raw_dim, embed_dim),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Linear(embed_dim, embed_dim),
|
33 |
+
)
|
34 |
+
|
35 |
+
@torch.compile
|
36 |
+
def forward(self, x):
|
37 |
+
return self.mlp(x)
|
LHM/models/encoders/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Empty
|
LHM/models/encoders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (169 Bytes). View file
|
|
LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc
ADDED
Binary file (4.98 kB). View file
|
|
LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc
ADDED
Binary file (9.21 kB). View file
|
|
LHM/models/encoders/dino_wrapper.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from transformers import ViTImageProcessor, ViTModel
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
|
21 |
+
|
22 |
+
logger = get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class DinoWrapper(nn.Module):
|
26 |
+
"""
|
27 |
+
Dino v1 wrapper using huggingface transformer implementation.
|
28 |
+
"""
|
29 |
+
def __init__(self, model_name: str, freeze: bool = True, encoder_feat_dim: int = 384):
|
30 |
+
super().__init__()
|
31 |
+
self.model, self.processor = self._build_dino(model_name)
|
32 |
+
if freeze:
|
33 |
+
self._freeze()
|
34 |
+
|
35 |
+
@torch.compile
|
36 |
+
def forward_model(self, inputs):
|
37 |
+
return self.model(**inputs, interpolate_pos_encoding=True)
|
38 |
+
|
39 |
+
def forward(self, image):
|
40 |
+
# image: [N, C, H, W], on cpu
|
41 |
+
# RGB image with [0,1] scale and properly sized
|
42 |
+
inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
|
43 |
+
# This resampling of positional embedding uses bicubic interpolation
|
44 |
+
outputs = self.forward_model(inputs)
|
45 |
+
last_hidden_states = outputs.last_hidden_state
|
46 |
+
return last_hidden_states
|
47 |
+
|
48 |
+
def _freeze(self):
|
49 |
+
logger.warning(f"======== Freezing DinoWrapper ========")
|
50 |
+
self.model.eval()
|
51 |
+
for name, param in self.model.named_parameters():
|
52 |
+
param.requires_grad = False
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
|
56 |
+
import requests
|
57 |
+
try:
|
58 |
+
model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
|
59 |
+
processor = ViTImageProcessor.from_pretrained(model_name)
|
60 |
+
return model, processor
|
61 |
+
except requests.exceptions.ProxyError as err:
|
62 |
+
if proxy_error_retries > 0:
|
63 |
+
print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
|
64 |
+
import time
|
65 |
+
time.sleep(proxy_error_cooldown)
|
66 |
+
return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
|
67 |
+
else:
|
68 |
+
raise err
|
LHM/models/encoders/dinov2/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Empty
|
LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (176 Bytes). View file
|
|
LHM/models/encoders/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (180 Bytes). View file
|
|
LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc
ADDED
Binary file (4.45 kB). View file
|
|
LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.79 kB). View file
|
|
LHM/models/encoders/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
12 |
+
|
13 |
+
|
14 |
+
class Weights(Enum):
|
15 |
+
LVD142M = "LVD142M"
|
16 |
+
|
17 |
+
|
18 |
+
def _make_dinov2_model(
|
19 |
+
*,
|
20 |
+
arch_name: str = "vit_large",
|
21 |
+
img_size: int = 518,
|
22 |
+
patch_size: int = 14,
|
23 |
+
init_values: float = 1.0,
|
24 |
+
ffn_layer: str = "mlp",
|
25 |
+
block_chunks: int = 0,
|
26 |
+
num_register_tokens: int = 0,
|
27 |
+
interpolate_antialias: bool = False,
|
28 |
+
interpolate_offset: float = 0.1,
|
29 |
+
pretrained: bool = True,
|
30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
from ..models import vision_transformer as vits
|
34 |
+
|
35 |
+
if isinstance(weights, str):
|
36 |
+
try:
|
37 |
+
weights = Weights[weights]
|
38 |
+
except KeyError:
|
39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
40 |
+
|
41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
42 |
+
vit_kwargs = dict(
|
43 |
+
img_size=img_size,
|
44 |
+
patch_size=patch_size,
|
45 |
+
init_values=init_values,
|
46 |
+
ffn_layer=ffn_layer,
|
47 |
+
block_chunks=block_chunks,
|
48 |
+
num_register_tokens=num_register_tokens,
|
49 |
+
interpolate_antialias=interpolate_antialias,
|
50 |
+
interpolate_offset=interpolate_offset,
|
51 |
+
)
|
52 |
+
vit_kwargs.update(**kwargs)
|
53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
54 |
+
|
55 |
+
if pretrained:
|
56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
59 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
60 |
+
state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern
|
61 |
+
if vit_kwargs.get("modulation_dim") is not None:
|
62 |
+
state_dict = {
|
63 |
+
k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
|
64 |
+
for k, v in state_dict.items()
|
65 |
+
}
|
66 |
+
model.load_state_dict(state_dict, strict=False)
|
67 |
+
else:
|
68 |
+
model.load_state_dict(state_dict, strict=True)
|
69 |
+
# ********************************************************
|
70 |
+
|
71 |
+
return model
|
72 |
+
|
73 |
+
|
74 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
75 |
+
"""
|
76 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
77 |
+
"""
|
78 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
82 |
+
"""
|
83 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
84 |
+
"""
|
85 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
86 |
+
|
87 |
+
|
88 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
89 |
+
"""
|
90 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
91 |
+
"""
|
92 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
93 |
+
|
94 |
+
|
95 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
96 |
+
"""
|
97 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
98 |
+
"""
|
99 |
+
return _make_dinov2_model(
|
100 |
+
arch_name="vit_giant2",
|
101 |
+
ffn_layer="swiglufused",
|
102 |
+
weights=weights,
|
103 |
+
pretrained=pretrained,
|
104 |
+
**kwargs,
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
109 |
+
"""
|
110 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
111 |
+
"""
|
112 |
+
return _make_dinov2_model(
|
113 |
+
arch_name="vit_small",
|
114 |
+
pretrained=pretrained,
|
115 |
+
weights=weights,
|
116 |
+
num_register_tokens=4,
|
117 |
+
interpolate_antialias=True,
|
118 |
+
interpolate_offset=0.0,
|
119 |
+
**kwargs,
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
124 |
+
"""
|
125 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
126 |
+
"""
|
127 |
+
return _make_dinov2_model(
|
128 |
+
arch_name="vit_base",
|
129 |
+
pretrained=pretrained,
|
130 |
+
weights=weights,
|
131 |
+
num_register_tokens=4,
|
132 |
+
interpolate_antialias=True,
|
133 |
+
interpolate_offset=0.0,
|
134 |
+
**kwargs,
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
139 |
+
"""
|
140 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
141 |
+
"""
|
142 |
+
return _make_dinov2_model(
|
143 |
+
arch_name="vit_large",
|
144 |
+
pretrained=pretrained,
|
145 |
+
weights=weights,
|
146 |
+
num_register_tokens=4,
|
147 |
+
interpolate_antialias=True,
|
148 |
+
interpolate_offset=0.0,
|
149 |
+
**kwargs,
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
154 |
+
"""
|
155 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
156 |
+
"""
|
157 |
+
return _make_dinov2_model(
|
158 |
+
arch_name="vit_giant2",
|
159 |
+
ffn_layer="swiglufused",
|
160 |
+
weights=weights,
|
161 |
+
pretrained=pretrained,
|
162 |
+
num_register_tokens=4,
|
163 |
+
interpolate_antialias=True,
|
164 |
+
interpolate_offset=0.0,
|
165 |
+
**kwargs,
|
166 |
+
)
|
LHM/models/encoders/dinov2/hub/classifiers.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from .backbones import _make_dinov2_model
|
13 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
14 |
+
|
15 |
+
|
16 |
+
class Weights(Enum):
|
17 |
+
IMAGENET1K = "IMAGENET1K"
|
18 |
+
|
19 |
+
|
20 |
+
def _make_dinov2_linear_classification_head(
|
21 |
+
*,
|
22 |
+
arch_name: str = "vit_large",
|
23 |
+
patch_size: int = 14,
|
24 |
+
embed_dim: int = 1024,
|
25 |
+
layers: int = 4,
|
26 |
+
pretrained: bool = True,
|
27 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
28 |
+
num_register_tokens: int = 0,
|
29 |
+
**kwargs,
|
30 |
+
):
|
31 |
+
if layers not in (1, 4):
|
32 |
+
raise AssertionError(f"Unsupported number of layers: {layers}")
|
33 |
+
if isinstance(weights, str):
|
34 |
+
try:
|
35 |
+
weights = Weights[weights]
|
36 |
+
except KeyError:
|
37 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
38 |
+
|
39 |
+
linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
|
40 |
+
|
41 |
+
if pretrained:
|
42 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
43 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
44 |
+
layers_str = str(layers) if layers == 4 else ""
|
45 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
|
46 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
47 |
+
linear_head.load_state_dict(state_dict, strict=True)
|
48 |
+
|
49 |
+
return linear_head
|
50 |
+
|
51 |
+
|
52 |
+
class _LinearClassifierWrapper(nn.Module):
|
53 |
+
def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
|
54 |
+
super().__init__()
|
55 |
+
self.backbone = backbone
|
56 |
+
self.linear_head = linear_head
|
57 |
+
self.layers = layers
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
if self.layers == 1:
|
61 |
+
x = self.backbone.forward_features(x)
|
62 |
+
cls_token = x["x_norm_clstoken"]
|
63 |
+
patch_tokens = x["x_norm_patchtokens"]
|
64 |
+
# fmt: off
|
65 |
+
linear_input = torch.cat([
|
66 |
+
cls_token,
|
67 |
+
patch_tokens.mean(dim=1),
|
68 |
+
], dim=1)
|
69 |
+
# fmt: on
|
70 |
+
elif self.layers == 4:
|
71 |
+
x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
|
72 |
+
# fmt: off
|
73 |
+
linear_input = torch.cat([
|
74 |
+
x[0][1],
|
75 |
+
x[1][1],
|
76 |
+
x[2][1],
|
77 |
+
x[3][1],
|
78 |
+
x[3][0].mean(dim=1),
|
79 |
+
], dim=1)
|
80 |
+
# fmt: on
|
81 |
+
else:
|
82 |
+
assert False, f"Unsupported number of layers: {self.layers}"
|
83 |
+
return self.linear_head(linear_input)
|
84 |
+
|
85 |
+
|
86 |
+
def _make_dinov2_linear_classifier(
|
87 |
+
*,
|
88 |
+
arch_name: str = "vit_large",
|
89 |
+
layers: int = 4,
|
90 |
+
pretrained: bool = True,
|
91 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
92 |
+
num_register_tokens: int = 0,
|
93 |
+
interpolate_antialias: bool = False,
|
94 |
+
interpolate_offset: float = 0.1,
|
95 |
+
**kwargs,
|
96 |
+
):
|
97 |
+
backbone = _make_dinov2_model(
|
98 |
+
arch_name=arch_name,
|
99 |
+
pretrained=pretrained,
|
100 |
+
num_register_tokens=num_register_tokens,
|
101 |
+
interpolate_antialias=interpolate_antialias,
|
102 |
+
interpolate_offset=interpolate_offset,
|
103 |
+
**kwargs,
|
104 |
+
)
|
105 |
+
|
106 |
+
embed_dim = backbone.embed_dim
|
107 |
+
patch_size = backbone.patch_size
|
108 |
+
linear_head = _make_dinov2_linear_classification_head(
|
109 |
+
arch_name=arch_name,
|
110 |
+
patch_size=patch_size,
|
111 |
+
embed_dim=embed_dim,
|
112 |
+
layers=layers,
|
113 |
+
pretrained=pretrained,
|
114 |
+
weights=weights,
|
115 |
+
num_register_tokens=num_register_tokens,
|
116 |
+
)
|
117 |
+
|
118 |
+
return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
|
119 |
+
|
120 |
+
|
121 |
+
def dinov2_vits14_lc(
|
122 |
+
*,
|
123 |
+
layers: int = 4,
|
124 |
+
pretrained: bool = True,
|
125 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
126 |
+
**kwargs,
|
127 |
+
):
|
128 |
+
"""
|
129 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
130 |
+
"""
|
131 |
+
return _make_dinov2_linear_classifier(
|
132 |
+
arch_name="vit_small",
|
133 |
+
layers=layers,
|
134 |
+
pretrained=pretrained,
|
135 |
+
weights=weights,
|
136 |
+
**kwargs,
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
def dinov2_vitb14_lc(
|
141 |
+
*,
|
142 |
+
layers: int = 4,
|
143 |
+
pretrained: bool = True,
|
144 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
145 |
+
**kwargs,
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
149 |
+
"""
|
150 |
+
return _make_dinov2_linear_classifier(
|
151 |
+
arch_name="vit_base",
|
152 |
+
layers=layers,
|
153 |
+
pretrained=pretrained,
|
154 |
+
weights=weights,
|
155 |
+
**kwargs,
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
def dinov2_vitl14_lc(
|
160 |
+
*,
|
161 |
+
layers: int = 4,
|
162 |
+
pretrained: bool = True,
|
163 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
164 |
+
**kwargs,
|
165 |
+
):
|
166 |
+
"""
|
167 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
168 |
+
"""
|
169 |
+
return _make_dinov2_linear_classifier(
|
170 |
+
arch_name="vit_large",
|
171 |
+
layers=layers,
|
172 |
+
pretrained=pretrained,
|
173 |
+
weights=weights,
|
174 |
+
**kwargs,
|
175 |
+
)
|
176 |
+
|
177 |
+
|
178 |
+
def dinov2_vitg14_lc(
|
179 |
+
*,
|
180 |
+
layers: int = 4,
|
181 |
+
pretrained: bool = True,
|
182 |
+
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
183 |
+
**kwargs,
|
184 |
+
):
|
185 |
+
"""
|
186 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
187 |
+
"""
|
188 |
+
return _make_dinov2_linear_classifier(
|
189 |
+
arch_name="vit_giant2",
|
190 |
+
layers=layers,
|
191 |
+
ffn_layer="swiglufused",
|
192 |
+
pretrained=pretrained,
|
193 |
+
weights=weights,
|
194 |
+
**kwargs,
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
def dinov2_vits14_reg_lc(
|
199 |
+
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
200 |
+
):
|
201 |
+
"""
|
202 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
203 |
+
"""
|
204 |
+
return _make_dinov2_linear_classifier(
|
205 |
+
arch_name="vit_small",
|
206 |
+
layers=layers,
|
207 |
+
pretrained=pretrained,
|
208 |
+
weights=weights,
|
209 |
+
num_register_tokens=4,
|
210 |
+
interpolate_antialias=True,
|
211 |
+
interpolate_offset=0.0,
|
212 |
+
**kwargs,
|
213 |
+
)
|
214 |
+
|
215 |
+
|
216 |
+
def dinov2_vitb14_reg_lc(
|
217 |
+
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
218 |
+
):
|
219 |
+
"""
|
220 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
221 |
+
"""
|
222 |
+
return _make_dinov2_linear_classifier(
|
223 |
+
arch_name="vit_base",
|
224 |
+
layers=layers,
|
225 |
+
pretrained=pretrained,
|
226 |
+
weights=weights,
|
227 |
+
num_register_tokens=4,
|
228 |
+
interpolate_antialias=True,
|
229 |
+
interpolate_offset=0.0,
|
230 |
+
**kwargs,
|
231 |
+
)
|
232 |
+
|
233 |
+
|
234 |
+
def dinov2_vitl14_reg_lc(
|
235 |
+
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
236 |
+
):
|
237 |
+
"""
|
238 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
239 |
+
"""
|
240 |
+
return _make_dinov2_linear_classifier(
|
241 |
+
arch_name="vit_large",
|
242 |
+
layers=layers,
|
243 |
+
pretrained=pretrained,
|
244 |
+
weights=weights,
|
245 |
+
num_register_tokens=4,
|
246 |
+
interpolate_antialias=True,
|
247 |
+
interpolate_offset=0.0,
|
248 |
+
**kwargs,
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
def dinov2_vitg14_reg_lc(
|
253 |
+
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
254 |
+
):
|
255 |
+
"""
|
256 |
+
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
257 |
+
"""
|
258 |
+
return _make_dinov2_linear_classifier(
|
259 |
+
arch_name="vit_giant2",
|
260 |
+
layers=layers,
|
261 |
+
ffn_layer="swiglufused",
|
262 |
+
pretrained=pretrained,
|
263 |
+
weights=weights,
|
264 |
+
num_register_tokens=4,
|
265 |
+
interpolate_antialias=True,
|
266 |
+
interpolate_offset=0.0,
|
267 |
+
**kwargs,
|
268 |
+
)
|
LHM/models/encoders/dinov2/hub/depth/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .decode_heads import BNHead, DPTHead
|
7 |
+
from .encoder_decoder import DepthEncoderDecoder
|
LHM/models/encoders/dinov2/hub/depth/decode_heads.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import copy
|
7 |
+
from functools import partial
|
8 |
+
import math
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from .ops import resize
|
15 |
+
|
16 |
+
|
17 |
+
# XXX: (Untested) replacement for mmcv.imdenormalize()
|
18 |
+
def _imdenormalize(img, mean, std, to_bgr=True):
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
mean = mean.reshape(1, -1).astype(np.float64)
|
22 |
+
std = std.reshape(1, -1).astype(np.float64)
|
23 |
+
img = (img * std) + mean
|
24 |
+
if to_bgr:
|
25 |
+
img = img[::-1]
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
class DepthBaseDecodeHead(nn.Module):
|
30 |
+
"""Base class for BaseDecodeHead.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
in_channels (List): Input channels.
|
34 |
+
channels (int): Channels after modules, before conv_depth.
|
35 |
+
conv_layer (nn.Module): Conv layers. Default: None.
|
36 |
+
act_layer (nn.Module): Activation layers. Default: nn.ReLU.
|
37 |
+
loss_decode (dict): Config of decode loss.
|
38 |
+
Default: ().
|
39 |
+
sampler (dict|None): The config of depth map sampler.
|
40 |
+
Default: None.
|
41 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
42 |
+
Default: False.
|
43 |
+
min_depth (int): Min depth in dataset setting.
|
44 |
+
Default: 1e-3.
|
45 |
+
max_depth (int): Max depth in dataset setting.
|
46 |
+
Default: None.
|
47 |
+
norm_layer (dict|None): Norm layers.
|
48 |
+
Default: None.
|
49 |
+
classify (bool): Whether predict depth in a cls.-reg. manner.
|
50 |
+
Default: False.
|
51 |
+
n_bins (int): The number of bins used in cls. step.
|
52 |
+
Default: 256.
|
53 |
+
bins_strategy (str): The discrete strategy used in cls. step.
|
54 |
+
Default: 'UD'.
|
55 |
+
norm_strategy (str): The norm strategy on cls. probability
|
56 |
+
distribution. Default: 'linear'
|
57 |
+
scale_up (str): Whether predict depth in a scale-up manner.
|
58 |
+
Default: False.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
in_channels,
|
64 |
+
conv_layer=None,
|
65 |
+
act_layer=nn.ReLU,
|
66 |
+
channels=96,
|
67 |
+
loss_decode=(),
|
68 |
+
sampler=None,
|
69 |
+
align_corners=False,
|
70 |
+
min_depth=1e-3,
|
71 |
+
max_depth=None,
|
72 |
+
norm_layer=None,
|
73 |
+
classify=False,
|
74 |
+
n_bins=256,
|
75 |
+
bins_strategy="UD",
|
76 |
+
norm_strategy="linear",
|
77 |
+
scale_up=False,
|
78 |
+
):
|
79 |
+
super(DepthBaseDecodeHead, self).__init__()
|
80 |
+
|
81 |
+
self.in_channels = in_channels
|
82 |
+
self.channels = channels
|
83 |
+
self.conf_layer = conv_layer
|
84 |
+
self.act_layer = act_layer
|
85 |
+
self.loss_decode = loss_decode
|
86 |
+
self.align_corners = align_corners
|
87 |
+
self.min_depth = min_depth
|
88 |
+
self.max_depth = max_depth
|
89 |
+
self.norm_layer = norm_layer
|
90 |
+
self.classify = classify
|
91 |
+
self.n_bins = n_bins
|
92 |
+
self.scale_up = scale_up
|
93 |
+
|
94 |
+
if self.classify:
|
95 |
+
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
|
96 |
+
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
|
97 |
+
|
98 |
+
self.bins_strategy = bins_strategy
|
99 |
+
self.norm_strategy = norm_strategy
|
100 |
+
self.softmax = nn.Softmax(dim=1)
|
101 |
+
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
|
102 |
+
else:
|
103 |
+
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
|
104 |
+
|
105 |
+
self.relu = nn.ReLU()
|
106 |
+
self.sigmoid = nn.Sigmoid()
|
107 |
+
|
108 |
+
def forward(self, inputs, img_metas):
|
109 |
+
"""Placeholder of forward function."""
|
110 |
+
pass
|
111 |
+
|
112 |
+
def forward_train(self, img, inputs, img_metas, depth_gt):
|
113 |
+
"""Forward function for training.
|
114 |
+
Args:
|
115 |
+
inputs (list[Tensor]): List of multi-level img features.
|
116 |
+
img_metas (list[dict]): List of image info dict where each dict
|
117 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
118 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
119 |
+
For details on the values of these keys see
|
120 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
121 |
+
depth_gt (Tensor): GT depth
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
dict[str, Tensor]: a dictionary of loss components
|
125 |
+
"""
|
126 |
+
depth_pred = self.forward(inputs, img_metas)
|
127 |
+
losses = self.losses(depth_pred, depth_gt)
|
128 |
+
|
129 |
+
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
|
130 |
+
losses.update(**log_imgs)
|
131 |
+
|
132 |
+
return losses
|
133 |
+
|
134 |
+
def forward_test(self, inputs, img_metas):
|
135 |
+
"""Forward function for testing.
|
136 |
+
Args:
|
137 |
+
inputs (list[Tensor]): List of multi-level img features.
|
138 |
+
img_metas (list[dict]): List of image info dict where each dict
|
139 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
140 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
141 |
+
For details on the values of these keys see
|
142 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
Tensor: Output depth map.
|
146 |
+
"""
|
147 |
+
return self.forward(inputs, img_metas)
|
148 |
+
|
149 |
+
def depth_pred(self, feat):
|
150 |
+
"""Prediction each pixel."""
|
151 |
+
if self.classify:
|
152 |
+
logit = self.conv_depth(feat)
|
153 |
+
|
154 |
+
if self.bins_strategy == "UD":
|
155 |
+
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
156 |
+
elif self.bins_strategy == "SID":
|
157 |
+
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
158 |
+
|
159 |
+
# following Adabins, default linear
|
160 |
+
if self.norm_strategy == "linear":
|
161 |
+
logit = torch.relu(logit)
|
162 |
+
eps = 0.1
|
163 |
+
logit = logit + eps
|
164 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
165 |
+
elif self.norm_strategy == "softmax":
|
166 |
+
logit = torch.softmax(logit, dim=1)
|
167 |
+
elif self.norm_strategy == "sigmoid":
|
168 |
+
logit = torch.sigmoid(logit)
|
169 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
170 |
+
|
171 |
+
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
|
172 |
+
|
173 |
+
else:
|
174 |
+
if self.scale_up:
|
175 |
+
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
|
176 |
+
else:
|
177 |
+
output = self.relu(self.conv_depth(feat)) + self.min_depth
|
178 |
+
return output
|
179 |
+
|
180 |
+
def losses(self, depth_pred, depth_gt):
|
181 |
+
"""Compute depth loss."""
|
182 |
+
loss = dict()
|
183 |
+
depth_pred = resize(
|
184 |
+
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
|
185 |
+
)
|
186 |
+
if not isinstance(self.loss_decode, nn.ModuleList):
|
187 |
+
losses_decode = [self.loss_decode]
|
188 |
+
else:
|
189 |
+
losses_decode = self.loss_decode
|
190 |
+
for loss_decode in losses_decode:
|
191 |
+
if loss_decode.loss_name not in loss:
|
192 |
+
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
|
193 |
+
else:
|
194 |
+
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
|
195 |
+
return loss
|
196 |
+
|
197 |
+
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
|
198 |
+
import numpy as np
|
199 |
+
|
200 |
+
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
|
201 |
+
show_img = show_img.numpy().astype(np.float32)
|
202 |
+
show_img = _imdenormalize(
|
203 |
+
show_img,
|
204 |
+
img_meta["img_norm_cfg"]["mean"],
|
205 |
+
img_meta["img_norm_cfg"]["std"],
|
206 |
+
img_meta["img_norm_cfg"]["to_rgb"],
|
207 |
+
)
|
208 |
+
show_img = np.clip(show_img, 0, 255)
|
209 |
+
show_img = show_img.astype(np.uint8)
|
210 |
+
show_img = show_img[:, :, ::-1]
|
211 |
+
show_img = show_img.transpose(0, 2, 1)
|
212 |
+
show_img = show_img.transpose(1, 0, 2)
|
213 |
+
|
214 |
+
depth_pred = depth_pred / torch.max(depth_pred)
|
215 |
+
depth_gt = depth_gt / torch.max(depth_gt)
|
216 |
+
|
217 |
+
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
|
218 |
+
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
|
219 |
+
|
220 |
+
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
|
221 |
+
|
222 |
+
|
223 |
+
class BNHead(DepthBaseDecodeHead):
|
224 |
+
"""Just a batchnorm."""
|
225 |
+
|
226 |
+
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
|
227 |
+
super().__init__(**kwargs)
|
228 |
+
self.input_transform = input_transform
|
229 |
+
self.in_index = in_index
|
230 |
+
self.upsample = upsample
|
231 |
+
# self.bn = nn.SyncBatchNorm(self.in_channels)
|
232 |
+
if self.classify:
|
233 |
+
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
|
234 |
+
else:
|
235 |
+
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
|
236 |
+
|
237 |
+
def _transform_inputs(self, inputs):
|
238 |
+
"""Transform inputs for decoder.
|
239 |
+
Args:
|
240 |
+
inputs (list[Tensor]): List of multi-level img features.
|
241 |
+
Returns:
|
242 |
+
Tensor: The transformed inputs
|
243 |
+
"""
|
244 |
+
|
245 |
+
if "concat" in self.input_transform:
|
246 |
+
inputs = [inputs[i] for i in self.in_index]
|
247 |
+
if "resize" in self.input_transform:
|
248 |
+
inputs = [
|
249 |
+
resize(
|
250 |
+
input=x,
|
251 |
+
size=[s * self.upsample for s in inputs[0].shape[2:]],
|
252 |
+
mode="bilinear",
|
253 |
+
align_corners=self.align_corners,
|
254 |
+
)
|
255 |
+
for x in inputs
|
256 |
+
]
|
257 |
+
inputs = torch.cat(inputs, dim=1)
|
258 |
+
elif self.input_transform == "multiple_select":
|
259 |
+
inputs = [inputs[i] for i in self.in_index]
|
260 |
+
else:
|
261 |
+
inputs = inputs[self.in_index]
|
262 |
+
|
263 |
+
return inputs
|
264 |
+
|
265 |
+
def _forward_feature(self, inputs, img_metas=None, **kwargs):
|
266 |
+
"""Forward function for feature maps before classifying each pixel with
|
267 |
+
``self.cls_seg`` fc.
|
268 |
+
Args:
|
269 |
+
inputs (list[Tensor]): List of multi-level img features.
|
270 |
+
Returns:
|
271 |
+
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
272 |
+
H, W) which is feature map for last layer of decoder head.
|
273 |
+
"""
|
274 |
+
# accept lists (for cls token)
|
275 |
+
inputs = list(inputs)
|
276 |
+
for i, x in enumerate(inputs):
|
277 |
+
if len(x) == 2:
|
278 |
+
x, cls_token = x[0], x[1]
|
279 |
+
if len(x.shape) == 2:
|
280 |
+
x = x[:, :, None, None]
|
281 |
+
cls_token = cls_token[:, :, None, None].expand_as(x)
|
282 |
+
inputs[i] = torch.cat((x, cls_token), 1)
|
283 |
+
else:
|
284 |
+
x = x[0]
|
285 |
+
if len(x.shape) == 2:
|
286 |
+
x = x[:, :, None, None]
|
287 |
+
inputs[i] = x
|
288 |
+
x = self._transform_inputs(inputs)
|
289 |
+
# feats = self.bn(x)
|
290 |
+
return x
|
291 |
+
|
292 |
+
def forward(self, inputs, img_metas=None, **kwargs):
|
293 |
+
"""Forward function."""
|
294 |
+
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
|
295 |
+
output = self.depth_pred(output)
|
296 |
+
return output
|
297 |
+
|
298 |
+
|
299 |
+
class ConvModule(nn.Module):
|
300 |
+
"""A conv block that bundles conv/norm/activation layers.
|
301 |
+
|
302 |
+
This block simplifies the usage of convolution layers, which are commonly
|
303 |
+
used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
304 |
+
It is based upon three build methods: `build_conv_layer()`,
|
305 |
+
`build_norm_layer()` and `build_activation_layer()`.
|
306 |
+
|
307 |
+
Besides, we add some additional features in this module.
|
308 |
+
1. Automatically set `bias` of the conv layer.
|
309 |
+
2. Spectral norm is supported.
|
310 |
+
3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
|
311 |
+
supports zero and circular padding, and we add "reflect" padding mode.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
in_channels (int): Number of channels in the input feature map.
|
315 |
+
Same as that in ``nn._ConvNd``.
|
316 |
+
out_channels (int): Number of channels produced by the convolution.
|
317 |
+
Same as that in ``nn._ConvNd``.
|
318 |
+
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
319 |
+
Same as that in ``nn._ConvNd``.
|
320 |
+
stride (int | tuple[int]): Stride of the convolution.
|
321 |
+
Same as that in ``nn._ConvNd``.
|
322 |
+
padding (int | tuple[int]): Zero-padding added to both sides of
|
323 |
+
the input. Same as that in ``nn._ConvNd``.
|
324 |
+
dilation (int | tuple[int]): Spacing between kernel elements.
|
325 |
+
Same as that in ``nn._ConvNd``.
|
326 |
+
groups (int): Number of blocked connections from input channels to
|
327 |
+
output channels. Same as that in ``nn._ConvNd``.
|
328 |
+
bias (bool | str): If specified as `auto`, it will be decided by the
|
329 |
+
norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
|
330 |
+
False. Default: "auto".
|
331 |
+
conv_layer (nn.Module): Convolution layer. Default: None,
|
332 |
+
which means using conv2d.
|
333 |
+
norm_layer (nn.Module): Normalization layer. Default: None.
|
334 |
+
act_layer (nn.Module): Activation layer. Default: nn.ReLU.
|
335 |
+
inplace (bool): Whether to use inplace mode for activation.
|
336 |
+
Default: True.
|
337 |
+
with_spectral_norm (bool): Whether use spectral norm in conv module.
|
338 |
+
Default: False.
|
339 |
+
padding_mode (str): If the `padding_mode` has not been supported by
|
340 |
+
current `Conv2d` in PyTorch, we will use our own padding layer
|
341 |
+
instead. Currently, we support ['zeros', 'circular'] with official
|
342 |
+
implementation and ['reflect'] with our own implementation.
|
343 |
+
Default: 'zeros'.
|
344 |
+
order (tuple[str]): The order of conv/norm/activation layers. It is a
|
345 |
+
sequence of "conv", "norm" and "act". Common examples are
|
346 |
+
("conv", "norm", "act") and ("act", "conv", "norm").
|
347 |
+
Default: ('conv', 'norm', 'act').
|
348 |
+
"""
|
349 |
+
|
350 |
+
_abbr_ = "conv_block"
|
351 |
+
|
352 |
+
def __init__(
|
353 |
+
self,
|
354 |
+
in_channels,
|
355 |
+
out_channels,
|
356 |
+
kernel_size,
|
357 |
+
stride=1,
|
358 |
+
padding=0,
|
359 |
+
dilation=1,
|
360 |
+
groups=1,
|
361 |
+
bias="auto",
|
362 |
+
conv_layer=nn.Conv2d,
|
363 |
+
norm_layer=None,
|
364 |
+
act_layer=nn.ReLU,
|
365 |
+
inplace=True,
|
366 |
+
with_spectral_norm=False,
|
367 |
+
padding_mode="zeros",
|
368 |
+
order=("conv", "norm", "act"),
|
369 |
+
):
|
370 |
+
super(ConvModule, self).__init__()
|
371 |
+
official_padding_mode = ["zeros", "circular"]
|
372 |
+
self.conv_layer = conv_layer
|
373 |
+
self.norm_layer = norm_layer
|
374 |
+
self.act_layer = act_layer
|
375 |
+
self.inplace = inplace
|
376 |
+
self.with_spectral_norm = with_spectral_norm
|
377 |
+
self.with_explicit_padding = padding_mode not in official_padding_mode
|
378 |
+
self.order = order
|
379 |
+
assert isinstance(self.order, tuple) and len(self.order) == 3
|
380 |
+
assert set(order) == set(["conv", "norm", "act"])
|
381 |
+
|
382 |
+
self.with_norm = norm_layer is not None
|
383 |
+
self.with_activation = act_layer is not None
|
384 |
+
# if the conv layer is before a norm layer, bias is unnecessary.
|
385 |
+
if bias == "auto":
|
386 |
+
bias = not self.with_norm
|
387 |
+
self.with_bias = bias
|
388 |
+
|
389 |
+
if self.with_explicit_padding:
|
390 |
+
if padding_mode == "zeros":
|
391 |
+
padding_layer = nn.ZeroPad2d
|
392 |
+
else:
|
393 |
+
raise AssertionError(f"Unsupported padding mode: {padding_mode}")
|
394 |
+
self.pad = padding_layer(padding)
|
395 |
+
|
396 |
+
# reset padding to 0 for conv module
|
397 |
+
conv_padding = 0 if self.with_explicit_padding else padding
|
398 |
+
# build convolution layer
|
399 |
+
self.conv = self.conv_layer(
|
400 |
+
in_channels,
|
401 |
+
out_channels,
|
402 |
+
kernel_size,
|
403 |
+
stride=stride,
|
404 |
+
padding=conv_padding,
|
405 |
+
dilation=dilation,
|
406 |
+
groups=groups,
|
407 |
+
bias=bias,
|
408 |
+
)
|
409 |
+
# export the attributes of self.conv to a higher level for convenience
|
410 |
+
self.in_channels = self.conv.in_channels
|
411 |
+
self.out_channels = self.conv.out_channels
|
412 |
+
self.kernel_size = self.conv.kernel_size
|
413 |
+
self.stride = self.conv.stride
|
414 |
+
self.padding = padding
|
415 |
+
self.dilation = self.conv.dilation
|
416 |
+
self.transposed = self.conv.transposed
|
417 |
+
self.output_padding = self.conv.output_padding
|
418 |
+
self.groups = self.conv.groups
|
419 |
+
|
420 |
+
if self.with_spectral_norm:
|
421 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
422 |
+
|
423 |
+
# build normalization layers
|
424 |
+
if self.with_norm:
|
425 |
+
# norm layer is after conv layer
|
426 |
+
if order.index("norm") > order.index("conv"):
|
427 |
+
norm_channels = out_channels
|
428 |
+
else:
|
429 |
+
norm_channels = in_channels
|
430 |
+
norm = partial(norm_layer, num_features=norm_channels)
|
431 |
+
self.add_module("norm", norm)
|
432 |
+
if self.with_bias:
|
433 |
+
from torch.nnModules.batchnorm import _BatchNorm
|
434 |
+
from torch.nnModules.instancenorm import _InstanceNorm
|
435 |
+
|
436 |
+
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
|
437 |
+
warnings.warn("Unnecessary conv bias before batch/instance norm")
|
438 |
+
else:
|
439 |
+
self.norm_name = None
|
440 |
+
|
441 |
+
# build activation layer
|
442 |
+
if self.with_activation:
|
443 |
+
# nn.Tanh has no 'inplace' argument
|
444 |
+
# (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
|
445 |
+
if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
|
446 |
+
act_layer = partial(act_layer, inplace=inplace)
|
447 |
+
self.activate = act_layer()
|
448 |
+
|
449 |
+
# Use msra init by default
|
450 |
+
self.init_weights()
|
451 |
+
|
452 |
+
@property
|
453 |
+
def norm(self):
|
454 |
+
if self.norm_name:
|
455 |
+
return getattr(self, self.norm_name)
|
456 |
+
else:
|
457 |
+
return None
|
458 |
+
|
459 |
+
def init_weights(self):
|
460 |
+
# 1. It is mainly for customized conv layers with their own
|
461 |
+
# initialization manners by calling their own ``init_weights()``,
|
462 |
+
# and we do not want ConvModule to override the initialization.
|
463 |
+
# 2. For customized conv layers without their own initialization
|
464 |
+
# manners (that is, they don't have their own ``init_weights()``)
|
465 |
+
# and PyTorch's conv layers, they will be initialized by
|
466 |
+
# this method with default ``kaiming_init``.
|
467 |
+
# Note: For PyTorch's conv layers, they will be overwritten by our
|
468 |
+
# initialization implementation using default ``kaiming_init``.
|
469 |
+
if not hasattr(self.conv, "init_weights"):
|
470 |
+
if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
|
471 |
+
nonlinearity = "leaky_relu"
|
472 |
+
a = 0.01 # XXX: default negative_slope
|
473 |
+
else:
|
474 |
+
nonlinearity = "relu"
|
475 |
+
a = 0
|
476 |
+
if hasattr(self.conv, "weight") and self.conv.weight is not None:
|
477 |
+
nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
|
478 |
+
if hasattr(self.conv, "bias") and self.conv.bias is not None:
|
479 |
+
nn.init.constant_(self.conv.bias, 0)
|
480 |
+
if self.with_norm:
|
481 |
+
if hasattr(self.norm, "weight") and self.norm.weight is not None:
|
482 |
+
nn.init.constant_(self.norm.weight, 1)
|
483 |
+
if hasattr(self.norm, "bias") and self.norm.bias is not None:
|
484 |
+
nn.init.constant_(self.norm.bias, 0)
|
485 |
+
|
486 |
+
def forward(self, x, activate=True, norm=True):
|
487 |
+
for layer in self.order:
|
488 |
+
if layer == "conv":
|
489 |
+
if self.with_explicit_padding:
|
490 |
+
x = self.pad(x)
|
491 |
+
x = self.conv(x)
|
492 |
+
elif layer == "norm" and norm and self.with_norm:
|
493 |
+
x = self.norm(x)
|
494 |
+
elif layer == "act" and activate and self.with_activation:
|
495 |
+
x = self.activate(x)
|
496 |
+
return x
|
497 |
+
|
498 |
+
|
499 |
+
class Interpolate(nn.Module):
|
500 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
501 |
+
super(Interpolate, self).__init__()
|
502 |
+
self.interp = nn.functional.interpolate
|
503 |
+
self.scale_factor = scale_factor
|
504 |
+
self.mode = mode
|
505 |
+
self.align_corners = align_corners
|
506 |
+
|
507 |
+
def forward(self, x):
|
508 |
+
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
509 |
+
return x
|
510 |
+
|
511 |
+
|
512 |
+
class HeadDepth(nn.Module):
|
513 |
+
def __init__(self, features):
|
514 |
+
super(HeadDepth, self).__init__()
|
515 |
+
self.head = nn.Sequential(
|
516 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
517 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
518 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
519 |
+
nn.ReLU(),
|
520 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
521 |
+
)
|
522 |
+
|
523 |
+
def forward(self, x):
|
524 |
+
x = self.head(x)
|
525 |
+
return x
|
526 |
+
|
527 |
+
|
528 |
+
class ReassembleBlocks(nn.Module):
|
529 |
+
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
530 |
+
rearrange the feature vector to feature map.
|
531 |
+
Args:
|
532 |
+
in_channels (int): ViT feature channels. Default: 768.
|
533 |
+
out_channels (List): output channels of each stage.
|
534 |
+
Default: [96, 192, 384, 768].
|
535 |
+
readout_type (str): Type of readout operation. Default: 'ignore'.
|
536 |
+
patch_size (int): The patch size. Default: 16.
|
537 |
+
"""
|
538 |
+
|
539 |
+
def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
|
540 |
+
super(ReassembleBlocks, self).__init__()
|
541 |
+
|
542 |
+
assert readout_type in ["ignore", "add", "project"]
|
543 |
+
self.readout_type = readout_type
|
544 |
+
self.patch_size = patch_size
|
545 |
+
|
546 |
+
self.projects = nn.ModuleList(
|
547 |
+
[
|
548 |
+
ConvModule(
|
549 |
+
in_channels=in_channels,
|
550 |
+
out_channels=out_channel,
|
551 |
+
kernel_size=1,
|
552 |
+
act_layer=None,
|
553 |
+
)
|
554 |
+
for out_channel in out_channels
|
555 |
+
]
|
556 |
+
)
|
557 |
+
|
558 |
+
self.resize_layers = nn.ModuleList(
|
559 |
+
[
|
560 |
+
nn.ConvTranspose2d(
|
561 |
+
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activate residual unit.
    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super(PreActResidualConvUnit, self).__init__()

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature map from different stages.
    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether expand the channels in post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super(FeatureFusionBlock, self).__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x


class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.
    This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        self.post_process_channels = [
            channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        assert len(inputs) == self.num_reassemble_blocks
        x = [inp for inp in inputs]
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
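A minimal usage sketch (not part of the commit): the reassembly path above turns four (patch_tokens, cls_token) pairs from a ViT into a four-scale feature pyramid. The positional constructor arguments mirror the call made in `DPTHead.__init__`, and the tensor shapes are illustrative, assuming a ViT-B/16 (embed dim 768) on a 224x224 crop.

import torch

# Four (patch_feature, cls_token) pairs, e.g. taken from four intermediate ViT layers;
# a 224x224 crop with patch size 16 gives a 14x14 patch grid per layer.
vit_outputs = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768)) for _ in range(4)]

# Same positional arguments DPTHead passes: embed_dims, post_process_channels, readout_type, patch_size.
reassemble = ReassembleBlocks(768, [96, 192, 384, 768], "project", 16)
pyramid = reassemble(vit_outputs)
# projects + resize_layers yield four maps at 4x, 2x, 1x and 0.5x of the patch grid:
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)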
LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py
ADDED
@@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ops import resize


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:

        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs


class DepthEncoderDecoder(nn.Module):
    """Encoder Decoder depther.

    EncoderDecoder typically consists of backbone and decode_head.
    """

    def __init__(self, backbone, decode_head):
        super(DepthEncoderDecoder, self).__init__()

        self.backbone = backbone
        self.decode_head = decode_head
        self.align_corners = self.decode_head.align_corners

    def extract_feat(self, img):
        """Extract features from images."""
        return self.backbone(img)

    def encode_decode(self, img, img_metas, rescale=True, size=None):
        """Encode images with backbone and decode into a depth estimation
        map of the same size as input."""
        x = self.extract_feat(img)
        out = self._decode_head_forward_test(x, img_metas)
        # crop the pred depth to the certain range.
        out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
        if rescale:
            if size is None:
                if img_metas is not None:
                    size = img_metas[0]["ori_shape"][:2]
                else:
                    size = img.shape[2:]
            out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
        losses.update(add_prefix(loss_decode, "decode"))
        return losses

    def _decode_head_forward_test(self, x, img_metas):
        """Run forward function and calculate loss for decode head in
        inference."""
        depth_pred = self.decode_head.forward_test(x, img_metas)
        return depth_pred

    def forward_dummy(self, img):
        """Dummy forward function."""
        depth = self.encode_decode(img, None)

        return depth

    def forward_train(self, img, img_metas, depth_gt, **kwargs):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): Depth gt
                used if the architecture supports depth estimation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(img)

        losses = dict()

        # the last of x saves the info from neck
        loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)

        losses.update(loss_decode)

        return losses

    def whole_inference(self, img, img_meta, rescale, size=None):
        """Inference with full image."""
        return self.encode_decode(img, img_meta, rescale, size=size)

    def slide_inference(self, img, img_meta, rescale, stride, crop_size):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = img.size()
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, 1, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                depth_pred = self.encode_decode(crop_img, img_meta, rescale)
                preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        return preds

    def inference(self, img, img_meta, rescale, size=None, mode="whole"):
        """Inference with slide/whole style.

        Args:
            img (Tensor): The input image of shape (N, 3, H, W).
            img_meta (dict): Image info dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            rescale (bool): Whether rescale back to original shape.

        Returns:
            Tensor: The output depth map.
        """

        assert mode in ["slide", "whole"]
        ori_shape = img_meta[0]["ori_shape"]
        assert all(_["ori_shape"] == ori_shape for _ in img_meta)
        if mode == "slide":
            depth_pred = self.slide_inference(img, img_meta, rescale)
        else:
            depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
        output = depth_pred
        flip = img_meta[0]["flip"]
        if flip:
            flip_direction = img_meta[0]["flip_direction"]
            assert flip_direction in ["horizontal", "vertical"]
            if flip_direction == "horizontal":
                output = output.flip(dims=(3,))
            elif flip_direction == "vertical":
                output = output.flip(dims=(2,))

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Simple test with single image."""
        depth_pred = self.inference(img, img_meta, rescale)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only support 4D output
            depth_pred = depth_pred.unsqueeze(0)
            return depth_pred
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented depth logit inplace
        depth_pred = self.inference(imgs[0], img_metas[0], rescale)
        for i in range(1, len(imgs)):
            cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
            depth_pred += cur_depth_pred
        depth_pred /= len(imgs)
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
        """
        for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
            if not isinstance(var, list):
                raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
        # all images in the same aug batch all of the same ori_shape and pad
        # shape
        for img_meta in img_metas:
            ori_shapes = [_["ori_shape"] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_["img_shape"] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_["pad_shape"] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``resturn_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)

        # split losses and images
        real_losses = {}
        log_imgs = {}
        for k, v in losses.items():
            if "img" in k:
                log_imgs[k] = v
            else:
                real_losses[k] = v

        loss, log_vars = self._parse_losses(real_losses)

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        import torch.distributed as dist

        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f"{loss_name} is not a tensor or list of tensors")

        loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

        log_vars["loss"] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
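A minimal, self-contained inference sketch for `DepthEncoderDecoder` (not part of the commit). The stub modules below are placeholders for a real backbone (e.g. a DINOv2 ViT) and decode head (e.g. the `DPTHead` above); they expose only the attributes the wrapper actually reads (`align_corners`, `min_depth`, `max_depth`, `forward_test`), and all shapes are illustrative.

import torch
import torch.nn as nn


class StubBackbone(nn.Module):
    """Placeholder for a real feature extractor; extract_feat() expects a list of feature maps."""

    def forward(self, x):
        return [x]


class StubDecodeHead(nn.Module):
    """Placeholder for a real decode head; exposes what encode_decode() reads."""

    align_corners = False
    min_depth, max_depth = 1e-3, 10.0

    def forward_test(self, x, img_metas):
        # Coarse depth prediction at a quarter of the input resolution.
        return torch.rand(x[0].shape[0], 1, 56, 56) * self.max_depth


model = DepthEncoderDecoder(StubBackbone(), StubDecodeHead()).eval()

img = torch.randn(1, 3, 224, 224)
img_metas = [{"ori_shape": (224, 224, 3), "img_shape": (224, 224, 3),
              "pad_shape": (224, 224, 3), "flip": False}]

with torch.no_grad():
    # return_loss=False routes forward() -> forward_test() -> simple_test(),
    # so img and img_metas are wrapped in the outer test-time-augmentation lists.
    depth_maps = model([img], [img_metas], return_loss=False)
# depth_maps is a list of per-image numpy arrays, resized to ori_shape[:2]
# and clamped to [min_depth, max_depth].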