QZFantasies committed
Commit c614b0f · 1 Parent(s): 7d7a0a6

add wheels
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. LHM/__init__.py +15 -0
  3. LHM/__pycache__/__init__.cpython-310.pyc +0 -0
  4. LHM/__pycache__/launch.cpython-310.pyc +0 -0
  5. LHM/datasets/__init__.py +16 -0
  6. LHM/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  7. LHM/datasets/__pycache__/cam_utils.cpython-310.pyc +0 -0
  8. LHM/datasets/__pycache__/mixer.cpython-310.pyc +0 -0
  9. LHM/datasets/base.py +70 -0
  10. LHM/datasets/bedlam.py +493 -0
  11. LHM/datasets/bedlam_util.py +306 -0
  12. LHM/datasets/cam_utils.py +205 -0
  13. LHM/datasets/mixer.py +120 -0
  14. LHM/launch.py +35 -0
  15. LHM/losses/__init__.py +20 -0
  16. LHM/losses/ball_loss.py +54 -0
  17. LHM/losses/offset_loss.py +52 -0
  18. LHM/losses/perceptual.py +70 -0
  19. LHM/losses/pixelwise.py +58 -0
  20. LHM/losses/tvloss.py +55 -0
  21. LHM/models/ESRGANer_utils.py +482 -0
  22. LHM/models/__init__.py +30 -0
  23. LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc +0 -0
  24. LHM/models/__pycache__/__init__.cpython-310.pyc +0 -0
  25. LHM/models/__pycache__/arcface_utils.cpython-310.pyc +0 -0
  26. LHM/models/__pycache__/embedder.cpython-310.pyc +0 -0
  27. LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc +0 -0
  28. LHM/models/__pycache__/transformer.cpython-310.pyc +0 -0
  29. LHM/models/__pycache__/transformer_dit.cpython-310.pyc +0 -0
  30. LHM/models/__pycache__/utils.cpython-310.pyc +0 -0
  31. LHM/models/arcface_utils.py +360 -0
  32. LHM/models/block.py +124 -0
  33. LHM/models/discriminator.py +120 -0
  34. LHM/models/embedder.py +37 -0
  35. LHM/models/encoders/__init__.py +15 -0
  36. LHM/models/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
  37. LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc +0 -0
  38. LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc +0 -0
  39. LHM/models/encoders/dino_wrapper.py +68 -0
  40. LHM/models/encoders/dinov2/__init__.py +15 -0
  41. LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  42. LHM/models/encoders/dinov2/hub/__init__.py +4 -0
  43. LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
  44. LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
  45. LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
  46. LHM/models/encoders/dinov2/hub/backbones.py +166 -0
  47. LHM/models/encoders/dinov2/hub/classifiers.py +268 -0
  48. LHM/models/encoders/dinov2/hub/depth/__init__.py +7 -0
  49. LHM/models/encoders/dinov2/hub/depth/decode_heads.py +747 -0
  50. LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.whl filter=lfs diff=lfs merge=lfs -text
LHM/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # Empty
LHM/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes)

LHM/__pycache__/launch.cpython-310.pyc ADDED
Binary file (743 Bytes)

LHM/datasets/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from .mixer import MixerDataset
LHM/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (204 Bytes)

LHM/datasets/__pycache__/cam_utils.cpython-310.pyc ADDED
Binary file (5.43 kB)

LHM/datasets/__pycache__/mixer.cpython-310.pyc ADDED
Binary file (2.91 kB)

LHM/datasets/base.py ADDED
@@ -0,0 +1,70 @@
+ # -*- coding: utf-8 -*-
+ # @Organization : Alibaba XR-Lab
+ # @Author : Peihao Li & Lingteng Qiu & Xiaodong Gu & Qi Zuo
+ # @Email : [email protected]
+ # @Time : 2025-03-10 18:47:56
+ # @Function : dataset base
+
+ import json
+ import pdb
+ import traceback
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch
+ from megfile import smart_exists, smart_open, smart_path_join
+ from PIL import Image
+
+
+ class BaseDataset(torch.utils.data.Dataset, ABC):
+     def __init__(self, root_dirs: str, meta_path: str):
+         super().__init__()
+         self.root_dirs = root_dirs
+         self.uids = self._load_uids(meta_path)
+
+     def __len__(self):
+         return len(self.uids)
+
+     @abstractmethod
+     def inner_get_item(self, idx):
+         pass
+
+     def __getitem__(self, idx):
+         try:
+             return self.inner_get_item(idx)
+         except Exception as e:
+             traceback.print_exc()
+             print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
+             # raise e
+             return self.__getitem__((idx + 1) % self.__len__())
+
+     @staticmethod
+     def _load_uids(meta_path: str):
+         # meta_path is a json file
+         if meta_path is None:
+             uids = []
+         else:
+             with open(meta_path, "r") as f:
+                 uids = json.load(f)
+
+         return uids
+
+     @staticmethod
+     def _load_rgba_image(file_path, bg_color: float = 1.0):
+         """Load and blend RGBA image to RGB with certain background, 0-1 scaled"""
+         rgba = np.array(Image.open(smart_open(file_path, "rb")))
+         rgba = torch.from_numpy(rgba).float() / 255.0
+         rgba = rgba.permute(2, 0, 1).unsqueeze(0)
+         rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (
+             1 - rgba[:, 3:, :, :]
+         )
+         # rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...])
+         return rgb
+
+     @staticmethod
+     def _locate_datadir(root_dirs, uid, locator: str):
+         for root_dir in root_dirs:
+             datadir = smart_path_join(root_dir, uid, locator)
+             if smart_exists(datadir):
+                 return root_dir
+         raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
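BaseDataset leaves `inner_get_item` abstract and, on a failed load, retries with the next index. A minimal sketch of a concrete subclass, assuming a hypothetical layout in which each uid folder contains an `rgba.png` (the folder layout and class name are illustrative, not part of this commit):

```python
# Hypothetical subclass sketch; only the helpers come from LHM/datasets/base.py above.
from LHM.datasets.base import BaseDataset
from megfile import smart_path_join


class ToyRGBADataset(BaseDataset):
    def inner_get_item(self, idx):
        uid = self.uids[idx]
        # _locate_datadir returns the first root dir that actually contains the uid.
        root = self._locate_datadir(self.root_dirs, uid, locator="rgba.png")
        rgb = self._load_rgba_image(
            smart_path_join(root, uid, "rgba.png"), bg_color=1.0
        )  # [1, 3, H, W], alpha blended onto a white background
        return {"uid": uid, "rgb": rgb}
```

Note that `root_dirs` is iterated by `_locate_datadir`, so passing a list of directories is the expected usage even though the type hint says `str`.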
LHM/datasets/bedlam.py ADDED
@@ -0,0 +1,493 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import glob
+
+ # from megfile import smart_path_join, smart_open
+ import json
+ import os
+ import random
+ from collections import defaultdict
+ from typing import Union
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ from LHM.datasets.base import BaseDataset
+ from LHM.datasets.cam_utils import (
+     build_camera_principle,
+     build_camera_standard,
+     camera_normalization_objaverse,
+ )
+ from LHM.utils.proxy import no_proxy
+
+ __all__ = ["BedlamDataset"]
+
+
+ class BedlamDataset(BaseDataset):
+
+     def __init__(
+         self,
+         root_dirs: str,
+         meta_path: str,
+         sample_side_views: int,
+         render_image_res_low: int,
+         render_image_res_high: int,
+         render_region_size: int,
+         source_image_res: int,
+         repeat_num=1,
+         crop_range_ratio_hw=[1.0, 1.0],
+         valid_area_ratio=0.4,
+         debug=False,
+         **kwargs,
+     ):
+         super().__init__(root_dirs, meta_path)
+         self.sample_side_views = sample_side_views
+         self.render_image_res_low = render_image_res_low
+         self.render_image_res_high = render_image_res_high
+         if not (
+             isinstance(render_region_size, list)
+             or isinstance(render_region_size, tuple)
+         ):
+             render_region_size = render_region_size, render_region_size  # [H, W]
+         self.render_region_size = render_region_size
+         self.source_image_res = source_image_res
+
+         self.uids = self.uids * repeat_num
+         self.crop_range_ratio_hw = crop_range_ratio_hw
+         self.debug = debug
+         self.valid_area_ratio = valid_area_ratio
+         print(
+             f"BedlamDataset, data_len:{len(self.uids)}, repeat_num:{repeat_num}, debug:{debug}"
+         )
+         self.multiply = kwargs.get("multiply", 14)
+
+     @staticmethod
+     def _load_pose(pose):
+         intrinsic = torch.eye(4)
+         intrinsic[0, 0] = pose["focal"][0]
+         intrinsic[1, 1] = pose["focal"][1]
+         intrinsic[0, 2] = pose["princpt"][0]
+         intrinsic[1, 2] = pose["princpt"][1]
+         intrinsic = intrinsic.float()
+
+         c2w = torch.eye(4)
+         # c2w[:3, :3] = torch.tensor(pose["R"])
+         # c2w[3, :3] = torch.tensor(pose["t"])
+         c2w = c2w.float()
+
+         return c2w, intrinsic
+
+     def load_rgb_image_with_aug_bg(self, rgb_path, mask_path, bg_color):
+         rgb = np.array(Image.open(rgb_path))
+         rgb = torch.from_numpy(rgb).float() / 255.0
+         rgb = rgb.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
+         mask = None
+
+         if mask_path is not None:
+             mask = np.array(Image.open(mask_path))
+             mask = torch.from_numpy(mask).float() / 255.0
+             mask = (mask > 0.5).float()
+             if len(mask.shape) == 3:
+                 mask = mask[:, :, 0:1]
+             if len(mask.shape) == 2:
+                 mask = mask.unsqueeze(-1)
+             mask = mask.permute(2, 0, 1).unsqueeze(0)  # [1, 1, H, W]
+             rgb = torch.cat([rgb, mask], dim=1)  # [1, 4, H, W]
+         else:
+             mask = rgb[:, 3:4, :, :]
+
+             # erode mask
+             mask_np = (mask[0, 0].numpy() * 255).astype(np.uint8)
+             kernel_size, iterations = 3, 1
+             kernel = np.ones((kernel_size, kernel_size), np.uint8)
+             mask_np = cv2.erode(mask_np, kernel, iterations=iterations)
+             mask = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0) / 255.0
+             mask = (mask > 0.5).float()
+             rgb = torch.cat([rgb[:, :3], mask], dim=1)  # [1, 4, H, W]
+
+         if rgb.shape[1] == 4:
+             rgb = rgb[:, :3, :, :] * rgb[:, 3:4, :, :] + bg_color * (
+                 1 - rgb[:, 3:, :, :]
+             )
+
+         return rgb, mask
+
+     def scale_intrs(self, intrs, ratio_x, ratio_y):
+         intrs[:, 0] = intrs[:, 0] * ratio_x
+         intrs[:, 1] = intrs[:, 1] * ratio_y
+         return intrs
+
+     def uniform_sample_in_chunk(self, sample_num, sample_data):
+         chunks = np.array_split(sample_data, sample_num)
+         select_list = []
+         for chunk in chunks:
+             select_list.append(np.random.choice(chunk))
+         return select_list
+
+     @no_proxy
+     def inner_get_item(self, idx):
+         """
+         Loaded contents:
+             rgbs: [M, 3, H, W]
+             poses: [M, 3, 4], [R|t]
+             intrinsics: [3, 2], [[fx, fy], [cx, cy], [weight, height]]
+         """
+         uid = self.uids[idx]
+         seq_id = uid["seq_id"]
+         all_frame_info = uid["all_frame_info"]
+         uid = os.path.join(self.root_dirs, seq_id)
+         valid_imgs = [
+             e["frame_name"]
+             for e in all_frame_info
+             if e["valid_area_ratio"] > self.valid_area_ratio
+         ]
+         assert len(valid_imgs) >= 1
+
+         if self.sample_side_views + 1 <= len(valid_imgs):
+             cam_id_list = np.random.choice(
+                 valid_imgs, self.sample_side_views + 1, replace=False
+             )
+         else:
+             cam_id_list = np.random.choice(
+                 valid_imgs, self.sample_side_views + 1, replace=True
+             )
+
+         assert self.sample_side_views + 1 == len(cam_id_list)
+         crop_ratio_h, crop_ratio_w = self.crop_range_ratio_hw
+
+         frame_id_list = cam_id_list
+
+         # source images
+         c2ws, intrs, rgbs, bg_colors, masks = [], [], [], [], []
+         source_c2ws, source_intrs, source_rgbs = [], [], []
+         smplx_params = []
+         shape_param = None
+         for cam_id, frame_id in zip(cam_id_list, frame_id_list):
+             frame_path = os.path.join(uid, cam_id + ".png")
+             frame_name = os.path.splitext(os.path.basename(frame_path))[0]
+             smplx_path = os.path.join(
+                 uid.replace("/png_post/", "/smplx/"), f"{frame_name}.json"
+             )
+
+             with open(smplx_path) as f:
+                 smplx_param = {
+                     k: torch.FloatTensor(v)
+                     for k, v in json.load(f).items()
+                     if "valid_area_ratio" not in k
+                 }
+
+             # if cam_id == 0:
+             shape_param = smplx_param["betas"]
+
+             c2w, intrinsic = self._load_pose(smplx_param)
+
+             bg_color = random.choice([0.0, 0.5, 1.0])
+             rgb, mask = self.load_rgb_image_with_aug_bg(
+                 frame_path, mask_path=None, bg_color=bg_color
+             )
+
+             # crop image to enlarge human area.
+             if (crop_ratio_h < 1.0) or (crop_ratio_w < 1.0):
+                 img_size_hw = rgb.shape[2], rgb.shape[3]
+                 h_crop, w_crop = round(img_size_hw[0] * crop_ratio_h), round(
+                     img_size_hw[1] * crop_ratio_w
+                 )
+                 h_crop_offset, w_crop_offset = round(
+                     (img_size_hw[0] - h_crop) / 2
+                 ), round((img_size_hw[1] - w_crop) / 2)
+                 rgb = rgb[
+                     :,
+                     :,
+                     h_crop_offset : h_crop_offset + h_crop,
+                     w_crop_offset : w_crop_offset + w_crop,
+                 ]
+                 mask = mask[
+                     :,
+                     :,
+                     h_crop_offset : h_crop_offset + h_crop,
+                     w_crop_offset : w_crop_offset + w_crop,
+                 ]
+                 intrinsic[0, 2] -= w_crop_offset
+                 intrinsic[1, 2] -= h_crop_offset
+
+             assert (
+                 abs(intrinsic[0, 2] * 2 - rgb.shape[-1]) <= 1
+             ), f"{intrinsic[0, 2] * 2}, {rgb.shape[-1]}"
+
+             c2ws.append(c2w)
+             rgbs.append(rgb)
+             bg_colors.append(bg_color)
+             intrs.append(intrinsic)
+             smplx_params.append(smplx_param)
+             masks.append(mask)
+
+         c2ws = torch.stack(c2ws, dim=0)  # [N, 4, 4]
+         intrs = torch.stack(intrs, dim=0)  # [N, 4, 4]
+         rgbs = torch.cat(rgbs, dim=0)  # [N, 3, H, W]
+         bg_colors = (
+             torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3)
+         )  # [N, 3]
+         masks = torch.cat(masks, dim=0)  # [N, 1, H, W]
+
+         smplx_params_tmp = defaultdict(list)
+         for smplx in smplx_params:
+             for k, v in smplx.items():
+                 smplx_params_tmp[k].append(v)
+         for k, v in smplx_params_tmp.items():
+             smplx_params_tmp[k] = torch.stack(v)
+         smplx_params = smplx_params_tmp
+         # TODO check different betas for same person
+         smplx_params["betas"] = shape_param
+
+         # reference images
+         # TODO check prob
+         ref_idx = np.random.choice(self.sample_side_views + 1)
+
+         cam_id_source_list = cam_id_list[ref_idx : ref_idx + 1]
+         frame_id_source_list = frame_id_list[ref_idx : ref_idx + 1]
+
+         for cam_id, frame_id in zip(cam_id_source_list, frame_id_source_list):
+             frame_path = os.path.join(uid, cam_id + ".png")
+             frame_name = os.path.splitext(os.path.basename(frame_path))[0]
+             smplx_path = os.path.join(
+                 uid.replace("/png_post/", "/smplx/"), f"{frame_name}.json"
+             )
+
+             with open(smplx_path) as f:
+                 smplx_param = {
+                     k: torch.FloatTensor(v)
+                     for k, v in json.load(f).items()
+                     if "valid_area_ratio" not in k
+                 }
+
+             c2w, intrinsic = self._load_pose(smplx_param)
+
+             bg_color = 1.0
+             rgb, mask = self.load_rgb_image_with_aug_bg(
+                 frame_path, mask_path=None, bg_color=bg_color
+             )
+
+             # crop image to enlarge human area.
+             if (crop_ratio_h < 1.0) or (crop_ratio_w < 1.0):
+                 img_size_hw = rgb.shape[2], rgb.shape[3]
+                 h_crop, w_crop = round(img_size_hw[0] * crop_ratio_h), round(
+                     img_size_hw[1] * crop_ratio_w
+                 )
+                 h_crop_offset, w_crop_offset = round(
+                     (img_size_hw[0] - h_crop) / 2
+                 ), round((img_size_hw[1] - w_crop) / 2)
+                 rgb = rgb[
+                     :,
+                     :,
+                     h_crop_offset : h_crop_offset + h_crop,
+                     w_crop_offset : w_crop_offset + w_crop,
+                 ]
+                 mask = mask[
+                     :,
+                     :,
+                     h_crop_offset : h_crop_offset + h_crop,
+                     w_crop_offset : w_crop_offset + w_crop,
+                 ]
+                 intrinsic[0, 2] -= w_crop_offset
+                 intrinsic[1, 2] -= h_crop_offset
+
+             assert (
+                 abs(intrinsic[0, 2] * 2 - rgb.shape[-1]) <= 1
+             ), f"{intrinsic[0, 2] * 2}, {rgb.shape[-1]}"
+
+             source_c2ws.append(c2w)
+             source_intrs.append(intrinsic)
+             source_rgbs.append(rgb)
+
+         source_c2ws = torch.stack(source_c2ws, dim=0)
+         source_intrs = torch.stack(source_intrs, dim=0)
+         source_rgbs = torch.cat(source_rgbs, dim=0)
+
+         # adjust source image resolution
+         # TODO check 224x224 need to padding?
+         # ratio_x, ratio_y = self.source_image_res / source_rgbs.shape[3], self.source_image_res / source_rgbs.shape[2]
+         ratio = self.source_image_res / min(source_rgbs.shape[2:])
+         tgt_size = int(ratio * source_rgbs.shape[2]), int(ratio * source_rgbs.shape[3])
+         multiply = self.multiply
+         tgt_size = (
+             int(tgt_size[0] / multiply) * multiply,
+             int(tgt_size[1] / multiply) * multiply,
+         )
+         ratio_y, ratio_x = (
+             tgt_size[0] / source_rgbs.shape[2],
+             tgt_size[1] / source_rgbs.shape[3],
+         )
+         source_rgbs = torch.nn.functional.interpolate(
+             source_rgbs, size=tgt_size, mode="bicubic", align_corners=True
+         )
+         source_rgbs = torch.clamp(source_rgbs, 0, 1)
+         source_intrs = self.scale_intrs(source_intrs, ratio_x=ratio_x, ratio_y=ratio_y)
+
+         # adjust render image resolution and sample intended rendering region
+         render_image_res = np.random.randint(
+             self.render_image_res_low, self.render_image_res_high + 1
+         )
+         ratio = render_image_res / min(rgbs.shape[2:])
+         tgt_size = int(ratio * rgbs.shape[2]), int(ratio * rgbs.shape[3])
+         # multiply = 14
+         # tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+         # ratio_y, ratio_x = tgt_size[0] / rgbs.shape[2], tgt_size[1] / rgbs.shape[3]
+         render_image = torch.nn.functional.interpolate(
+             rgbs, size=tgt_size, mode="bicubic", align_corners=True
+         )
+         render_image = torch.clamp(render_image, 0, 1)
+         intrs = self.scale_intrs(intrs, ratio_x=ratio, ratio_y=ratio)
+
+         render_mask = torch.nn.functional.interpolate(
+             masks, size=tgt_size, mode="bicubic", align_corners=True
+         )
+         render_mask = torch.clamp(render_mask, 0, 1)
+
+         assert (
+             abs(intrs[0, 0, 2] * 2 - render_image.shape[3]) <= 1.1
+         ), f"{intrs[0, 0, 2] * 2}, {render_image.shape}"
+         assert (
+             abs(intrs[0, 1, 2] * 2 - render_image.shape[2]) <= 1.1
+         ), f"{intrs[0, 1, 2] * 2}, {render_image.shape}"
+
+         # anchors = torch.randint(
+         #     0, render_image_res - min(self.render_region_size) + 1, size=(self.sample_side_views + 1, 2))
+         # crop_indices_h = torch.arange(0, self.render_region_size[0], device=render_image.device)
+         # crop_indices_w = torch.arange(0, self.render_region_size[1], device=render_image.device)
+         # index_h = (anchors[:, 0].unsqueeze(1) + crop_indices_h).view(-1, self.render_region_size[0], 1)
+         # index_w = (anchors[:, 1].unsqueeze(1) + crop_indices_w).view(-1, 1, self.render_region_size[1])
+         # batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
+         # cropped_render_image = render_image[batch_indices, :, index_h, index_w].permute(0, 3, 1, 2)
+
+         ret = {
+             "uid": uid,
+             "source_c2ws": source_c2ws,  # [N1, 4, 4]
+             "source_intrs": source_intrs,  # [N1, 4, 4]
+             "source_rgbs": source_rgbs,  # [N1, 3, H, W]
+             "render_image": render_image,  # [N, 3, H, W]
+             "render_mask": render_mask,  # [N, 1, H, W]
+             "c2ws": c2ws,  # [N, 4, 4]
+             "intrs": intrs,  # [N, 4, 4]
+             # 'render_anchors': anchors,  # [N, 2]
+             "render_full_resolutions": torch.tensor(
+                 [tgt_size], dtype=torch.float32
+             ).repeat(
+                 self.sample_side_views + 1, 1
+             ),  # [N, 2]
+             "render_bg_colors": bg_colors,  # [N, 3]
+         }
+
+         # ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas']
+         # 'smplx_params': smplx_params,  # dict: body_pose:[N, 21, 3],
+         ret.update(smplx_params)
+
+         return ret
+
+
+ if __name__ == "__main__":
+     import cv2
+
+     root_dir = "./train_data/bedlam/data/"
+     meta_path = "./train_data/bedlam/data/annots/valid_list.json"
+     dataset = BedlamDataset(
+         root_dirs=root_dir,
+         meta_path=meta_path,
+         sample_side_views=3,
+         render_image_res_low=384,
+         render_image_res_high=384,
+         render_region_size=(682, 384),
+         source_image_res=384,
+         valid_area_ratio=0.1,
+         debug=False,
+     )
+
+     for data in dataset:
+         print(
+             "source_c2ws.shape",
+             data["source_c2ws"].shape,
+         )
+         print(
+             "source_intrs.shape",
+             data["source_intrs"].shape,
+         )
+         print(
+             "source_rgbs.shape",
+             data["source_rgbs"].shape,
+         )
+         print(
+             "render_image.shape",
+             data["render_image"].shape,
+         )
+         print(
+             "c2ws.shape",
+             data["c2ws"].shape,
+         )
+         print(
+             "intrs.shape",
+             data["intrs"].shape,
+         )
+         # print("render_anchors.shape", data["render_anchors"].shape, )
+         print(
+             "render_full_resolutions.shape",
+             data["render_full_resolutions"].shape,
+         )
+         print(
+             "render_bg_colors.shape",
+             data["render_bg_colors"].shape,
+         )
+         # print("smplx_params", data["smplx_params"].keys())
+         print("smplx_params.body_pose.shape", data["body_pose"].shape)
+         print("smplx_params.expr.shape", data["expr"].shape)
+         print("smplx_params.betas.shape", data["betas"].shape)
+         os.makedirs("debug_vis/dataloader", exist_ok=True)
+         for i in range(data["source_rgbs"].shape[0]):
+             cv2.imwrite(
+                 f"debug_vis/dataloader/source_rgbs_{i}.jpg",
+                 (
+                     (
+                         data["source_rgbs"][i].permute(1, 2, 0).numpy()[:, :, (2, 1, 0)]
+                         * 255
+                     ).astype(np.uint8)
+                 ),
+             )
+             print(
+                 "source_rgbs",
+                 data["source_rgbs"].shape,
+             )
+             print("source_intrs", data["source_intrs"][i])
+
+         for i in range(data["render_image"].shape[0]):
+             cv2.imwrite(
+                 f"debug_vis/dataloader/rgbs{i}.jpg",
+                 (
+                     (
+                         data["render_image"][i]
+                         .permute(1, 2, 0)
+                         .numpy()[:, :, (2, 1, 0)]
+                         * 255
+                     ).astype(np.uint8)
+                 ),
+             )
+             print(
+                 "render_image",
+                 data["render_image"].shape,
+             )
+             print("render_full_resolutions", data["render_full_resolutions"][i])
+             # print("render_anchors", data["render_anchors"][i])
+             print("intrs", data["intrs"][i])
+         break  # stop after the first sample (was a bare `xx`, which crashed with NameError)
LHM/datasets/bedlam_util.py ADDED
@@ -0,0 +1,306 @@
+ # Multi-HMR
+ # Copyright (c) 2024-present NAVER Corp.
+ # CC BY-NC-SA 4.0 license
+
+ import os
+ # os.environ["PYOPENGL_PLATFORM"] = "egl"
+ # os.environ['EGL_DEVICE_ID'] = '0'
+
+ import warnings
+ import pickle
+ import torch
+ import smplx
+ from tqdm import tqdm
+ import sys
+ import numpy as np
+ from PIL import Image, ImageOps, ImageFile
+ import random
+ import json
+ import tqdm
+ import cv2
+ import traceback
+ ImageFile.LOAD_TRUNCATED_IMAGES = True  # to avoid "OSError: image file is truncated"
+ from torch.utils.data import Dataset
+
+ BEDLAM_DIR = "./train_data/bedlam/data"
+ SMPLX_DIR = "./pretrained_models/human_model_files"
+ ANNOT_DIR = "./train_data/bedlam/data/annots"
+
+ class BEDLAMSeg(Dataset):
+     def __init__(self,
+                  split='training',
+                  training=False,
+                  img_size=512,
+                  root_dir=BEDLAM_DIR,
+                  force_build_dataset=0,
+                  n_iter=None,
+                  subsample=1,
+                  extension='png',
+                  crops=[0],
+                  flip=1,
+                  res=None,
+                  n=-1,
+                  ):
+         super().__init__()
+
+         self.name = 'bedlam'
+         self.annotations_dir = ANNOT_DIR
+         self.training = training
+         self.img_size = img_size
+         self.n_iter = n_iter
+         self.subsample = subsample
+         self.crops = crops  # 0 is the default
+         self.flip = flip  # 1 by default
+
+         assert split in ['training', 'validation']
+
+         self.root_dir = root_dir
+         self.split = split
+         self.image_dir = os.path.join(self.root_dir, f"{self.split}")
+         self.mask_dir = os.path.join(self.root_dir, "masks")
+
+         self.annot_file = os.path.join(self.annotations_dir, f"{self.name}_{split}.pkl")
+         # self.force_build_dataset = force_build_dataset
+
+         self.annots = None
+         # if self.force_build_dataset or not os.path.isfile(self.annot_file):
+         #     self.annots = self.build_dataset()
+         if self.annots is None:
+             with open(self.annot_file, 'rb') as f:
+                 self.annots = pickle.load(f)
+
+         self.imagenames = list(self.annots.keys())
+         self.imagenames.sort()
+
+
+     def __len__(self):
+         return len(self.imagenames)
+
+     def __repr__(self):
+         return f"{self.name}: split={self.split} - N={len(self.imagenames)}"
+
+     def save_smplx_params_to_json(self, person, focal, princpt, valid_area_ratio, save_path):
+         smplx_params = {}
+         smplx_params["betas"] = person['smplx_shape'].reshape(11).tolist()
+         smplx_params["root_pose"] = person["smplx_root_pose"].reshape(3).tolist()
+         smplx_params['body_pose'] = person["smplx_body_pose"].tolist()
+         smplx_params['jaw_pose'] = person["smplx_jaw_pose"].reshape(3).tolist()
+         smplx_params['leye_pose'] = person["smplx_leye_pose"].reshape(3).tolist()
+         smplx_params['reye_pose'] = person["smplx_reye_pose"].reshape(3).tolist()
+         smplx_params['lhand_pose'] = person["smplx_left_hand_pose"].tolist()
+         smplx_params['rhand_pose'] = person["smplx_right_hand_pose"].tolist()
+         smplx_params['trans'] = person["smplx_transl"].reshape(3).tolist()
+         smplx_params['expr'] = np.zeros(10).tolist()
+
+         smplx_params['focal'] = focal
+         smplx_params['princpt'] = princpt
+         smplx_params['valid_area_ratio'] = valid_area_ratio
+
+         # for k, v in smplx_params.items():
+         #     print(k, np.array(v).shape)
+
+         with open(save_path, 'w') as fp:
+             json.dump(smplx_params, fp)
+
+         return smplx_params
+
+
+     def center_crop_and_resize(self, img, mask, princpt_x, princpt_y, fx, fy, area_ratio):
+
+         ys, xs = np.where(mask > 0)
+
+         if len(xs) == 0 or len(ys) == 0:
+             print(f"invalid: no body")
+             return None
+
+         x_min = np.min(xs)
+         x_max = np.max(xs)
+         y_min = np.min(ys)
+         y_max = np.max(ys)
+
+         center_x, center_y = img.shape[1]//2, img.shape[0]//2
+
+         half_w = max(abs(center_x - x_min), abs(center_x - x_max))
+         half_h = max(abs(center_y - y_min), abs(center_y - y_max))
+         ratio = half_h / half_w
+         ratio_standard = 1280 / 720
+         if ratio >= 1:
+             if ratio >= ratio_standard:
+                 half_w = round(half_h / ratio_standard)
+             else:
+                 half_h = round(half_w * ratio_standard)
+         else:
+             print(f"invalid: h/w ratio:{ratio}")
+             return None
+
+         assert abs(half_h / half_w - ratio_standard) < 0.1
+         offset_x = center_x - half_w
+         offset_y = center_y - half_h
+
+         new_img = img[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+         new_mask = mask[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+
+         princpt_x -= offset_x
+         princpt_y -= offset_y
+
+         new_img = cv2.resize(new_img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
+         new_mask = cv2.resize(new_mask, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_NEAREST)
+
+         valid_area_ratio = np.sum(new_mask > 0) / new_mask.shape[0] / new_mask.shape[1]
+         if valid_area_ratio < area_ratio:
+             print(f"invalid: area ratio:{valid_area_ratio}")
+             return None
+
+         scale = img.shape[0] / 2. / half_h
+
+         fx *= scale
+         princpt_x *= scale
+
+         fy *= scale
+         princpt_y *= scale
+
+         new_img = np.concatenate([new_img, new_mask[:, :, None]], axis=2)
+
+         return new_img, princpt_x, princpt_y, fx, fy, valid_area_ratio
+
+
+     def __getitem__(self, idx):
+         imagename = self.imagenames[idx]
+         annot = self.annots[imagename].copy()
+         annot['imagename'] = imagename
+
+         # find appropriate image_dir
+         img_path = os.path.join(self.image_dir, imagename)
+
+         mask_path = os.path.join(self.mask_dir, imagename.replace("_6fps/png/", "/masks/").replace(".png", "_env.png"))
+         assert os.path.exists(mask_path), f"mask_path:{mask_path}"
+
+         # Original size
+         real_width, real_height = annot['size']
+
+         # preprocessing the image
+         img_pil = Image.open(img_path)
+         if img_pil.mode != 'RGB':
+             img_pil = img_pil.convert('RGB')
+
+         # BEDLAM specific fix for the rotation issue
+         # https://github.com/pixelite1201/BEDLAM/blob/ebf8bb14a43de46cc74dca4c00c13e571b325726/visualize_ground_truth.py#L183
+         if self.name == 'bedlam' and 'closeup' in imagename and self.split != 'test':
+             img_pil = img_pil.rotate(-90, expand=True)
+
+
+         # preprocessing the mask
+         mask_pil = Image.open(mask_path)
+         # if mask_pil.mode != 'RGB':
+         #     img_pil = img_pil.convert('RGB')
+
+         # BEDLAM specific fix for the rotation issue
+         # https://github.com/pixelite1201/BEDLAM/blob/ebf8bb14a43de46cc74dca4c00c13e571b325726/visualize_ground_truth.py#L183
+         if self.name == 'bedlam' and 'closeup' in imagename and self.split != 'test':
+             mask_pil = mask_pil.rotate(-90, expand=True)
+
+         img = np.asarray(img_pil)
+         mask = np.asarray(mask_pil)
+         mask = 255 * (mask < 1).astype(np.uint8)
+
+         princpt, focal = annot['princpt'], annot['focal']
+
+         ret = self.center_crop_and_resize(img, mask, princpt[0], princpt[1], focal[0], focal[1], area_ratio=0.05)
+         if ret is None:
+             print(f"invalid, img_path:{img_path}")
+             return
+
+         new_img, princpt_x, princpt_y, fx, fy, valid_area_ratio = ret
+         # print(new_img.shape, princpt_x, princpt_y, fx, fy, "ori", princpt, focal)
+
+         princpt = princpt_x, princpt_y
+         focal = fx, fy
+
+
+         save_path = img_path.replace("/png/", "/png_post/")
+         save_vis_path = img_path.replace("/png/", "/png_post_vis/")
+         save_smplx_path = img_path.replace("/png/", "/smplx/").replace(".png", ".json")
+
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+         os.makedirs(os.path.dirname(save_vis_path), exist_ok=True)
+         os.makedirs(os.path.dirname(save_smplx_path), exist_ok=True)
+
+         cv2.imwrite(save_path, new_img[:, :, (2, 1, 0, 3)])
+         cv2.imwrite(save_vis_path, np.hstack([np.concatenate([img, 255 * np.ones_like(mask[:, :, None])], axis=2), new_img])[:, :, (2, 1, 0, 3)])
+
+         # Humans
+         _humans = annot['humans'].copy()
+         # annot.pop('humans')
+         # if self.training:
+         humans = [hum for hum in _humans if hum['smplx_transl'][-1] > 0.01]  # the person should be in front of the camera
+         # else:
+         #     humans = [hum for hum in _humans]
+
+         assert len(humans) == 1
+
+         self.save_smplx_params_to_json(humans[0], focal, princpt, valid_area_ratio, save_smplx_path)
+
+         # return img_array, annot
+
+ def create_annots(splits=['validation', 'training']):
+     for split in splits:
+         dataset = BEDLAMSeg(split=split, force_build_dataset=1)
+
+
+ def visualize(split='validation', i=1500, res=None, extension='png', training=0, img_size=800):
+     # training - 52287 for a closeup
+     from utils import render_meshes, demo_color
+     model_neutral = smplx.create(SMPLX_DIR, 'smplx', gender='neutral', num_betas=11, use_pca=False, flat_hand_mean=True)
+
+     dataset = BEDLAMSeg(split=split, force_build_dataset=0,
+                         res=res, extension=extension,
+                         training=training,
+                         img_size=img_size,
+                         )
+     print(dataset)
+
+     img_array, annot = dataset.__getitem__(i)
+
+     img_array = denormalize_rgb(img_array, imagenet_normalization=1)
+     verts_list = []
+     for person in annot['humans']:
+         with torch.no_grad():
+             verts = model_neutral(
+                 global_orient=torch.from_numpy(person['smplx_root_pose']).reshape(1, -1),
+                 body_pose=torch.from_numpy(person['smplx_body_pose']).reshape(1, -1),
+                 jaw_pose=torch.from_numpy(person['smplx_jaw_pose']).reshape(1, -1),
+                 leye_pose=torch.from_numpy(person['smplx_leye_pose']).reshape(1, -1),
+                 reye_pose=torch.from_numpy(person['smplx_reye_pose']).reshape(1, -1),
+                 left_hand_pose=torch.from_numpy(person['smplx_left_hand_pose']).reshape(1, -1),
+                 right_hand_pose=torch.from_numpy(person['smplx_right_hand_pose']).reshape(1, -1),
+                 betas=torch.from_numpy(person['smplx_shape']).reshape(1, -1),
+                 transl=torch.from_numpy(person['smplx_transl']).reshape(1, -1),
+             ).vertices.cpu().numpy().reshape(-1, 3)
+         verts_list.append(verts)
+     faces_list = [model_neutral.faces for _ in annot['humans']]
+     _color = [demo_color[0] for _ in annot['humans']]
+     pred_rend_array = render_meshes(img_array.copy(),
+                                     verts_list,
+                                     faces_list,
+                                     {'focal': annot['K'][[0, 1], [0, 1]],
+                                      'princpt': annot['K'][[0, 1], [-1, -1]]},
+                                     alpha=0.7,
+                                     color=_color)
+     img_array = np.concatenate([img_array, np.asarray(pred_rend_array)], 1)
+
+     fn = f"{dataset.name}_{split}_{i}.jpg"
+     Image.fromarray(img_array).save(fn)
+     print(f"open {fn}")
+     return 1
+
+
+
+ if __name__ == "__main__":
+     # exec(sys.argv[1])
+     dataset = BEDLAMSeg(split="validation")
+     for i in tqdm.tqdm(range(len(dataset))):
+         try:
+             dataset.__getitem__(i)
+         except:
+             traceback.print_exc()
+             continue
LHM/datasets/cam_utils.py ADDED
@@ -0,0 +1,205 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import math
+ import torch
+
+ """
+ R: (N, 3, 3)
+ T: (N, 3)
+ E: (N, 4, 4)
+ vector: (N, 3)
+ """
+
+
+ def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
+     """
+     Compose the standard form extrinsic matrix from R and T.
+     Batched I/O.
+     """
+     RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
+     return compose_extrinsic_RT(RT)
+
+
+ def compose_extrinsic_RT(RT: torch.Tensor):
+     """
+     Compose the standard form extrinsic matrix from RT.
+     Batched I/O.
+     """
+     return torch.cat([
+         RT,
+         torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
+     ], dim=1)
+
+
+ def decompose_extrinsic_R_T(E: torch.Tensor):
+     """
+     Decompose the standard extrinsic matrix into R and T.
+     Batched I/O.
+     """
+     RT = decompose_extrinsic_RT(E)
+     return RT[:, :, :3], RT[:, :, 3]
+
+
+ def decompose_extrinsic_RT(E: torch.Tensor):
+     """
+     Decompose the standard extrinsic matrix into RT.
+     Batched I/O.
+     """
+     return E[:, :3, :]
+
+
+ def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
+     assert normed_dist_to_center is not None
+     pivotal_pose = compose_extrinsic_RT(poses[:1])
+     dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
+         if normed_dist_to_center == 'auto' else normed_dist_to_center
+
+     # compute camera norm (new version)
+     canonical_camera_extrinsics = torch.tensor([[
+         [1, 0, 0, 0],
+         [0, 0, -1, -dist_to_center],
+         [0, 1, 0, 0],
+         [0, 0, 0, 1],
+     ]], dtype=torch.float32)
+     pivotal_pose_inv = torch.inverse(pivotal_pose)
+     camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)
+
+     # normalize all views
+     poses = compose_extrinsic_RT(poses)
+     poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
+     poses = decompose_extrinsic_RT(poses)
+
+     if ret_transform:
+         return poses, camera_norm_matrix.squeeze(dim=0)
+     return poses
+
+
+ def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
+     """
+     intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+     Return batched fx, fy, cx, cy
+     """
+     fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+     cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+     width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+     fx, fy = fx / width, fy / height
+     cx, cy = cx / width, cy / height
+     return fx, fy, cx, cy
+
+
+ def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
+     """
+     RT: (N, 3, 4)
+     intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+     """
+     fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+     return torch.cat([
+         RT.reshape(-1, 12),
+         fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
+     ], dim=-1)
+
+
+ def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
+     """
+     RT: (N, 3, 4)
+     intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+     """
+     E = compose_extrinsic_RT(RT)
+     fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+     I = torch.stack([
+         torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+         torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+         torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+     ], dim=1)
+     return torch.cat([
+         E.reshape(-1, 16),
+         I.reshape(-1, 9),
+     ], dim=-1)
+
+
+ def center_looking_at_camera_pose(
+     camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
+     device: torch.device = torch.device('cpu'),
+ ):
+     """
+     camera_position: (M, 3)
+     look_at: (3)
+     up_world: (3)
+     return: (M, 3, 4)
+     """
+     # by default, looking at the origin and world up is pos-z
+     if look_at is None:
+         look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
+     if up_world is None:
+         up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+     look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+     up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+     z_axis = camera_position - look_at
+     z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+     x_axis = torch.cross(up_world, z_axis)
+     x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+     y_axis = torch.cross(z_axis, x_axis)
+     y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+     extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+     return extrinsics
+
+
+ def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
+     """
+     n_views: number of surrounding views
+     radius: camera dist to center
+     height: height of the camera
+     return: (M, 3, 4)
+     """
+     assert n_views > 0
+     assert radius > 0
+
+     theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
+     projected_radius = math.sqrt(radius ** 2 - height ** 2)
+     x = torch.cos(theta) * projected_radius
+     y = torch.sin(theta) * projected_radius
+     z = torch.full((n_views,), height, device=device)
+
+     camera_positions = torch.stack([x, y, z], dim=1)
+     extrinsics = center_looking_at_camera_pose(camera_positions, device=device)
+
+     return extrinsics
+
+
+ def create_intrinsics(
+     f: float,
+     c: float = None, cx: float = None, cy: float = None,
+     w: float = 1., h: float = 1.,
+     dtype: torch.dtype = torch.float32,
+     device: torch.device = torch.device('cpu'),
+ ):
+     """
+     return: (3, 2)
+     """
+     fx = fy = f
+     if c is not None:
+         assert cx is None and cy is None, "c and cx/cy cannot be used together"
+         cx = cy = c
+     else:
+         assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
+     fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
+     intrinsics = torch.tensor([
+         [fx, fy],
+         [cx, cy],
+         [w, h],
+     ], dtype=dtype, device=device)
+     return intrinsics
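The intrinsics convention used throughout these helpers is a (3, 2) tensor laid out as [[fx, fy], [cx, cy], [width, height]]. A small sketch with arbitrary example values (focal length 500 px, principal point at the center of a 512x512 image) showing how `create_intrinsics` feeds `build_camera_principle`:

```python
import torch
from LHM.datasets.cam_utils import (
    build_camera_principle,
    create_intrinsics,
    surrounding_views_linspace,
)

# create_intrinsics normalizes by (w, h), so the returned tensor uses width=height=1.
intrinsics = create_intrinsics(f=500.0, cx=256.0, cy=256.0, w=512.0, h=512.0)  # (3, 2)
poses = surrounding_views_linspace(n_views=4)                                  # (4, 3, 4) [R|t]
cameras = build_camera_principle(poses, intrinsics.unsqueeze(0).repeat(4, 1, 1))
print(cameras.shape)  # torch.Size([4, 16]): 12 flattened extrinsic values + fx, fy, cx, cy
```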
LHM/datasets/mixer.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import math
+ import pdb
+ from functools import partial
+
+ import torch
+
+ __all__ = ["MixerDataset"]
+
+
+ class MixerDataset(torch.utils.data.Dataset):
+     """Mixes several sub-datasets with per-subset sampling rates."""
+
+     def __init__(
+         self,
+         split: str,
+         subsets: dict,
+         **dataset_kwargs,
+     ):
+
+         self.subsets = [
+             self._dataset_fn(subset, split)(
+                 use_flame=subset["use_flame"],
+                 src_head_size=subset.get("src_head_size", 448),
+                 **dataset_kwargs,
+             )
+             for subset in subsets
+         ]
+         self.virtual_lens = [
+             math.ceil(subset_config["sample_rate"] * len(subset_obj))
+             for subset_config, subset_obj in zip(subsets, self.subsets)
+         ]
+
+     @staticmethod
+     def _dataset_fn(subset_config: dict, split: str):
+         name = subset_config["name"]
+
+         dataset_cls = None
+         if name == "exavatar":
+             from .exavatar import ExAvatarDataset
+
+             dataset_cls = ExAvatarDataset
+         elif name == "humman":
+             from .humman import HuMManDataset
+
+             dataset_cls = HuMManDataset
+         elif name == "static_human":
+             from .static_human import StaticHumanDataset
+
+             dataset_cls = StaticHumanDataset
+         elif name == "singleview_human":
+             from .singleview_human import SingleViewHumanDataset
+
+             dataset_cls = SingleViewHumanDataset
+         elif name == "singleview_square_human":
+             from .singleview_square_human import SingleViewSquareHumanDataset
+
+             dataset_cls = SingleViewSquareHumanDataset
+         elif name == "bedlam":
+             from .bedlam import BedlamDataset
+
+             dataset_cls = BedlamDataset
+         elif name == "dna_human":
+             from .dna import DNAHumanDataset
+
+             dataset_cls = DNAHumanDataset
+         elif name == "video_human":
+             from .video_human import VideoHumanDataset
+
+             dataset_cls = VideoHumanDataset
+         elif name == "video_human_flame":
+             from .video_human_flame import VideoHumanFlameDataset
+
+             dataset_cls = VideoHumanFlameDataset
+         elif name == "video_human_flame_dp":
+             from .video_human_flame_df import VideoHumanFlameDFDataset
+
+             # add deepfashion random sampling to video_human_flame
+             dataset_cls = VideoHumanFlameDFDataset
+         elif name == "objaverse":
+             from .objaverse import ObjaverseDataset
+
+             dataset_cls = ObjaverseDataset
+         # elif name == 'mvimgnet':
+         #     from .mvimgnet import MVImgNetDataset
+         #     dataset_cls = MVImgNetDataset
+         else:
+             raise NotImplementedError(f"Dataset {name} not implemented")
+
+         return partial(
+             dataset_cls,
+             root_dirs=subset_config["root_dirs"],
+             meta_path=subset_config["meta_path"][split],
+         )
+
+     def __len__(self):
+         return sum(self.virtual_lens)
+
+     def __getitem__(self, idx):
+         subset_idx = 0
+         virtual_idx = idx
+         while virtual_idx >= self.virtual_lens[subset_idx]:
+             virtual_idx -= self.virtual_lens[subset_idx]
+             subset_idx += 1
+         real_idx = virtual_idx % len(self.subsets[subset_idx])
+         return self.subsets[subset_idx][real_idx]
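MixerDataset expects each subset entry to carry its own class name, paths, sampling rate, and per-subset options. A sketch of that configuration, where the key names (`name`, `root_dirs`, `meta_path[split]`, `sample_rate`, `use_flame`) mirror what mixer.py reads and the concrete paths and split name are placeholders only:

```python
from LHM.datasets.mixer import MixerDataset

# Placeholder paths and split name; only the key names are taken from mixer.py itself.
subsets = [
    {
        "name": "bedlam",
        "root_dirs": "./train_data/bedlam/data/",
        "meta_path": {"train": "./train_data/bedlam/data/annots/valid_list.json"},
        "sample_rate": 1.0,   # virtual length = ceil(sample_rate * len(subset))
        "use_flame": False,
    },
]

dataset = MixerDataset(
    split="train",
    subsets=subsets,
    sample_side_views=3,
    render_image_res_low=384,
    render_image_res_high=384,
    render_region_size=(682, 384),
    source_image_res=384,
)
```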
LHM/launch.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import argparse
+ import pdb
+
+ from LHM.runners import REGISTRY_RUNNERS
+
+
+ def main():
+
+     parser = argparse.ArgumentParser(description="OpenLRM launcher")
+     parser.add_argument("runner", type=str, help="Runner to launch")
+     args, unknown = parser.parse_known_args()
+
+     if args.runner not in REGISTRY_RUNNERS:
+         raise ValueError("Runner {} not found".format(args.runner))
+
+     RunnerClass = REGISTRY_RUNNERS[args.runner]
+     with RunnerClass() as runner:
+         runner.run()
+
+
+ if __name__ == "__main__":
+     main()
LHM/losses/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from .ball_loss import *
+ from .offset_loss import *
+ from .perceptual import *
+ from .pixelwise import *
+ from .tvloss import *
LHM/losses/ball_loss.py ADDED
@@ -0,0 +1,54 @@
+ # -*- coding: utf-8 -*-
+ # @Organization : Alibaba XR-Lab
+ # @Author : Lingteng Qiu
+ # @Email : [email protected]
+ # @Time : 2025-03-10 19:08:35
+ # @Function : ASAP loss
+ import pdb
+
+ import torch
+ import torch.nn as nn
+
+ __all__ = ["ASAP_Loss", "Heuristic_ASAP_Loss"]
+
+
+ class ASAP_Loss(nn.Module):
+
+     def forward(self, scaling, r=1, **params):
+         """where r is the radius of the ball between max-axis and min-axis."""
+         raise NotImplementedError(
+             "ASAP_Loss is not implemented yet in Inference version"
+         )
+
+
+ class Heuristic_ASAP_Loss(nn.Module):
+     def __init__(self, group_dict, group_body_mapping):
+         super(Heuristic_ASAP_Loss, self).__init__()
+
+         self.group_dict = group_dict  # register weights for different body parts
+         self.group_body_mapping = group_body_mapping  # mapping of body parts to group
+
+     def _heuristic_loss(self, _ball_loss):
+
+         _loss = 0.0
+         for key in self.group_dict.keys():
+             key_weights = self.group_dict[key]
+             group_mapping_idx = self.group_body_mapping[key]
+             _loss += key_weights * _ball_loss[:, group_mapping_idx].mean()
+
+         return _loss
+
+     def forward(self, scaling, r=5, **params):
+         """where r is the radius of the ball between max-axis and min-axis.
+
+         Human motion and rotation differ a lot across body parts (e.g. the head is
+         more stable than the legs and hands), so a heuristic per-group ball loss is used.
+         """
+         _scale = scaling
+
+         _scale_min = torch.min(_scale, dim=-1)[0]
+         _scale_max = torch.max(_scale, dim=-1)[0]
+
+         scale_ratio = _scale_max / (_scale_min + 1e-6)
+
+         _ball_loss = torch.clamp(scale_ratio, min=r) - r
+
+         return self._heuristic_loss(_ball_loss)
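Heuristic_ASAP_Loss weights the anisotropy penalty per body-part group. A toy sketch with two made-up groups over ten Gaussians (the real group dictionaries come from the training configuration, not from this file, and the weights below are illustrative only):

```python
import torch
from LHM.losses.ball_loss import Heuristic_ASAP_Loss

# Toy grouping: indices 0-4 "head", 5-9 "hand"; weights are illustrative only.
group_dict = {"head": 1.0, "hand": 0.1}
group_body_mapping = {"head": list(range(0, 5)), "hand": list(range(5, 10))}

loss_fn = Heuristic_ASAP_Loss(group_dict, group_body_mapping)
scaling = torch.rand(2, 10, 3)  # [batch, num_gaussians, 3] per-axis scales
loss = loss_fn(scaling, r=5)    # penalizes max/min axis ratios above r, weighted per group
print(loss.item())
```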
LHM/losses/offset_loss.py ADDED
@@ -0,0 +1,52 @@
+ # -*- coding: utf-8 -*-
+ # @Organization : Alibaba XR-Lab
+ # @Author : Lingteng Qiu
+ # @Email : [email protected]
+ # @Time : 2025-03-10 19:08:56
+ # @Function : ACAP Loss
+ import pdb
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ __all__ = ["ACAP_Loss", "Heuristic_ACAP_Loss"]
+
+
+ class ACAP_Loss(nn.Module):
+     """As close as possible loss"""
+
+     def forward(self, offset, d=0.05625, **params):
+         """Empirically, d is the threshold distance points may leave from the prior: 1.8/32 = 0.05625."""
+
+         offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
+
+         return offset_loss.mean()
+
+
+ class Heuristic_ACAP_Loss(nn.Module):
+     """As close as possible loss"""
+
+     def __init__(self, group_dict, group_body_mapping):
+         super(Heuristic_ACAP_Loss, self).__init__()
+
+         self.group_dict = group_dict  # register weights for different body parts
+         self.group_body_mapping = group_body_mapping  # mapping of body parts to group
+
+     def _heuristic_loss(self, _offset_loss):
+
+         _loss = 0.0
+         for key in self.group_dict.keys():
+             key_weights = self.group_dict[key]
+             group_mapping_idx = self.group_body_mapping[key]
+             _loss += key_weights * _offset_loss[:, group_mapping_idx].mean()
+
+         return _loss
+
+     def forward(self, offset, d=0.05625, **params):
+         """Empirically, d is the threshold distance points may leave from the human prior model: 1.8/32 = 0.05625.
+
+         Human motion and rotation differ a lot across body parts (e.g. the head is
+         more stable than the legs and hands), so a heuristic per-group loss is used.
+         """
+         _offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
+
+         return self._heuristic_loss(_offset_loss)
LHM/losses/perceptual.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright (c) 2023-2024, Zexin He
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import torch
+ import torch.nn as nn
+
+ __all__ = ['LPIPSLoss']
+
+
+ class LPIPSLoss(nn.Module):
+     """
+     Compute LPIPS loss between two images.
+     """
+
+     def __init__(self, device, prefech: bool = False):
+         super().__init__()
+         self.device = device
+         self.cached_models = {}
+         if prefech:
+             self.prefetch_models()
+
+     def _get_model(self, model_name: str):
+         if model_name not in self.cached_models:
+             import warnings
+             with warnings.catch_warnings():
+                 warnings.filterwarnings('ignore', category=UserWarning)
+                 import lpips
+                 _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
+             _model = torch.compile(_model)
+             self.cached_models[model_name] = _model
+         return self.cached_models[model_name]
+
+     def prefetch_models(self):
+         _model_names = ['alex', 'vgg']
+         for model_name in _model_names:
+             self._get_model(model_name)
+
+     def forward(self, x, y, is_training: bool = True):
+         """
+         Assume images are 0-1 scaled and channel first.
+
+         Args:
+             x: [N, M, C, H, W]
+             y: [N, M, C, H, W]
+             is_training: whether to use VGG or AlexNet.
+
+         Returns:
+             Mean-reduced LPIPS loss across batch.
+         """
+         model_name = 'vgg' if is_training else 'alex'
+         loss_fn = self._get_model(model_name)
+         N, M, C, H, W = x.shape
+         x = x.reshape(N*M, C, H, W)
+         y = y.reshape(N*M, C, H, W)
+         image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
+         batch_loss = image_loss.reshape(N, M).mean(dim=1)
+         all_loss = batch_loss.mean()
+         return all_loss
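LPIPSLoss lazily builds and caches one lpips network per backbone (VGG while training, AlexNet otherwise). A minimal usage sketch, assuming the `lpips` package is installed and following the [N, M, C, H, W] convention shared by the other losses in this commit:

```python
import torch
from LHM.losses.perceptual import LPIPSLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = LPIPSLoss(device=device, prefech=False)  # backbones are fetched on first use

# N batches of M views, 0-1 scaled, channel first.
x = torch.rand(2, 4, 3, 128, 128, device=device)
y = torch.rand(2, 4, 3, 128, 128, device=device)
loss = loss_fn(x, y, is_training=True)  # VGG backbone
print(loss.item())
```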
LHM/losses/pixelwise.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ __all__ = ['PixelLoss']
20
+
21
+
22
+ class PixelLoss(nn.Module):
23
+ """
24
+ Pixel-wise loss between two images.
25
+ """
26
+
27
+ def __init__(self, option: str = 'mse'):
28
+ super().__init__()
29
+ self.loss_fn = self._build_from_option(option)
30
+
31
+ @staticmethod
32
+ def _build_from_option(option: str, reduction: str = 'none'):
33
+ if option == 'mse':
34
+ return nn.MSELoss(reduction=reduction)
35
+ elif option == 'l1':
36
+ return nn.L1Loss(reduction=reduction)
37
+ else:
38
+ raise NotImplementedError(f'Unknown pixel loss option: {option}')
39
+
40
+ @torch.compile
41
+ def forward(self, x, y):
42
+ """
43
+ Assume images are channel first.
44
+
45
+ Args:
46
+ x: [N, M, C, H, W]
47
+ y: [N, M, C, H, W]
48
+
49
+ Returns:
50
+ Mean-reduced pixel loss across batch.
51
+ """
52
+ N, M, C, H, W = x.shape
53
+ x = x.reshape(N*M, C, H, W)
54
+ y = y.reshape(N*M, C, H, W)
55
+ image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3])
56
+ batch_loss = image_loss.reshape(N, M).mean(dim=1)
57
+ all_loss = batch_loss.mean()
58
+ return all_loss
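Usage sketch for the pixel loss, mirroring the [N, M, C, H, W] layout expected above:

    import torch

    loss_fn = PixelLoss(option='l1')   # or 'mse'
    x = torch.rand(2, 4, 3, 64, 64)    # [N, M, C, H, W]
    y = torch.rand(2, 4, 3, 64, 64)
    loss = loss_fn(x, y)               # scalar: mean over views, then over the batch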
LHM/losses/tvloss.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ __all__ = ['TVLoss']
20
+
21
+
22
+ class TVLoss(nn.Module):
23
+ """
24
+ Total variance loss.
25
+ """
26
+
27
+ def __init__(self):
28
+ super().__init__()
29
+
30
+ def numel_excluding_first_dim(self, x):
31
+ return x.numel() // x.shape[0]
32
+
33
+ @torch.compile
34
+ def forward(self, x):
35
+ """
36
+ Assume batched and channel first with inner sizes.
37
+
38
+ Args:
39
+ x: [N, M, C, H, W]
40
+
41
+ Returns:
42
+ Mean-reduced TV loss with element-level scaling.
43
+ """
44
+ N, M, C, H, W = x.shape
45
+ x = x.reshape(N*M, C, H, W)
46
+ diff_i = x[..., 1:, :] - x[..., :-1, :]
47
+ diff_j = x[..., :, 1:] - x[..., :, :-1]
48
+ div_i = self.numel_excluding_first_dim(diff_i)
49
+ div_j = self.numel_excluding_first_dim(diff_j)
50
+ tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
51
+ tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
52
+ tv = tv_i + tv_j
53
+ batch_tv = tv.reshape(N, M).mean(dim=1)
54
+ all_tv = batch_tv.mean()
55
+ return all_tv
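Usage sketch; the tensor shape is illustrative (a batch of N samples with M feature planes each):

    import torch

    tv = TVLoss()
    planes = torch.rand(2, 3, 32, 64, 64)   # [N, M, C, H, W]
    loss = tv(planes)                       # penalizes squared differences between neighbouring pixels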
LHM/models/ESRGANer_utils.py ADDED
@@ -0,0 +1,482 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Organization : Alibaba XR-Lab
3
+ # @Author : Lingteng Qiu
4
+ # @Email : [email protected]
5
+ # @Time : 2025-03-1 17:39:52
6
+ # @Function : Utilities to improve face quality during training.
7
+
8
+ import math
9
+ import os
10
+ import queue
11
+ import sys
12
+
13
+ sys.path.append("./")
14
+ import threading
15
+
16
+ import cv2
17
+ import numpy as np
18
+ import torch
19
+ from basicsr.utils.download_util import load_file_from_url
20
+ from torch.nn import functional as F
21
+
22
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
23
+ import pdb
24
+
25
+ import torch
26
+ from basicsr.archs.rrdbnet_arch import RRDBNet
27
+
28
+
29
+ def avaliable_device():
30
+ if torch.cuda.is_available():
31
+ current_device_id = torch.cuda.current_device()
32
+ device = f"cuda:{current_device_id}"
33
+ else:
34
+ device = "cpu"
35
+
36
+ return device
37
+
38
+
39
+ class RealESRGANer:
40
+ """A helper class for upsampling images with RealESRGAN.
41
+
42
+ Args:
43
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
44
+ model_path (str): The path to the pretrained model. It can also be a URL (the weights are downloaded automatically).
45
+ model (nn.Module): The defined network. Default: None.
46
+ tile (int): Very large images can run out of GPU memory, so this option first crops the input
47
+ image into tiles, processes each tile, and then merges the results back into one image.
48
+ 0 disables tiling. Default: 0.
49
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
50
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
51
+ half (bool): Whether to use half precision during inference. Default: False.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ scale,
57
+ model_path,
58
+ dni_weight=None,
59
+ model=None,
60
+ tile=0,
61
+ tile_pad=10,
62
+ pre_pad=10,
63
+ half=False,
64
+ device=None,
65
+ gpu_id=None,
66
+ ):
67
+ self.scale = scale
68
+ self.tile_size = tile
69
+ self.tile_pad = tile_pad
70
+ self.pre_pad = pre_pad
71
+ self.mod_scale = None
72
+ self.half = half
73
+
74
+ # initialize model
75
+ if gpu_id:
76
+ self.device = (
77
+ torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
78
+ if device is None
79
+ else device
80
+ )
81
+ else:
82
+ self.device = (
83
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ if device is None
85
+ else device
86
+ )
87
+
88
+ if isinstance(model_path, list):
89
+ # dni
90
+ assert len(model_path) == len(
91
+ dni_weight
92
+ ), "model_path and dni_weight should have the save length."
93
+ loadnet = self.dni(model_path[0], model_path[1], dni_weight)
94
+ else:
95
+ # if the model_path starts with https, it will first download models to the folder: weights
96
+ if model_path.startswith("https://"):
97
+ model_path = load_file_from_url(
98
+ url=model_path,
99
+ model_dir=os.path.join(ROOT_DIR, "weights"),
100
+ progress=True,
101
+ file_name=None,
102
+ )
103
+ loadnet = torch.load(model_path, map_location=torch.device("cpu"))
104
+
105
+ # prefer to use params_ema
106
+ if "params_ema" in loadnet:
107
+ keyname = "params_ema"
108
+ else:
109
+ keyname = "params"
110
+ model.load_state_dict(loadnet[keyname], strict=True)
111
+
112
+ model.eval()
113
+ self.model = model.to(self.device)
114
+ if self.half:
115
+ self.model = self.model.half()
116
+
117
+ def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
118
+ """Deep network interpolation.
119
+
120
+ ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
121
+ """
122
+ net_a = torch.load(net_a, map_location=torch.device(loc))
123
+ net_b = torch.load(net_b, map_location=torch.device(loc))
124
+ for k, v_a in net_a[key].items():
125
+ net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
126
+ return net_a
127
+
128
+ def pre_process(self, img):
129
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
130
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
131
+ self.img = img.unsqueeze(0).to(self.device)
132
+ if self.half:
133
+ self.img = self.img.half()
134
+
135
+ # pre_pad
136
+ if self.pre_pad != 0:
137
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
138
+ # mod pad for divisible borders
139
+ if self.scale == 2:
140
+ self.mod_scale = 2
141
+ elif self.scale == 1:
142
+ self.mod_scale = 4
143
+ if self.mod_scale is not None:
144
+ self.mod_pad_h, self.mod_pad_w = 0, 0
145
+ _, _, h, w = self.img.size()
146
+ if h % self.mod_scale != 0:
147
+ self.mod_pad_h = self.mod_scale - h % self.mod_scale
148
+ if w % self.mod_scale != 0:
149
+ self.mod_pad_w = self.mod_scale - w % self.mod_scale
150
+ self.img = F.pad(
151
+ self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
152
+ )
153
+
154
+ def process(self):
155
+ # model inference
156
+ self.output = self.model(self.img)
157
+
158
+ def tile_process(self):
159
+ """It will first crop input images to tiles, and then process each tile.
160
+ Finally, all the processed tiles are merged into one images.
161
+
162
+ Modified from: https://github.com/ata4/esrgan-launcher
163
+ """
164
+ batch, channel, height, width = self.img.shape
165
+ output_height = height * self.scale
166
+ output_width = width * self.scale
167
+ output_shape = (batch, channel, output_height, output_width)
168
+
169
+ # start with black image
170
+ self.output = self.img.new_zeros(output_shape)
171
+ tiles_x = math.ceil(width / self.tile_size)
172
+ tiles_y = math.ceil(height / self.tile_size)
173
+
174
+ # loop over all tiles
175
+ for y in range(tiles_y):
176
+ for x in range(tiles_x):
177
+ # extract tile from input image
178
+ ofs_x = x * self.tile_size
179
+ ofs_y = y * self.tile_size
180
+ # input tile area on total image
181
+ input_start_x = ofs_x
182
+ input_end_x = min(ofs_x + self.tile_size, width)
183
+ input_start_y = ofs_y
184
+ input_end_y = min(ofs_y + self.tile_size, height)
185
+
186
+ # input tile area on total image with padding
187
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
188
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
189
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
190
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
191
+
192
+ # input tile dimensions
193
+ input_tile_width = input_end_x - input_start_x
194
+ input_tile_height = input_end_y - input_start_y
195
+ tile_idx = y * tiles_x + x + 1
196
+ input_tile = self.img[
197
+ :,
198
+ :,
199
+ input_start_y_pad:input_end_y_pad,
200
+ input_start_x_pad:input_end_x_pad,
201
+ ]
202
+
203
+ # upscale tile
204
+ try:
205
+ with torch.no_grad():
206
+ output_tile = self.model(input_tile)
207
+ except RuntimeError as error:
208
+ print("Error", error)
209
+ print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
210
+
211
+ # output tile area on total image
212
+ output_start_x = input_start_x * self.scale
213
+ output_end_x = input_end_x * self.scale
214
+ output_start_y = input_start_y * self.scale
215
+ output_end_y = input_end_y * self.scale
216
+
217
+ # output tile area without padding
218
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
219
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
220
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
221
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
222
+
223
+ # put tile into output image
224
+ self.output[
225
+ :, :, output_start_y:output_end_y, output_start_x:output_end_x
226
+ ] = output_tile[
227
+ :,
228
+ :,
229
+ output_start_y_tile:output_end_y_tile,
230
+ output_start_x_tile:output_end_x_tile,
231
+ ]
232
+
233
+ def post_process(self):
234
+ # remove extra pad
235
+ if self.mod_scale is not None:
236
+ _, _, h, w = self.output.size()
237
+ self.output = self.output[
238
+ :,
239
+ :,
240
+ 0 : h - self.mod_pad_h * self.scale,
241
+ 0 : w - self.mod_pad_w * self.scale,
242
+ ]
243
+ # remove prepad
244
+ if self.pre_pad != 0:
245
+ _, _, h, w = self.output.size()
246
+ self.output = self.output[
247
+ :,
248
+ :,
249
+ 0 : h - self.pre_pad * self.scale,
250
+ 0 : w - self.pre_pad * self.scale,
251
+ ]
252
+ return self.output
253
+
254
+ @torch.no_grad()
255
+ def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
256
+ h_input, w_input = img.shape[0:2]
257
+ # img: numpy
258
+ img = img.astype(np.float32)
259
+ if np.max(img) > 256: # 16-bit image
260
+ max_range = 65535
261
+ print("\tInput is a 16-bit image")
262
+ else:
263
+ max_range = 255
264
+ img = img / max_range
265
+ if len(img.shape) == 2: # gray image
266
+ img_mode = "L"
267
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
268
+ elif img.shape[2] == 4: # RGBA image with alpha channel
269
+ img_mode = "RGBA"
270
+ alpha = img[:, :, 3]
271
+ img = img[:, :, 0:3]
272
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
273
+ if alpha_upsampler == "realesrgan":
274
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
275
+ else:
276
+ img_mode = "RGB"
277
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
278
+
279
+ # ------------------- process image (without the alpha channel) ------------------- #
280
+ self.pre_process(img)
281
+ if self.tile_size > 0:
282
+ self.tile_process()
283
+ else:
284
+ self.process()
285
+ output_img = self.post_process()
286
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
287
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
288
+ if img_mode == "L":
289
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
290
+
291
+ # ------------------- process the alpha channel if necessary ------------------- #
292
+ if img_mode == "RGBA":
293
+ if alpha_upsampler == "realesrgan":
294
+ self.pre_process(alpha)
295
+ if self.tile_size > 0:
296
+ self.tile_process()
297
+ else:
298
+ self.process()
299
+ output_alpha = self.post_process()
300
+ output_alpha = (
301
+ output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
302
+ )
303
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
304
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
305
+ else: # use the cv2 resize for alpha channel
306
+ h, w = alpha.shape[0:2]
307
+ output_alpha = cv2.resize(
308
+ alpha,
309
+ (w * self.scale, h * self.scale),
310
+ interpolation=cv2.INTER_LINEAR,
311
+ )
312
+
313
+ # merge the alpha channel
314
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
315
+ output_img[:, :, 3] = output_alpha
316
+
317
+ # ------------------------------ return ------------------------------ #
318
+ if max_range == 65535: # 16-bit image
319
+ output = (output_img * 65535.0).round().astype(np.uint16)
320
+ else:
321
+ output = (output_img * 255.0).round().astype(np.uint8)
322
+
323
+ if outscale is not None and outscale != float(self.scale):
324
+ output = cv2.resize(
325
+ output,
326
+ (
327
+ int(w_input * outscale),
328
+ int(h_input * outscale),
329
+ ),
330
+ interpolation=cv2.INTER_LANCZOS4,
331
+ )
332
+
333
+ return output, img_mode
334
+
335
+
336
+ class PrefetchReader(threading.Thread):
337
+ """Prefetch images.
338
+
339
+ Args:
340
+ img_list (list[str]): A list of image paths to be read.
341
+ num_prefetch_queue (int): Number of prefetch queue.
342
+ """
343
+
344
+ def __init__(self, img_list, num_prefetch_queue):
345
+ super().__init__()
346
+ self.que = queue.Queue(num_prefetch_queue)
347
+ self.img_list = img_list
348
+
349
+ def run(self):
350
+ for img_path in self.img_list:
351
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
352
+ self.que.put(img)
353
+
354
+ self.que.put(None)
355
+
356
+ def __next__(self):
357
+ next_item = self.que.get()
358
+ if next_item is None:
359
+ raise StopIteration
360
+ return next_item
361
+
362
+ def __iter__(self):
363
+ return self
364
+
365
+
366
+ class IOConsumer(threading.Thread):
367
+
368
+ def __init__(self, opt, que, qid):
369
+ super().__init__()
370
+ self._queue = que
371
+ self.qid = qid
372
+ self.opt = opt
373
+
374
+ def run(self):
375
+ while True:
376
+ msg = self._queue.get()
377
+ if isinstance(msg, str) and msg == "quit":
378
+ break
379
+
380
+ output = msg["output"]
381
+ save_path = msg["save_path"]
382
+ cv2.imwrite(save_path, output)
383
+ print(f"IO worker {self.qid} is done.")
384
+
385
+
386
+ class ESRGANEasyModel:
387
+ def __init__(
388
+ self, model_path="./pretrained_models/RealESRGAN_x4plus.pth", face_enhance=True
389
+ ):
390
+ model = RRDBNet(
391
+ num_in_ch=3,
392
+ num_out_ch=3,
393
+ num_feat=64,
394
+ num_block=23,
395
+ num_grow_ch=32,
396
+ scale=4,
397
+ )
398
+ self.net_scale = 4
399
+ file_url = [
400
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
401
+ ]
402
+ if model_path is None:
403
+ model_path = os.path.join("weights", "RealESRGAN_x4plus.pth")  # default x4plus weights
404
+ if not os.path.isfile(model_path):
405
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
406
+ for url in file_url:
407
+ # model_path will be updated
408
+ model_path = load_file_from_url(
409
+ url=url,
410
+ model_dir=os.path.join("./", "pretrained_models"),
411
+ progress=True,
412
+ file_name=None,
413
+ )
414
+ self.face_enhance = face_enhance
415
+
416
+ dni_weight = None
417
+
418
+ self.upsampler = RealESRGANer(
419
+ scale=self.net_scale,
420
+ model_path=model_path,
421
+ dni_weight=dni_weight,
422
+ model=model,
423
+ tile=0,
424
+ tile_pad=10,
425
+ pre_pad=0,
426
+ half=False,
427
+ )
428
+
429
+ self.upsampler.model.to(avaliable_device())
430
+ if face_enhance: # Use GFPGAN for face enhancement
431
+ from gfpgan import GFPGANer
432
+
433
+ self.face_enhancer = GFPGANer(
434
+ model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
435
+ upscale=4,
436
+ arch="clean",
437
+ channel_multiplier=2,
438
+ bg_upsampler=self.upsampler,
439
+ )
440
+ else:
441
+ self.face_enhancer = None
442
+
443
+ @torch.no_grad()
444
+ def __call__(self, img):
445
+ if self.face_enhancer is not None:
446
+ _, _, output = self.face_enhancer.enhance(
447
+ img, has_aligned=False, only_center_face=False, paste_back=True
448
+ )
449
+ else:
450
+ output, _ = self.upsampler.enhance(img, outscale=4)
451
+ return output
452
+
453
+ def __repr__(self):
454
+ return f"ESRGANEasyModel:\n {self.upsampler}"
455
+
456
+
457
+ if __name__ == "__main__":
458
+
459
+ import time
460
+
461
+ model = ESRGANEasyModel(face_enhance=True)
462
+ input_img = "./debug/face_debug/gt/head_gt_0.png"
463
+
464
+ img_np = cv2.imread(input_img)
465
+ set1 = [
466
+ "./debug/face_debug/gt/head_gt_0.png",
467
+ "./debug/face_debug/gt/head_gt_1.png",
468
+ "./debug/face_debug/gt/head_gt_2.png",
469
+ "./debug/face_debug/gt/head_gt_3.png",
470
+ "./debug/face_debug/gt/head_gt_4.png",
471
+ "./debug/face_debug/gt/head_gt_5.png",
472
+ "./debug/face_debug/gt/head_gt_6.png",
473
+ "./debug/face_debug/gt/head_gt_0.png",
474
+ ]
475
+ img_set1 = [cv2.imread(img_path) for img_path in set1]
476
+
477
+ sr = model(img_set1[0])
478
+
479
+ s0 = time.time()
480
+ for img in img_set1:
481
+
482
+ sr = model(img)
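A minimal inference sketch for the wrapper above (paths are illustrative; basicsr and the RealESRGAN weights are required, and gfpgan is only needed when face_enhance=True):

    import cv2

    model = ESRGANEasyModel(
        model_path="./pretrained_models/RealESRGAN_x4plus.pth",
        face_enhance=False,
    )
    img = cv2.imread("input.png")     # BGR uint8, as expected by enhance()
    sr = model(img)                   # 4x upscaled BGR image
    cv2.imwrite("output_x4.png", sr)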
LHM/models/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .modeling_human_lrm import (
17
+ ModelHumanLRM,
18
+ ModelHumanLRMSapdinoBodyHeadSD3,
19
+ ModelHumanLRMSapdinoBodyHeadSD3_5,
20
+ ModelHumanLRMSapdinoSD3,
21
+ ModelHumanLRMSD3,
22
+ )
23
+
24
+ model_dict = {
25
+ "human_lrm": ModelHumanLRM,
26
+ "human_lrm_sd3": ModelHumanLRMSD3,
27
+ "human_lrm_sapdino_sd3": ModelHumanLRMSapdinoSD3,
28
+ "human_lrm_sapdino_bh_sd3": ModelHumanLRMSapdinoBodyHeadSD3,
29
+ "human_lrm_sapdino_bh_sd3_5": ModelHumanLRMSapdinoBodyHeadSD3_5,
30
+ }
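The registry above is meant to be looked up by model name; a sketch, with the constructor arguments left to the repo's configs (cfg is a hypothetical config object):

    model_cls = model_dict["human_lrm_sapdino_bh_sd3_5"]
    # model = model_cls(**cfg.model_kwargs)   # hypothetical; the actual kwargs live in the repo's config files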
LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
LHM/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (508 Bytes). View file
 
LHM/models/__pycache__/arcface_utils.cpython-310.pyc ADDED
Binary file (9.71 kB). View file
 
LHM/models/__pycache__/embedder.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc ADDED
Binary file (23.9 kB). View file
 
LHM/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.82 kB). View file
 
LHM/models/__pycache__/transformer_dit.cpython-310.pyc ADDED
Binary file (15 kB). View file
 
LHM/models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
LHM/models/arcface_utils.py ADDED
@@ -0,0 +1,360 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Organization : Alibaba XR-Lab
3
+ # @Author : Lingteng Qiu
4
+ # @Email : [email protected]
5
+ # @Time : 2025-03-10 17:38:29
6
+ # @Function : Arc-Similarity Loss
7
+ import sys
8
+
9
+ sys.path.append(".")
10
+
11
+ import pdb
12
+ from copy import deepcopy
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
18
+
19
+
20
+ def conv3x3(inplanes, outplanes, stride=1):
21
+ """A simple wrapper for 3x3 convolution with padding.
22
+
23
+ Args:
24
+ inplanes (int): Channel number of inputs.
25
+ outplanes (int): Channel number of outputs.
26
+ stride (int): Stride in convolution. Default: 1.
27
+ """
28
+ return nn.Conv2d(
29
+ inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
30
+ )
31
+
32
+
33
+ class BasicBlock(nn.Module):
34
+ """Basic residual block used in the ResNetArcFace architecture.
35
+
36
+ Args:
37
+ inplanes (int): Channel number of inputs.
38
+ planes (int): Channel number of outputs.
39
+ stride (int): Stride in convolution. Default: 1.
40
+ downsample (nn.Module): The downsample module. Default: None.
41
+ """
42
+
43
+ expansion = 1 # output channel expansion ratio
44
+
45
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
46
+ super(BasicBlock, self).__init__()
47
+ self.conv1 = conv3x3(inplanes, planes, stride)
48
+ self.bn1 = nn.BatchNorm2d(planes)
49
+ self.relu = nn.ReLU(inplace=True)
50
+ self.conv2 = conv3x3(planes, planes)
51
+ self.bn2 = nn.BatchNorm2d(planes)
52
+ self.downsample = downsample
53
+ self.stride = stride
54
+
55
+ def forward(self, x):
56
+ residual = x
57
+
58
+ out = self.conv1(x)
59
+ out = self.bn1(out)
60
+ out = self.relu(out)
61
+
62
+ out = self.conv2(out)
63
+ out = self.bn2(out)
64
+
65
+ if self.downsample is not None:
66
+ residual = self.downsample(x)
67
+
68
+ out += residual
69
+ out = self.relu(out)
70
+
71
+ return out
72
+
73
+
74
+ class IRBlock(nn.Module):
75
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
76
+
77
+ Args:
78
+ inplanes (int): Channel number of inputs.
79
+ planes (int): Channel number of outputs.
80
+ stride (int): Stride in convolution. Default: 1.
81
+ downsample (nn.Module): The downsample module. Default: None.
82
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
83
+ """
84
+
85
+ expansion = 1 # output channel expansion ratio
86
+
87
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
88
+ super(IRBlock, self).__init__()
89
+ self.bn0 = nn.BatchNorm2d(inplanes)
90
+ self.conv1 = conv3x3(inplanes, inplanes)
91
+ self.bn1 = nn.BatchNorm2d(inplanes)
92
+ self.prelu = nn.PReLU()
93
+ self.conv2 = conv3x3(inplanes, planes, stride)
94
+ self.bn2 = nn.BatchNorm2d(planes)
95
+ self.downsample = downsample
96
+ self.stride = stride
97
+ self.use_se = use_se
98
+ if self.use_se:
99
+ self.se = SEBlock(planes)
100
+
101
+ def forward(self, x):
102
+ residual = x
103
+ out = self.bn0(x)
104
+ out = self.conv1(out)
105
+ out = self.bn1(out)
106
+ out = self.prelu(out)
107
+
108
+ out = self.conv2(out)
109
+ out = self.bn2(out)
110
+ if self.use_se:
111
+ out = self.se(out)
112
+
113
+ if self.downsample is not None:
114
+ residual = self.downsample(x)
115
+
116
+ out += residual
117
+ out = self.prelu(out)
118
+
119
+ return out
120
+
121
+
122
+ class Bottleneck(nn.Module):
123
+ """Bottleneck block used in the ResNetArcFace architecture.
124
+
125
+ Args:
126
+ inplanes (int): Channel number of inputs.
127
+ planes (int): Channel number of outputs.
128
+ stride (int): Stride in convolution. Default: 1.
129
+ downsample (nn.Module): The downsample module. Default: None.
130
+ """
131
+
132
+ expansion = 4 # output channel expansion ratio
133
+
134
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
135
+ super(Bottleneck, self).__init__()
136
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
137
+ self.bn1 = nn.BatchNorm2d(planes)
138
+ self.conv2 = nn.Conv2d(
139
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
140
+ )
141
+ self.bn2 = nn.BatchNorm2d(planes)
142
+ self.conv3 = nn.Conv2d(
143
+ planes, planes * self.expansion, kernel_size=1, bias=False
144
+ )
145
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
146
+ self.relu = nn.ReLU(inplace=True)
147
+ self.downsample = downsample
148
+ self.stride = stride
149
+
150
+ def forward(self, x):
151
+ residual = x
152
+
153
+ out = self.conv1(x)
154
+ out = self.bn1(out)
155
+ out = self.relu(out)
156
+
157
+ out = self.conv2(out)
158
+ out = self.bn2(out)
159
+ out = self.relu(out)
160
+
161
+ out = self.conv3(out)
162
+ out = self.bn3(out)
163
+
164
+ if self.downsample is not None:
165
+ residual = self.downsample(x)
166
+
167
+ out += residual
168
+ out = self.relu(out)
169
+
170
+ return out
171
+
172
+
173
+ class SEBlock(nn.Module):
174
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
175
+
176
+ Args:
177
+ channel (int): Channel number of inputs.
178
+ reduction (int): Channel reduction ratio. Default: 16.
179
+ """
180
+
181
+ def __init__(self, channel, reduction=16):
182
+ super(SEBlock, self).__init__()
183
+ self.avg_pool = nn.AdaptiveAvgPool2d(
184
+ 1
185
+ ) # pool to 1x1 without spatial information
186
+ self.fc = nn.Sequential(
187
+ nn.Linear(channel, channel // reduction),
188
+ nn.PReLU(),
189
+ nn.Linear(channel // reduction, channel),
190
+ nn.Sigmoid(),
191
+ )
192
+
193
+ def forward(self, x):
194
+ b, c, _, _ = x.size()
195
+ y = self.avg_pool(x).view(b, c)
196
+ y = self.fc(y).view(b, c, 1, 1)
197
+ return x * y
198
+
199
+
200
+ class ResNetArcFace(nn.Module):
201
+ """ArcFace with ResNet architectures.
202
+
203
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
204
+
205
+ Args:
206
+ block (str): Block used in the ArcFace architecture.
207
+ layers (tuple(int)): Block numbers in each layer.
208
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ block="IRBlock",
214
+ layers=[2, 2, 2, 2],
215
+ use_se=False,
216
+ pretrain_model="./pretrained_models/arcface_resnet18.pth",
217
+ ):
218
+ if block == "IRBlock":
219
+ block = IRBlock
220
+ self.inplanes = 64
221
+ self.use_se = use_se
222
+ super(ResNetArcFace, self).__init__()
223
+
224
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
225
+ self.bn1 = nn.BatchNorm2d(64)
226
+ self.prelu = nn.PReLU()
227
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
228
+ self.layer1 = self._make_layer(block, 64, layers[0])
229
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
230
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
231
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
232
+ self.bn4 = nn.BatchNorm2d(512)
233
+ self.dropout = nn.Dropout()
234
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
235
+ self.bn5 = nn.BatchNorm1d(512)
236
+
237
+ # initialization
238
+ for m in self.modules():
239
+ if isinstance(m, nn.Conv2d):
240
+ nn.init.xavier_normal_(m.weight)
241
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
242
+ nn.init.constant_(m.weight, 1)
243
+ nn.init.constant_(m.bias, 0)
244
+ elif isinstance(m, nn.Linear):
245
+ nn.init.xavier_normal_(m.weight)
246
+ nn.init.constant_(m.bias, 0)
247
+
248
+ if pretrain_model is not None:
249
+ self.load_network(self, pretrain_model, strict=True, param_key=None)
250
+ else:
251
+ raise ValueError("Please specify the pretrain model path.")
252
+
253
+ self.freeze()
254
+
255
+ @staticmethod
256
+ def load_network(net, load_path, strict=True, param_key=None):
257
+
258
+ def get_bare_model(net):
259
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
260
+ net = net.module
261
+ return net
262
+
263
+ net = get_bare_model(net)
264
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
265
+ if param_key is not None:
266
+ if param_key not in load_net and "params" in load_net:
267
+ param_key = "params"
268
+ load_net = load_net[param_key]
269
+ # remove unnecessary 'module.'
270
+ for k, v in deepcopy(load_net).items():
271
+ if k.startswith("module."):
272
+ load_net[k[7:]] = v
273
+ load_net.pop(k)
274
+ ret = net.load_state_dict(load_net, strict=strict)
275
+ print(ret)
276
+
277
+ def _make_layer(self, block, planes, num_blocks, stride=1):
278
+ downsample = None
279
+ if stride != 1 or self.inplanes != planes * block.expansion:
280
+ downsample = nn.Sequential(
281
+ nn.Conv2d(
282
+ self.inplanes,
283
+ planes * block.expansion,
284
+ kernel_size=1,
285
+ stride=stride,
286
+ bias=False,
287
+ ),
288
+ nn.BatchNorm2d(planes * block.expansion),
289
+ )
290
+ layers = []
291
+ layers.append(
292
+ block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
293
+ )
294
+ self.inplanes = planes
295
+ for _ in range(1, num_blocks):
296
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
297
+
298
+ return nn.Sequential(*layers)
299
+
300
+ def forward(self, x):
301
+ x = self.conv1(x)
302
+ x = self.bn1(x)
303
+ x = self.prelu(x)
304
+ x = self.maxpool(x)
305
+
306
+ x = self.layer1(x)
307
+ x = self.layer2(x)
308
+ x = self.layer3(x)
309
+ x = self.layer4(x)
310
+ x = self.bn4(x)
311
+ x = self.dropout(x)
312
+ x = x.view(x.size(0), -1)
313
+ x = self.fc5(x)
314
+ x = self.bn5(x)
315
+
316
+ return x
317
+
318
+ def freeze(self):
319
+ self.eval()
320
+ for param in self.parameters():
321
+ param.requires_grad = False
322
+
323
+
324
+ if __name__ == "__main__":
325
+ model = ResNetArcFace()
326
+ model.cuda()
327
+ model.eval()
328
+ # model.eval()
329
+
330
+ set1 = [
331
+ "./debug/face_debug/gt/head_gt_0.png",
332
+ "./debug/face_debug/gt/head_gt_1.png",
333
+ "./debug/face_debug/gt/head_gt_2.png",
334
+ "./debug/face_debug/gt/head_gt_3.png",
335
+ "./debug/face_debug/gt/head_gt_4.png",
336
+ "./debug/face_debug/gt/head_gt_5.png",
337
+ "./debug/face_debug/gt/head_gt_6.png",
338
+ ]
339
+ import cv2
340
+
341
+ img_set1 = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in set1]
342
+
343
+ F1_list = []
344
+
345
+ f1_scores = []
346
+ for img in img_set1:
347
+ img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) / 255.0
348
+ img = img.cuda()
349
+ F1 = model(img)
350
+ F1_list.append(F1)
351
+ for i in range(len(F1_list)):
352
+ for j in range(len(F1_list)):
353
+ f1_scores.append(F.l1_loss(F1_list[i], F1_list[j]))
354
+
355
+ print(len(f1_scores))
356
+
357
+ f1_scores = torch.tensor(f1_scores)
358
+ print(f1_scores)
359
+ f1_scores = f1_scores.view(len(F1_list), len(F1_list))
360
+ print(f1_scores)
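A sketch of how the frozen ArcFace backbone can serve as an identity-similarity term, mirroring the __main__ comparison above; it assumes the pretrained checkpoint exists at the repo's default path and that face crops are 128x128 grayscale (which yields the 8x8 feature map consumed by fc5):

    import torch
    import torch.nn.functional as F

    arcface = ResNetArcFace(pretrain_model="./pretrained_models/arcface_resnet18.pth").cuda()

    def identity_loss(face_a, face_b):
        # face_a, face_b: [B, 1, 128, 128] grayscale crops in [0, 1]
        return F.l1_loss(arcface(face_a), arcface(face_b))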
LHM/models/block.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch.nn as nn
17
+
18
+ from .modulate import ModLN
19
+
20
+
21
+ class BasicBlock(nn.Module):
22
+ """
23
+ Transformer block that is in its simplest form.
24
+ Designed for PF-LRM architecture.
25
+ """
26
+ # Block contains a self-attention layer and an MLP
27
+ def __init__(self, inner_dim: int, num_heads: int, eps: float,
28
+ attn_drop: float = 0., attn_bias: bool = False,
29
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
30
+ super().__init__()
31
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
32
+ self.self_attn = nn.MultiheadAttention(
33
+ embed_dim=inner_dim, num_heads=num_heads,
34
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
35
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
36
+ self.mlp = nn.Sequential(
37
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
38
+ nn.GELU(),
39
+ nn.Dropout(mlp_drop),
40
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
41
+ nn.Dropout(mlp_drop),
42
+ )
43
+
44
+ def forward(self, x):
45
+ # x: [N, L, D]
46
+ before_sa = self.norm1(x)
47
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
48
+ x = x + self.mlp(self.norm2(x))
49
+ return x
50
+
51
+
52
+ class ConditionBlock(nn.Module):
53
+ """
54
+ Transformer block that takes in a cross-attention condition.
55
+ Designed for SparseLRM architecture.
56
+ """
57
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
58
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
59
+ attn_drop: float = 0., attn_bias: bool = False,
60
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
61
+ super().__init__()
62
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
63
+ self.cross_attn = nn.MultiheadAttention(
64
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
65
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
66
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
67
+ self.self_attn = nn.MultiheadAttention(
68
+ embed_dim=inner_dim, num_heads=num_heads,
69
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
70
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
71
+ self.mlp = nn.Sequential(
72
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
73
+ nn.GELU(),
74
+ nn.Dropout(mlp_drop),
75
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
76
+ nn.Dropout(mlp_drop),
77
+ )
78
+
79
+ def forward(self, x, cond):
80
+ # x: [N, L, D]
81
+ # cond: [N, L_cond, D_cond]
82
+ x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
83
+ before_sa = self.norm2(x)
84
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
85
+ x = x + self.mlp(self.norm3(x))
86
+ return x
87
+
88
+
89
+ class ConditionModulationBlock(nn.Module):
90
+ """
91
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
92
+ Designed for raw LRM architecture.
93
+ """
94
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
95
+ def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
96
+ attn_drop: float = 0., attn_bias: bool = False,
97
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
98
+ super().__init__()
99
+ self.norm1 = ModLN(inner_dim, mod_dim, eps)
100
+ self.cross_attn = nn.MultiheadAttention(
101
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
102
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
103
+ self.norm2 = ModLN(inner_dim, mod_dim, eps)
104
+ self.self_attn = nn.MultiheadAttention(
105
+ embed_dim=inner_dim, num_heads=num_heads,
106
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
107
+ self.norm3 = ModLN(inner_dim, mod_dim, eps)
108
+ self.mlp = nn.Sequential(
109
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
110
+ nn.GELU(),
111
+ nn.Dropout(mlp_drop),
112
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
113
+ nn.Dropout(mlp_drop),
114
+ )
115
+
116
+ def forward(self, x, cond, mod):
117
+ # x: [N, L, D]
118
+ # cond: [N, L_cond, D_cond]
119
+ # mod: [N, D_mod]
120
+ x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
121
+ before_sa = self.norm2(x, mod)
122
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
123
+ x = x + self.mlp(self.norm3(x, mod))
124
+ return x
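A shape sketch for the modulated variant above (all dimensions are illustrative):

    import torch

    x = torch.randn(2, 1024, 512)    # [N, L, D] latent tokens
    cond = torch.randn(2, 257, 768)  # [N, L_cond, D_cond] image-encoder tokens
    mod = torch.randn(2, 256)        # [N, D_mod] modulation vector, e.g. a camera embedding

    blk = ConditionModulationBlock(inner_dim=512, cond_dim=768, mod_dim=256, num_heads=8, eps=1e-6)
    out = blk(x, cond, mod)          # [2, 1024, 512], same shape as x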
LHM/models/discriminator.py ADDED
@@ -0,0 +1,120 @@
1
+ """
2
+ Ported from Paella
3
+ """
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+
11
+ import functools
12
+ # import torch.nn as nn
13
+ from taming.modules.util import ActNorm
14
+
15
+
16
+ # Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
17
+ class Discriminator(ModelMixin, ConfigMixin):
18
+ @register_to_config
19
+ def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
20
+ super().__init__()
21
+ d = max(depth - 3, 3)
22
+ layers = [
23
+ nn.utils.spectral_norm(
24
+ nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
25
+ ),
26
+ nn.LeakyReLU(0.2),
27
+ ]
28
+ for i in range(depth - 1):
29
+ c_in = hidden_channels // (2 ** max((d - i), 0))
30
+ c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
31
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
32
+ layers.append(nn.InstanceNorm2d(c_out))
33
+ layers.append(nn.LeakyReLU(0.2))
34
+ self.encoder = nn.Sequential(*layers)
35
+ self.shuffle = nn.Conv2d(
36
+ (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
37
+ )
38
+ # self.logits = nn.Sigmoid()
39
+
40
+
41
+ def forward(self, x, cond=None):
42
+ x = self.encoder(x)
43
+ if cond is not None:
44
+ cond = cond.view(
45
+ cond.size(0),
46
+ cond.size(1),
47
+ 1,
48
+ 1,
49
+ ).expand(-1, -1, x.size(-2), x.size(-1))
50
+ x = torch.cat([x, cond], dim=1)
51
+ x = self.shuffle(x)
52
+ # x = self.logits(x)
53
+ return x
54
+
55
+
56
+
57
+
58
+ def weights_init(m):
59
+ classname = m.__class__.__name__
60
+ if classname.find('Conv') != -1:
61
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
62
+ elif classname.find('BatchNorm') != -1:
63
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
64
+ nn.init.constant_(m.bias.data, 0)
65
+
66
+
67
+ class NLayerDiscriminator(nn.Module):
68
+ """Defines a PatchGAN discriminator as in Pix2Pix
69
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
70
+ """
71
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
72
+ """Construct a PatchGAN discriminator
73
+ Parameters:
74
+ input_nc (int) -- the number of channels in input images
75
+ ndf (int) -- the number of filters in the last conv layer
76
+ n_layers (int) -- the number of conv layers in the discriminator
77
+ use_actnorm (bool) -- use ActNorm instead of InstanceNorm as the normalization layer
78
+ """
79
+ super(NLayerDiscriminator, self).__init__()
80
+ if not use_actnorm:
81
+ # norm_layer = nn.BatchNorm2d
82
+ norm_layer = nn.InstanceNorm2d
83
+ else:
84
+ norm_layer = ActNorm
85
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
86
+ # use_bias = norm_layer.func != nn.BatchNorm2d
87
+ use_bias = norm_layer.func != nn.InstanceNorm2d
88
+ else:
89
+ # use_bias = norm_layer != nn.BatchNorm2d
90
+ use_bias = norm_layer != nn.InstanceNorm2d
91
+
92
+ kw = 4
93
+ padw = 1
94
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
95
+ nf_mult = 1
96
+ nf_mult_prev = 1
97
+ for n in range(1, n_layers): # gradually increase the number of filters
98
+ nf_mult_prev = nf_mult
99
+ nf_mult = min(2 ** n, 8)
100
+ sequence += [
101
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
102
+ norm_layer(ndf * nf_mult),
103
+ nn.LeakyReLU(0.2, False)
104
+ ]
105
+
106
+ nf_mult_prev = nf_mult
107
+ nf_mult = min(2 ** n_layers, 8)
108
+ sequence += [
109
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
110
+ norm_layer(ndf * nf_mult),
111
+ nn.LeakyReLU(0.2, False)
112
+ ]
113
+
114
+ sequence += [
115
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
116
+ self.main = nn.Sequential(*sequence)
117
+
118
+ def forward(self, input):
119
+ """Standard forward."""
120
+ return self.main(input)
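Usage sketch for the PatchGAN variant (the file imports ActNorm, so the taming package must be installed even when use_actnorm=False):

    import torch

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
    disc.apply(weights_init)                    # Pix2Pix-style initialization of the conv layers
    logits = disc(torch.randn(2, 3, 256, 256))  # [2, 1, 30, 30] patch-level real/fake scores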
LHM/models/embedder.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class CameraEmbedder(nn.Module):
21
+ """
22
+ Embed camera features to a high-dimensional vector.
23
+
24
+ Reference:
25
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
26
+ """
27
+ def __init__(self, raw_dim: int, embed_dim: int):
28
+ super().__init__()
29
+ self.mlp = nn.Sequential(
30
+ nn.Linear(raw_dim, embed_dim),
31
+ nn.SiLU(),
32
+ nn.Linear(embed_dim, embed_dim),
33
+ )
34
+
35
+ @torch.compile
36
+ def forward(self, x):
37
+ return self.mlp(x)
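Usage sketch; raw_dim=16 is only an assumption about the flattened camera format (e.g. a 3x4 extrinsic plus 4 intrinsic values):

    import torch

    embedder = CameraEmbedder(raw_dim=16, embed_dim=1024)
    cam = torch.randn(2, 16)   # flattened per-sample camera parameters
    mod = embedder(cam)        # [2, 1024] modulation vector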
LHM/models/encoders/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Empty
LHM/models/encoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes). View file
 
LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc ADDED
Binary file (9.21 kB). View file
 
LHM/models/encoders/dino_wrapper.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import ViTImageProcessor, ViTModel
19
+ from accelerate.logging import get_logger
20
+
21
+
22
+ logger = get_logger(__name__)
23
+
24
+
25
+ class DinoWrapper(nn.Module):
26
+ """
27
+ Dino v1 wrapper using huggingface transformer implementation.
28
+ """
29
+ def __init__(self, model_name: str, freeze: bool = True, encoder_feat_dim: int = 384):
30
+ super().__init__()
31
+ self.model, self.processor = self._build_dino(model_name)
32
+ if freeze:
33
+ self._freeze()
34
+
35
+ @torch.compile
36
+ def forward_model(self, inputs):
37
+ return self.model(**inputs, interpolate_pos_encoding=True)
38
+
39
+ def forward(self, image):
40
+ # image: [N, C, H, W], on cpu
41
+ # RGB image with [0,1] scale and properly sized
42
+ inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
43
+ # This resampling of positional embedding uses bicubic interpolation
44
+ outputs = self.forward_model(inputs)
45
+ last_hidden_states = outputs.last_hidden_state
46
+ return last_hidden_states
47
+
48
+ def _freeze(self):
49
+ logger.warning(f"======== Freezing DinoWrapper ========")
50
+ self.model.eval()
51
+ for name, param in self.model.named_parameters():
52
+ param.requires_grad = False
53
+
54
+ @staticmethod
55
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
56
+ import requests
57
+ try:
58
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
59
+ processor = ViTImageProcessor.from_pretrained(model_name)
60
+ return model, processor
61
+ except requests.exceptions.ProxyError as err:
62
+ if proxy_error_retries > 0:
63
+ print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
64
+ import time
65
+ time.sleep(proxy_error_cooldown)
66
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
67
+ else:
68
+ raise err
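Usage sketch; the model name is an example DINO v1 checkpoint on the Hugging Face Hub and is downloaded on first use:

    import torch

    dino = DinoWrapper(model_name="facebook/dino-vits16", freeze=True)
    imgs = torch.rand(2, 3, 224, 224)   # RGB in [0, 1], already resized
    tokens = dino(imgs)                 # [2, 197, 384] last hidden states for ViT-S/16 at 224x224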
LHM/models/encoders/dinov2/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Empty
LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
LHM/models/encoders/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes). View file
 
LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (4.45 kB). View file
 
LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.79 kB). View file
 
LHM/models/encoders/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = "LVD142M"
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = "vit_large",
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = "mlp",
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f"Unsupported weights: {weights}")
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
57
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
58
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
59
+ # ********** Modified by Zexin He in 2023-2024 **********
60
+ state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern
61
+ if vit_kwargs.get("modulation_dim") is not None:
62
+ state_dict = {
63
+ k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
64
+ for k, v in state_dict.items()
65
+ }
66
+ model.load_state_dict(state_dict, strict=False)
67
+ else:
68
+ model.load_state_dict(state_dict, strict=True)
69
+ # ********************************************************
70
+
71
+ return model
72
+
73
+
74
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
75
+ """
76
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
77
+ """
78
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
79
+
80
+
81
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
82
+ """
83
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
84
+ """
85
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
86
+
87
+
88
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
89
+ """
90
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
91
+ """
92
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
93
+
94
+
95
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
96
+ """
97
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
98
+ """
99
+ return _make_dinov2_model(
100
+ arch_name="vit_giant2",
101
+ ffn_layer="swiglufused",
102
+ weights=weights,
103
+ pretrained=pretrained,
104
+ **kwargs,
105
+ )
106
+
107
+
108
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
109
+ """
110
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
111
+ """
112
+ return _make_dinov2_model(
113
+ arch_name="vit_small",
114
+ pretrained=pretrained,
115
+ weights=weights,
116
+ num_register_tokens=4,
117
+ interpolate_antialias=True,
118
+ interpolate_offset=0.0,
119
+ **kwargs,
120
+ )
121
+
122
+
123
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
124
+ """
125
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
126
+ """
127
+ return _make_dinov2_model(
128
+ arch_name="vit_base",
129
+ pretrained=pretrained,
130
+ weights=weights,
131
+ num_register_tokens=4,
132
+ interpolate_antialias=True,
133
+ interpolate_offset=0.0,
134
+ **kwargs,
135
+ )
136
+
137
+
138
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
139
+ """
140
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
141
+ """
142
+ return _make_dinov2_model(
143
+ arch_name="vit_large",
144
+ pretrained=pretrained,
145
+ weights=weights,
146
+ num_register_tokens=4,
147
+ interpolate_antialias=True,
148
+ interpolate_offset=0.0,
149
+ **kwargs,
150
+ )
151
+
152
+
153
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
154
+ """
155
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
156
+ """
157
+ return _make_dinov2_model(
158
+ arch_name="vit_giant2",
159
+ ffn_layer="swiglufused",
160
+ weights=weights,
161
+ pretrained=pretrained,
162
+ num_register_tokens=4,
163
+ interpolate_antialias=True,
164
+ interpolate_offset=0.0,
165
+ **kwargs,
166
+ )
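A loading sketch for these hub entry points, assuming the repo's dinov2 package is importable; the pretrained LVD-142M checkpoint is downloaded on first use:

    import torch

    backbone = dinov2_vitb14(pretrained=True)
    feats = backbone.forward_features(torch.randn(1, 3, 518, 518))
    cls_token = feats["x_norm_clstoken"]        # [1, 768]
    patch_tokens = feats["x_norm_patchtokens"]  # [1, 1369, 768] for 518/14 = 37 patches per side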
LHM/models/encoders/dinov2/hub/classifiers.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .backbones import _make_dinov2_model
13
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
14
+
15
+
16
+ class Weights(Enum):
17
+ IMAGENET1K = "IMAGENET1K"
18
+
19
+
20
+ def _make_dinov2_linear_classification_head(
21
+ *,
22
+ arch_name: str = "vit_large",
23
+ patch_size: int = 14,
24
+ embed_dim: int = 1024,
25
+ layers: int = 4,
26
+ pretrained: bool = True,
27
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
28
+ num_register_tokens: int = 0,
29
+ **kwargs,
30
+ ):
31
+ if layers not in (1, 4):
32
+ raise AssertionError(f"Unsupported number of layers: {layers}")
33
+ if isinstance(weights, str):
34
+ try:
35
+ weights = Weights[weights]
36
+ except KeyError:
37
+ raise AssertionError(f"Unsupported weights: {weights}")
38
+
39
+ linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
40
+
41
+ if pretrained:
42
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
43
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
44
+ layers_str = str(layers) if layers == 4 else ""
45
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
46
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
47
+ linear_head.load_state_dict(state_dict, strict=True)
48
+
49
+ return linear_head
50
+
51
+
52
+ class _LinearClassifierWrapper(nn.Module):
53
+ def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
54
+ super().__init__()
55
+ self.backbone = backbone
56
+ self.linear_head = linear_head
57
+ self.layers = layers
58
+
59
+ def forward(self, x):
60
+ if self.layers == 1:
61
+ x = self.backbone.forward_features(x)
62
+ cls_token = x["x_norm_clstoken"]
63
+ patch_tokens = x["x_norm_patchtokens"]
64
+ # fmt: off
65
+ linear_input = torch.cat([
66
+ cls_token,
67
+ patch_tokens.mean(dim=1),
68
+ ], dim=1)
69
+ # fmt: on
70
+ elif self.layers == 4:
71
+ x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
72
+ # fmt: off
73
+ linear_input = torch.cat([
74
+ x[0][1],
75
+ x[1][1],
76
+ x[2][1],
77
+ x[3][1],
78
+ x[3][0].mean(dim=1),
79
+ ], dim=1)
80
+ # fmt: on
81
+ else:
82
+ assert False, f"Unsupported number of layers: {self.layers}"
83
+ return self.linear_head(linear_input)
84
+
85
+
86
+ def _make_dinov2_linear_classifier(
87
+ *,
88
+ arch_name: str = "vit_large",
89
+ layers: int = 4,
90
+ pretrained: bool = True,
91
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
92
+ num_register_tokens: int = 0,
93
+ interpolate_antialias: bool = False,
94
+ interpolate_offset: float = 0.1,
95
+ **kwargs,
96
+ ):
97
+ backbone = _make_dinov2_model(
98
+ arch_name=arch_name,
99
+ pretrained=pretrained,
100
+ num_register_tokens=num_register_tokens,
101
+ interpolate_antialias=interpolate_antialias,
102
+ interpolate_offset=interpolate_offset,
103
+ **kwargs,
104
+ )
105
+
106
+ embed_dim = backbone.embed_dim
107
+ patch_size = backbone.patch_size
108
+ linear_head = _make_dinov2_linear_classification_head(
109
+ arch_name=arch_name,
110
+ patch_size=patch_size,
111
+ embed_dim=embed_dim,
112
+ layers=layers,
113
+ pretrained=pretrained,
114
+ weights=weights,
115
+ num_register_tokens=num_register_tokens,
116
+ )
117
+
118
+ return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
119
+
120
+
121
+ def dinov2_vits14_lc(
122
+ *,
123
+ layers: int = 4,
124
+ pretrained: bool = True,
125
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
126
+ **kwargs,
127
+ ):
128
+ """
129
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
130
+ """
131
+ return _make_dinov2_linear_classifier(
132
+ arch_name="vit_small",
133
+ layers=layers,
134
+ pretrained=pretrained,
135
+ weights=weights,
136
+ **kwargs,
137
+ )
138
+
139
+
140
+ def dinov2_vitb14_lc(
141
+ *,
142
+ layers: int = 4,
143
+ pretrained: bool = True,
144
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
145
+ **kwargs,
146
+ ):
147
+ """
148
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
149
+ """
150
+ return _make_dinov2_linear_classifier(
151
+ arch_name="vit_base",
152
+ layers=layers,
153
+ pretrained=pretrained,
154
+ weights=weights,
155
+ **kwargs,
156
+ )
157
+
158
+
159
+ def dinov2_vitl14_lc(
160
+ *,
161
+ layers: int = 4,
162
+ pretrained: bool = True,
163
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
164
+ **kwargs,
165
+ ):
166
+ """
167
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
168
+ """
169
+ return _make_dinov2_linear_classifier(
170
+ arch_name="vit_large",
171
+ layers=layers,
172
+ pretrained=pretrained,
173
+ weights=weights,
174
+ **kwargs,
175
+ )
176
+
177
+
178
+ def dinov2_vitg14_lc(
179
+ *,
180
+ layers: int = 4,
181
+ pretrained: bool = True,
182
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
183
+ **kwargs,
184
+ ):
185
+ """
186
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
187
+ """
188
+ return _make_dinov2_linear_classifier(
189
+ arch_name="vit_giant2",
190
+ layers=layers,
191
+ ffn_layer="swiglufused",
192
+ pretrained=pretrained,
193
+ weights=weights,
194
+ **kwargs,
195
+ )
196
+
197
+
198
+ def dinov2_vits14_reg_lc(
199
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
200
+ ):
201
+ """
202
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
203
+ """
204
+ return _make_dinov2_linear_classifier(
205
+ arch_name="vit_small",
206
+ layers=layers,
207
+ pretrained=pretrained,
208
+ weights=weights,
209
+ num_register_tokens=4,
210
+ interpolate_antialias=True,
211
+ interpolate_offset=0.0,
212
+ **kwargs,
213
+ )
214
+
215
+
216
+ def dinov2_vitb14_reg_lc(
217
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
218
+ ):
219
+ """
220
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
221
+ """
222
+ return _make_dinov2_linear_classifier(
223
+ arch_name="vit_base",
224
+ layers=layers,
225
+ pretrained=pretrained,
226
+ weights=weights,
227
+ num_register_tokens=4,
228
+ interpolate_antialias=True,
229
+ interpolate_offset=0.0,
230
+ **kwargs,
231
+ )
232
+
233
+
234
+ def dinov2_vitl14_reg_lc(
235
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
236
+ ):
237
+ """
238
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
239
+ """
240
+ return _make_dinov2_linear_classifier(
241
+ arch_name="vit_large",
242
+ layers=layers,
243
+ pretrained=pretrained,
244
+ weights=weights,
245
+ num_register_tokens=4,
246
+ interpolate_antialias=True,
247
+ interpolate_offset=0.0,
248
+ **kwargs,
249
+ )
250
+
251
+
252
+ def dinov2_vitg14_reg_lc(
253
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
254
+ ):
255
+ """
256
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
257
+ """
258
+ return _make_dinov2_linear_classifier(
259
+ arch_name="vit_giant2",
260
+ layers=layers,
261
+ ffn_layer="swiglufused",
262
+ pretrained=pretrained,
263
+ weights=weights,
264
+ num_register_tokens=4,
265
+ interpolate_antialias=True,
266
+ interpolate_offset=0.0,
267
+ **kwargs,
268
+ )
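
The *_lc entrypoints above pair a backbone with the ImageNet-1k linear head; a hedged sketch of calling one of them (same assumed import path as before, pretrained=False keeps everything offline):

    import torch
    from LHM.models.encoders.dinov2.hub.classifiers import dinov2_vitb14_lc

    # layers=4 concatenates the class tokens of the last four blocks plus the
    # mean patch token of the last block, i.e. (1 + 4) * embed_dim features,
    # before the 1000-way linear head.
    classifier = dinov2_vitb14_lc(layers=4, pretrained=False).eval()

    with torch.no_grad():
        logits = classifier(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
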
LHM/models/encoders/dinov2/hub/depth/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .decode_heads import BNHead, DPTHead
7
+ from .encoder_decoder import DepthEncoderDecoder
LHM/models/encoders/dinov2/hub/depth/decode_heads.py ADDED
@@ -0,0 +1,747 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from functools import partial
8
+ import math
9
+ import warnings
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .ops import resize
15
+
16
+
17
+ # XXX: (Untested) replacement for mmcv.imdenormalize()
18
+ def _imdenormalize(img, mean, std, to_bgr=True):
19
+ import numpy as np
20
+
21
+ mean = mean.reshape(1, -1).astype(np.float64)
22
+ std = std.reshape(1, -1).astype(np.float64)
23
+ img = (img * std) + mean
24
+ if to_bgr:
25
+ img = img[::-1]
26
+ return img
27
+
28
+
29
+ class DepthBaseDecodeHead(nn.Module):
30
+ """Base class for BaseDecodeHead.
31
+
32
+ Args:
33
+ in_channels (List): Input channels.
34
+ channels (int): Channels after modules, before conv_depth.
35
+ conv_layer (nn.Module): Conv layers. Default: None.
36
+ act_layer (nn.Module): Activation layers. Default: nn.ReLU.
37
+ loss_decode (dict): Config of decode loss.
38
+ Default: ().
39
+ sampler (dict|None): The config of depth map sampler.
40
+ Default: None.
41
+ align_corners (bool): align_corners argument of F.interpolate.
42
+ Default: False.
43
+ min_depth (float): Min depth in dataset setting.
44
+ Default: 1e-3.
45
+ max_depth (float): Max depth in dataset setting.
46
+ Default: None.
47
+ norm_layer (dict|None): Norm layers.
48
+ Default: None.
49
+ classify (bool): Whether predict depth in a cls.-reg. manner.
50
+ Default: False.
51
+ n_bins (int): The number of bins used in cls. step.
52
+ Default: 256.
53
+ bins_strategy (str): The discrete strategy used in cls. step.
54
+ Default: 'UD'.
55
+ norm_strategy (str): The norm strategy on cls. probability
56
+ distribution. Default: 'linear'
57
+ scale_up (str): Whether predict depth in a scale-up manner.
58
+ Default: False.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ conv_layer=None,
65
+ act_layer=nn.ReLU,
66
+ channels=96,
67
+ loss_decode=(),
68
+ sampler=None,
69
+ align_corners=False,
70
+ min_depth=1e-3,
71
+ max_depth=None,
72
+ norm_layer=None,
73
+ classify=False,
74
+ n_bins=256,
75
+ bins_strategy="UD",
76
+ norm_strategy="linear",
77
+ scale_up=False,
78
+ ):
79
+ super(DepthBaseDecodeHead, self).__init__()
80
+
81
+ self.in_channels = in_channels
82
+ self.channels = channels
83
+ self.conv_layer = conv_layer
84
+ self.act_layer = act_layer
85
+ self.loss_decode = loss_decode
86
+ self.align_corners = align_corners
87
+ self.min_depth = min_depth
88
+ self.max_depth = max_depth
89
+ self.norm_layer = norm_layer
90
+ self.classify = classify
91
+ self.n_bins = n_bins
92
+ self.scale_up = scale_up
93
+
94
+ if self.classify:
95
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
96
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
97
+
98
+ self.bins_strategy = bins_strategy
99
+ self.norm_strategy = norm_strategy
100
+ self.softmax = nn.Softmax(dim=1)
101
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
102
+ else:
103
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
104
+
105
+ self.relu = nn.ReLU()
106
+ self.sigmoid = nn.Sigmoid()
107
+
108
+ def forward(self, inputs, img_metas):
109
+ """Placeholder of forward function."""
110
+ pass
111
+
112
+ def forward_train(self, img, inputs, img_metas, depth_gt):
113
+ """Forward function for training.
114
+ Args:
115
+ inputs (list[Tensor]): List of multi-level img features.
116
+ img_metas (list[dict]): List of image info dict where each dict
117
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
118
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
119
+ For details on the values of these keys see
120
+ `depth/datasets/pipelines/formatting.py:Collect`.
121
+ depth_gt (Tensor): GT depth
122
+
123
+ Returns:
124
+ dict[str, Tensor]: a dictionary of loss components
125
+ """
126
+ depth_pred = self.forward(inputs, img_metas)
127
+ losses = self.losses(depth_pred, depth_gt)
128
+
129
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
130
+ losses.update(**log_imgs)
131
+
132
+ return losses
133
+
134
+ def forward_test(self, inputs, img_metas):
135
+ """Forward function for testing.
136
+ Args:
137
+ inputs (list[Tensor]): List of multi-level img features.
138
+ img_metas (list[dict]): List of image info dict where each dict
139
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
140
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
141
+ For details on the values of these keys see
142
+ `depth/datasets/pipelines/formatting.py:Collect`.
143
+
144
+ Returns:
145
+ Tensor: Output depth map.
146
+ """
147
+ return self.forward(inputs, img_metas)
148
+
149
+ def depth_pred(self, feat):
150
+ """Prediction each pixel."""
151
+ if self.classify:
152
+ logit = self.conv_depth(feat)
153
+
154
+ if self.bins_strategy == "UD":
155
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
156
+ elif self.bins_strategy == "SID":
157
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
158
+
159
+ # following Adabins, default linear
160
+ if self.norm_strategy == "linear":
161
+ logit = torch.relu(logit)
162
+ eps = 0.1
163
+ logit = logit + eps
164
+ logit = logit / logit.sum(dim=1, keepdim=True)
165
+ elif self.norm_strategy == "softmax":
166
+ logit = torch.softmax(logit, dim=1)
167
+ elif self.norm_strategy == "sigmoid":
168
+ logit = torch.sigmoid(logit)
169
+ logit = logit / logit.sum(dim=1, keepdim=True)
170
+
171
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
172
+
173
+ else:
174
+ if self.scale_up:
175
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
176
+ else:
177
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
178
+ return output
179
+
180
+ def losses(self, depth_pred, depth_gt):
181
+ """Compute depth loss."""
182
+ loss = dict()
183
+ depth_pred = resize(
184
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
185
+ )
186
+ if not isinstance(self.loss_decode, nn.ModuleList):
187
+ losses_decode = [self.loss_decode]
188
+ else:
189
+ losses_decode = self.loss_decode
190
+ for loss_decode in losses_decode:
191
+ if loss_decode.loss_name not in loss:
192
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
193
+ else:
194
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
195
+ return loss
196
+
197
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
198
+ import numpy as np
199
+
200
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
201
+ show_img = show_img.numpy().astype(np.float32)
202
+ show_img = _imdenormalize(
203
+ show_img,
204
+ img_meta["img_norm_cfg"]["mean"],
205
+ img_meta["img_norm_cfg"]["std"],
206
+ img_meta["img_norm_cfg"]["to_rgb"],
207
+ )
208
+ show_img = np.clip(show_img, 0, 255)
209
+ show_img = show_img.astype(np.uint8)
210
+ show_img = show_img[:, :, ::-1]
211
+ show_img = show_img.transpose(0, 2, 1)
212
+ show_img = show_img.transpose(1, 0, 2)
213
+
214
+ depth_pred = depth_pred / torch.max(depth_pred)
215
+ depth_gt = depth_gt / torch.max(depth_gt)
216
+
217
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
218
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
219
+
220
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
221
+
222
+
223
+ class BNHead(DepthBaseDecodeHead):
224
+ """Just a batchnorm."""
225
+
226
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
227
+ super().__init__(**kwargs)
228
+ self.input_transform = input_transform
229
+ self.in_index = in_index
230
+ self.upsample = upsample
231
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
232
+ if self.classify:
233
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
234
+ else:
235
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
236
+
237
+ def _transform_inputs(self, inputs):
238
+ """Transform inputs for decoder.
239
+ Args:
240
+ inputs (list[Tensor]): List of multi-level img features.
241
+ Returns:
242
+ Tensor: The transformed inputs
243
+ """
244
+
245
+ if "concat" in self.input_transform:
246
+ inputs = [inputs[i] for i in self.in_index]
247
+ if "resize" in self.input_transform:
248
+ inputs = [
249
+ resize(
250
+ input=x,
251
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
252
+ mode="bilinear",
253
+ align_corners=self.align_corners,
254
+ )
255
+ for x in inputs
256
+ ]
257
+ inputs = torch.cat(inputs, dim=1)
258
+ elif self.input_transform == "multiple_select":
259
+ inputs = [inputs[i] for i in self.in_index]
260
+ else:
261
+ inputs = inputs[self.in_index]
262
+
263
+ return inputs
264
+
265
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
266
+ """Forward function for feature maps before classifying each pixel with
267
+ ``self.cls_seg`` fc.
268
+ Args:
269
+ inputs (list[Tensor]): List of multi-level img features.
270
+ Returns:
271
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
272
+ H, W) which is feature map for last layer of decoder head.
273
+ """
274
+ # accept lists (for cls token)
275
+ inputs = list(inputs)
276
+ for i, x in enumerate(inputs):
277
+ if len(x) == 2:
278
+ x, cls_token = x[0], x[1]
279
+ if len(x.shape) == 2:
280
+ x = x[:, :, None, None]
281
+ cls_token = cls_token[:, :, None, None].expand_as(x)
282
+ inputs[i] = torch.cat((x, cls_token), 1)
283
+ else:
284
+ x = x[0]
285
+ if len(x.shape) == 2:
286
+ x = x[:, :, None, None]
287
+ inputs[i] = x
288
+ x = self._transform_inputs(inputs)
289
+ # feats = self.bn(x)
290
+ return x
291
+
292
+ def forward(self, inputs, img_metas=None, **kwargs):
293
+ """Forward function."""
294
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
295
+ output = self.depth_pred(output)
296
+ return output
297
+
298
+
299
+ class ConvModule(nn.Module):
300
+ """A conv block that bundles conv/norm/activation layers.
301
+
302
+ This block simplifies the usage of convolution layers, which are commonly
303
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
304
+ It is based upon three build methods: `build_conv_layer()`,
305
+ `build_norm_layer()` and `build_activation_layer()`.
306
+
307
+ Besides, we add some additional features in this module.
308
+ 1. Automatically set `bias` of the conv layer.
309
+ 2. Spectral norm is supported.
310
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
311
+ supports zero and circular padding, and we add "reflect" padding mode.
312
+
313
+ Args:
314
+ in_channels (int): Number of channels in the input feature map.
315
+ Same as that in ``nn._ConvNd``.
316
+ out_channels (int): Number of channels produced by the convolution.
317
+ Same as that in ``nn._ConvNd``.
318
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
319
+ Same as that in ``nn._ConvNd``.
320
+ stride (int | tuple[int]): Stride of the convolution.
321
+ Same as that in ``nn._ConvNd``.
322
+ padding (int | tuple[int]): Zero-padding added to both sides of
323
+ the input. Same as that in ``nn._ConvNd``.
324
+ dilation (int | tuple[int]): Spacing between kernel elements.
325
+ Same as that in ``nn._ConvNd``.
326
+ groups (int): Number of blocked connections from input channels to
327
+ output channels. Same as that in ``nn._ConvNd``.
328
+ bias (bool | str): If specified as `auto`, it will be decided by the
329
+ norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
330
+ False. Default: "auto".
331
+ conv_layer (nn.Module): Convolution layer. Default: None,
332
+ which means using conv2d.
333
+ norm_layer (nn.Module): Normalization layer. Default: None.
334
+ act_layer (nn.Module): Activation layer. Default: nn.ReLU.
335
+ inplace (bool): Whether to use inplace mode for activation.
336
+ Default: True.
337
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
338
+ Default: False.
339
+ padding_mode (str): If the `padding_mode` has not been supported by
340
+ current `Conv2d` in PyTorch, we will use our own padding layer
341
+ instead. Currently, we support ['zeros', 'circular'] with official
342
+ implementation and ['reflect'] with our own implementation.
343
+ Default: 'zeros'.
344
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
345
+ sequence of "conv", "norm" and "act". Common examples are
346
+ ("conv", "norm", "act") and ("act", "conv", "norm").
347
+ Default: ('conv', 'norm', 'act').
348
+ """
349
+
350
+ _abbr_ = "conv_block"
351
+
352
+ def __init__(
353
+ self,
354
+ in_channels,
355
+ out_channels,
356
+ kernel_size,
357
+ stride=1,
358
+ padding=0,
359
+ dilation=1,
360
+ groups=1,
361
+ bias="auto",
362
+ conv_layer=nn.Conv2d,
363
+ norm_layer=None,
364
+ act_layer=nn.ReLU,
365
+ inplace=True,
366
+ with_spectral_norm=False,
367
+ padding_mode="zeros",
368
+ order=("conv", "norm", "act"),
369
+ ):
370
+ super(ConvModule, self).__init__()
371
+ official_padding_mode = ["zeros", "circular"]
372
+ self.conv_layer = conv_layer
373
+ self.norm_layer = norm_layer
374
+ self.act_layer = act_layer
375
+ self.inplace = inplace
376
+ self.with_spectral_norm = with_spectral_norm
377
+ self.with_explicit_padding = padding_mode not in official_padding_mode
378
+ self.order = order
379
+ assert isinstance(self.order, tuple) and len(self.order) == 3
380
+ assert set(order) == set(["conv", "norm", "act"])
381
+
382
+ self.with_norm = norm_layer is not None
383
+ self.with_activation = act_layer is not None
384
+ # if the conv layer is before a norm layer, bias is unnecessary.
385
+ if bias == "auto":
386
+ bias = not self.with_norm
387
+ self.with_bias = bias
388
+
389
+ if self.with_explicit_padding:
390
+ if padding_mode == "zeros":
391
+ padding_layer = nn.ZeroPad2d
392
+ else:
393
+ raise AssertionError(f"Unsupported padding mode: {padding_mode}")
394
+ self.pad = padding_layer(padding)
395
+
396
+ # reset padding to 0 for conv module
397
+ conv_padding = 0 if self.with_explicit_padding else padding
398
+ # build convolution layer
399
+ self.conv = self.conv_layer(
400
+ in_channels,
401
+ out_channels,
402
+ kernel_size,
403
+ stride=stride,
404
+ padding=conv_padding,
405
+ dilation=dilation,
406
+ groups=groups,
407
+ bias=bias,
408
+ )
409
+ # export the attributes of self.conv to a higher level for convenience
410
+ self.in_channels = self.conv.in_channels
411
+ self.out_channels = self.conv.out_channels
412
+ self.kernel_size = self.conv.kernel_size
413
+ self.stride = self.conv.stride
414
+ self.padding = padding
415
+ self.dilation = self.conv.dilation
416
+ self.transposed = self.conv.transposed
417
+ self.output_padding = self.conv.output_padding
418
+ self.groups = self.conv.groups
419
+
420
+ if self.with_spectral_norm:
421
+ self.conv = nn.utils.spectral_norm(self.conv)
422
+
423
+ # build normalization layers
424
+ if self.with_norm:
425
+ # norm layer is after conv layer
426
+ if order.index("norm") > order.index("conv"):
427
+ norm_channels = out_channels
428
+ else:
429
+ norm_channels = in_channels
430
+ norm = partial(norm_layer, num_features=norm_channels)
431
+ self.add_module("norm", norm)
432
+ if self.with_bias:
433
+ from torch.nn.modules.batchnorm import _BatchNorm
434
+ from torch.nn.modules.instancenorm import _InstanceNorm
435
+
436
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
437
+ warnings.warn("Unnecessary conv bias before batch/instance norm")
438
+ else:
439
+ self.norm_name = None
440
+
441
+ # build activation layer
442
+ if self.with_activation:
443
+ # nn.Tanh has no 'inplace' argument
444
+ # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
445
+ if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
446
+ act_layer = partial(act_layer, inplace=inplace)
447
+ self.activate = act_layer()
448
+
449
+ # Use msra init by default
450
+ self.init_weights()
451
+
452
+ @property
453
+ def norm(self):
454
+ if self.norm_name:
455
+ return getattr(self, self.norm_name)
456
+ else:
457
+ return None
458
+
459
+ def init_weights(self):
460
+ # 1. It is mainly for customized conv layers with their own
461
+ # initialization manners by calling their own ``init_weights()``,
462
+ # and we do not want ConvModule to override the initialization.
463
+ # 2. For customized conv layers without their own initialization
464
+ # manners (that is, they don't have their own ``init_weights()``)
465
+ # and PyTorch's conv layers, they will be initialized by
466
+ # this method with default ``kaiming_init``.
467
+ # Note: For PyTorch's conv layers, they will be overwritten by our
468
+ # initialization implementation using default ``kaiming_init``.
469
+ if not hasattr(self.conv, "init_weights"):
470
+ if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
471
+ nonlinearity = "leaky_relu"
472
+ a = 0.01 # XXX: default negative_slope
473
+ else:
474
+ nonlinearity = "relu"
475
+ a = 0
476
+ if hasattr(self.conv, "weight") and self.conv.weight is not None:
477
+ nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
478
+ if hasattr(self.conv, "bias") and self.conv.bias is not None:
479
+ nn.init.constant_(self.conv.bias, 0)
480
+ if self.with_norm:
481
+ if hasattr(self.norm, "weight") and self.norm.weight is not None:
482
+ nn.init.constant_(self.norm.weight, 1)
483
+ if hasattr(self.norm, "bias") and self.norm.bias is not None:
484
+ nn.init.constant_(self.norm.bias, 0)
485
+
486
+ def forward(self, x, activate=True, norm=True):
487
+ for layer in self.order:
488
+ if layer == "conv":
489
+ if self.with_explicit_padding:
490
+ x = self.pad(x)
491
+ x = self.conv(x)
492
+ elif layer == "norm" and norm and self.with_norm:
493
+ x = self.norm(x)
494
+ elif layer == "act" and activate and self.with_activation:
495
+ x = self.activate(x)
496
+ return x
497
+
498
+
499
+ class Interpolate(nn.Module):
500
+ def __init__(self, scale_factor, mode, align_corners=False):
501
+ super(Interpolate, self).__init__()
502
+ self.interp = nn.functional.interpolate
503
+ self.scale_factor = scale_factor
504
+ self.mode = mode
505
+ self.align_corners = align_corners
506
+
507
+ def forward(self, x):
508
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
509
+ return x
510
+
511
+
512
+ class HeadDepth(nn.Module):
513
+ def __init__(self, features):
514
+ super(HeadDepth, self).__init__()
515
+ self.head = nn.Sequential(
516
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
517
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
518
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
519
+ nn.ReLU(),
520
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
521
+ )
522
+
523
+ def forward(self, x):
524
+ x = self.head(x)
525
+ return x
526
+
527
+
528
+ class ReassembleBlocks(nn.Module):
529
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
530
+ rearrange the feature vector to feature map.
531
+ Args:
532
+ in_channels (int): ViT feature channels. Default: 768.
533
+ out_channels (List): output channels of each stage.
534
+ Default: [96, 192, 384, 768].
535
+ readout_type (str): Type of readout operation. Default: 'ignore'.
536
+ patch_size (int): The patch size. Default: 16.
537
+ """
538
+
539
+ def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
540
+ super(ReassembleBlocks, self).__init__()
541
+
542
+ assert readout_type in ["ignore", "add", "project"]
543
+ self.readout_type = readout_type
544
+ self.patch_size = patch_size
545
+
546
+ self.projects = nn.ModuleList(
547
+ [
548
+ ConvModule(
549
+ in_channels=in_channels,
550
+ out_channels=out_channel,
551
+ kernel_size=1,
552
+ act_layer=None,
553
+ )
554
+ for out_channel in out_channels
555
+ ]
556
+ )
557
+
558
+ self.resize_layers = nn.ModuleList(
559
+ [
560
+ nn.ConvTranspose2d(
561
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
562
+ ),
563
+ nn.ConvTranspose2d(
564
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
565
+ ),
566
+ nn.Identity(),
567
+ nn.Conv2d(
568
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
569
+ ),
570
+ ]
571
+ )
572
+ if self.readout_type == "project":
573
+ self.readout_projects = nn.ModuleList()
574
+ for _ in range(len(self.projects)):
575
+ self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
576
+
577
+ def forward(self, inputs):
578
+ assert isinstance(inputs, list)
579
+ out = []
580
+ for i, x in enumerate(inputs):
581
+ assert len(x) == 2
582
+ x, cls_token = x[0], x[1]
583
+ feature_shape = x.shape
584
+ if self.readout_type == "project":
585
+ x = x.flatten(2).permute((0, 2, 1))
586
+ readout = cls_token.unsqueeze(1).expand_as(x)
587
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
588
+ x = x.permute(0, 2, 1).reshape(feature_shape)
589
+ elif self.readout_type == "add":
590
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
591
+ x = x.reshape(feature_shape)
592
+ else:
593
+ pass
594
+ x = self.projects[i](x)
595
+ x = self.resize_layers[i](x)
596
+ out.append(x)
597
+ return out
598
+
599
+
600
+ class PreActResidualConvUnit(nn.Module):
601
+ """ResidualConvUnit, pre-activate residual unit.
602
+ Args:
603
+ in_channels (int): number of channels in the input feature map.
604
+ act_layer (nn.Module): activation layer.
605
+ norm_layer (nn.Module): norm layer.
606
+ stride (int): stride of the first block. Default: 1
607
+ dilation (int): dilation rate for convs layers. Default: 1.
608
+ """
609
+
610
+ def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
611
+ super(PreActResidualConvUnit, self).__init__()
612
+
613
+ self.conv1 = ConvModule(
614
+ in_channels,
615
+ in_channels,
616
+ 3,
617
+ stride=stride,
618
+ padding=dilation,
619
+ dilation=dilation,
620
+ norm_layer=norm_layer,
621
+ act_layer=act_layer,
622
+ bias=False,
623
+ order=("act", "conv", "norm"),
624
+ )
625
+
626
+ self.conv2 = ConvModule(
627
+ in_channels,
628
+ in_channels,
629
+ 3,
630
+ padding=1,
631
+ norm_layer=norm_layer,
632
+ act_layer=act_layer,
633
+ bias=False,
634
+ order=("act", "conv", "norm"),
635
+ )
636
+
637
+ def forward(self, inputs):
638
+ inputs_ = inputs.clone()
639
+ x = self.conv1(inputs)
640
+ x = self.conv2(x)
641
+ return x + inputs_
642
+
643
+
644
+ class FeatureFusionBlock(nn.Module):
645
+ """FeatureFusionBlock, merge feature map from different stages.
646
+ Args:
647
+ in_channels (int): Input channels.
648
+ act_layer (nn.Module): activation layer for ResidualConvUnit.
649
+ norm_layer (nn.Module): normalization layer.
650
+ expand (bool): Whether expand the channels in post process block.
651
+ Default: False.
652
+ align_corners (bool): align_corner setting for bilinear upsample.
653
+ Default: True.
654
+ """
655
+
656
+ def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
657
+ super(FeatureFusionBlock, self).__init__()
658
+
659
+ self.in_channels = in_channels
660
+ self.expand = expand
661
+ self.align_corners = align_corners
662
+
663
+ self.out_channels = in_channels
664
+ if self.expand:
665
+ self.out_channels = in_channels // 2
666
+
667
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
668
+
669
+ self.res_conv_unit1 = PreActResidualConvUnit(
670
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
671
+ )
672
+ self.res_conv_unit2 = PreActResidualConvUnit(
673
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
674
+ )
675
+
676
+ def forward(self, *inputs):
677
+ x = inputs[0]
678
+ if len(inputs) == 2:
679
+ if x.shape != inputs[1].shape:
680
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
681
+ else:
682
+ res = inputs[1]
683
+ x = x + self.res_conv_unit1(res)
684
+ x = self.res_conv_unit2(x)
685
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
686
+ x = self.project(x)
687
+ return x
688
+
689
+
690
+ class DPTHead(DepthBaseDecodeHead):
691
+ """Vision Transformers for Dense Prediction.
692
+ This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
693
+ Args:
694
+ embed_dims (int): The embed dimension of the ViT backbone.
695
+ Default: 768.
696
+ post_process_channels (List): Out channels of post process conv
697
+ layers. Default: [96, 192, 384, 768].
698
+ readout_type (str): Type of readout operation. Default: 'ignore'.
699
+ patch_size (int): The patch size. Default: 16.
700
+ expand_channels (bool): Whether expand the channels in post process
701
+ block. Default: False.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ embed_dims=768,
707
+ post_process_channels=[96, 192, 384, 768],
708
+ readout_type="ignore",
709
+ patch_size=16,
710
+ expand_channels=False,
711
+ **kwargs,
712
+ ):
713
+ super(DPTHead, self).__init__(**kwargs)
714
+
715
+ self.in_channels = self.in_channels
716
+ self.expand_channels = expand_channels
717
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
718
+
719
+ self.post_process_channels = [
720
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
721
+ ]
722
+ self.convs = nn.ModuleList()
723
+ for channel in self.post_process_channels:
724
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
725
+ self.fusion_blocks = nn.ModuleList()
726
+ for _ in range(len(self.convs)):
727
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
728
+ self.fusion_blocks[0].res_conv_unit1 = None
729
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
730
+ self.num_fusion_blocks = len(self.fusion_blocks)
731
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
732
+ self.num_post_process_channels = len(self.post_process_channels)
733
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
734
+ assert self.num_reassemble_blocks == self.num_post_process_channels
735
+ self.conv_depth = HeadDepth(self.channels)
736
+
737
+ def forward(self, inputs, img_metas):
738
+ assert len(inputs) == self.num_reassemble_blocks
739
+ x = [inp for inp in inputs]
740
+ x = self.reassemble_blocks(x)
741
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
742
+ out = self.fusion_blocks[0](x[-1])
743
+ for i in range(1, len(self.fusion_blocks)):
744
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
745
+ out = self.project(out)
746
+ out = self.depth_pred(out)
747
+ return out
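
DPTHead consumes a list of four (feature_map, cls_token) pairs, e.g. intermediate ViT features reshaped to a (B, C, H/14, W/14) grid. A minimal sketch under that assumption; the import path and all shapes are illustrative:

    import torch
    from LHM.models.encoders.dinov2.hub.depth.decode_heads import DPTHead

    head = DPTHead(
        in_channels=[768] * 4,                      # bookkeeping only in this head
        channels=256,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="project",
        min_depth=1e-3,
        max_depth=10.0,
    )

    # Four (feature_map, cls_token) pairs; a 224x224 input gives a 16x16 token grid.
    feats = [(torch.randn(1, 768, 16, 16), torch.randn(1, 768)) for _ in range(4)]
    with torch.no_grad():
        depth = head(feats, img_metas=None)
    # The reassemble/fusion stages upsample back to 128x128 and HeadDepth adds a
    # final 2x interpolation; rescaling to the input resolution happens elsewhere.
    print(depth.shape)  # torch.Size([1, 1, 256, 256])
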
LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py ADDED
@@ -0,0 +1,351 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .ops import resize
13
+
14
+
15
+ def add_prefix(inputs, prefix):
16
+ """Add prefix for dict.
17
+
18
+ Args:
19
+ inputs (dict): The input dict with str keys.
20
+ prefix (str): The prefix to add.
21
+
22
+ Returns:
23
+
24
+ dict: The dict with keys updated with ``prefix``.
25
+ """
26
+
27
+ outputs = dict()
28
+ for name, value in inputs.items():
29
+ outputs[f"{prefix}.{name}"] = value
30
+
31
+ return outputs
32
+
33
+
34
+ class DepthEncoderDecoder(nn.Module):
35
+ """Encoder Decoder depther.
36
+
37
+ EncoderDecoder typically consists of backbone and decode_head.
38
+ """
39
+
40
+ def __init__(self, backbone, decode_head):
41
+ super(DepthEncoderDecoder, self).__init__()
42
+
43
+ self.backbone = backbone
44
+ self.decode_head = decode_head
45
+ self.align_corners = self.decode_head.align_corners
46
+
47
+ def extract_feat(self, img):
48
+ """Extract features from images."""
49
+ return self.backbone(img)
50
+
51
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
52
+ """Encode images with backbone and decode into a depth estimation
53
+ map of the same size as input."""
54
+ x = self.extract_feat(img)
55
+ out = self._decode_head_forward_test(x, img_metas)
56
+ # clamp the predicted depth to the valid [min_depth, max_depth] range.
57
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
58
+ if rescale:
59
+ if size is None:
60
+ if img_metas is not None:
61
+ size = img_metas[0]["ori_shape"][:2]
62
+ else:
63
+ size = img.shape[2:]
64
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
65
+ return out
66
+
67
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
68
+ """Run forward function and calculate loss for decode head in
69
+ training."""
70
+ losses = dict()
71
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
72
+ losses.update(add_prefix(loss_decode, "decode"))
73
+ return losses
74
+
75
+ def _decode_head_forward_test(self, x, img_metas):
76
+ """Run forward function and calculate loss for decode head in
77
+ inference."""
78
+ depth_pred = self.decode_head.forward_test(x, img_metas)
79
+ return depth_pred
80
+
81
+ def forward_dummy(self, img):
82
+ """Dummy forward function."""
83
+ depth = self.encode_decode(img, None)
84
+
85
+ return depth
86
+
87
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
88
+ """Forward function for training.
89
+
90
+ Args:
91
+ img (Tensor): Input images.
92
+ img_metas (list[dict]): List of image info dict where each dict
93
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
94
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
95
+ For details on the values of these keys see
96
+ `depth/datasets/pipelines/formatting.py:Collect`.
97
+ depth_gt (Tensor): Depth gt
98
+ used if the architecture supports depth estimation task.
99
+
100
+ Returns:
101
+ dict[str, Tensor]: a dictionary of loss components
102
+ """
103
+
104
+ x = self.extract_feat(img)
105
+
106
+ losses = dict()
107
+
108
+ # the last of x saves the info from neck
109
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
110
+
111
+ losses.update(loss_decode)
112
+
113
+ return losses
114
+
115
+ def whole_inference(self, img, img_meta, rescale, size=None):
116
+ """Inference with full image."""
117
+ return self.encode_decode(img, img_meta, rescale, size=size)
118
+
119
+ def slide_inference(self, img, img_meta, rescale, stride, crop_size):
120
+ """Inference by sliding-window with overlap.
121
+
122
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
123
+ decode without padding.
124
+ """
125
+
126
+ h_stride, w_stride = stride
127
+ h_crop, w_crop = crop_size
128
+ batch_size, _, h_img, w_img = img.size()
129
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
130
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
131
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
132
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
133
+ for h_idx in range(h_grids):
134
+ for w_idx in range(w_grids):
135
+ y1 = h_idx * h_stride
136
+ x1 = w_idx * w_stride
137
+ y2 = min(y1 + h_crop, h_img)
138
+ x2 = min(x1 + w_crop, w_img)
139
+ y1 = max(y2 - h_crop, 0)
140
+ x1 = max(x2 - w_crop, 0)
141
+ crop_img = img[:, :, y1:y2, x1:x2]
142
+ depth_pred = self.encode_decode(crop_img, img_meta, rescale)
143
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
144
+
145
+ count_mat[:, :, y1:y2, x1:x2] += 1
146
+ assert (count_mat == 0).sum() == 0
147
+ if torch.onnx.is_in_onnx_export():
148
+ # cast count_mat to constant while exporting to ONNX
149
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
150
+ preds = preds / count_mat
151
+ return preds
152
+
153
+ def inference(self, img, img_meta, rescale, size=None, mode="whole"):
154
+ """Inference with slide/whole style.
155
+
156
+ Args:
157
+ img (Tensor): The input image of shape (N, 3, H, W).
158
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
159
+ 'scale_factor', 'flip', and may also contain
160
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
161
+ For details on the values of these keys see
162
+ `depth/datasets/pipelines/formatting.py:Collect`.
163
+ rescale (bool): Whether rescale back to original shape.
164
+
165
+ Returns:
166
+ Tensor: The output depth map.
167
+ """
168
+
169
+ assert mode in ["slide", "whole"]
170
+ ori_shape = img_meta[0]["ori_shape"]
171
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
172
+ if mode == "slide":
173
+ depth_pred = self.slide_inference(img, img_meta, rescale)
174
+ else:
175
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
176
+ output = depth_pred
177
+ flip = img_meta[0]["flip"]
178
+ if flip:
179
+ flip_direction = img_meta[0]["flip_direction"]
180
+ assert flip_direction in ["horizontal", "vertical"]
181
+ if flip_direction == "horizontal":
182
+ output = output.flip(dims=(3,))
183
+ elif flip_direction == "vertical":
184
+ output = output.flip(dims=(2,))
185
+
186
+ return output
187
+
188
+ def simple_test(self, img, img_meta, rescale=True):
189
+ """Simple test with single image."""
190
+ depth_pred = self.inference(img, img_meta, rescale)
191
+ if torch.onnx.is_in_onnx_export():
192
+ # our inference backend only supports 4D output
193
+ depth_pred = depth_pred.unsqueeze(0)
194
+ return depth_pred
195
+ depth_pred = depth_pred.cpu().numpy()
196
+ # unravel batch dim
197
+ depth_pred = list(depth_pred)
198
+ return depth_pred
199
+
200
+ def aug_test(self, imgs, img_metas, rescale=True):
201
+ """Test with augmentations.
202
+
203
+ Only rescale=True is supported.
204
+ """
205
+ # aug_test rescale all imgs back to ori_shape for now
206
+ assert rescale
207
+ # to save memory, we get augmented depth logit inplace
208
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
209
+ for i in range(1, len(imgs)):
210
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
211
+ depth_pred += cur_depth_pred
212
+ depth_pred /= len(imgs)
213
+ depth_pred = depth_pred.cpu().numpy()
214
+ # unravel batch dim
215
+ depth_pred = list(depth_pred)
216
+ return depth_pred
217
+
218
+ def forward_test(self, imgs, img_metas, **kwargs):
219
+ """
220
+ Args:
221
+ imgs (List[Tensor]): the outer list indicates test-time
222
+ augmentations and inner Tensor should have a shape NxCxHxW,
223
+ which contains all images in the batch.
224
+ img_metas (List[List[dict]]): the outer list indicates test-time
225
+ augs (multiscale, flip, etc.) and the inner list indicates
226
+ images in a batch.
227
+ """
228
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
229
+ if not isinstance(var, list):
230
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
231
+ num_augs = len(imgs)
232
+ if num_augs != len(img_metas):
233
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
234
+ # all images in the same aug batch must have the same ori_shape and pad
235
+ # shape
236
+ for img_meta in img_metas:
237
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
238
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
239
+ img_shapes = [_["img_shape"] for _ in img_meta]
240
+ assert all(shape == img_shapes[0] for shape in img_shapes)
241
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
242
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
243
+
244
+ if num_augs == 1:
245
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
246
+ else:
247
+ return self.aug_test(imgs, img_metas, **kwargs)
248
+
249
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
250
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
251
+ on whether ``return_loss`` is ``True``.
252
+
253
+ Note this setting will change the expected inputs. When
254
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
255
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
256
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
257
+ the outer list indicating test time augmentations.
258
+ """
259
+ if return_loss:
260
+ return self.forward_train(img, img_metas, **kwargs)
261
+ else:
262
+ return self.forward_test(img, img_metas, **kwargs)
263
+
264
+ def train_step(self, data_batch, optimizer, **kwargs):
265
+ """The iteration step during training.
266
+
267
+ This method defines an iteration step during training, except for the
268
+ back propagation and optimizer updating, which are done in an optimizer
269
+ hook. Note that in some complicated cases or models, the whole process
270
+ including back propagation and optimizer updating is also defined in
271
+ this method, such as GAN.
272
+
273
+ Args:
274
+ data (dict): The output of dataloader.
275
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
276
+ runner is passed to ``train_step()``. This argument is unused
277
+ and reserved.
278
+
279
+ Returns:
280
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
281
+ ``num_samples``.
282
+ ``loss`` is a tensor for back propagation, which can be a
283
+ weighted sum of multiple losses.
284
+ ``log_vars`` contains all the variables to be sent to the
285
+ logger.
286
+ ``num_samples`` indicates the batch size (when the model is
287
+ DDP, it means the batch size on each GPU), which is used for
288
+ averaging the logs.
289
+ """
290
+ losses = self(**data_batch)
291
+
292
+ # split losses and images
293
+ real_losses = {}
294
+ log_imgs = {}
295
+ for k, v in losses.items():
296
+ if "img" in k:
297
+ log_imgs[k] = v
298
+ else:
299
+ real_losses[k] = v
300
+
301
+ loss, log_vars = self._parse_losses(real_losses)
302
+
303
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
304
+
305
+ return outputs
306
+
307
+ def val_step(self, data_batch, **kwargs):
308
+ """The iteration step during validation.
309
+
310
+ This method shares the same signature as :func:`train_step`, but used
311
+ during val epochs. Note that the evaluation after training epochs is
312
+ not implemented with this method, but an evaluation hook.
313
+ """
314
+ output = self(**data_batch, **kwargs)
315
+ return output
316
+
317
+ @staticmethod
318
+ def _parse_losses(losses):
319
+ import torch.distributed as dist
320
+
321
+ """Parse the raw outputs (losses) of the network.
322
+
323
+ Args:
324
+ losses (dict): Raw output of the network, which usually contain
325
+ losses and other necessary information.
326
+
327
+ Returns:
328
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
329
+ which may be a weighted sum of all losses, log_vars contains
330
+ all the variables to be sent to the logger.
331
+ """
332
+ log_vars = OrderedDict()
333
+ for loss_name, loss_value in losses.items():
334
+ if isinstance(loss_value, torch.Tensor):
335
+ log_vars[loss_name] = loss_value.mean()
336
+ elif isinstance(loss_value, list):
337
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
338
+ else:
339
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
340
+
341
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
342
+
343
+ log_vars["loss"] = loss
344
+ for loss_name, loss_value in log_vars.items():
345
+ # reduce loss when distributed training
346
+ if dist.is_available() and dist.is_initialized():
347
+ loss_value = loss_value.data.clone()
348
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
349
+ log_vars[loss_name] = loss_value.item()
350
+
351
+ return loss, log_vars
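
DepthEncoderDecoder simply glues a backbone to a decode head. A hedged end-to-end sketch using BNHead; the ToyBackbone and the image metadata below are illustrative stand-ins, not part of this commit:

    import torch
    import torch.nn as nn
    from LHM.models.encoders.dinov2.hub.depth import BNHead, DepthEncoderDecoder

    EMBED_DIM = 384  # e.g. a ViT-S/14 backbone

    class ToyBackbone(nn.Module):
        """Stand-in producing four levels of (patch_features, cls_token) outputs."""
        def forward(self, img):
            b, _, h, w = img.shape
            feats = torch.randn(b, EMBED_DIM, h // 14, w // 14)
            cls = torch.randn(b, EMBED_DIM)
            return [(feats, cls) for _ in range(4)]

    # BNHead appends the cls token to each level (2 * EMBED_DIM channels) and then
    # resize-concatenates the four levels, so channels must be 4 * 2 * EMBED_DIM.
    head = BNHead(
        in_channels=[2 * EMBED_DIM] * 4,
        channels=4 * 2 * EMBED_DIM,
        min_depth=1e-3,
        max_depth=10.0,
    )
    model = DepthEncoderDecoder(backbone=ToyBackbone(), decode_head=head).eval()

    img = torch.randn(1, 3, 224, 224)
    img_meta = {"ori_shape": (224, 224, 3), "img_shape": (224, 224, 3),
                "pad_shape": (224, 224, 3), "flip": False}
    with torch.no_grad():
        depth = model.encode_decode(img, [img_meta])
    print(depth.shape)  # torch.Size([1, 1, 224, 224]) after clamping and rescaling
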