rawalkhirodkar commited on
Commit
28c256d
·
0 Parent(s):

Add initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +38 -0
  2. .gitignore +1 -0
  3. NOTES.md +11 -0
  4. README.md +13 -0
  5. app.py +453 -0
  6. assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth +3 -0
  7. assets/checkpoints/sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2 +3 -0
  8. assets/checkpoints/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2 +3 -0
  9. assets/images/68204.png +3 -0
  10. assets/images/68210.png +3 -0
  11. assets/images/68658.png +3 -0
  12. assets/images/68666.png +3 -0
  13. assets/images/68691.png +3 -0
  14. assets/images/68956.png +3 -0
  15. assets/images/pexels-amresh444-17315601.png +3 -0
  16. assets/images/pexels-gabby-k-6311686.png +3 -0
  17. assets/images/pexels-julia-m-cameron-4145040.png +3 -0
  18. assets/images/pexels-marcus-aurelius-6787357.png +3 -0
  19. assets/images/pexels-mo-saeed-3616599-5409085.png +3 -0
  20. assets/images/pexels-riedelmax-27355495.png +3 -0
  21. assets/images/pexels-sergeymakashin-5368660.png +3 -0
  22. assets/images/pexels-vinicius-wiesehofer-289347-4219918.png +3 -0
  23. assets/rtmdet_m_640-8xb32_coco-person_no_nms.py +20 -0
  24. build_wheel.py +26 -0
  25. classes_and_palettes.py +1024 -0
  26. detector_utils.py +196 -0
  27. external/cv/.gitignore +125 -0
  28. external/cv/MANIFEST.in +6 -0
  29. external/cv/dist/sapiens_cv-1.0.0-cp310-cp310-linux_x86_64.whl +3 -0
  30. external/cv/mmcv/__init__.py +18 -0
  31. external/cv/mmcv/arraymisc/__init__.py +9 -0
  32. external/cv/mmcv/arraymisc/quantization.py +70 -0
  33. external/cv/mmcv/cnn/__init__.py +33 -0
  34. external/cv/mmcv/cnn/alexnet.py +68 -0
  35. external/cv/mmcv/cnn/bricks/__init__.py +37 -0
  36. external/cv/mmcv/cnn/bricks/activation.py +119 -0
  37. external/cv/mmcv/cnn/bricks/context_block.py +131 -0
  38. external/cv/mmcv/cnn/bricks/conv.py +56 -0
  39. external/cv/mmcv/cnn/bricks/conv2d_adaptive_padding.py +68 -0
  40. external/cv/mmcv/cnn/bricks/conv_module.py +343 -0
  41. external/cv/mmcv/cnn/bricks/conv_ws.py +158 -0
  42. external/cv/mmcv/cnn/bricks/depthwise_separable_conv_module.py +104 -0
  43. external/cv/mmcv/cnn/bricks/drop.py +72 -0
  44. external/cv/mmcv/cnn/bricks/generalized_attention.py +416 -0
  45. external/cv/mmcv/cnn/bricks/hsigmoid.py +55 -0
  46. external/cv/mmcv/cnn/bricks/hswish.py +44 -0
  47. external/cv/mmcv/cnn/bricks/non_local.py +313 -0
  48. external/cv/mmcv/cnn/bricks/norm.py +161 -0
  49. external/cv/mmcv/cnn/bricks/padding.py +48 -0
  50. external/cv/mmcv/cnn/bricks/plugin.py +106 -0
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
38
+ *.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **pycache**
NOTES.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Create wheel for mmcv
2
+ ```
3
+ cd ./external/engine
4
+ python setup.py bdist_wheel
5
+
6
+ cd ./external/cv
7
+ MMCV_WITH_OPS=1 python setup.py bdist_wheel
8
+
9
+ cd ./external/det
10
+ python setup.py bdist_wheel
11
+ ```
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sapiens Pose
3
+ emoji: 📊
4
+ colorFrom: pink
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.42.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ import spaces
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import json
8
+ import tempfile
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from PIL import Image
12
+ import cv2
13
+ from gradio.themes.utils import sizes
14
+ from classes_and_palettes import (
15
+ COCO_KPTS_COLORS,
16
+ COCO_WHOLEBODY_KPTS_COLORS,
17
+ GOLIATH_KPTS_COLORS,
18
+ GOLIATH_SKELETON_INFO,
19
+ GOLIATH_KEYPOINTS
20
+ )
21
+
22
+ import os
23
+ import sys
24
+ import subprocess
25
+ import importlib.util
26
+
27
+ def is_package_installed(package_name):
28
+ return importlib.util.find_spec(package_name) is not None
29
+
30
+ def find_wheel(package_path):
31
+ dist_dir = os.path.join(package_path, "dist")
32
+ if os.path.exists(dist_dir):
33
+ wheel_files = [f for f in os.listdir(dist_dir) if f.endswith('.whl')]
34
+ if wheel_files:
35
+ return os.path.join(dist_dir, wheel_files[0])
36
+ return None
37
+
38
+ def install_from_wheel(package_name, package_path):
39
+ wheel_file = find_wheel(package_path)
40
+ if wheel_file:
41
+ print(f"Installing {package_name} from wheel: {wheel_file}")
42
+ subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_file])
43
+ else:
44
+ print(f"{package_name} wheel not found in {package_path}. Please build it first.")
45
+ sys.exit(1)
46
+
47
+ def install_local_packages():
48
+ packages = [
49
+ ("mmengine", "./external/engine"),
50
+ ("mmcv", "./external/cv"),
51
+ ("mmdet", "./external/det")
52
+ ]
53
+
54
+ for package_name, package_path in packages:
55
+ if not is_package_installed(package_name):
56
+ print(f"Installing {package_name}...")
57
+ install_from_wheel(package_name, package_path)
58
+ else:
59
+ print(f"{package_name} is already installed.")
60
+
61
+ # Run the installation at the start of your app
62
+ install_local_packages()
63
+
64
+ from detector_utils import (
65
+ adapt_mmdet_pipeline,
66
+ init_detector,
67
+ process_images_detector,
68
+ )
69
+
70
+ class Config:
71
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
72
+ CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
73
+ CHECKPOINTS = {
74
+ "0.3b": "sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2",
75
+ "1b": "sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2",
76
+ }
77
+ DETECTION_CHECKPOINT = os.path.join(CHECKPOINTS_DIR, 'rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth')
78
+ DETECTION_CONFIG = os.path.join(ASSETS_DIR, 'rtmdet_m_640-8xb32_coco-person_no_nms.py')
79
+
80
+ class ModelManager:
81
+ @staticmethod
82
+ def load_model(checkpoint_name: str):
83
+ if checkpoint_name is None:
84
+ return None
85
+ checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
86
+ model = torch.jit.load(checkpoint_path)
87
+ model.eval()
88
+ model.to("cuda")
89
+ return model
90
+
91
+ @staticmethod
92
+ @torch.inference_mode()
93
+ def run_model(model, input_tensor):
94
+ return model(input_tensor)
95
+
96
+ class ImageProcessor:
97
+ def __init__(self):
98
+ self.transform = transforms.Compose([
99
+ transforms.Resize((1024, 768)),
100
+ transforms.ToTensor(),
101
+ transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255],
102
+ std=[58.5/255, 57.0/255, 57.5/255])
103
+ ])
104
+ self.detector = init_detector(
105
+ Config.DETECTION_CONFIG, Config.DETECTION_CHECKPOINT, device='cpu'
106
+ )
107
+ self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
108
+
109
+ def detect_persons(self, image: Image.Image):
110
+ # Convert PIL Image to tensor
111
+ image = np.array(image)
112
+ image = np.expand_dims(image, axis=0)
113
+
114
+ # Perform person detection
115
+ bboxes_batch = process_images_detector(
116
+ image,
117
+ self.detector
118
+ )
119
+ bboxes = self.get_person_bboxes(bboxes_batch[0]) # Get bboxes for the first (and only) image
120
+
121
+ return bboxes
122
+
123
+ def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
124
+ person_bboxes = []
125
+ for bbox in bboxes_batch:
126
+ if len(bbox) == 5: # [x1, y1, x2, y2, score]
127
+ if bbox[4] > score_thr:
128
+ person_bboxes.append(bbox)
129
+ elif len(bbox) == 4: # [x1, y1, x2, y2]
130
+ person_bboxes.append(bbox + [1.0]) # Add a default score of 1.0
131
+ return person_bboxes
132
+
133
+ @spaces.GPU
134
+ @torch.inference_mode()
135
+ def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
136
+ pose_model = ModelManager.load_model(Config.CHECKPOINTS[model_name])
137
+
138
+ result_image = image.copy()
139
+ all_keypoints = [] # List to store keypoints for all persons
140
+
141
+ for bbox in bboxes:
142
+ cropped_img = self.crop_image(result_image, bbox)
143
+ input_tensor = self.transform(cropped_img).unsqueeze(0).to("cuda")
144
+ heatmaps = ModelManager.run_model(pose_model, input_tensor)
145
+ keypoints = self.heatmaps_to_keypoints(heatmaps[0].cpu().numpy())
146
+ all_keypoints.append(keypoints) # Collect keypoints
147
+ result_image = self.draw_keypoints(result_image, keypoints, bbox, kpt_threshold)
148
+
149
+ return result_image, all_keypoints
150
+
151
+ def process_image(self, image: Image.Image, model_name: str, kpt_threshold: str):
152
+ bboxes = self.detect_persons(image)
153
+ result_image, keypoints = self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))
154
+ return result_image, keypoints
155
+
156
+ def crop_image(self, image, bbox):
157
+ if len(bbox) == 4:
158
+ x1, y1, x2, y2 = map(int, bbox)
159
+ elif len(bbox) >= 5:
160
+ x1, y1, x2, y2, _ = map(int, bbox[:5])
161
+ else:
162
+ raise ValueError(f"Unexpected bbox format: {bbox}")
163
+
164
+ crop = image.crop((x1, y1, x2, y2))
165
+ return crop
166
+
167
+ @staticmethod
168
+ def heatmaps_to_keypoints(heatmaps):
169
+ num_joints = heatmaps.shape[0] # Should be 308
170
+ keypoints = {}
171
+ for i, name in enumerate(GOLIATH_KEYPOINTS):
172
+ if i < num_joints:
173
+ heatmap = heatmaps[i]
174
+ y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
175
+ conf = heatmap[y, x]
176
+ keypoints[name] = (float(x), float(y), float(conf))
177
+ return keypoints
178
+
179
+ @staticmethod
180
+ def draw_keypoints(image, keypoints, bbox, kpt_threshold):
181
+ image = np.array(image)
182
+
183
+ # Handle both 4 and 5-element bounding boxes
184
+ if len(bbox) == 4:
185
+ x1, y1, x2, y2 = map(int, bbox)
186
+ elif len(bbox) >= 5:
187
+ x1, y1, x2, y2, _ = map(int, bbox[:5])
188
+ else:
189
+ raise ValueError(f"Unexpected bbox format: {bbox}")
190
+
191
+ # Calculate adaptive radius and thickness based on bounding box size
192
+ bbox_width = x2 - x1
193
+ bbox_height = y2 - y1
194
+ bbox_size = np.sqrt(bbox_width * bbox_height)
195
+
196
+ radius = max(1, int(bbox_size * 0.006)) # minimum 1 pixel
197
+ thickness = max(1, int(bbox_size * 0.006)) # minimum 1 pixel
198
+ bbox_thickness = max(1, thickness//4)
199
+
200
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)
201
+
202
+ # Draw keypoints
203
+ for i, (name, (x, y, conf)) in enumerate(keypoints.items()):
204
+ if conf > kpt_threshold and i < len(GOLIATH_KPTS_COLORS):
205
+ x_coord = int(x * bbox_width / 192) + x1
206
+ y_coord = int(y * bbox_height / 256) + y1
207
+ color = GOLIATH_KPTS_COLORS[i]
208
+ cv2.circle(image, (x_coord, y_coord), radius, color, -1)
209
+
210
+ # Draw skeleton
211
+ for _, link_info in GOLIATH_SKELETON_INFO.items():
212
+ pt1_name, pt2_name = link_info['link']
213
+ color = link_info['color']
214
+
215
+ if pt1_name in keypoints and pt2_name in keypoints:
216
+ pt1 = keypoints[pt1_name]
217
+ pt2 = keypoints[pt2_name]
218
+ if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
219
+ x1_coord = int(pt1[0] * bbox_width / 192) + x1
220
+ y1_coord = int(pt1[1] * bbox_height / 256) + y1
221
+ x2_coord = int(pt2[0] * bbox_width / 192) + x1
222
+ y2_coord = int(pt2[1] * bbox_height / 256) + y1
223
+ cv2.line(image, (x1_coord, y1_coord), (x2_coord, y2_coord), color, thickness=thickness)
224
+
225
+ return Image.fromarray(image)
226
+
227
+ class GradioInterface:
228
+ def __init__(self):
229
+ self.image_processor = ImageProcessor()
230
+
231
+ def create_interface(self):
232
+ app_styles = """
233
+ <style>
234
+ /* Global Styles */
235
+ body, #root {
236
+ font-family: Helvetica, Arial, sans-serif;
237
+ background-color: #1a1a1a;
238
+ color: #fafafa;
239
+ }
240
+ /* Header Styles */
241
+ .app-header {
242
+ background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
243
+ padding: 24px;
244
+ border-radius: 8px;
245
+ margin-bottom: 24px;
246
+ text-align: center;
247
+ }
248
+ .app-title {
249
+ font-size: 48px;
250
+ margin: 0;
251
+ color: #fafafa;
252
+ }
253
+ .app-subtitle {
254
+ font-size: 24px;
255
+ margin: 8px 0 16px;
256
+ color: #fafafa;
257
+ }
258
+ .app-description {
259
+ font-size: 16px;
260
+ line-height: 1.6;
261
+ opacity: 0.8;
262
+ margin-bottom: 24px;
263
+ }
264
+ /* Button Styles */
265
+ .publication-links {
266
+ display: flex;
267
+ justify-content: center;
268
+ flex-wrap: wrap;
269
+ gap: 8px;
270
+ margin-bottom: 16px;
271
+ }
272
+ .publication-link {
273
+ display: inline-flex;
274
+ align-items: center;
275
+ padding: 8px 16px;
276
+ background-color: #333;
277
+ color: #fff !important;
278
+ text-decoration: none !important;
279
+ border-radius: 20px;
280
+ font-size: 14px;
281
+ transition: background-color 0.3s;
282
+ }
283
+ .publication-link:hover {
284
+ background-color: #555;
285
+ }
286
+ .publication-link i {
287
+ margin-right: 8px;
288
+ }
289
+ /* Content Styles */
290
+ .content-container {
291
+ background-color: #2a2a2a;
292
+ border-radius: 8px;
293
+ padding: 24px;
294
+ margin-bottom: 24px;
295
+ }
296
+ /* Image Styles */
297
+ .image-preview img {
298
+ max-width: 512px;
299
+ max-height: 512px;
300
+ margin: 0 auto;
301
+ border-radius: 4px;
302
+ display: block;
303
+ object-fit: contain;
304
+ }
305
+ /* Control Styles */
306
+ .control-panel {
307
+ background-color: #333;
308
+ padding: 16px;
309
+ border-radius: 8px;
310
+ margin-top: 16px;
311
+ }
312
+ /* Gradio Component Overrides */
313
+ .gr-button {
314
+ background-color: #4a4a4a;
315
+ color: #fff;
316
+ border: none;
317
+ border-radius: 4px;
318
+ padding: 8px 16px;
319
+ cursor: pointer;
320
+ transition: background-color 0.3s;
321
+ }
322
+ .gr-button:hover {
323
+ background-color: #5a5a5a;
324
+ }
325
+ .gr-input, .gr-dropdown {
326
+ background-color: #3a3a3a;
327
+ color: #fff;
328
+ border: 1px solid #4a4a4a;
329
+ border-radius: 4px;
330
+ padding: 8px;
331
+ }
332
+ .gr-form {
333
+ background-color: transparent;
334
+ }
335
+ .gr-panel {
336
+ border: none;
337
+ background-color: transparent;
338
+ }
339
+ /* Override any conflicting styles from Bulma */
340
+ .button.is-normal.is-rounded.is-dark {
341
+ color: #fff !important;
342
+ text-decoration: none !important;
343
+ }
344
+ </style>
345
+ """
346
+
347
+ header_html = f"""
348
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
349
+ <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
350
+ {app_styles}
351
+ <div class="app-header">
352
+ <h1 class="app-title">Sapiens: Pose Estimation</h1>
353
+ <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
354
+ <p class="app-description">
355
+ Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
356
+ This demo showcases the finetuned pose estimation model. <br>
357
+ </p>
358
+ <div class="publication-links">
359
+ <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
360
+ <i class="fas fa-file-pdf"></i>arXiv
361
+ </a>
362
+ <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
363
+ <i class="fab fa-github"></i>Code
364
+ </a>
365
+ <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
366
+ <i class="fas fa-globe"></i>Meta
367
+ </a>
368
+ <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
369
+ <i class="fas fa-chart-bar"></i>Results
370
+ </a>
371
+ </div>
372
+ <div class="publication-links">
373
+ <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
374
+ <i class="fas fa-user"></i>Demo-Pose
375
+ </a>
376
+ <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
377
+ <i class="fas fa-puzzle-piece"></i>Demo-Seg
378
+ </a>
379
+ <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
380
+ <i class="fas fa-cube"></i>Demo-Depth
381
+ </a>
382
+ <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
383
+ <i class="fas fa-vector-square"></i>Demo-Normal
384
+ </a>
385
+ </div>
386
+ </div>
387
+ """
388
+
389
+ js_func = """
390
+ function refresh() {
391
+ const url = new URL(window.location);
392
+ if (url.searchParams.get('__theme') !== 'dark') {
393
+ url.searchParams.set('__theme', 'dark');
394
+ window.location.href = url.href;
395
+ }
396
+ }
397
+ """
398
+
399
+ def process_image(image, model_name, kpt_threshold):
400
+ result_image, keypoints = self.image_processor.process_image(image, model_name, kpt_threshold)
401
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w') as json_file:
402
+ json.dump(keypoints, json_file)
403
+ json_file_path = json_file.name
404
+ return result_image, json_file_path
405
+
406
+ with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
407
+ gr.HTML(header_html)
408
+ with gr.Row(elem_classes="content-container"):
409
+ with gr.Column():
410
+ input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
411
+ with gr.Row():
412
+ model_name = gr.Dropdown(
413
+ label="Model Size",
414
+ choices=list(Config.CHECKPOINTS.keys()),
415
+ value="1b",
416
+ )
417
+ kpt_threshold = gr.Dropdown(
418
+ label="Min Keypoint Confidence",
419
+ choices=["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"],
420
+ value="0.3",
421
+ )
422
+ example_model = gr.Examples(
423
+ inputs=input_image,
424
+ examples_per_page=14,
425
+ examples=[
426
+ os.path.join(Config.ASSETS_DIR, "images", img)
427
+ for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
428
+ ],
429
+ )
430
+ with gr.Column():
431
+ result_image = gr.Image(label="Pose-308 Result", type="pil", elem_classes="image-preview")
432
+ json_output = gr.File(label="Pose-308 Output (.json)")
433
+ run_button = gr.Button("Run")
434
+
435
+ run_button.click(
436
+ fn=process_image,
437
+ inputs=[input_image, model_name, kpt_threshold],
438
+ outputs=[result_image, json_output],
439
+ )
440
+
441
+ return demo
442
+
443
+ def main():
444
+ if torch.cuda.is_available():
445
+ torch.backends.cuda.matmul.allow_tf32 = True
446
+ torch.backends.cudnn.allow_tf32 = True
447
+
448
+ interface = GradioInterface()
449
+ demo = interface.create_interface()
450
+ demo.launch(share=False)
451
+
452
+ if __name__ == "__main__":
453
+ main()
assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b66b27072c6a3cd4f093882df440921987076131fb78a7df7b1cf92d67f41509
3
+ size 99149914
assets/checkpoints/sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21cf7e3e723720d847bee6d3b321bfcdb33268c9f1418d7552552264ae0a5a9b
3
+ size 1319579523
assets/checkpoints/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6218c6be17697157f9e65ee34054a94ab8ca0f637380fa5748c18e04814976e
3
+ size 4677162331
assets/images/68204.png ADDED

Git LFS Details

  • SHA256: 9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
assets/images/68210.png ADDED

Git LFS Details

  • SHA256: dbe5f80498af4ebd1ff09ae4184f37c20ba981e53bd554c3cc78d39ae0ee7fd7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.93 MB
assets/images/68658.png ADDED

Git LFS Details

  • SHA256: 61a68b619bd17235e683324f2826ce0693322e45ab8c86f1c057851ecb333ac7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.1 MB
assets/images/68666.png ADDED

Git LFS Details

  • SHA256: ea3047e6c2ccb485fdb3966aa2325e803cbf49c27c0bff00287b44bc16f18914
  • Pointer size: 132 Bytes
  • Size of remote file: 4.56 MB
assets/images/68691.png ADDED

Git LFS Details

  • SHA256: fae39e4055c1b297af7068cdddfeeba8d685363281b839d8c5afac1980204b57
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB
assets/images/68956.png ADDED

Git LFS Details

  • SHA256: eee1f27082b10999d0fa848121ecb06cda3386b1a864b9aa0f59ae78261f8908
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/images/pexels-amresh444-17315601.png ADDED

Git LFS Details

  • SHA256: 4e17ee1b229147e4b52e8348a6ef426bc9e9a2f90738e776e15b26b325abb9b3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.5 MB
assets/images/pexels-gabby-k-6311686.png ADDED

Git LFS Details

  • SHA256: 3f10eded3fb05ab04b963f7b9fd2e183d8d4e81b20569b1c6b0653549639421f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
assets/images/pexels-julia-m-cameron-4145040.png ADDED

Git LFS Details

  • SHA256: 459cf0280667b028ffbca16aa11188780d7a0205c0defec02916ff3cbaeecb72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.92 MB
assets/images/pexels-marcus-aurelius-6787357.png ADDED

Git LFS Details

  • SHA256: 7d35452f76492125eaf7d5783aa9fd6b0d5990ebe0579fe9dfd58a9d634f4955
  • Pointer size: 132 Bytes
  • Size of remote file: 3.3 MB
assets/images/pexels-mo-saeed-3616599-5409085.png ADDED

Git LFS Details

  • SHA256: 7c1ca7afd6c2a654e94ef59d5fb56fca4f3cde5fb5216f6b218c34a7b8c143dc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
assets/images/pexels-riedelmax-27355495.png ADDED

Git LFS Details

  • SHA256: 4141d2f5f718f162ea1f6710c06b28b5cb51fd69598fde35948f8f3491228164
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
assets/images/pexels-sergeymakashin-5368660.png ADDED

Git LFS Details

  • SHA256: af8f5a8f26dd102d87d94c1be36ec903791fe8e6d951c68ebb9ebcfc6d7397bb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.08 MB
assets/images/pexels-vinicius-wiesehofer-289347-4219918.png ADDED

Git LFS Details

  • SHA256: a6eef5eee15b81fe65ea95627e9a46040b9889466689b3c1ca6ed273e02fe84f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.63 MB
assets/rtmdet_m_640-8xb32_coco-person_no_nms.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = 'mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py'
2
+
3
+ checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth' # noqa
4
+
5
+ model = dict(
6
+ backbone=dict(
7
+ init_cfg=dict(
8
+ type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
9
+ bbox_head=dict(num_classes=1),
10
+ test_cfg=dict(
11
+ nms_pre=1000,
12
+ min_bbox_size=0,
13
+ score_thr=0.05,
14
+ nms=None,
15
+ max_per_img=100))
16
+
17
+ train_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
18
+
19
+ val_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
20
+ test_dataloader = val_dataloader
build_wheel.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ def build_wheel(package_path):
6
+ current_dir = os.getcwd()
7
+ os.chdir(package_path)
8
+ try:
9
+ subprocess.check_call([sys.executable, "setup.py", "bdist_wheel"])
10
+ finally:
11
+ os.chdir(current_dir)
12
+
13
+ def main():
14
+ packages = [
15
+ "./external/engine",
16
+ "./external/cv",
17
+ "./external/det"
18
+ ]
19
+
20
+ for package in packages:
21
+ print(f"Building wheel for {package}...")
22
+ build_wheel(package)
23
+ print(f"Wheel built for {package}")
24
+
25
+ if __name__ == "__main__":
26
+ main()
classes_and_palettes.py ADDED
@@ -0,0 +1,1024 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ COCO_KPTS_COLORS = [
2
+ [51, 153, 255], # 0: nose
3
+ [51, 153, 255], # 1: left_eye
4
+ [51, 153, 255], # 2: right_eye
5
+ [51, 153, 255], # 3: left_ear
6
+ [51, 153, 255], # 4: right_ear
7
+ [0, 255, 0], # 5: left_shoulder
8
+ [255, 128, 0], # 6: right_shoulder
9
+ [0, 255, 0], # 7: left_elbow
10
+ [255, 128, 0], # 8: right_elbow
11
+ [0, 255, 0], # 9: left_wrist
12
+ [255, 128, 0], # 10: right_wrist
13
+ [0, 255, 0], # 11: left_hip
14
+ [255, 128, 0], # 12: right_hip
15
+ [0, 255, 0], # 13: left_knee
16
+ [255, 128, 0], # 14: right_knee
17
+ [0, 255, 0], # 15: left_ankle
18
+ [255, 128, 0], # 16: right_ankle
19
+ ]
20
+
21
+ COCO_WHOLEBODY_KPTS_COLORS = [
22
+ [51, 153, 255], # 0: nose
23
+ [51, 153, 255], # 1: left_eye
24
+ [51, 153, 255], # 2: right_eye
25
+ [51, 153, 255], # 3: left_ear
26
+ [51, 153, 255], # 4: right_ear
27
+ [0, 255, 0], # 5: left_shoulder
28
+ [255, 128, 0], # 6: right_shoulder
29
+ [0, 255, 0], # 7: left_elbow
30
+ [255, 128, 0], # 8: right_elbow
31
+ [0, 255, 0], # 9: left_wrist
32
+ [255, 128, 0], # 10: right_wrist
33
+ [0, 255, 0], # 11: left_hip
34
+ [255, 128, 0], # 12: right_hip
35
+ [0, 255, 0], # 13: left_knee
36
+ [255, 128, 0], # 14: right_knee
37
+ [0, 255, 0], # 15: left_ankle
38
+ [255, 128, 0], # 16: right_ankle
39
+ [255, 128, 0], # 17: left_big_toe
40
+ [255, 128, 0], # 18: left_small_toe
41
+ [255, 128, 0], # 19: left_heel
42
+ [255, 128, 0], # 20: right_big_toe
43
+ [255, 128, 0], # 21: right_small_toe
44
+ [255, 128, 0], # 22: right_heel
45
+ [255, 255, 255], # 23: face-0
46
+ [255, 255, 255], # 24: face-1
47
+ [255, 255, 255], # 25: face-2
48
+ [255, 255, 255], # 26: face-3
49
+ [255, 255, 255], # 27: face-4
50
+ [255, 255, 255], # 28: face-5
51
+ [255, 255, 255], # 29: face-6
52
+ [255, 255, 255], # 30: face-7
53
+ [255, 255, 255], # 31: face-8
54
+ [255, 255, 255], # 32: face-9
55
+ [255, 255, 255], # 33: face-10
56
+ [255, 255, 255], # 34: face-11
57
+ [255, 255, 255], # 35: face-12
58
+ [255, 255, 255], # 36: face-13
59
+ [255, 255, 255], # 37: face-14
60
+ [255, 255, 255], # 38: face-15
61
+ [255, 255, 255], # 39: face-16
62
+ [255, 255, 255], # 40: face-17
63
+ [255, 255, 255], # 41: face-18
64
+ [255, 255, 255], # 42: face-19
65
+ [255, 255, 255], # 43: face-20
66
+ [255, 255, 255], # 44: face-21
67
+ [255, 255, 255], # 45: face-22
68
+ [255, 255, 255], # 46: face-23
69
+ [255, 255, 255], # 47: face-24
70
+ [255, 255, 255], # 48: face-25
71
+ [255, 255, 255], # 49: face-26
72
+ [255, 255, 255], # 50: face-27
73
+ [255, 255, 255], # 51: face-28
74
+ [255, 255, 255], # 52: face-29
75
+ [255, 255, 255], # 53: face-30
76
+ [255, 255, 255], # 54: face-31
77
+ [255, 255, 255], # 55: face-32
78
+ [255, 255, 255], # 56: face-33
79
+ [255, 255, 255], # 57: face-34
80
+ [255, 255, 255], # 58: face-35
81
+ [255, 255, 255], # 59: face-36
82
+ [255, 255, 255], # 60: face-37
83
+ [255, 255, 255], # 61: face-38
84
+ [255, 255, 255], # 62: face-39
85
+ [255, 255, 255], # 63: face-40
86
+ [255, 255, 255], # 64: face-41
87
+ [255, 255, 255], # 65: face-42
88
+ [255, 255, 255], # 66: face-43
89
+ [255, 255, 255], # 67: face-44
90
+ [255, 255, 255], # 68: face-45
91
+ [255, 255, 255], # 69: face-46
92
+ [255, 255, 255], # 70: face-47
93
+ [255, 255, 255], # 71: face-48
94
+ [255, 255, 255], # 72: face-49
95
+ [255, 255, 255], # 73: face-50
96
+ [255, 255, 255], # 74: face-51
97
+ [255, 255, 255], # 75: face-52
98
+ [255, 255, 255], # 76: face-53
99
+ [255, 255, 255], # 77: face-54
100
+ [255, 255, 255], # 78: face-55
101
+ [255, 255, 255], # 79: face-56
102
+ [255, 255, 255], # 80: face-57
103
+ [255, 255, 255], # 81: face-58
104
+ [255, 255, 255], # 82: face-59
105
+ [255, 255, 255], # 83: face-60
106
+ [255, 255, 255], # 84: face-61
107
+ [255, 255, 255], # 85: face-62
108
+ [255, 255, 255], # 86: face-63
109
+ [255, 255, 255], # 87: face-64
110
+ [255, 255, 255], # 88: face-65
111
+ [255, 255, 255], # 89: face-66
112
+ [255, 255, 255], # 90: face-67
113
+ [255, 255, 255], # 91: left_hand_root
114
+ [255, 128, 0], # 92: left_thumb1
115
+ [255, 128, 0], # 93: left_thumb2
116
+ [255, 128, 0], # 94: left_thumb3
117
+ [255, 128, 0], # 95: left_thumb4
118
+ [255, 153, 255], # 96: left_forefinger1
119
+ [255, 153, 255], # 97: left_forefinger2
120
+ [255, 153, 255], # 98: left_forefinger3
121
+ [255, 153, 255], # 99: left_forefinger4
122
+ [102, 178, 255], # 100: left_middle_finger1
123
+ [102, 178, 255], # 101: left_middle_finger2
124
+ [102, 178, 255], # 102: left_middle_finger3
125
+ [102, 178, 255], # 103: left_middle_finger4
126
+ [255, 51, 51], # 104: left_ring_finger1
127
+ [255, 51, 51], # 105: left_ring_finger2
128
+ [255, 51, 51], # 106: left_ring_finger3
129
+ [255, 51, 51], # 107: left_ring_finger4
130
+ [0, 255, 0], # 108: left_pinky_finger1
131
+ [0, 255, 0], # 109: left_pinky_finger2
132
+ [0, 255, 0], # 110: left_pinky_finger3
133
+ [0, 255, 0], # 111: left_pinky_finger4
134
+ [255, 255, 255], # 112: right_hand_root
135
+ [255, 128, 0], # 113: right_thumb1
136
+ [255, 128, 0], # 114: right_thumb2
137
+ [255, 128, 0], # 115: right_thumb3
138
+ [255, 128, 0], # 116: right_thumb4
139
+ [255, 153, 255], # 117: right_forefinger1
140
+ [255, 153, 255], # 118: right_forefinger2
141
+ [255, 153, 255], # 119: right_forefinger3
142
+ [255, 153, 255], # 120: right_forefinger4
143
+ [102, 178, 255], # 121: right_middle_finger1
144
+ [102, 178, 255], # 122: right_middle_finger2
145
+ [102, 178, 255], # 123: right_middle_finger3
146
+ [102, 178, 255], # 124: right_middle_finger4
147
+ [255, 51, 51], # 125: right_ring_finger1
148
+ [255, 51, 51], # 126: right_ring_finger2
149
+ [255, 51, 51], # 127: right_ring_finger3
150
+ [255, 51, 51], # 128: right_ring_finger4
151
+ [0, 255, 0], # 129: right_pinky_finger1
152
+ [0, 255, 0], # 130: right_pinky_finger2
153
+ [0, 255, 0], # 131: right_pinky_finger3
154
+ [0, 255, 0], # 132: right_pinky_finger4
155
+ ]
156
+
157
+
158
+ GOLIATH_KPTS_COLORS = [
159
+ [51, 153, 255], # 0: nose
160
+ [51, 153, 255], # 1: left_eye
161
+ [51, 153, 255], # 2: right_eye
162
+ [51, 153, 255], # 3: left_ear
163
+ [51, 153, 255], # 4: right_ear
164
+ [51, 153, 255], # 5: left_shoulder
165
+ [51, 153, 255], # 6: right_shoulder
166
+ [51, 153, 255], # 7: left_elbow
167
+ [51, 153, 255], # 8: right_elbow
168
+ [51, 153, 255], # 9: left_hip
169
+ [51, 153, 255], # 10: right_hip
170
+ [51, 153, 255], # 11: left_knee
171
+ [51, 153, 255], # 12: right_knee
172
+ [51, 153, 255], # 13: left_ankle
173
+ [51, 153, 255], # 14: right_ankle
174
+ [51, 153, 255], # 15: left_big_toe
175
+ [51, 153, 255], # 16: left_small_toe
176
+ [51, 153, 255], # 17: left_heel
177
+ [51, 153, 255], # 18: right_big_toe
178
+ [51, 153, 255], # 19: right_small_toe
179
+ [51, 153, 255], # 20: right_heel
180
+ [51, 153, 255], # 21: right_thumb4
181
+ [51, 153, 255], # 22: right_thumb3
182
+ [51, 153, 255], # 23: right_thumb2
183
+ [51, 153, 255], # 24: right_thumb_third_joint
184
+ [51, 153, 255], # 25: right_forefinger4
185
+ [51, 153, 255], # 26: right_forefinger3
186
+ [51, 153, 255], # 27: right_forefinger2
187
+ [51, 153, 255], # 28: right_forefinger_third_joint
188
+ [51, 153, 255], # 29: right_middle_finger4
189
+ [51, 153, 255], # 30: right_middle_finger3
190
+ [51, 153, 255], # 31: right_middle_finger2
191
+ [51, 153, 255], # 32: right_middle_finger_third_joint
192
+ [51, 153, 255], # 33: right_ring_finger4
193
+ [51, 153, 255], # 34: right_ring_finger3
194
+ [51, 153, 255], # 35: right_ring_finger2
195
+ [51, 153, 255], # 36: right_ring_finger_third_joint
196
+ [51, 153, 255], # 37: right_pinky_finger4
197
+ [51, 153, 255], # 38: right_pinky_finger3
198
+ [51, 153, 255], # 39: right_pinky_finger2
199
+ [51, 153, 255], # 40: right_pinky_finger_third_joint
200
+ [51, 153, 255], # 41: right_wrist
201
+ [51, 153, 255], # 42: left_thumb4
202
+ [51, 153, 255], # 43: left_thumb3
203
+ [51, 153, 255], # 44: left_thumb2
204
+ [51, 153, 255], # 45: left_thumb_third_joint
205
+ [51, 153, 255], # 46: left_forefinger4
206
+ [51, 153, 255], # 47: left_forefinger3
207
+ [51, 153, 255], # 48: left_forefinger2
208
+ [51, 153, 255], # 49: left_forefinger_third_joint
209
+ [51, 153, 255], # 50: left_middle_finger4
210
+ [51, 153, 255], # 51: left_middle_finger3
211
+ [51, 153, 255], # 52: left_middle_finger2
212
+ [51, 153, 255], # 53: left_middle_finger_third_joint
213
+ [51, 153, 255], # 54: left_ring_finger4
214
+ [51, 153, 255], # 55: left_ring_finger3
215
+ [51, 153, 255], # 56: left_ring_finger2
216
+ [51, 153, 255], # 57: left_ring_finger_third_joint
217
+ [51, 153, 255], # 58: left_pinky_finger4
218
+ [51, 153, 255], # 59: left_pinky_finger3
219
+ [51, 153, 255], # 60: left_pinky_finger2
220
+ [51, 153, 255], # 61: left_pinky_finger_third_joint
221
+ [51, 153, 255], # 62: left_wrist
222
+ [51, 153, 255], # 63: left_olecranon
223
+ [51, 153, 255], # 64: right_olecranon
224
+ [51, 153, 255], # 65: left_cubital_fossa
225
+ [51, 153, 255], # 66: right_cubital_fossa
226
+ [51, 153, 255], # 67: left_acromion
227
+ [51, 153, 255], # 68: right_acromion
228
+ [51, 153, 255], # 69: neck
229
+ [255, 255, 255], # 70: center_of_glabella
230
+ [255, 255, 255], # 71: center_of_nose_root
231
+ [255, 255, 255], # 72: tip_of_nose_bridge
232
+ [255, 255, 255], # 73: midpoint_1_of_nose_bridge
233
+ [255, 255, 255], # 74: midpoint_2_of_nose_bridge
234
+ [255, 255, 255], # 75: midpoint_3_of_nose_bridge
235
+ [255, 255, 255], # 76: center_of_labiomental_groove
236
+ [255, 255, 255], # 77: tip_of_chin
237
+ [255, 255, 255], # 78: upper_startpoint_of_r_eyebrow
238
+ [255, 255, 255], # 79: lower_startpoint_of_r_eyebrow
239
+ [255, 255, 255], # 80: end_of_r_eyebrow
240
+ [255, 255, 255], # 81: upper_midpoint_1_of_r_eyebrow
241
+ [255, 255, 255], # 82: lower_midpoint_1_of_r_eyebrow
242
+ [255, 255, 255], # 83: upper_midpoint_2_of_r_eyebrow
243
+ [255, 255, 255], # 84: upper_midpoint_3_of_r_eyebrow
244
+ [255, 255, 255], # 85: lower_midpoint_2_of_r_eyebrow
245
+ [255, 255, 255], # 86: lower_midpoint_3_of_r_eyebrow
246
+ [255, 255, 255], # 87: upper_startpoint_of_l_eyebrow
247
+ [255, 255, 255], # 88: lower_startpoint_of_l_eyebrow
248
+ [255, 255, 255], # 89: end_of_l_eyebrow
249
+ [255, 255, 255], # 90: upper_midpoint_1_of_l_eyebrow
250
+ [255, 255, 255], # 91: lower_midpoint_1_of_l_eyebrow
251
+ [255, 255, 255], # 92: upper_midpoint_2_of_l_eyebrow
252
+ [255, 255, 255], # 93: upper_midpoint_3_of_l_eyebrow
253
+ [255, 255, 255], # 94: lower_midpoint_2_of_l_eyebrow
254
+ [255, 255, 255], # 95: lower_midpoint_3_of_l_eyebrow
255
+ [192, 64, 128], # 96: l_inner_end_of_upper_lash_line
256
+ [192, 64, 128], # 97: l_outer_end_of_upper_lash_line
257
+ [192, 64, 128], # 98: l_centerpoint_of_upper_lash_line
258
+ [192, 64, 128], # 99: l_midpoint_2_of_upper_lash_line
259
+ [192, 64, 128], # 100: l_midpoint_1_of_upper_lash_line
260
+ [192, 64, 128], # 101: l_midpoint_6_of_upper_lash_line
261
+ [192, 64, 128], # 102: l_midpoint_5_of_upper_lash_line
262
+ [192, 64, 128], # 103: l_midpoint_4_of_upper_lash_line
263
+ [192, 64, 128], # 104: l_midpoint_3_of_upper_lash_line
264
+ [192, 64, 128], # 105: l_outer_end_of_upper_eyelid_line
265
+ [192, 64, 128], # 106: l_midpoint_6_of_upper_eyelid_line
266
+ [192, 64, 128], # 107: l_midpoint_2_of_upper_eyelid_line
267
+ [192, 64, 128], # 108: l_midpoint_5_of_upper_eyelid_line
268
+ [192, 64, 128], # 109: l_centerpoint_of_upper_eyelid_line
269
+ [192, 64, 128], # 110: l_midpoint_4_of_upper_eyelid_line
270
+ [192, 64, 128], # 111: l_midpoint_1_of_upper_eyelid_line
271
+ [192, 64, 128], # 112: l_midpoint_3_of_upper_eyelid_line
272
+ [192, 64, 128], # 113: l_midpoint_6_of_upper_crease_line
273
+ [192, 64, 128], # 114: l_midpoint_2_of_upper_crease_line
274
+ [192, 64, 128], # 115: l_midpoint_5_of_upper_crease_line
275
+ [192, 64, 128], # 116: l_centerpoint_of_upper_crease_line
276
+ [192, 64, 128], # 117: l_midpoint_4_of_upper_crease_line
277
+ [192, 64, 128], # 118: l_midpoint_1_of_upper_crease_line
278
+ [192, 64, 128], # 119: l_midpoint_3_of_upper_crease_line
279
+ [64, 32, 192], # 120: r_inner_end_of_upper_lash_line
280
+ [64, 32, 192], # 121: r_outer_end_of_upper_lash_line
281
+ [64, 32, 192], # 122: r_centerpoint_of_upper_lash_line
282
+ [64, 32, 192], # 123: r_midpoint_1_of_upper_lash_line
283
+ [64, 32, 192], # 124: r_midpoint_2_of_upper_lash_line
284
+ [64, 32, 192], # 125: r_midpoint_3_of_upper_lash_line
285
+ [64, 32, 192], # 126: r_midpoint_4_of_upper_lash_line
286
+ [64, 32, 192], # 127: r_midpoint_5_of_upper_lash_line
287
+ [64, 32, 192], # 128: r_midpoint_6_of_upper_lash_line
288
+ [64, 32, 192], # 129: r_outer_end_of_upper_eyelid_line
289
+ [64, 32, 192], # 130: r_midpoint_3_of_upper_eyelid_line
290
+ [64, 32, 192], # 131: r_midpoint_1_of_upper_eyelid_line
291
+ [64, 32, 192], # 132: r_midpoint_4_of_upper_eyelid_line
292
+ [64, 32, 192], # 133: r_centerpoint_of_upper_eyelid_line
293
+ [64, 32, 192], # 134: r_midpoint_5_of_upper_eyelid_line
294
+ [64, 32, 192], # 135: r_midpoint_2_of_upper_eyelid_line
295
+ [64, 32, 192], # 136: r_midpoint_6_of_upper_eyelid_line
296
+ [64, 32, 192], # 137: r_midpoint_3_of_upper_crease_line
297
+ [64, 32, 192], # 138: r_midpoint_1_of_upper_crease_line
298
+ [64, 32, 192], # 139: r_midpoint_4_of_upper_crease_line
299
+ [64, 32, 192], # 140: r_centerpoint_of_upper_crease_line
300
+ [64, 32, 192], # 141: r_midpoint_5_of_upper_crease_line
301
+ [64, 32, 192], # 142: r_midpoint_2_of_upper_crease_line
302
+ [64, 32, 192], # 143: r_midpoint_6_of_upper_crease_line
303
+ [64, 192, 128], # 144: l_inner_end_of_lower_lash_line
304
+ [64, 192, 128], # 145: l_outer_end_of_lower_lash_line
305
+ [64, 192, 128], # 146: l_centerpoint_of_lower_lash_line
306
+ [64, 192, 128], # 147: l_midpoint_2_of_lower_lash_line
307
+ [64, 192, 128], # 148: l_midpoint_1_of_lower_lash_line
308
+ [64, 192, 128], # 149: l_midpoint_6_of_lower_lash_line
309
+ [64, 192, 128], # 150: l_midpoint_5_of_lower_lash_line
310
+ [64, 192, 128], # 151: l_midpoint_4_of_lower_lash_line
311
+ [64, 192, 128], # 152: l_midpoint_3_of_lower_lash_line
312
+ [64, 192, 128], # 153: l_outer_end_of_lower_eyelid_line
313
+ [64, 192, 128], # 154: l_midpoint_6_of_lower_eyelid_line
314
+ [64, 192, 128], # 155: l_midpoint_2_of_lower_eyelid_line
315
+ [64, 192, 128], # 156: l_midpoint_5_of_lower_eyelid_line
316
+ [64, 192, 128], # 157: l_centerpoint_of_lower_eyelid_line
317
+ [64, 192, 128], # 158: l_midpoint_4_of_lower_eyelid_line
318
+ [64, 192, 128], # 159: l_midpoint_1_of_lower_eyelid_line
319
+ [64, 192, 128], # 160: l_midpoint_3_of_lower_eyelid_line
320
+ [64, 192, 32], # 161: r_inner_end_of_lower_lash_line
321
+ [64, 192, 32], # 162: r_outer_end_of_lower_lash_line
322
+ [64, 192, 32], # 163: r_centerpoint_of_lower_lash_line
323
+ [64, 192, 32], # 164: r_midpoint_1_of_lower_lash_line
324
+ [64, 192, 32], # 165: r_midpoint_2_of_lower_lash_line
325
+ [64, 192, 32], # 166: r_midpoint_3_of_lower_lash_line
326
+ [64, 192, 32], # 167: r_midpoint_4_of_lower_lash_line
327
+ [64, 192, 32], # 168: r_midpoint_5_of_lower_lash_line
328
+ [64, 192, 32], # 169: r_midpoint_6_of_lower_lash_line
329
+ [64, 192, 32], # 170: r_outer_end_of_lower_eyelid_line
330
+ [64, 192, 32], # 171: r_midpoint_3_of_lower_eyelid_line
331
+ [64, 192, 32], # 172: r_midpoint_1_of_lower_eyelid_line
332
+ [64, 192, 32], # 173: r_midpoint_4_of_lower_eyelid_line
333
+ [64, 192, 32], # 174: r_centerpoint_of_lower_eyelid_line
334
+ [64, 192, 32], # 175: r_midpoint_5_of_lower_eyelid_line
335
+ [64, 192, 32], # 176: r_midpoint_2_of_lower_eyelid_line
336
+ [64, 192, 32], # 177: r_midpoint_6_of_lower_eyelid_line
337
+ [0, 192, 0], # 178: tip_of_nose
338
+ [0, 192, 0], # 179: bottom_center_of_nose
339
+ [0, 192, 0], # 180: r_outer_corner_of_nose
340
+ [0, 192, 0], # 181: l_outer_corner_of_nose
341
+ [0, 192, 0], # 182: inner_corner_of_r_nostril
342
+ [0, 192, 0], # 183: outer_corner_of_r_nostril
343
+ [0, 192, 0], # 184: upper_corner_of_r_nostril
344
+ [0, 192, 0], # 185: inner_corner_of_l_nostril
345
+ [0, 192, 0], # 186: outer_corner_of_l_nostril
346
+ [0, 192, 0], # 187: upper_corner_of_l_nostril
347
+ [192, 0, 0], # 188: r_outer_corner_of_mouth
348
+ [192, 0, 0], # 189: l_outer_corner_of_mouth
349
+ [192, 0, 0], # 190: center_of_cupid_bow
350
+ [192, 0, 0], # 191: center_of_lower_outer_lip
351
+ [192, 0, 0], # 192: midpoint_1_of_upper_outer_lip
352
+ [192, 0, 0], # 193: midpoint_2_of_upper_outer_lip
353
+ [192, 0, 0], # 194: midpoint_1_of_lower_outer_lip
354
+ [192, 0, 0], # 195: midpoint_2_of_lower_outer_lip
355
+ [192, 0, 0], # 196: midpoint_3_of_upper_outer_lip
356
+ [192, 0, 0], # 197: midpoint_4_of_upper_outer_lip
357
+ [192, 0, 0], # 198: midpoint_5_of_upper_outer_lip
358
+ [192, 0, 0], # 199: midpoint_6_of_upper_outer_lip
359
+ [192, 0, 0], # 200: midpoint_3_of_lower_outer_lip
360
+ [192, 0, 0], # 201: midpoint_4_of_lower_outer_lip
361
+ [192, 0, 0], # 202: midpoint_5_of_lower_outer_lip
362
+ [192, 0, 0], # 203: midpoint_6_of_lower_outer_lip
363
+ [0, 192, 192], # 204: r_inner_corner_of_mouth
364
+ [0, 192, 192], # 205: l_inner_corner_of_mouth
365
+ [0, 192, 192], # 206: center_of_upper_inner_lip
366
+ [0, 192, 192], # 207: center_of_lower_inner_lip
367
+ [0, 192, 192], # 208: midpoint_1_of_upper_inner_lip
368
+ [0, 192, 192], # 209: midpoint_2_of_upper_inner_lip
369
+ [0, 192, 192], # 210: midpoint_1_of_lower_inner_lip
370
+ [0, 192, 192], # 211: midpoint_2_of_lower_inner_lip
371
+ [0, 192, 192], # 212: midpoint_3_of_upper_inner_lip
372
+ [0, 192, 192], # 213: midpoint_4_of_upper_inner_lip
373
+ [0, 192, 192], # 214: midpoint_5_of_upper_inner_lip
374
+ [0, 192, 192], # 215: midpoint_6_of_upper_inner_lip
375
+ [0, 192, 192], # 216: midpoint_3_of_lower_inner_lip
376
+ [0, 192, 192], # 217: midpoint_4_of_lower_inner_lip
377
+ [0, 192, 192], # 218: midpoint_5_of_lower_inner_lip
378
+ [0, 192, 192], # 219: midpoint_6_of_lower_inner_lip. teeths removed
379
+ [200, 200, 0], # 256: l_top_end_of_inferior_crus
380
+ [200, 200, 0], # 257: l_top_end_of_superior_crus
381
+ [200, 200, 0], # 258: l_start_of_antihelix
382
+ [200, 200, 0], # 259: l_end_of_antihelix
383
+ [200, 200, 0], # 260: l_midpoint_1_of_antihelix
384
+ [200, 200, 0], # 261: l_midpoint_1_of_inferior_crus
385
+ [200, 200, 0], # 262: l_midpoint_2_of_antihelix
386
+ [200, 200, 0], # 263: l_midpoint_3_of_antihelix
387
+ [200, 200, 0], # 264: l_point_1_of_inner_helix
388
+ [200, 200, 0], # 265: l_point_2_of_inner_helix
389
+ [200, 200, 0], # 266: l_point_3_of_inner_helix
390
+ [200, 200, 0], # 267: l_point_4_of_inner_helix
391
+ [200, 200, 0], # 268: l_point_5_of_inner_helix
392
+ [200, 200, 0], # 269: l_point_6_of_inner_helix
393
+ [200, 200, 0], # 270: l_point_7_of_inner_helix
394
+ [200, 200, 0], # 271: l_highest_point_of_antitragus
395
+ [200, 200, 0], # 272: l_bottom_point_of_tragus
396
+ [200, 200, 0], # 273: l_protruding_point_of_tragus
397
+ [200, 200, 0], # 274: l_top_point_of_tragus
398
+ [200, 200, 0], # 275: l_start_point_of_crus_of_helix
399
+ [200, 200, 0], # 276: l_deepest_point_of_concha
400
+ [200, 200, 0], # 277: l_tip_of_ear_lobe
401
+ [200, 200, 0], # 278: l_midpoint_between_22_15
402
+ [200, 200, 0], # 279: l_bottom_connecting_point_of_ear_lobe
403
+ [200, 200, 0], # 280: l_top_connecting_point_of_helix
404
+ [200, 200, 0], # 281: l_point_8_of_inner_helix
405
+ [0, 200, 200], # 282: r_top_end_of_inferior_crus
406
+ [0, 200, 200], # 283: r_top_end_of_superior_crus
407
+ [0, 200, 200], # 284: r_start_of_antihelix
408
+ [0, 200, 200], # 285: r_end_of_antihelix
409
+ [0, 200, 200], # 286: r_midpoint_1_of_antihelix
410
+ [0, 200, 200], # 287: r_midpoint_1_of_inferior_crus
411
+ [0, 200, 200], # 288: r_midpoint_2_of_antihelix
412
+ [0, 200, 200], # 289: r_midpoint_3_of_antihelix
413
+ [0, 200, 200], # 290: r_point_1_of_inner_helix
414
+ [0, 200, 200], # 291: r_point_8_of_inner_helix
415
+ [0, 200, 200], # 292: r_point_3_of_inner_helix
416
+ [0, 200, 200], # 293: r_point_4_of_inner_helix
417
+ [0, 200, 200], # 294: r_point_5_of_inner_helix
418
+ [0, 200, 200], # 295: r_point_6_of_inner_helix
419
+ [0, 200, 200], # 296: r_point_7_of_inner_helix
420
+ [0, 200, 200], # 297: r_highest_point_of_antitragus
421
+ [0, 200, 200], # 298: r_bottom_point_of_tragus
422
+ [0, 200, 200], # 299: r_protruding_point_of_tragus
423
+ [0, 200, 200], # 300: r_top_point_of_tragus
424
+ [0, 200, 200], # 301: r_start_point_of_crus_of_helix
425
+ [0, 200, 200], # 302: r_deepest_point_of_concha
426
+ [0, 200, 200], # 303: r_tip_of_ear_lobe
427
+ [0, 200, 200], # 304: r_midpoint_between_22_15
428
+ [0, 200, 200], # 305: r_bottom_connecting_point_of_ear_lobe
429
+ [0, 200, 200], # 306: r_top_connecting_point_of_helix
430
+ [0, 200, 200], # 307: r_point_2_of_inner_helix
431
+ [128, 192, 64], # 308: l_center_of_iris
432
+ [128, 192, 64], # 309: l_border_of_iris_3
433
+ [128, 192, 64], # 310: l_border_of_iris_midpoint_1
434
+ [128, 192, 64], # 311: l_border_of_iris_12
435
+ [128, 192, 64], # 312: l_border_of_iris_midpoint_4
436
+ [128, 192, 64], # 313: l_border_of_iris_9
437
+ [128, 192, 64], # 314: l_border_of_iris_midpoint_3
438
+ [128, 192, 64], # 315: l_border_of_iris_6
439
+ [128, 192, 64], # 316: l_border_of_iris_midpoint_2
440
+ [192, 32, 64], # 317: r_center_of_iris
441
+ [192, 32, 64], # 318: r_border_of_iris_3
442
+ [192, 32, 64], # 319: r_border_of_iris_midpoint_1
443
+ [192, 32, 64], # 320: r_border_of_iris_12
444
+ [192, 32, 64], # 321: r_border_of_iris_midpoint_4
445
+ [192, 32, 64], # 322: r_border_of_iris_9
446
+ [192, 32, 64], # 323: r_border_of_iris_midpoint_3
447
+ [192, 32, 64], # 324: r_border_of_iris_6
448
+ [192, 32, 64], # 325: r_border_of_iris_midpoint_2
449
+ [192, 128, 64], # 326: l_center_of_pupil
450
+ [192, 128, 64], # 327: l_border_of_pupil_3
451
+ [192, 128, 64], # 328: l_border_of_pupil_midpoint_1
452
+ [192, 128, 64], # 329: l_border_of_pupil_12
453
+ [192, 128, 64], # 330: l_border_of_pupil_midpoint_4
454
+ [192, 128, 64], # 331: l_border_of_pupil_9
455
+ [192, 128, 64], # 332: l_border_of_pupil_midpoint_3
456
+ [192, 128, 64], # 333: l_border_of_pupil_6
457
+ [192, 128, 64], # 334: l_border_of_pupil_midpoint_2
458
+ [32, 192, 192], # 335: r_center_of_pupil
459
+ [32, 192, 192], # 336: r_border_of_pupil_3
460
+ [32, 192, 192], # 337: r_border_of_pupil_midpoint_1
461
+ [32, 192, 192], # 338: r_border_of_pupil_12
462
+ [32, 192, 192], # 339: r_border_of_pupil_midpoint_4
463
+ [32, 192, 192], # 340: r_border_of_pupil_9
464
+ [32, 192, 192], # 341: r_border_of_pupil_midpoint_3
465
+ [32, 192, 192], # 342: r_border_of_pupil_6
466
+ [32, 192, 192], # 343: r_border_of_pupil_midpoint_2
467
+ ]
468
+
469
+ GOLIATH_KEYPOINTS = [
470
+ "nose",
471
+ "left_eye",
472
+ "right_eye",
473
+ "left_ear",
474
+ "right_ear",
475
+ "left_shoulder",
476
+ "right_shoulder",
477
+ "left_elbow",
478
+ "right_elbow",
479
+ "left_hip",
480
+ "right_hip",
481
+ "left_knee",
482
+ "right_knee",
483
+ "left_ankle",
484
+ "right_ankle",
485
+ "left_big_toe",
486
+ "left_small_toe",
487
+ "left_heel",
488
+ "right_big_toe",
489
+ "right_small_toe",
490
+ "right_heel",
491
+ "right_thumb4",
492
+ "right_thumb3",
493
+ "right_thumb2",
494
+ "right_thumb_third_joint",
495
+ "right_forefinger4",
496
+ "right_forefinger3",
497
+ "right_forefinger2",
498
+ "right_forefinger_third_joint",
499
+ "right_middle_finger4",
500
+ "right_middle_finger3",
501
+ "right_middle_finger2",
502
+ "right_middle_finger_third_joint",
503
+ "right_ring_finger4",
504
+ "right_ring_finger3",
505
+ "right_ring_finger2",
506
+ "right_ring_finger_third_joint",
507
+ "right_pinky_finger4",
508
+ "right_pinky_finger3",
509
+ "right_pinky_finger2",
510
+ "right_pinky_finger_third_joint",
511
+ "right_wrist",
512
+ "left_thumb4",
513
+ "left_thumb3",
514
+ "left_thumb2",
515
+ "left_thumb_third_joint",
516
+ "left_forefinger4",
517
+ "left_forefinger3",
518
+ "left_forefinger2",
519
+ "left_forefinger_third_joint",
520
+ "left_middle_finger4",
521
+ "left_middle_finger3",
522
+ "left_middle_finger2",
523
+ "left_middle_finger_third_joint",
524
+ "left_ring_finger4",
525
+ "left_ring_finger3",
526
+ "left_ring_finger2",
527
+ "left_ring_finger_third_joint",
528
+ "left_pinky_finger4",
529
+ "left_pinky_finger3",
530
+ "left_pinky_finger2",
531
+ "left_pinky_finger_third_joint",
532
+ "left_wrist",
533
+ "left_olecranon",
534
+ "right_olecranon",
535
+ "left_cubital_fossa",
536
+ "right_cubital_fossa",
537
+ "left_acromion",
538
+ "right_acromion",
539
+ "neck",
540
+ "center_of_glabella",
541
+ "center_of_nose_root",
542
+ "tip_of_nose_bridge",
543
+ "midpoint_1_of_nose_bridge",
544
+ "midpoint_2_of_nose_bridge",
545
+ "midpoint_3_of_nose_bridge",
546
+ "center_of_labiomental_groove",
547
+ "tip_of_chin",
548
+ "upper_startpoint_of_r_eyebrow",
549
+ "lower_startpoint_of_r_eyebrow",
550
+ "end_of_r_eyebrow",
551
+ "upper_midpoint_1_of_r_eyebrow",
552
+ "lower_midpoint_1_of_r_eyebrow",
553
+ "upper_midpoint_2_of_r_eyebrow",
554
+ "upper_midpoint_3_of_r_eyebrow",
555
+ "lower_midpoint_2_of_r_eyebrow",
556
+ "lower_midpoint_3_of_r_eyebrow",
557
+ "upper_startpoint_of_l_eyebrow",
558
+ "lower_startpoint_of_l_eyebrow",
559
+ "end_of_l_eyebrow",
560
+ "upper_midpoint_1_of_l_eyebrow",
561
+ "lower_midpoint_1_of_l_eyebrow",
562
+ "upper_midpoint_2_of_l_eyebrow",
563
+ "upper_midpoint_3_of_l_eyebrow",
564
+ "lower_midpoint_2_of_l_eyebrow",
565
+ "lower_midpoint_3_of_l_eyebrow",
566
+ "l_inner_end_of_upper_lash_line",
567
+ "l_outer_end_of_upper_lash_line",
568
+ "l_centerpoint_of_upper_lash_line",
569
+ "l_midpoint_2_of_upper_lash_line",
570
+ "l_midpoint_1_of_upper_lash_line",
571
+ "l_midpoint_6_of_upper_lash_line",
572
+ "l_midpoint_5_of_upper_lash_line",
573
+ "l_midpoint_4_of_upper_lash_line",
574
+ "l_midpoint_3_of_upper_lash_line",
575
+ "l_outer_end_of_upper_eyelid_line",
576
+ "l_midpoint_6_of_upper_eyelid_line",
577
+ "l_midpoint_2_of_upper_eyelid_line",
578
+ "l_midpoint_5_of_upper_eyelid_line",
579
+ "l_centerpoint_of_upper_eyelid_line",
580
+ "l_midpoint_4_of_upper_eyelid_line",
581
+ "l_midpoint_1_of_upper_eyelid_line",
582
+ "l_midpoint_3_of_upper_eyelid_line",
583
+ "l_midpoint_6_of_upper_crease_line",
584
+ "l_midpoint_2_of_upper_crease_line",
585
+ "l_midpoint_5_of_upper_crease_line",
586
+ "l_centerpoint_of_upper_crease_line",
587
+ "l_midpoint_4_of_upper_crease_line",
588
+ "l_midpoint_1_of_upper_crease_line",
589
+ "l_midpoint_3_of_upper_crease_line",
590
+ "r_inner_end_of_upper_lash_line",
591
+ "r_outer_end_of_upper_lash_line",
592
+ "r_centerpoint_of_upper_lash_line",
593
+ "r_midpoint_1_of_upper_lash_line",
594
+ "r_midpoint_2_of_upper_lash_line",
595
+ "r_midpoint_3_of_upper_lash_line",
596
+ "r_midpoint_4_of_upper_lash_line",
597
+ "r_midpoint_5_of_upper_lash_line",
598
+ "r_midpoint_6_of_upper_lash_line",
599
+ "r_outer_end_of_upper_eyelid_line",
600
+ "r_midpoint_3_of_upper_eyelid_line",
601
+ "r_midpoint_1_of_upper_eyelid_line",
602
+ "r_midpoint_4_of_upper_eyelid_line",
603
+ "r_centerpoint_of_upper_eyelid_line",
604
+ "r_midpoint_5_of_upper_eyelid_line",
605
+ "r_midpoint_2_of_upper_eyelid_line",
606
+ "r_midpoint_6_of_upper_eyelid_line",
607
+ "r_midpoint_3_of_upper_crease_line",
608
+ "r_midpoint_1_of_upper_crease_line",
609
+ "r_midpoint_4_of_upper_crease_line",
610
+ "r_centerpoint_of_upper_crease_line",
611
+ "r_midpoint_5_of_upper_crease_line",
612
+ "r_midpoint_2_of_upper_crease_line",
613
+ "r_midpoint_6_of_upper_crease_line",
614
+ "l_inner_end_of_lower_lash_line",
615
+ "l_outer_end_of_lower_lash_line",
616
+ "l_centerpoint_of_lower_lash_line",
617
+ "l_midpoint_2_of_lower_lash_line",
618
+ "l_midpoint_1_of_lower_lash_line",
619
+ "l_midpoint_6_of_lower_lash_line",
620
+ "l_midpoint_5_of_lower_lash_line",
621
+ "l_midpoint_4_of_lower_lash_line",
622
+ "l_midpoint_3_of_lower_lash_line",
623
+ "l_outer_end_of_lower_eyelid_line",
624
+ "l_midpoint_6_of_lower_eyelid_line",
625
+ "l_midpoint_2_of_lower_eyelid_line",
626
+ "l_midpoint_5_of_lower_eyelid_line",
627
+ "l_centerpoint_of_lower_eyelid_line",
628
+ "l_midpoint_4_of_lower_eyelid_line",
629
+ "l_midpoint_1_of_lower_eyelid_line",
630
+ "l_midpoint_3_of_lower_eyelid_line",
631
+ "r_inner_end_of_lower_lash_line",
632
+ "r_outer_end_of_lower_lash_line",
633
+ "r_centerpoint_of_lower_lash_line",
634
+ "r_midpoint_1_of_lower_lash_line",
635
+ "r_midpoint_2_of_lower_lash_line",
636
+ "r_midpoint_3_of_lower_lash_line",
637
+ "r_midpoint_4_of_lower_lash_line",
638
+ "r_midpoint_5_of_lower_lash_line",
639
+ "r_midpoint_6_of_lower_lash_line",
640
+ "r_outer_end_of_lower_eyelid_line",
641
+ "r_midpoint_3_of_lower_eyelid_line",
642
+ "r_midpoint_1_of_lower_eyelid_line",
643
+ "r_midpoint_4_of_lower_eyelid_line",
644
+ "r_centerpoint_of_lower_eyelid_line",
645
+ "r_midpoint_5_of_lower_eyelid_line",
646
+ "r_midpoint_2_of_lower_eyelid_line",
647
+ "r_midpoint_6_of_lower_eyelid_line",
648
+ "tip_of_nose",
649
+ "bottom_center_of_nose",
650
+ "r_outer_corner_of_nose",
651
+ "l_outer_corner_of_nose",
652
+ "inner_corner_of_r_nostril",
653
+ "outer_corner_of_r_nostril",
654
+ "upper_corner_of_r_nostril",
655
+ "inner_corner_of_l_nostril",
656
+ "outer_corner_of_l_nostril",
657
+ "upper_corner_of_l_nostril",
658
+ "r_outer_corner_of_mouth",
659
+ "l_outer_corner_of_mouth",
660
+ "center_of_cupid_bow",
661
+ "center_of_lower_outer_lip",
662
+ "midpoint_1_of_upper_outer_lip",
663
+ "midpoint_2_of_upper_outer_lip",
664
+ "midpoint_1_of_lower_outer_lip",
665
+ "midpoint_2_of_lower_outer_lip",
666
+ "midpoint_3_of_upper_outer_lip",
667
+ "midpoint_4_of_upper_outer_lip",
668
+ "midpoint_5_of_upper_outer_lip",
669
+ "midpoint_6_of_upper_outer_lip",
670
+ "midpoint_3_of_lower_outer_lip",
671
+ "midpoint_4_of_lower_outer_lip",
672
+ "midpoint_5_of_lower_outer_lip",
673
+ "midpoint_6_of_lower_outer_lip",
674
+ "r_inner_corner_of_mouth",
675
+ "l_inner_corner_of_mouth",
676
+ "center_of_upper_inner_lip",
677
+ "center_of_lower_inner_lip",
678
+ "midpoint_1_of_upper_inner_lip",
679
+ "midpoint_2_of_upper_inner_lip",
680
+ "midpoint_1_of_lower_inner_lip",
681
+ "midpoint_2_of_lower_inner_lip",
682
+ "midpoint_3_of_upper_inner_lip",
683
+ "midpoint_4_of_upper_inner_lip",
684
+ "midpoint_5_of_upper_inner_lip",
685
+ "midpoint_6_of_upper_inner_lip",
686
+ "midpoint_3_of_lower_inner_lip",
687
+ "midpoint_4_of_lower_inner_lip",
688
+ "midpoint_5_of_lower_inner_lip",
689
+ "midpoint_6_of_lower_inner_lip",
690
+ "l_top_end_of_inferior_crus",
691
+ "l_top_end_of_superior_crus",
692
+ "l_start_of_antihelix",
693
+ "l_end_of_antihelix",
694
+ "l_midpoint_1_of_antihelix",
695
+ "l_midpoint_1_of_inferior_crus",
696
+ "l_midpoint_2_of_antihelix",
697
+ "l_midpoint_3_of_antihelix",
698
+ "l_point_1_of_inner_helix",
699
+ "l_point_2_of_inner_helix",
700
+ "l_point_3_of_inner_helix",
701
+ "l_point_4_of_inner_helix",
702
+ "l_point_5_of_inner_helix",
703
+ "l_point_6_of_inner_helix",
704
+ "l_point_7_of_inner_helix",
705
+ "l_highest_point_of_antitragus",
706
+ "l_bottom_point_of_tragus",
707
+ "l_protruding_point_of_tragus",
708
+ "l_top_point_of_tragus",
709
+ "l_start_point_of_crus_of_helix",
710
+ "l_deepest_point_of_concha",
711
+ "l_tip_of_ear_lobe",
712
+ "l_midpoint_between_22_15",
713
+ "l_bottom_connecting_point_of_ear_lobe",
714
+ "l_top_connecting_point_of_helix",
715
+ "l_point_8_of_inner_helix",
716
+ "r_top_end_of_inferior_crus",
717
+ "r_top_end_of_superior_crus",
718
+ "r_start_of_antihelix",
719
+ "r_end_of_antihelix",
720
+ "r_midpoint_1_of_antihelix",
721
+ "r_midpoint_1_of_inferior_crus",
722
+ "r_midpoint_2_of_antihelix",
723
+ "r_midpoint_3_of_antihelix",
724
+ "r_point_1_of_inner_helix",
725
+ "r_point_8_of_inner_helix",
726
+ "r_point_3_of_inner_helix",
727
+ "r_point_4_of_inner_helix",
728
+ "r_point_5_of_inner_helix",
729
+ "r_point_6_of_inner_helix",
730
+ "r_point_7_of_inner_helix",
731
+ "r_highest_point_of_antitragus",
732
+ "r_bottom_point_of_tragus",
733
+ "r_protruding_point_of_tragus",
734
+ "r_top_point_of_tragus",
735
+ "r_start_point_of_crus_of_helix",
736
+ "r_deepest_point_of_concha",
737
+ "r_tip_of_ear_lobe",
738
+ "r_midpoint_between_22_15",
739
+ "r_bottom_connecting_point_of_ear_lobe",
740
+ "r_top_connecting_point_of_helix",
741
+ "r_point_2_of_inner_helix",
742
+ "l_center_of_iris",
743
+ "l_border_of_iris_3",
744
+ "l_border_of_iris_midpoint_1",
745
+ "l_border_of_iris_12",
746
+ "l_border_of_iris_midpoint_4",
747
+ "l_border_of_iris_9",
748
+ "l_border_of_iris_midpoint_3",
749
+ "l_border_of_iris_6",
750
+ "l_border_of_iris_midpoint_2",
751
+ "r_center_of_iris",
752
+ "r_border_of_iris_3",
753
+ "r_border_of_iris_midpoint_1",
754
+ "r_border_of_iris_12",
755
+ "r_border_of_iris_midpoint_4",
756
+ "r_border_of_iris_9",
757
+ "r_border_of_iris_midpoint_3",
758
+ "r_border_of_iris_6",
759
+ "r_border_of_iris_midpoint_2",
760
+ "l_center_of_pupil",
761
+ "l_border_of_pupil_3",
762
+ "l_border_of_pupil_midpoint_1",
763
+ "l_border_of_pupil_12",
764
+ "l_border_of_pupil_midpoint_4",
765
+ "l_border_of_pupil_9",
766
+ "l_border_of_pupil_midpoint_3",
767
+ "l_border_of_pupil_6",
768
+ "l_border_of_pupil_midpoint_2",
769
+ "r_center_of_pupil",
770
+ "r_border_of_pupil_3",
771
+ "r_border_of_pupil_midpoint_1",
772
+ "r_border_of_pupil_12",
773
+ "r_border_of_pupil_midpoint_4",
774
+ "r_border_of_pupil_9",
775
+ "r_border_of_pupil_midpoint_3",
776
+ "r_border_of_pupil_6",
777
+ "r_border_of_pupil_midpoint_2"
778
+ ]
779
+
780
+ GOLIATH_SKELETON_INFO = {
781
+ 0:
782
+ dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]),
783
+ 1:
784
+ dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]),
785
+ 2:
786
+ dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]),
787
+ 3:
788
+ dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]),
789
+ 4:
790
+ dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]),
791
+ 5:
792
+ dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]),
793
+ 6:
794
+ dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]),
795
+ 7:
796
+ dict(
797
+ link=('left_shoulder', 'right_shoulder'),
798
+ id=7,
799
+ color=[51, 153, 255]),
800
+ 8:
801
+ dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]),
802
+ 9:
803
+ dict(
804
+ link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]),
805
+ 10:
806
+ dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]),
807
+ 11:
808
+ dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]),
809
+ 12:
810
+ dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]),
811
+ 13:
812
+ dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]),
813
+ 14:
814
+ dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]),
815
+ 15:
816
+ dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]),
817
+ 16:
818
+ dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]),
819
+ 17:
820
+ dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]),
821
+ 18:
822
+ dict(
823
+ link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255]),
824
+ 19:
825
+ dict(link=('left_ankle', 'left_big_toe'), id=19, color=[0, 255, 0]),
826
+ 20:
827
+ dict(link=('left_ankle', 'left_small_toe'), id=20, color=[0, 255, 0]),
828
+ 21:
829
+ dict(link=('left_ankle', 'left_heel'), id=21, color=[0, 255, 0]),
830
+ 22:
831
+ dict(
832
+ link=('right_ankle', 'right_big_toe'), id=22, color=[255, 128, 0]),
833
+ 23:
834
+ dict(
835
+ link=('right_ankle', 'right_small_toe'),
836
+ id=23,
837
+ color=[255, 128, 0]),
838
+ 24:
839
+ dict(link=('right_ankle', 'right_heel'), id=24, color=[255, 128, 0]),
840
+ 25:
841
+ dict(
842
+ link=('left_wrist', 'left_thumb_third_joint'), id=25, color=[255, 128,
843
+ 0]),
844
+ 26:
845
+ dict(link=('left_thumb_third_joint', 'left_thumb2'), id=26, color=[255, 128, 0]),
846
+ 27:
847
+ dict(link=('left_thumb2', 'left_thumb3'), id=27, color=[255, 128, 0]),
848
+ 28:
849
+ dict(link=('left_thumb3', 'left_thumb4'), id=28, color=[255, 128, 0]),
850
+ 29:
851
+ dict(
852
+ link=('left_wrist', 'left_forefinger_third_joint'),
853
+ id=29,
854
+ color=[255, 153, 255]),
855
+ 30:
856
+ dict(
857
+ link=('left_forefinger_third_joint', 'left_forefinger2'),
858
+ id=30,
859
+ color=[255, 153, 255]),
860
+ 31:
861
+ dict(
862
+ link=('left_forefinger2', 'left_forefinger3'),
863
+ id=31,
864
+ color=[255, 153, 255]),
865
+ 32:
866
+ dict(
867
+ link=('left_forefinger3', 'left_forefinger4'),
868
+ id=32,
869
+ color=[255, 153, 255]),
870
+ 33:
871
+ dict(
872
+ link=('left_wrist', 'left_middle_finger_third_joint'),
873
+ id=33,
874
+ color=[102, 178, 255]),
875
+ 34:
876
+ dict(
877
+ link=('left_middle_finger_third_joint', 'left_middle_finger2'),
878
+ id=34,
879
+ color=[102, 178, 255]),
880
+ 35:
881
+ dict(
882
+ link=('left_middle_finger2', 'left_middle_finger3'),
883
+ id=35,
884
+ color=[102, 178, 255]),
885
+ 36:
886
+ dict(
887
+ link=('left_middle_finger3', 'left_middle_finger4'),
888
+ id=36,
889
+ color=[102, 178, 255]),
890
+ 37:
891
+ dict(
892
+ link=('left_wrist', 'left_ring_finger_third_joint'),
893
+ id=37,
894
+ color=[255, 51, 51]),
895
+ 38:
896
+ dict(
897
+ link=('left_ring_finger_third_joint', 'left_ring_finger2'),
898
+ id=38,
899
+ color=[255, 51, 51]),
900
+ 39:
901
+ dict(
902
+ link=('left_ring_finger2', 'left_ring_finger3'),
903
+ id=39,
904
+ color=[255, 51, 51]),
905
+ 40:
906
+ dict(
907
+ link=('left_ring_finger3', 'left_ring_finger4'),
908
+ id=40,
909
+ color=[255, 51, 51]),
910
+ 41:
911
+ dict(
912
+ link=('left_wrist', 'left_pinky_finger_third_joint'),
913
+ id=41,
914
+ color=[0, 255, 0]),
915
+ 42:
916
+ dict(
917
+ link=('left_pinky_finger_third_joint', 'left_pinky_finger2'),
918
+ id=42,
919
+ color=[0, 255, 0]),
920
+ 43:
921
+ dict(
922
+ link=('left_pinky_finger2', 'left_pinky_finger3'),
923
+ id=43,
924
+ color=[0, 255, 0]),
925
+ 44:
926
+ dict(
927
+ link=('left_pinky_finger3', 'left_pinky_finger4'),
928
+ id=44,
929
+ color=[0, 255, 0]),
930
+ 45:
931
+ dict(
932
+ link=('right_wrist', 'right_thumb_third_joint'),
933
+ id=45,
934
+ color=[255, 128, 0]),
935
+ 46:
936
+ dict(
937
+ link=('right_thumb_third_joint', 'right_thumb2'), id=46, color=[255, 128, 0]),
938
+ 47:
939
+ dict(
940
+ link=('right_thumb2', 'right_thumb3'), id=47, color=[255, 128, 0]),
941
+ 48:
942
+ dict(
943
+ link=('right_thumb3', 'right_thumb4'), id=48, color=[255, 128, 0]),
944
+ 49:
945
+ dict(
946
+ link=('right_wrist', 'right_forefinger_third_joint'),
947
+ id=49,
948
+ color=[255, 153, 255]),
949
+ 50:
950
+ dict(
951
+ link=('right_forefinger_third_joint', 'right_forefinger2'),
952
+ id=50,
953
+ color=[255, 153, 255]),
954
+ 51:
955
+ dict(
956
+ link=('right_forefinger2', 'right_forefinger3'),
957
+ id=51,
958
+ color=[255, 153, 255]),
959
+ 52:
960
+ dict(
961
+ link=('right_forefinger3', 'right_forefinger4'),
962
+ id=52,
963
+ color=[255, 153, 255]),
964
+ 53:
965
+ dict(
966
+ link=('right_wrist', 'right_middle_finger_third_joint'),
967
+ id=53,
968
+ color=[102, 178, 255]),
969
+ 54:
970
+ dict(
971
+ link=('right_middle_finger_third_joint', 'right_middle_finger2'),
972
+ id=54,
973
+ color=[102, 178, 255]),
974
+ 55:
975
+ dict(
976
+ link=('right_middle_finger2', 'right_middle_finger3'),
977
+ id=55,
978
+ color=[102, 178, 255]),
979
+ 56:
980
+ dict(
981
+ link=('right_middle_finger3', 'right_middle_finger4'),
982
+ id=56,
983
+ color=[102, 178, 255]),
984
+ 57:
985
+ dict(
986
+ link=('right_wrist', 'right_ring_finger_third_joint'),
987
+ id=57,
988
+ color=[255, 51, 51]),
989
+ 58:
990
+ dict(
991
+ link=('right_ring_finger_third_joint', 'right_ring_finger2'),
992
+ id=58,
993
+ color=[255, 51, 51]),
994
+ 59:
995
+ dict(
996
+ link=('right_ring_finger2', 'right_ring_finger3'),
997
+ id=59,
998
+ color=[255, 51, 51]),
999
+ 60:
1000
+ dict(
1001
+ link=('right_ring_finger3', 'right_ring_finger4'),
1002
+ id=60,
1003
+ color=[255, 51, 51]),
1004
+ 61:
1005
+ dict(
1006
+ link=('right_wrist', 'right_pinky_finger_third_joint'),
1007
+ id=61,
1008
+ color=[0, 255, 0]),
1009
+ 62:
1010
+ dict(
1011
+ link=('right_pinky_finger_third_joint', 'right_pinky_finger2'),
1012
+ id=62,
1013
+ color=[0, 255, 0]),
1014
+ 63:
1015
+ dict(
1016
+ link=('right_pinky_finger2', 'right_pinky_finger3'),
1017
+ id=63,
1018
+ color=[0, 255, 0]),
1019
+ 64:
1020
+ dict(
1021
+ link=('right_pinky_finger3', 'right_pinky_finger4'),
1022
+ id=64,
1023
+ color=[0, 255, 0])
1024
+ }
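A minimal sketch (not part of the committed file) of how the skeleton table above might be consumed to draw link overlays. The helper name `draw_goliath_skeleton` and the `keypoints` dict format (keypoint name -> (x, y) pixel coordinates) are assumptions for illustration; only `GOLIATH_SKELETON_INFO` comes from this module.

import cv2
import numpy as np
from classes_and_palettes import GOLIATH_SKELETON_INFO

def draw_goliath_skeleton(image: np.ndarray, keypoints: dict) -> np.ndarray:
    # keypoints: {keypoint_name: (x, y)} in pixel coordinates (assumed format)
    canvas = image.copy()
    for link_info in GOLIATH_SKELETON_INFO.values():
        src_name, dst_name = link_info['link']
        if src_name in keypoints and dst_name in keypoints:
            src = tuple(int(v) for v in keypoints[src_name])
            dst = tuple(int(v) for v in keypoints[dst_name])
            # assume the palette triples are RGB; reverse for OpenCV's BGR drawing
            color = tuple(reversed(link_info['color']))
            cv2.line(canvas, src, dst, color, thickness=2)
    return canvas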
detector_utils.py ADDED
@@ -0,0 +1,196 @@
1
+ from typing import List, Optional, Sequence, Union
2
+
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from mmcv.ops import RoIPool
7
+ from mmengine.dataset import Compose, pseudo_collate
8
+ from mmengine.device import get_device
9
+ from mmengine.registry import init_default_scope
10
+ from mmdet.apis import inference_detector, init_detector
11
+ from mmdet.structures import DetDataSample, SampleList
12
+ from mmdet.utils import get_test_pipeline_cfg
13
+
14
+
15
+ ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
16
+
17
+ def nms(dets: np.ndarray, thr: float):
18
+ """Greedily select boxes with high confidence and overlap <= thr.
19
+ Args:
20
+ dets (np.ndarray): [[x1, y1, x2, y2, score]].
21
+ thr (float): Boxes whose IoU with a kept box is <= thr are retained.
22
+ Returns:
23
+ list: Indexes to keep.
24
+ """
25
+ if len(dets) == 0:
26
+ return []
27
+
28
+ x1 = dets[:, 0]
29
+ y1 = dets[:, 1]
30
+ x2 = dets[:, 2]
31
+ y2 = dets[:, 3]
32
+ scores = dets[:, 4]
33
+
34
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
35
+ order = scores.argsort()[::-1]
36
+
37
+ keep = []
38
+ while len(order) > 0:
39
+ i = order[0]
40
+ keep.append(i)
41
+ xx1 = np.maximum(x1[i], x1[order[1:]])
42
+ yy1 = np.maximum(y1[i], y1[order[1:]])
43
+ xx2 = np.minimum(x2[i], x2[order[1:]])
44
+ yy2 = np.minimum(y2[i], y2[order[1:]])
45
+
46
+ w = np.maximum(0.0, xx2 - xx1 + 1)
47
+ h = np.maximum(0.0, yy2 - yy1 + 1)
48
+ inter = w * h
49
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
50
+
51
+ inds = np.where(ovr <= thr)[0]
52
+ order = order[inds + 1]
53
+
54
+ return keep
55
+
56
+ def adapt_mmdet_pipeline(cfg):
57
+ """Converts pipeline types in MMDetection's test dataloader to use the
58
+ 'mmdet' namespace.
59
+
60
+ Args:
61
+ cfg (ConfigDict): Configuration dictionary for MMDetection.
62
+
63
+ Returns:
64
+ ConfigDict: Configuration dictionary with updated pipeline types.
65
+ """
66
+ # use lazy import to avoid hard dependence on mmdet
67
+ from mmdet.datasets import transforms
68
+
69
+ if 'test_dataloader' not in cfg:
70
+ return cfg
71
+
72
+ pipeline = cfg.test_dataloader.dataset.pipeline
73
+ for trans in pipeline:
74
+ if trans['type'] in dir(transforms):
75
+ trans['type'] = 'mmdet.' + trans['type']
76
+
77
+ return cfg
78
+
79
+
80
+ def inference_detector(
81
+ model: torch.nn.Module,
82
+ imgs: ImagesType,
83
+ test_pipeline: Optional[Compose] = None,
84
+ text_prompt: Optional[str] = None,
85
+ custom_entities: bool = False,
86
+ ) -> Union[DetDataSample, SampleList]:
87
+ """Inference image(s) with the detector.
88
+
89
+ Args:
90
+ model (nn.Module): The loaded detector.
91
+ imgs (str, ndarray, Sequence[str/ndarray]):
92
+ Either image files or loaded images.
93
+ test_pipeline (:obj:`Compose`): Test pipeline.
94
+
95
+ Returns:
96
+ :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
97
+ If imgs is a list or tuple, the same length list type results
98
+ will be returned, otherwise return the detection results directly.
99
+ """
100
+ if isinstance(imgs, torch.Tensor):
101
+ if imgs.is_cuda:
102
+ imgs = imgs.cpu()
103
+
104
+ # Remove batch dimension and transpose
105
+ imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()
106
+
107
+ # Ensure the data type is appropriate (uint8 for most image processing functions)
108
+ imgs = (imgs * 255).astype(np.uint8)
109
+
110
+ if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
111
+ is_batch = True
112
+ else:
113
+ imgs = [imgs]
114
+ is_batch = False
115
+
116
+ cfg = model.cfg
117
+
118
+ if test_pipeline is None:
119
+ cfg = cfg.copy()
120
+ test_pipeline = get_test_pipeline_cfg(cfg)
121
+ if isinstance(imgs[0], np.ndarray):
122
+ # Calling this method across libraries will result
123
+ # in module unregistered error if not prefixed with mmdet.
124
+ test_pipeline[0].type = "mmdet.LoadImageFromNDArray"
125
+
126
+ test_pipeline = Compose(test_pipeline)
127
+
128
+ if model.data_preprocessor.device.type == "cpu":
129
+ for m in model.modules():
130
+ assert not isinstance(
131
+ m, RoIPool
132
+ ), "CPU inference with RoIPool is not supported currently."
133
+
134
+ result_list = []
135
+ for i, img in enumerate(imgs):
136
+ # prepare data
137
+ if isinstance(img, np.ndarray):
138
+ # TODO: remove img_id.
139
+ data_ = dict(img=img, img_id=0)
140
+ else:
141
+ # TODO: remove img_id.
142
+ data_ = dict(img_path=img, img_id=0)
143
+
144
+ if text_prompt:
145
+ data_["text"] = text_prompt
146
+ data_["custom_entities"] = custom_entities
147
+
148
+ # build the data pipeline
149
+ data_ = test_pipeline(data_)
150
+
151
+ data_["inputs"] = [data_["inputs"]]
152
+ data_["data_samples"] = [data_["data_samples"]]
153
+
154
+ # forward the model
155
+ with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
156
+ results = model.test_step(data_)[0]
157
+
158
+ result_list.append(results)
159
+
160
+ if not is_batch:
161
+ return result_list[0]
162
+ else:
163
+ return result_list
164
+
165
+
166
+ def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
167
+ bboxes = np.concatenate(
168
+ (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
169
+ )
170
+ bboxes = bboxes[
171
+ np.logical_and(
172
+ pred_instance.labels == det_cat_id,
173
+ pred_instance.scores > bbox_thr,
174
+ )
175
+ ]
176
+ bboxes = bboxes[nms(bboxes, nms_thr), :4]
177
+ return bboxes
178
+
179
+
180
+ def process_images_detector(imgs, detector):
181
+ """Visualize predicted keypoints (and heatmaps) of one image."""
182
+ # predict bbox
183
+ det_results = inference_detector(detector, imgs)
184
+ pred_instances = list(
185
+ map(lambda det_result: det_result.pred_instances.numpy(), det_results)
186
+ )
187
+ bboxes_batch = list(
188
+ map(
189
+ lambda pred_instance: process_one_image_bbox(
190
+ pred_instance, 0, 0.3, 0.3 ## argparse.Namespace(det_cat_id=0, bbox_thr=0.3, nms_thr=0.3),
191
+ ),
192
+ pred_instances,
193
+ )
194
+ )
195
+
196
+ return bboxes_batch
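A minimal usage sketch (not part of the file) chaining the helpers above with an mmdet detector. `init_detector` is the standard mmdet API; the config, checkpoint, and image paths are illustrative placeholders and should point at the assets shipped with the repo.

import cv2
from mmdet.apis import init_detector
from detector_utils import adapt_mmdet_pipeline, process_images_detector

# illustrative paths; adjust to the actual rtmdet config and checkpoint locations
detector = init_detector(
    'assets/rtmdet_m_640-8xb32_coco-person_no_nms.py',
    'assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
    device='cuda:0')
detector.cfg = adapt_mmdet_pipeline(detector.cfg)

img = cv2.imread('assets/images/68204.png')           # BGR ndarray
bboxes_batch = process_images_detector([img], detector)
print(bboxes_batch[0].shape)                           # (num_persons, 4) xyxy boxes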
external/cv/.gitignore ADDED
@@ -0,0 +1,125 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # PyTorch checkpoint
10
+ *.pth
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ #dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+ mlu-ops/
31
+ mlu-ops.*
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/en/_build/
73
+ docs/en/api/generated/
74
+ docs/zh_cn/_build/
75
+ docs/zh_cn/api/generated/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # celery beat schedule file
87
+ celerybeat-schedule
88
+
89
+ # SageMath parsed files
90
+ *.sage.py
91
+
92
+ # Environments
93
+ .env
94
+ .venv
95
+ env/
96
+ venv/
97
+ ENV/
98
+ env.bak/
99
+ venv.bak/
100
+
101
+ # Spyder project settings
102
+ .spyderproject
103
+ .spyproject
104
+
105
+ # Rope project settings
106
+ .ropeproject
107
+
108
+ # mkdocs documentation
109
+ /site
110
+
111
+ # mypy
112
+ .mypy_cache/
113
+
114
+ # editors and IDEs
115
+ .idea/
116
+ .vscode/
117
+
118
+ # custom
119
+ .DS_Store
120
+
121
+ # datasets and logs and checkpoints
122
+ data/
123
+ work_dir/
124
+
125
+ src/
external/cv/MANIFEST.in ADDED
@@ -0,0 +1,6 @@
1
+ include requirements/runtime.txt
2
+ include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
3
+ include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp
4
+ include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp
5
+ include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm
6
+ recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm
external/cv/dist/sapiens_cv-1.0.0-cp310-cp310-linux_x86_64.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:746f2be13eefdfe43a59d9c415e03a4b0b922e6ce487b76a572a376ae76c9300
3
+ size 30006791
external/cv/mmcv/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .arraymisc import *
9
+ from .image import *
10
+ from .transforms import *
11
+ from .version import *
12
+ from .video import *
13
+ from .visualization import *
14
+
15
+ # The following modules are not imported to this level, so mmcv may be used
16
+ # without PyTorch.
17
+ # - op
18
+ # - utils
external/cv/mmcv/arraymisc/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .quantization import dequantize, quantize
8
+
9
+ __all__ = ['quantize', 'dequantize']
external/cv/mmcv/arraymisc/quantization.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union
8
+
9
+ import numpy as np
10
+
11
+
12
+ def quantize(arr: np.ndarray,
13
+ min_val: Union[int, float],
14
+ max_val: Union[int, float],
15
+ levels: int,
16
+ dtype=np.int64) -> np.ndarray:
17
+ """Quantize an array of (-inf, inf) to [0, levels-1].
18
+
19
+ Args:
20
+ arr (ndarray): Input array.
21
+ min_val (int or float): Minimum value to be clipped.
22
+ max_val (int or float): Maximum value to be clipped.
23
+ levels (int): Quantization levels.
24
+ dtype (np.type): The type of the quantized array.
25
+
26
+ Returns:
27
+ np.ndarray: Quantized array.
28
+ """
29
+ if not (isinstance(levels, int) and levels > 1):
30
+ raise ValueError(
31
+ f'levels must be a positive integer, but got {levels}')
32
+ if min_val >= max_val:
33
+ raise ValueError(
34
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
35
+
36
+ arr = np.clip(arr, min_val, max_val) - min_val
37
+ quantized_arr = np.minimum(
38
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
39
+
40
+ return quantized_arr
41
+
42
+
43
+ def dequantize(arr: np.ndarray,
44
+ min_val: Union[int, float],
45
+ max_val: Union[int, float],
46
+ levels: int,
47
+ dtype=np.float64) -> np.ndarray:
48
+ """Dequantize an array.
49
+
50
+ Args:
51
+ arr (ndarray): Input array.
52
+ min_val (int or float): Minimum value to be clipped.
53
+ max_val (int or float): Maximum value to be clipped.
54
+ levels (int): Quantization levels.
55
+ dtype (np.type): The type of the dequantized array.
56
+
57
+ Returns:
58
+ np.ndarray: Dequantized array.
59
+ """
60
+ if not (isinstance(levels, int) and levels > 1):
61
+ raise ValueError(
62
+ f'levels must be a positive integer, but got {levels}')
63
+ if min_val >= max_val:
64
+ raise ValueError(
65
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
66
+
67
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
68
+ min_val) / levels + min_val
69
+
70
+ return dequantized_arr
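A short round-trip sketch (illustrative, not part of the file): quantizing to 256 levels and dequantizing reconstructs each value to within half a quantization step.

import numpy as np
from mmcv.arraymisc import quantize, dequantize

arr = np.random.rand(4, 4).astype(np.float32)
q = quantize(arr, min_val=0.0, max_val=1.0, levels=256, dtype=np.uint8)
rec = dequantize(q, min_val=0.0, max_val=1.0, levels=256)
# worst-case reconstruction error is half a level: 0.5 / 256
assert np.abs(rec - arr).max() <= 0.5 / 256 + 1e-6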
external/cv/mmcv/cnn/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .alexnet import AlexNet
8
+ # yapf: disable
9
+ from .bricks import (ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
10
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
11
+ DepthwiseSeparableConvModule, GeneralizedAttention,
12
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
13
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
14
+ build_activation_layer, build_conv_layer,
15
+ build_norm_layer, build_padding_layer, build_plugin_layer,
16
+ build_upsample_layer, conv_ws_2d, is_norm)
17
+ # yapf: enable
18
+ from .resnet import ResNet, make_res_layer
19
+ from .rfsearch import Conv2dRFSearchOp, RFSearchHook
20
+ from .utils import fuse_conv_bn, get_model_complexity_info
21
+ from .vgg import VGG, make_vgg_layer
22
+
23
+ __all__ = [
24
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
25
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
26
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
27
+ 'build_plugin_layer', 'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d',
28
+ 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention',
29
+ 'Scale', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
30
+ 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d',
31
+ 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'fuse_conv_bn',
32
+ 'get_model_complexity_info', 'Conv2dRFSearchOp', 'RFSearchHook'
33
+ ]
external/cv/mmcv/cnn/alexnet.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from mmengine.runner import load_checkpoint
13
+
14
+
15
+ class AlexNet(nn.Module):
16
+ """AlexNet backbone.
17
+
18
+ Args:
19
+ num_classes (int): number of classes for classification.
20
+ """
21
+
22
+ def __init__(self, num_classes: int = -1):
23
+ super().__init__()
24
+ self.num_classes = num_classes
25
+ self.features = nn.Sequential(
26
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
27
+ nn.ReLU(inplace=True),
28
+ nn.MaxPool2d(kernel_size=3, stride=2),
29
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
30
+ nn.ReLU(inplace=True),
31
+ nn.MaxPool2d(kernel_size=3, stride=2),
32
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
33
+ nn.ReLU(inplace=True),
34
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
35
+ nn.ReLU(inplace=True),
36
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
37
+ nn.ReLU(inplace=True),
38
+ nn.MaxPool2d(kernel_size=3, stride=2),
39
+ )
40
+ if self.num_classes > 0:
41
+ self.classifier = nn.Sequential(
42
+ nn.Dropout(),
43
+ nn.Linear(256 * 6 * 6, 4096),
44
+ nn.ReLU(inplace=True),
45
+ nn.Dropout(),
46
+ nn.Linear(4096, 4096),
47
+ nn.ReLU(inplace=True),
48
+ nn.Linear(4096, num_classes),
49
+ )
50
+
51
+ def init_weights(self, pretrained: Optional[str] = None) -> None:
52
+ if isinstance(pretrained, str):
53
+ logger = logging.getLogger()
54
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
55
+ elif pretrained is None:
56
+ # use default initializer
57
+ pass
58
+ else:
59
+ raise TypeError('pretrained must be a str or None')
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+
63
+ x = self.features(x)
64
+ if self.num_classes > 0:
65
+ x = x.view(x.size(0), 256 * 6 * 6)
66
+ x = self.classifier(x)
67
+
68
+ return x
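A minimal instantiation sketch (illustrative): with num_classes > 0 the classifier head is built, otherwise the raw conv feature map is returned.

import torch
from mmcv.cnn import AlexNet

model = AlexNet(num_classes=10)
model.init_weights()
logits = model(torch.randn(1, 3, 224, 224))          # (1, 10)
backbone_only = AlexNet()                             # num_classes=-1 by default
feats = backbone_only(torch.randn(1, 3, 224, 224))    # (1, 256, 6, 6) feature map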
external/cv/mmcv/cnn/bricks/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .activation import build_activation_layer
8
+ from .context_block import ContextBlock
9
+ from .conv import build_conv_layer
10
+ from .conv2d_adaptive_padding import Conv2dAdaptivePadding
11
+ from .conv_module import ConvModule
12
+ from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
13
+ from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
14
+ from .drop import Dropout, DropPath
15
+ from .generalized_attention import GeneralizedAttention
16
+ from .hsigmoid import HSigmoid
17
+ from .hswish import HSwish
18
+ from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
19
+ from .norm import build_norm_layer, is_norm
20
+ from .padding import build_padding_layer
21
+ from .plugin import build_plugin_layer
22
+ from .scale import LayerScale, Scale
23
+ from .swish import Swish
24
+ from .upsample import build_upsample_layer
25
+ from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
26
+ Linear, MaxPool2d, MaxPool3d)
27
+
28
+ __all__ = [
29
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
30
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
31
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
32
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
33
+ 'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
34
+ 'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
35
+ 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
36
+ 'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
37
+ ]
external/cv/mmcv/cnn/bricks/activation.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Dict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from mmengine.registry import MODELS
13
+ from mmengine.utils import digit_version
14
+ from mmengine.utils.dl_utils import TORCH_VERSION
15
+
16
+ for module in [
17
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
18
+ nn.Sigmoid, nn.Tanh
19
+ ]:
20
+ MODELS.register_module(module=module)
21
+
22
+ if digit_version(torch.__version__) >= digit_version('1.7.0'):
23
+ MODELS.register_module(module=nn.SiLU, name='SiLU')
24
+ else:
25
+
26
+ class SiLU(nn.Module):
27
+ """Sigmoid Weighted Liner Unit."""
28
+
29
+ def __init__(self, inplace=False):
30
+ super().__init__()
31
+ self.inplace = inplace
32
+
33
+ def forward(self, inputs) -> torch.Tensor:
34
+ if self.inplace:
35
+ return inputs.mul_(torch.sigmoid(inputs))
36
+ else:
37
+ return inputs * torch.sigmoid(inputs)
38
+
39
+ MODELS.register_module(module=SiLU, name='SiLU')
40
+
41
+
42
+ @MODELS.register_module(name='Clip')
43
+ @MODELS.register_module()
44
+ class Clamp(nn.Module):
45
+ """Clamp activation layer.
46
+
47
+ This activation function is to clamp the feature map value within
48
+ :math:`[min, max]`. More details can be found in ``torch.clamp()``.
49
+
50
+ Args:
51
+ min (Number | optional): Lower-bound of the range to be clamped to.
52
+ Default to -1.
53
+ max (Number | optional): Upper-bound of the range to be clamped to.
54
+ Default to 1.
55
+ """
56
+
57
+ def __init__(self, min: float = -1., max: float = 1.):
58
+ super().__init__()
59
+ self.min = min
60
+ self.max = max
61
+
62
+ def forward(self, x) -> torch.Tensor:
63
+ """Forward function.
64
+
65
+ Args:
66
+ x (torch.Tensor): The input tensor.
67
+
68
+ Returns:
69
+ torch.Tensor: Clamped tensor.
70
+ """
71
+ return torch.clamp(x, min=self.min, max=self.max)
72
+
73
+
74
+ class GELU(nn.Module):
75
+ r"""Applies the Gaussian Error Linear Units function:
76
+
77
+ .. math::
78
+ \text{GELU}(x) = x * \Phi(x)
79
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
80
+ Gaussian Distribution.
81
+
82
+ Shape:
83
+ - Input: :math:`(N, *)` where `*` means, any number of additional
84
+ dimensions
85
+ - Output: :math:`(N, *)`, same shape as the input
86
+
87
+ .. image:: scripts/activation_images/GELU.png
88
+
89
+ Examples::
90
+
91
+ >>> m = nn.GELU()
92
+ >>> input = torch.randn(2)
93
+ >>> output = m(input)
94
+ """
95
+
96
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
97
+ return F.gelu(input)
98
+
99
+
100
+ if (TORCH_VERSION == 'parrots'
101
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
102
+ MODELS.register_module(module=GELU)
103
+ else:
104
+ MODELS.register_module(module=nn.GELU)
105
+
106
+
107
+ def build_activation_layer(cfg: Dict) -> nn.Module:
108
+ """Build activation layer.
109
+
110
+ Args:
111
+ cfg (dict): The activation layer config, which should contain:
112
+
113
+ - type (str): Layer type.
114
+ - layer args: Args needed to instantiate an activation layer.
115
+
116
+ Returns:
117
+ nn.Module: Created activation layer.
118
+ """
119
+ return MODELS.build(cfg)
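A small sketch (illustrative) of building activation layers from config dicts, the pattern used by ConvModule further below.

import torch
from mmcv.cnn import build_activation_layer

relu = build_activation_layer(dict(type='ReLU', inplace=True))
gelu = build_activation_layer(dict(type='GELU'))
clamp = build_activation_layer(dict(type='Clamp', min=0.0, max=6.0))
print(relu(torch.tensor([-1.0, 2.0])))   # tensor([0., 2.])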
external/cv/mmcv/cnn/bricks/context_block.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union
8
+
9
+ import torch
10
+ from mmengine.model import constant_init, kaiming_init
11
+ from mmengine.registry import MODELS
12
+ from torch import nn
13
+
14
+
15
+ def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
16
+ if isinstance(m, nn.Sequential):
17
+ constant_init(m[-1], val=0)
18
+ else:
19
+ constant_init(m, val=0)
20
+
21
+
22
+ @MODELS.register_module()
23
+ class ContextBlock(nn.Module):
24
+ """ContextBlock module in GCNet.
25
+
26
+ See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
27
+ (https://arxiv.org/abs/1904.11492) for details.
28
+
29
+ Args:
30
+ in_channels (int): Channels of the input feature map.
31
+ ratio (float): Ratio of channels of transform bottleneck
32
+ pooling_type (str): Pooling method for context modeling.
33
+ Options are 'att' and 'avg', stand for attention pooling and
34
+ average pooling respectively. Default: 'att'.
35
+ fusion_types (Sequence[str]): Fusion method for feature fusion,
36
+ Options are 'channels_add', 'channel_mul', stand for channelwise
37
+ addition and multiplication respectively. Default: ('channel_add',)
38
+ """
39
+
40
+ _abbr_ = 'context_block'
41
+
42
+ def __init__(self,
43
+ in_channels: int,
44
+ ratio: float,
45
+ pooling_type: str = 'att',
46
+ fusion_types: tuple = ('channel_add', )):
47
+ super().__init__()
48
+ assert pooling_type in ['avg', 'att']
49
+ assert isinstance(fusion_types, (list, tuple))
50
+ valid_fusion_types = ['channel_add', 'channel_mul']
51
+ assert all([f in valid_fusion_types for f in fusion_types])
52
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
53
+ self.in_channels = in_channels
54
+ self.ratio = ratio
55
+ self.planes = int(in_channels * ratio)
56
+ self.pooling_type = pooling_type
57
+ self.fusion_types = fusion_types
58
+ if pooling_type == 'att':
59
+ self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
60
+ self.softmax = nn.Softmax(dim=2)
61
+ else:
62
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
63
+ if 'channel_add' in fusion_types:
64
+ self.channel_add_conv = nn.Sequential(
65
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
66
+ nn.LayerNorm([self.planes, 1, 1]),
67
+ nn.ReLU(inplace=True), # yapf: disable
68
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
69
+ else:
70
+ self.channel_add_conv = None
71
+ if 'channel_mul' in fusion_types:
72
+ self.channel_mul_conv = nn.Sequential(
73
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
74
+ nn.LayerNorm([self.planes, 1, 1]),
75
+ nn.ReLU(inplace=True), # yapf: disable
76
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
77
+ else:
78
+ self.channel_mul_conv = None
79
+ self.reset_parameters()
80
+
81
+ def reset_parameters(self):
82
+ if self.pooling_type == 'att':
83
+ kaiming_init(self.conv_mask, mode='fan_in')
84
+ self.conv_mask.inited = True
85
+
86
+ if self.channel_add_conv is not None:
87
+ last_zero_init(self.channel_add_conv)
88
+ if self.channel_mul_conv is not None:
89
+ last_zero_init(self.channel_mul_conv)
90
+
91
+ def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
92
+ batch, channel, height, width = x.size()
93
+ if self.pooling_type == 'att':
94
+ input_x = x
95
+ # [N, C, H * W]
96
+ input_x = input_x.view(batch, channel, height * width)
97
+ # [N, 1, C, H * W]
98
+ input_x = input_x.unsqueeze(1)
99
+ # [N, 1, H, W]
100
+ context_mask = self.conv_mask(x)
101
+ # [N, 1, H * W]
102
+ context_mask = context_mask.view(batch, 1, height * width)
103
+ # [N, 1, H * W]
104
+ context_mask = self.softmax(context_mask)
105
+ # [N, 1, H * W, 1]
106
+ context_mask = context_mask.unsqueeze(-1)
107
+ # [N, 1, C, 1]
108
+ context = torch.matmul(input_x, context_mask)
109
+ # [N, C, 1, 1]
110
+ context = context.view(batch, channel, 1, 1)
111
+ else:
112
+ # [N, C, 1, 1]
113
+ context = self.avg_pool(x)
114
+
115
+ return context
116
+
117
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
118
+ # [N, C, 1, 1]
119
+ context = self.spatial_pool(x)
120
+
121
+ out = x
122
+ if self.channel_mul_conv is not None:
123
+ # [N, C, 1, 1]
124
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
125
+ out = out * channel_mul_term
126
+ if self.channel_add_conv is not None:
127
+ # [N, C, 1, 1]
128
+ channel_add_term = self.channel_add_conv(context)
129
+ out = out + channel_add_term
130
+
131
+ return out
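A minimal sketch (illustrative) of applying a ContextBlock to a feature map; the output shape matches the input.

import torch
from mmcv.cnn import ContextBlock

gc_block = ContextBlock(in_channels=64, ratio=1. / 4)
feat = torch.randn(2, 64, 32, 32)
out = gc_block(feat)
assert out.shape == feat.shape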
external/cv/mmcv/cnn/bricks/conv.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ from typing import Dict, Optional
9
+
10
+ from mmengine.registry import MODELS
11
+ from torch import nn
12
+
13
+ MODELS.register_module('Conv1d', module=nn.Conv1d)
14
+ MODELS.register_module('Conv2d', module=nn.Conv2d)
15
+ MODELS.register_module('Conv3d', module=nn.Conv3d)
16
+ MODELS.register_module('Conv', module=nn.Conv2d)
17
+
18
+
19
+ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
20
+ """Build convolution layer.
21
+
22
+ Args:
23
+ cfg (None or dict): The conv layer config, which should contain:
24
+ - type (str): Layer type.
25
+ - layer args: Args needed to instantiate a conv layer.
26
+ args (argument list): Arguments passed to the `__init__`
27
+ method of the corresponding conv layer.
28
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
29
+ method of the corresponding conv layer.
30
+
31
+ Returns:
32
+ nn.Module: Created conv layer.
33
+ """
34
+ if cfg is None:
35
+ cfg_ = dict(type='Conv2d')
36
+ else:
37
+ if not isinstance(cfg, dict):
38
+ raise TypeError('cfg must be a dict')
39
+ if 'type' not in cfg:
40
+ raise KeyError('the cfg dict must contain the key "type"')
41
+ cfg_ = cfg.copy()
42
+
43
+ layer_type = cfg_.pop('type')
44
+ if inspect.isclass(layer_type):
45
+ return layer_type(*args, **kwargs, **cfg_) # type: ignore
46
+ # Switch registry to the target scope. If `conv_layer` cannot be found
47
+ # in the registry, fallback to search `conv_layer` in the
48
+ # mmengine.MODELS.
49
+ with MODELS.switch_scope_and_registry(None) as registry:
50
+ conv_layer = registry.get(layer_type)
51
+ if conv_layer is None:
52
+ raise KeyError(f'Cannot find {layer_type} in registry under scope '
53
+ f'name {registry.scope}')
54
+ layer = conv_layer(*args, **kwargs, **cfg_)
55
+
56
+ return layer
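A small sketch (illustrative) of build_conv_layer: cfg=None falls back to a plain nn.Conv2d, and a dict type is resolved through the registry.

from mmcv.cnn import build_conv_layer

conv2d = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)        # nn.Conv2d
conv1d = build_conv_layer(dict(type='Conv1d'), 16, 32, kernel_size=3)   # nn.Conv1d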
external/cv/mmcv/cnn/bricks/conv2d_adaptive_padding.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple, Union
9
+
10
+ import torch
11
+ from mmengine.registry import MODELS
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ @MODELS.register_module()
17
+ class Conv2dAdaptivePadding(nn.Conv2d):
18
+ """Implementation of 2D convolution in tensorflow with `padding` as "same",
19
+ which applies padding to input (if needed) so that input image gets fully
20
+ covered by filter and stride you specified. For stride 1, this will ensure
21
+ that output image size is same as input. For stride of 2, output dimensions
22
+ will be half, for example.
23
+
24
+ Args:
25
+ in_channels (int): Number of channels in the input image
26
+ out_channels (int): Number of channels produced by the convolution
27
+ kernel_size (int or tuple): Size of the convolving kernel
28
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
29
+ padding (int or tuple, optional): Zero-padding added to both sides of
30
+ the input. Default: 0
31
+ dilation (int or tuple, optional): Spacing between kernel elements.
32
+ Default: 1
33
+ groups (int, optional): Number of blocked connections from input
34
+ channels to output channels. Default: 1
35
+ bias (bool, optional): If ``True``, adds a learnable bias to the
36
+ output. Default: ``True``
37
+ """
38
+
39
+ def __init__(self,
40
+ in_channels: int,
41
+ out_channels: int,
42
+ kernel_size: Union[int, Tuple[int, int]],
43
+ stride: Union[int, Tuple[int, int]] = 1,
44
+ padding: Union[int, Tuple[int, int]] = 0,
45
+ dilation: Union[int, Tuple[int, int]] = 1,
46
+ groups: int = 1,
47
+ bias: bool = True):
48
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
49
+ dilation, groups, bias)
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ img_h, img_w = x.size()[-2:]
53
+ kernel_h, kernel_w = self.weight.size()[-2:]
54
+ stride_h, stride_w = self.stride
55
+ output_h = math.ceil(img_h / stride_h)
56
+ output_w = math.ceil(img_w / stride_w)
57
+ pad_h = (
58
+ max((output_h - 1) * self.stride[0] +
59
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
60
+ pad_w = (
61
+ max((output_w - 1) * self.stride[1] +
62
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
63
+ if pad_h > 0 or pad_w > 0:
64
+ x = F.pad(x, [
65
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
66
+ ])
67
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
68
+ self.dilation, self.groups)
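A short shape check (illustrative) of the "same"-style behaviour: the output spatial size equals ceil(input_size / stride) regardless of kernel size.

import torch
from mmcv.cnn.bricks import Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
y = conv(torch.randn(1, 3, 65, 65))
assert y.shape[-2:] == (33, 33)   # ceil(65 / 2) == 33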
external/cv/mmcv/cnn/bricks/conv_module.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from functools import partial
9
+ from typing import Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from mmengine.model import constant_init, kaiming_init
14
+ from mmengine.registry import MODELS
15
+ from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
16
+
17
+ from .activation import build_activation_layer
18
+ from .conv import build_conv_layer
19
+ from .norm import build_norm_layer
20
+ from .padding import build_padding_layer
21
+
22
+
23
+ def efficient_conv_bn_eval_forward(bn: _BatchNorm,
24
+ conv: nn.modules.conv._ConvNd,
25
+ x: torch.Tensor):
26
+ """
27
+ Implementation based on https://arxiv.org/abs/2305.11624
28
+ "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
29
+ It leverages the associative law between convolution and affine transform,
30
+ i.e., normalize (weight conv feature) = (normalize weight) conv feature.
31
+ It works for Eval mode of ConvBN blocks during validation, and can be used
32
+ for training as well. It reduces memory and computation cost.
33
+
34
+ Args:
35
+ bn (_BatchNorm): a BatchNorm module.
36
+ conv (nn._ConvNd): a conv module
37
+ x (torch.Tensor): Input feature map.
38
+ """
39
+ # These lines of code are designed to deal with various cases
40
+ # like bn without affine transform, and conv without bias
41
+ weight_on_the_fly = conv.weight
42
+ if conv.bias is not None:
43
+ bias_on_the_fly = conv.bias
44
+ else:
45
+ bias_on_the_fly = torch.zeros_like(bn.running_var)
46
+
47
+ if bn.weight is not None:
48
+ bn_weight = bn.weight
49
+ else:
50
+ bn_weight = torch.ones_like(bn.running_var)
51
+
52
+ if bn.bias is not None:
53
+ bn_bias = bn.bias
54
+ else:
55
+ bn_bias = torch.zeros_like(bn.running_var)
56
+
57
+ # shape of [C_out, 1, 1, 1] in Conv2d
58
+ weight_coeff = torch.rsqrt(bn.running_var +
59
+ bn.eps).reshape([-1] + [1] *
60
+ (len(conv.weight.shape) - 1))
61
+ # shape of [C_out, 1, 1, 1] in Conv2d
62
+ coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
63
+
64
+ # shape of [C_out, C_in, k, k] in Conv2d
65
+ weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
66
+ # shape of [C_out] in Conv2d
67
+ bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\
68
+ (bias_on_the_fly - bn.running_mean)
69
+
70
+ return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
71
+
72
+
73
+ @MODELS.register_module()
74
+ class ConvModule(nn.Module):
75
+ """A conv block that bundles conv/norm/activation layers.
76
+
77
+ This block simplifies the usage of convolution layers, which are commonly
78
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
79
+ It is based upon three build methods: `build_conv_layer()`,
80
+ `build_norm_layer()` and `build_activation_layer()`.
81
+
82
+ Besides, we add some additional features in this module.
83
+ 1. Automatically set `bias` of the conv layer.
84
+ 2. Spectral norm is supported.
85
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
86
+ supports zero and circular padding, and we add "reflect" padding mode.
87
+
88
+ Args:
89
+ in_channels (int): Number of channels in the input feature map.
90
+ Same as that in ``nn._ConvNd``.
91
+ out_channels (int): Number of channels produced by the convolution.
92
+ Same as that in ``nn._ConvNd``.
93
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
94
+ Same as that in ``nn._ConvNd``.
95
+ stride (int | tuple[int]): Stride of the convolution.
96
+ Same as that in ``nn._ConvNd``.
97
+ padding (int | tuple[int]): Zero-padding added to both sides of
98
+ the input. Same as that in ``nn._ConvNd``.
99
+ dilation (int | tuple[int]): Spacing between kernel elements.
100
+ Same as that in ``nn._ConvNd``.
101
+ groups (int): Number of blocked connections from input channels to
102
+ output channels. Same as that in ``nn._ConvNd``.
103
+ bias (bool | str): If specified as `auto`, it will be decided by the
104
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
105
+ False. Default: "auto".
106
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
107
+ which means using conv2d.
108
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
109
+ act_cfg (dict): Config dict for activation layer.
110
+ Default: dict(type='ReLU').
111
+ inplace (bool): Whether to use inplace mode for activation.
112
+ Default: True.
113
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
114
+ Default: False.
115
+ padding_mode (str): If the `padding_mode` has not been supported by
116
+ current `Conv2d` in PyTorch, we will use our own padding layer
117
+ instead. Currently, we support ['zeros', 'circular'] with official
118
+ implementation and ['reflect'] with our own implementation.
119
+ Default: 'zeros'.
120
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
121
+ sequence of "conv", "norm" and "act". Common examples are
122
+ ("conv", "norm", "act") and ("act", "conv", "norm").
123
+ Default: ('conv', 'norm', 'act').
124
+ efficient_conv_bn_eval (bool): Whether use efficient conv when the
125
+ consecutive bn is in eval mode (either training or testing), as
126
+ proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
127
+ """
128
+
129
+ _abbr_ = 'conv_block'
130
+
131
+ def __init__(self,
132
+ in_channels: int,
133
+ out_channels: int,
134
+ kernel_size: Union[int, Tuple[int, int]],
135
+ stride: Union[int, Tuple[int, int]] = 1,
136
+ padding: Union[int, Tuple[int, int]] = 0,
137
+ dilation: Union[int, Tuple[int, int]] = 1,
138
+ groups: int = 1,
139
+ bias: Union[bool, str] = 'auto',
140
+ conv_cfg: Optional[Dict] = None,
141
+ norm_cfg: Optional[Dict] = None,
142
+ act_cfg: Optional[Dict] = dict(type='ReLU'),
143
+ inplace: bool = True,
144
+ with_spectral_norm: bool = False,
145
+ padding_mode: str = 'zeros',
146
+ order: tuple = ('conv', 'norm', 'act'),
147
+ efficient_conv_bn_eval: bool = False):
148
+ super().__init__()
149
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
150
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
151
+ assert act_cfg is None or isinstance(act_cfg, dict)
152
+ official_padding_mode = ['zeros', 'circular']
153
+ self.conv_cfg = conv_cfg
154
+ self.norm_cfg = norm_cfg
155
+ self.act_cfg = act_cfg
156
+ self.inplace = inplace
157
+ self.with_spectral_norm = with_spectral_norm
158
+ self.with_explicit_padding = padding_mode not in official_padding_mode
159
+ self.order = order
160
+ assert isinstance(self.order, tuple) and len(self.order) == 3
161
+ assert set(order) == {'conv', 'norm', 'act'}
162
+
163
+ self.with_norm = norm_cfg is not None
164
+ self.with_activation = act_cfg is not None
165
+ # if the conv layer is before a norm layer, bias is unnecessary.
166
+ if bias == 'auto':
167
+ bias = not self.with_norm
168
+ self.with_bias = bias
169
+
170
+ if self.with_explicit_padding:
171
+ pad_cfg = dict(type=padding_mode)
172
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
173
+
174
+ # reset padding to 0 for conv module
175
+ conv_padding = 0 if self.with_explicit_padding else padding
176
+ # build convolution layer
177
+ self.conv = build_conv_layer(
178
+ conv_cfg,
179
+ in_channels,
180
+ out_channels,
181
+ kernel_size,
182
+ stride=stride,
183
+ padding=conv_padding,
184
+ dilation=dilation,
185
+ groups=groups,
186
+ bias=bias)
187
+ # export the attributes of self.conv to a higher level for convenience
188
+ self.in_channels = self.conv.in_channels
189
+ self.out_channels = self.conv.out_channels
190
+ self.kernel_size = self.conv.kernel_size
191
+ self.stride = self.conv.stride
192
+ self.padding = padding
193
+ self.dilation = self.conv.dilation
194
+ self.transposed = self.conv.transposed
195
+ self.output_padding = self.conv.output_padding
196
+ self.groups = self.conv.groups
197
+
198
+ if self.with_spectral_norm:
199
+ self.conv = nn.utils.spectral_norm(self.conv)
200
+
201
+ # build normalization layers
202
+ if self.with_norm:
203
+ # norm layer is after conv layer
204
+ if order.index('norm') > order.index('conv'):
205
+ norm_channels = out_channels
206
+ else:
207
+ norm_channels = in_channels
208
+ self.norm_name, norm = build_norm_layer(
209
+ norm_cfg, norm_channels) # type: ignore
210
+ self.add_module(self.norm_name, norm)
211
+ if self.with_bias:
212
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
213
+ warnings.warn(
214
+ 'Unnecessary conv bias before batch/instance norm')
215
+ else:
216
+ self.norm_name = None # type: ignore
217
+
218
+ self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
219
+
220
+ # build activation layer
221
+ if self.with_activation:
222
+ act_cfg_ = act_cfg.copy() # type: ignore
223
+ # nn.Tanh has no 'inplace' argument
224
+ if act_cfg_['type'] not in [
225
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
226
+ ]:
227
+ act_cfg_.setdefault('inplace', inplace)
228
+ self.activate = build_activation_layer(act_cfg_)
229
+
230
+ # Use msra init by default
231
+ self.init_weights()
232
+
233
+ @property
234
+ def norm(self):
235
+ if self.norm_name:
236
+ return getattr(self, self.norm_name)
237
+ else:
238
+ return None
239
+
240
+ def init_weights(self):
241
+ # 1. It is mainly for customized conv layers with their own
242
+ # initialization manners by calling their own ``init_weights()``,
243
+ # and we do not want ConvModule to override the initialization.
244
+ # 2. For customized conv layers without their own initialization
245
+ # manners (that is, they don't have their own ``init_weights()``)
246
+ # and PyTorch's conv layers, they will be initialized by
247
+ # this method with default ``kaiming_init``.
248
+ # Note: For PyTorch's conv layers, they will be overwritten by our
249
+ # initialization implementation using default ``kaiming_init``.
250
+ if not hasattr(self.conv, 'init_weights'):
251
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
252
+ nonlinearity = 'leaky_relu'
253
+ a = self.act_cfg.get('negative_slope', 0.01)
254
+ else:
255
+ nonlinearity = 'relu'
256
+ a = 0
257
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
258
+ if self.with_norm:
259
+ constant_init(self.norm, 1, bias=0)
260
+
261
+ def forward(self,
262
+ x: torch.Tensor,
263
+ activate: bool = True,
264
+ norm: bool = True) -> torch.Tensor:
265
+ layer_index = 0
266
+ while layer_index < len(self.order):
267
+ layer = self.order[layer_index]
268
+ if layer == 'conv':
269
+ if self.with_explicit_padding:
270
+ x = self.padding_layer(x)
271
+ # if the next operation is norm and we have a norm layer in
272
+ # eval mode and we have enabled `efficient_conv_bn_eval` for
273
+ # the conv operator, then activate the optimized forward and
274
+ # skip the next norm operator since it has been fused
275
+ if layer_index + 1 < len(self.order) and \
276
+ self.order[layer_index + 1] == 'norm' and norm and \
277
+ self.with_norm and not self.norm.training and \
278
+ self.efficient_conv_bn_eval_forward is not None:
279
+ self.conv.forward = partial(
280
+ self.efficient_conv_bn_eval_forward, self.norm,
281
+ self.conv)
282
+ layer_index += 1
283
+ x = self.conv(x)
284
+ del self.conv.forward
285
+ else:
286
+ x = self.conv(x)
287
+ elif layer == 'norm' and norm and self.with_norm:
288
+ x = self.norm(x)
289
+ elif layer == 'act' and activate and self.with_activation:
290
+ x = self.activate(x)
291
+ layer_index += 1
292
+ return x
293
+
294
+ def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
295
+ # efficient_conv_bn_eval works for conv + bn
296
+ # with `track_running_stats` option
297
+ if efficient_conv_bn_eval and self.norm \
298
+ and isinstance(self.norm, _BatchNorm) \
299
+ and self.norm.track_running_stats:
300
+ self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward # noqa: E501
301
+ else:
302
+ self.efficient_conv_bn_eval_forward = None # type: ignore
303
+
304
+ @staticmethod
305
+ def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
306
+ bn: torch.nn.modules.batchnorm._BatchNorm,
307
+ efficient_conv_bn_eval=True) -> 'ConvModule':
308
+ """Create a ConvModule from a conv and a bn module."""
309
+ self = ConvModule.__new__(ConvModule)
310
+ super(ConvModule, self).__init__()
311
+
312
+ self.conv_cfg = None
313
+ self.norm_cfg = None
314
+ self.act_cfg = None
315
+ self.inplace = False
316
+ self.with_spectral_norm = False
317
+ self.with_explicit_padding = False
318
+ self.order = ('conv', 'norm', 'act')
319
+
320
+ self.with_norm = True
321
+ self.with_activation = False
322
+ self.with_bias = conv.bias is not None
323
+
324
+ # build convolution layer
325
+ self.conv = conv
326
+ # export the attributes of self.conv to a higher level for convenience
327
+ self.in_channels = self.conv.in_channels
328
+ self.out_channels = self.conv.out_channels
329
+ self.kernel_size = self.conv.kernel_size
330
+ self.stride = self.conv.stride
331
+ self.padding = self.conv.padding
332
+ self.dilation = self.conv.dilation
333
+ self.transposed = self.conv.transposed
334
+ self.output_padding = self.conv.output_padding
335
+ self.groups = self.conv.groups
336
+
337
+ # build normalization layers
338
+ self.norm_name, norm = 'bn', bn
339
+ self.add_module(self.norm_name, norm)
340
+
341
+ self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
342
+
343
+ return self
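
A minimal usage sketch of the block above (illustrative only; it assumes the vendored package installs as `mmcv` and that `mmcv.cnn` re-exports `ConvModule` as upstream mmcv does):

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

# conv -> BN -> ReLU, assembled from config dicts as in the constructor above
block = ConvModule(3, 16, 3, padding=1,
                   norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
block.eval()
y = block(torch.randn(1, 3, 32, 32))   # -> torch.Size([1, 16, 32, 32])

# wrap an existing conv + bn pair without re-initializing their weights
fused = ConvModule.create_from_conv_bn(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))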
external/cv/mmcv/cnn/bricks/conv_ws.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import OrderedDict
8
+ from typing import Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from mmengine.registry import MODELS
14
+
15
+
16
+ def conv_ws_2d(input: torch.Tensor,
17
+ weight: torch.Tensor,
18
+ bias: Optional[torch.Tensor] = None,
19
+ stride: Union[int, Tuple[int, int]] = 1,
20
+ padding: Union[int, Tuple[int, int]] = 0,
21
+ dilation: Union[int, Tuple[int, int]] = 1,
22
+ groups: int = 1,
23
+ eps: float = 1e-5) -> torch.Tensor:
24
+ c_in = weight.size(0)
25
+ weight_flat = weight.view(c_in, -1)
26
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
27
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
28
+ weight = (weight - mean) / (std + eps)
29
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
30
+
31
+
32
+ @MODELS.register_module('ConvWS')
33
+ class ConvWS2d(nn.Conv2d):
34
+
35
+ def __init__(self,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ kernel_size: Union[int, Tuple[int, int]],
39
+ stride: Union[int, Tuple[int, int]] = 1,
40
+ padding: Union[int, Tuple[int, int]] = 0,
41
+ dilation: Union[int, Tuple[int, int]] = 1,
42
+ groups: int = 1,
43
+ bias: bool = True,
44
+ eps: float = 1e-5):
45
+ super().__init__(
46
+ in_channels,
47
+ out_channels,
48
+ kernel_size,
49
+ stride=stride,
50
+ padding=padding,
51
+ dilation=dilation,
52
+ groups=groups,
53
+ bias=bias)
54
+ self.eps = eps
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
58
+ self.dilation, self.groups, self.eps)
59
+
60
+
61
+ @MODELS.register_module(name='ConvAWS')
62
+ class ConvAWS2d(nn.Conv2d):
63
+ """AWS (Adaptive Weight Standardization)
64
+
65
+ This is a variant of Weight Standardization
66
+ (https://arxiv.org/pdf/1903.10520.pdf)
67
+ It is used in DetectoRS to avoid NaN
68
+ (https://arxiv.org/pdf/2006.02334.pdf)
69
+
70
+ Args:
71
+ in_channels (int): Number of channels in the input image
72
+ out_channels (int): Number of channels produced by the convolution
73
+ kernel_size (int or tuple): Size of the conv kernel
74
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
75
+ padding (int or tuple, optional): Zero-padding added to both sides of
76
+ the input. Default: 0
77
+ dilation (int or tuple, optional): Spacing between kernel elements.
78
+ Default: 1
79
+ groups (int, optional): Number of blocked connections from input
80
+ channels to output channels. Default: 1
81
+ bias (bool, optional): If set True, adds a learnable bias to the
82
+ output. Default: True
83
+ """
84
+
85
+ def __init__(self,
86
+ in_channels: int,
87
+ out_channels: int,
88
+ kernel_size: Union[int, Tuple[int, int]],
89
+ stride: Union[int, Tuple[int, int]] = 1,
90
+ padding: Union[int, Tuple[int, int]] = 0,
91
+ dilation: Union[int, Tuple[int, int]] = 1,
92
+ groups: int = 1,
93
+ bias: bool = True):
94
+ super().__init__(
95
+ in_channels,
96
+ out_channels,
97
+ kernel_size,
98
+ stride=stride,
99
+ padding=padding,
100
+ dilation=dilation,
101
+ groups=groups,
102
+ bias=bias)
103
+ self.register_buffer('weight_gamma',
104
+ torch.ones(self.out_channels, 1, 1, 1))
105
+ self.register_buffer('weight_beta',
106
+ torch.zeros(self.out_channels, 1, 1, 1))
107
+
108
+ def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
109
+ weight_flat = weight.view(weight.size(0), -1)
110
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
111
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
112
+ weight = (weight - mean) / std
113
+ weight = self.weight_gamma * weight + self.weight_beta
114
+ return weight
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ weight = self._get_weight(self.weight)
118
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding,
119
+ self.dilation, self.groups)
120
+
121
+ def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
122
+ local_metadata: Dict, strict: bool,
123
+ missing_keys: List[str],
124
+ unexpected_keys: List[str],
125
+ error_msgs: List[str]) -> None:
126
+ """Override default load function.
127
+
128
+ AWS overrides the function _load_from_state_dict to recover
129
+ weight_gamma and weight_beta if they are missing. If weight_gamma and
130
+ weight_beta are found in the checkpoint, this function will return
131
+ after super()._load_from_state_dict. Otherwise, it will compute the
132
+ mean and std of the pretrained weights and store them in weight_beta
133
+ and weight_gamma.
134
+ """
135
+
136
+ self.weight_gamma.data.fill_(-1)
137
+ local_missing_keys: List = []
138
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
139
+ strict, local_missing_keys,
140
+ unexpected_keys, error_msgs)
141
+ if self.weight_gamma.data.mean() > 0:
142
+ for k in local_missing_keys:
143
+ missing_keys.append(k)
144
+ return
145
+ weight = self.weight.data
146
+ weight_flat = weight.view(weight.size(0), -1)
147
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
148
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
149
+ self.weight_beta.data.copy_(mean)
150
+ self.weight_gamma.data.copy_(std)
151
+ missing_gamma_beta = [
152
+ k for k in local_missing_keys
153
+ if k.endswith('weight_gamma') or k.endswith('weight_beta')
154
+ ]
155
+ for k in missing_gamma_beta:
156
+ local_missing_keys.remove(k)
157
+ for k in local_missing_keys:
158
+ missing_keys.append(k)
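
A minimal usage sketch (illustrative; assumes the vendored package imports as `mmcv`). Because the classes are registered as 'ConvWS' and 'ConvAWS', they can also be selected through ConvModule's `conv_cfg`:

import torch
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.conv_ws import ConvWS2d

conv = ConvWS2d(3, 8, kernel_size=3, padding=1)   # weights are standardized on every forward pass
y = conv(torch.randn(2, 3, 16, 16))               # -> (2, 8, 16, 16)

# the registry route, typically combined with GroupNorm
block = ConvModule(3, 8, 3, padding=1,
                   conv_cfg=dict(type='ConvWS'),
                   norm_cfg=dict(type='GN', num_groups=4))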
external/cv/mmcv/cnn/bricks/depthwise_separable_conv_module.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .conv_module import ConvModule
13
+
14
+
15
+ class DepthwiseSeparableConvModule(nn.Module):
16
+ """Depthwise separable convolution module.
17
+
18
+ See https://arxiv.org/pdf/1704.04861.pdf for details.
19
+
20
+ This module can replace a ConvModule with the conv block replaced by two
21
+ conv blocks: a depthwise conv block and a pointwise conv block. The depthwise
22
+ conv block contains depthwise-conv/norm/activation layers. The pointwise
23
+ conv block contains pointwise-conv/norm/activation layers. It should be
24
+ noted that there will be a norm/activation layer in the depthwise conv block
25
+ if `norm_cfg` and `act_cfg` are specified.
26
+
27
+ Args:
28
+ in_channels (int): Number of channels in the input feature map.
29
+ Same as that in ``nn._ConvNd``.
30
+ out_channels (int): Number of channels produced by the convolution.
31
+ Same as that in ``nn._ConvNd``.
32
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
33
+ Same as that in ``nn._ConvNd``.
34
+ stride (int | tuple[int]): Stride of the convolution.
35
+ Same as that in ``nn._ConvNd``. Default: 1.
36
+ padding (int | tuple[int]): Zero-padding added to both sides of
37
+ the input. Same as that in ``nn._ConvNd``. Default: 0.
38
+ dilation (int | tuple[int]): Spacing between kernel elements.
39
+ Same as that in ``nn._ConvNd``. Default: 1.
40
+ norm_cfg (dict): Default norm config for both depthwise ConvModule and
41
+ pointwise ConvModule. Default: None.
42
+ act_cfg (dict): Default activation config for both depthwise ConvModule
43
+ and pointwise ConvModule. Default: dict(type='ReLU').
44
+ dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
45
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
46
+ dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
47
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
48
+ pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
49
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
50
+ pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
51
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
52
+ kwargs (optional): Other shared arguments for depthwise and pointwise
53
+ ConvModule. See ConvModule for ref.
54
+ """
55
+
56
+ def __init__(self,
57
+ in_channels: int,
58
+ out_channels: int,
59
+ kernel_size: Union[int, Tuple[int, int]],
60
+ stride: Union[int, Tuple[int, int]] = 1,
61
+ padding: Union[int, Tuple[int, int]] = 0,
62
+ dilation: Union[int, Tuple[int, int]] = 1,
63
+ norm_cfg: Optional[Dict] = None,
64
+ act_cfg: Dict = dict(type='ReLU'),
65
+ dw_norm_cfg: Union[Dict, str] = 'default',
66
+ dw_act_cfg: Union[Dict, str] = 'default',
67
+ pw_norm_cfg: Union[Dict, str] = 'default',
68
+ pw_act_cfg: Union[Dict, str] = 'default',
69
+ **kwargs):
70
+ super().__init__()
71
+ assert 'groups' not in kwargs, 'groups should not be specified'
72
+
73
+ # if norm/activation config of depthwise/pointwise ConvModule is not
74
+ # specified, use default config.
75
+ dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
76
+ dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
77
+ pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
78
+ pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
79
+
80
+ # depthwise convolution
81
+ self.depthwise_conv = ConvModule(
82
+ in_channels,
83
+ in_channels,
84
+ kernel_size,
85
+ stride=stride,
86
+ padding=padding,
87
+ dilation=dilation,
88
+ groups=in_channels,
89
+ norm_cfg=dw_norm_cfg, # type: ignore
90
+ act_cfg=dw_act_cfg, # type: ignore
91
+ **kwargs)
92
+
93
+ self.pointwise_conv = ConvModule(
94
+ in_channels,
95
+ out_channels,
96
+ 1,
97
+ norm_cfg=pw_norm_cfg, # type: ignore
98
+ act_cfg=pw_act_cfg, # type: ignore
99
+ **kwargs)
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ x = self.depthwise_conv(x)
103
+ x = self.pointwise_conv(x)
104
+ return x
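
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package). The module factorizes a k x k convolution into a depthwise k x k conv (groups == in_channels) followed by a pointwise 1 x 1 conv:

import torch
from mmcv.cnn.bricks.depthwise_separable_conv_module import DepthwiseSeparableConvModule

m = DepthwiseSeparableConvModule(32, 64, 3, padding=1, norm_cfg=dict(type='BN'))
y = m(torch.randn(1, 32, 56, 56))   # depthwise 3x3, then pointwise 1x1 -> (1, 64, 56, 56)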
external/cv/mmcv/cnn/bricks/drop.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from mmengine.registry import MODELS
12
+
13
+
14
+ def drop_path(x: torch.Tensor,
15
+ drop_prob: float = 0.,
16
+ training: bool = False) -> torch.Tensor:
17
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
18
+ residual blocks).
19
+
20
+ We follow the implementation
21
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
22
+ """
23
+ if not training:
24
+ return x
25
+ keep_prob = 1 - drop_prob
26
+ # handle tensors with different dimensions, not just 4D tensors.
27
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
28
+ random_tensor = keep_prob + torch.rand(
29
+ shape, dtype=x.dtype, device=x.device)
30
+ output = x.div(keep_prob) * random_tensor.floor()
31
+ return output
32
+
33
+
34
+ @MODELS.register_module()
35
+ class DropPath(nn.Module):
36
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
37
+ residual blocks).
38
+
39
+ We follow the implementation
40
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
41
+
42
+ Args:
43
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
44
+ """
45
+
46
+ def __init__(self, drop_prob: float = 0.1):
47
+ super().__init__()
48
+ self.drop_prob = drop_prob
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return drop_path(x, self.drop_prob, self.training)
52
+
53
+
54
+ @MODELS.register_module()
55
+ class Dropout(nn.Dropout):
56
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
57
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
58
+ ``DropPath``
59
+
60
+ Args:
61
+ drop_prob (float): Probability of the elements to be
62
+ zeroed. Default: 0.5.
63
+ inplace (bool): Do the operation inplace or not. Default: False.
64
+ """
65
+
66
+ def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
67
+ super().__init__(p=drop_prob, inplace=inplace)
68
+
69
+
70
+ def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
71
+ """Builder for drop out layers."""
72
+ return MODELS.build(cfg, default_args=default_args)
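
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package). DropPath zeroes entire residual paths per sample during training and rescales the survivors by 1/keep_prob, while build_dropout constructs either variant from a config dict:

import torch
from mmcv.cnn.bricks.drop import DropPath, build_dropout

dp = DropPath(drop_prob=0.2)
dp.train()
y = dp(torch.randn(4, 197, 384))    # some samples zeroed, the rest scaled by 1 / 0.8

layer = build_dropout(dict(type='Dropout', drop_prob=0.1))   # plain element-wise dropout via the registry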
external/cv/mmcv/cnn/bricks/generalized_attention.py ADDED
@@ -0,0 +1,416 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from mmengine.model import kaiming_init
14
+ from mmengine.registry import MODELS
15
+
16
+
17
+ @MODELS.register_module()
18
+ class GeneralizedAttention(nn.Module):
19
+ """GeneralizedAttention module.
20
+
21
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
22
+ (https://arxiv.org/abs/1904.05873) for details.
23
+
24
+ Args:
25
+ in_channels (int): Channels of the input feature map.
26
+ spatial_range (int): The spatial range. -1 indicates no spatial range
27
+ constraint. Default: -1.
28
+ num_heads (int): The head number of empirical_attention module.
29
+ Default: 9.
30
+ position_embedding_dim (int): The position embedding dimension.
31
+ Default: -1.
32
+ position_magnitude (int): A multiplier acting on coord difference.
33
+ Default: 1.
34
+ kv_stride (int): The feature stride acting on key/value feature map.
35
+ Default: 2.
36
+ q_stride (int): The feature stride acting on query feature map.
37
+ Default: 1.
38
+ attention_type (str): A binary indicator string for indicating which
39
+ items in generalized empirical_attention module are used.
40
+ Default: '1111'.
41
+
42
+ - '1000' indicates 'query and key content' (appr - appr) item,
43
+ - '0100' indicates 'query content and relative position'
44
+ (appr - position) item,
45
+ - '0010' indicates 'key content only' (bias - appr) item,
46
+ - '0001' indicates 'relative position only' (bias - position) item.
47
+ """
48
+
49
+ _abbr_ = 'gen_attention_block'
50
+
51
+ def __init__(self,
52
+ in_channels: int,
53
+ spatial_range: int = -1,
54
+ num_heads: int = 9,
55
+ position_embedding_dim: int = -1,
56
+ position_magnitude: int = 1,
57
+ kv_stride: int = 2,
58
+ q_stride: int = 1,
59
+ attention_type: str = '1111'):
60
+
61
+ super().__init__()
62
+
63
+ # hard range means local range for non-local operation
64
+ self.position_embedding_dim = (
65
+ position_embedding_dim
66
+ if position_embedding_dim > 0 else in_channels)
67
+
68
+ self.position_magnitude = position_magnitude
69
+ self.num_heads = num_heads
70
+ self.in_channels = in_channels
71
+ self.spatial_range = spatial_range
72
+ self.kv_stride = kv_stride
73
+ self.q_stride = q_stride
74
+ self.attention_type = [bool(int(_)) for _ in attention_type]
75
+ self.qk_embed_dim = in_channels // num_heads
76
+ out_c = self.qk_embed_dim * num_heads
77
+
78
+ if self.attention_type[0] or self.attention_type[1]:
79
+ self.query_conv = nn.Conv2d(
80
+ in_channels=in_channels,
81
+ out_channels=out_c,
82
+ kernel_size=1,
83
+ bias=False)
84
+ self.query_conv.kaiming_init = True
85
+
86
+ if self.attention_type[0] or self.attention_type[2]:
87
+ self.key_conv = nn.Conv2d(
88
+ in_channels=in_channels,
89
+ out_channels=out_c,
90
+ kernel_size=1,
91
+ bias=False)
92
+ self.key_conv.kaiming_init = True
93
+
94
+ self.v_dim = in_channels // num_heads
95
+ self.value_conv = nn.Conv2d(
96
+ in_channels=in_channels,
97
+ out_channels=self.v_dim * num_heads,
98
+ kernel_size=1,
99
+ bias=False)
100
+ self.value_conv.kaiming_init = True
101
+
102
+ if self.attention_type[1] or self.attention_type[3]:
103
+ self.appr_geom_fc_x = nn.Linear(
104
+ self.position_embedding_dim // 2, out_c, bias=False)
105
+ self.appr_geom_fc_x.kaiming_init = True
106
+
107
+ self.appr_geom_fc_y = nn.Linear(
108
+ self.position_embedding_dim // 2, out_c, bias=False)
109
+ self.appr_geom_fc_y.kaiming_init = True
110
+
111
+ if self.attention_type[2]:
112
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
113
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
114
+ self.appr_bias = nn.Parameter(appr_bias_value)
115
+
116
+ if self.attention_type[3]:
117
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
118
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
119
+ self.geom_bias = nn.Parameter(geom_bias_value)
120
+
121
+ self.proj_conv = nn.Conv2d(
122
+ in_channels=self.v_dim * num_heads,
123
+ out_channels=in_channels,
124
+ kernel_size=1,
125
+ bias=True)
126
+ self.proj_conv.kaiming_init = True
127
+ self.gamma = nn.Parameter(torch.zeros(1))
128
+
129
+ if self.spatial_range >= 0:
130
+ # only works when non local is after 3*3 conv
131
+ if in_channels == 256:
132
+ max_len = 84
133
+ elif in_channels == 512:
134
+ max_len = 42
135
+
136
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
137
+ local_constraint_map = np.ones(
138
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
139
+ for iy in range(max_len):
140
+ for ix in range(max_len):
141
+ local_constraint_map[
142
+ iy, ix,
143
+ max((iy - self.spatial_range) //
144
+ self.kv_stride, 0):min((iy + self.spatial_range +
145
+ 1) // self.kv_stride +
146
+ 1, max_len),
147
+ max((ix - self.spatial_range) //
148
+ self.kv_stride, 0):min((ix + self.spatial_range +
149
+ 1) // self.kv_stride +
150
+ 1, max_len)] = 0
151
+
152
+ self.local_constraint_map = nn.Parameter(
153
+ torch.from_numpy(local_constraint_map).byte(),
154
+ requires_grad=False)
155
+
156
+ if self.q_stride > 1:
157
+ self.q_downsample = nn.AvgPool2d(
158
+ kernel_size=1, stride=self.q_stride)
159
+ else:
160
+ self.q_downsample = None
161
+
162
+ if self.kv_stride > 1:
163
+ self.kv_downsample = nn.AvgPool2d(
164
+ kernel_size=1, stride=self.kv_stride)
165
+ else:
166
+ self.kv_downsample = None
167
+
168
+ self.init_weights()
169
+
170
+ def get_position_embedding(self,
171
+ h,
172
+ w,
173
+ h_kv,
174
+ w_kv,
175
+ q_stride,
176
+ kv_stride,
177
+ device,
178
+ dtype,
179
+ feat_dim,
180
+ wave_length=1000):
181
+ # the default type of Tensor is float32, leading to type mismatch
182
+ # in fp16 mode. Cast it to support fp16 mode.
183
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
184
+ h_idxs = h_idxs.view((h, 1)) * q_stride
185
+
186
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
187
+ w_idxs = w_idxs.view((w, 1)) * q_stride
188
+
189
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
190
+ device=device, dtype=dtype)
191
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
192
+
193
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
194
+ device=device, dtype=dtype)
195
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
196
+
197
+ # (h, h_kv, 1)
198
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
199
+ h_diff *= self.position_magnitude
200
+
201
+ # (w, w_kv, 1)
202
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
203
+ w_diff *= self.position_magnitude
204
+
205
+ feat_range = torch.arange(0, feat_dim / 4).to(
206
+ device=device, dtype=dtype)
207
+
208
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
209
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
210
+ dim_mat = dim_mat.view((1, 1, -1))
211
+
212
+ embedding_x = torch.cat(
213
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
214
+
215
+ embedding_y = torch.cat(
216
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
217
+
218
+ return embedding_x, embedding_y
219
+
220
+ def forward(self, x_input: torch.Tensor) -> torch.Tensor:
221
+ num_heads = self.num_heads
222
+
223
+ # use empirical_attention
224
+ if self.q_downsample is not None:
225
+ x_q = self.q_downsample(x_input)
226
+ else:
227
+ x_q = x_input
228
+ n, _, h, w = x_q.shape
229
+
230
+ if self.kv_downsample is not None:
231
+ x_kv = self.kv_downsample(x_input)
232
+ else:
233
+ x_kv = x_input
234
+ _, _, h_kv, w_kv = x_kv.shape
235
+
236
+ if self.attention_type[0] or self.attention_type[1]:
237
+ proj_query = self.query_conv(x_q).view(
238
+ (n, num_heads, self.qk_embed_dim, h * w))
239
+ proj_query = proj_query.permute(0, 1, 3, 2)
240
+
241
+ if self.attention_type[0] or self.attention_type[2]:
242
+ proj_key = self.key_conv(x_kv).view(
243
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
244
+
245
+ if self.attention_type[1] or self.attention_type[3]:
246
+ position_embed_x, position_embed_y = self.get_position_embedding(
247
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
248
+ x_input.device, x_input.dtype, self.position_embedding_dim)
249
+ # (n, num_heads, w, w_kv, dim)
250
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
251
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
252
+ permute(0, 3, 1, 2, 4).\
253
+ repeat(n, 1, 1, 1, 1)
254
+
255
+ # (n, num_heads, h, h_kv, dim)
256
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
257
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
258
+ permute(0, 3, 1, 2, 4).\
259
+ repeat(n, 1, 1, 1, 1)
260
+
261
+ position_feat_x /= math.sqrt(2)
262
+ position_feat_y /= math.sqrt(2)
263
+
264
+ # accelerate for saliency only
265
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
266
+ appr_bias = self.appr_bias.\
267
+ view(1, num_heads, 1, self.qk_embed_dim).\
268
+ repeat(n, 1, 1, 1)
269
+
270
+ energy = torch.matmul(appr_bias, proj_key).\
271
+ view(n, num_heads, 1, h_kv * w_kv)
272
+
273
+ h = 1
274
+ w = 1
275
+ else:
276
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
277
+ if not self.attention_type[0]:
278
+ energy = torch.zeros(
279
+ n,
280
+ num_heads,
281
+ h,
282
+ w,
283
+ h_kv,
284
+ w_kv,
285
+ dtype=x_input.dtype,
286
+ device=x_input.device)
287
+
288
+ # attention_type[0]: appr - appr
289
+ # attention_type[1]: appr - position
290
+ # attention_type[2]: bias - appr
291
+ # attention_type[3]: bias - position
292
+ if self.attention_type[0] or self.attention_type[2]:
293
+ if self.attention_type[0] and self.attention_type[2]:
294
+ appr_bias = self.appr_bias.\
295
+ view(1, num_heads, 1, self.qk_embed_dim)
296
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
297
+ view(n, num_heads, h, w, h_kv, w_kv)
298
+
299
+ elif self.attention_type[0]:
300
+ energy = torch.matmul(proj_query, proj_key).\
301
+ view(n, num_heads, h, w, h_kv, w_kv)
302
+
303
+ elif self.attention_type[2]:
304
+ appr_bias = self.appr_bias.\
305
+ view(1, num_heads, 1, self.qk_embed_dim).\
306
+ repeat(n, 1, 1, 1)
307
+
308
+ energy += torch.matmul(appr_bias, proj_key).\
309
+ view(n, num_heads, 1, 1, h_kv, w_kv)
310
+
311
+ if self.attention_type[1] or self.attention_type[3]:
312
+ if self.attention_type[1] and self.attention_type[3]:
313
+ geom_bias = self.geom_bias.\
314
+ view(1, num_heads, 1, self.qk_embed_dim)
315
+
316
+ proj_query_reshape = (proj_query + geom_bias).\
317
+ view(n, num_heads, h, w, self.qk_embed_dim)
318
+
319
+ energy_x = torch.matmul(
320
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
321
+ position_feat_x.permute(0, 1, 2, 4, 3))
322
+ energy_x = energy_x.\
323
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
324
+
325
+ energy_y = torch.matmul(
326
+ proj_query_reshape,
327
+ position_feat_y.permute(0, 1, 2, 4, 3))
328
+ energy_y = energy_y.unsqueeze(5)
329
+
330
+ energy += energy_x + energy_y
331
+
332
+ elif self.attention_type[1]:
333
+ proj_query_reshape = proj_query.\
334
+ view(n, num_heads, h, w, self.qk_embed_dim)
335
+ proj_query_reshape = proj_query_reshape.\
336
+ permute(0, 1, 3, 2, 4)
337
+ position_feat_x_reshape = position_feat_x.\
338
+ permute(0, 1, 2, 4, 3)
339
+ position_feat_y_reshape = position_feat_y.\
340
+ permute(0, 1, 2, 4, 3)
341
+
342
+ energy_x = torch.matmul(proj_query_reshape,
343
+ position_feat_x_reshape)
344
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
345
+
346
+ energy_y = torch.matmul(proj_query_reshape,
347
+ position_feat_y_reshape)
348
+ energy_y = energy_y.unsqueeze(5)
349
+
350
+ energy += energy_x + energy_y
351
+
352
+ elif self.attention_type[3]:
353
+ geom_bias = self.geom_bias.\
354
+ view(1, num_heads, self.qk_embed_dim, 1).\
355
+ repeat(n, 1, 1, 1)
356
+
357
+ position_feat_x_reshape = position_feat_x.\
358
+ view(n, num_heads, w * w_kv, self.qk_embed_dim)
359
+
360
+ position_feat_y_reshape = position_feat_y.\
361
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
362
+
363
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
364
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
365
+
366
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
367
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
368
+
369
+ energy += energy_x + energy_y
370
+
371
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
372
+
373
+ if self.spatial_range >= 0:
374
+ cur_local_constraint_map = \
375
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
376
+ contiguous().\
377
+ view(1, 1, h*w, h_kv*w_kv)
378
+
379
+ energy = energy.masked_fill_(cur_local_constraint_map,
380
+ float('-inf'))
381
+
382
+ attention = F.softmax(energy, 3)
383
+
384
+ proj_value = self.value_conv(x_kv)
385
+ proj_value_reshape = proj_value.\
386
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
387
+ permute(0, 1, 3, 2)
388
+
389
+ out = torch.matmul(attention, proj_value_reshape).\
390
+ permute(0, 1, 3, 2).\
391
+ contiguous().\
392
+ view(n, self.v_dim * self.num_heads, h, w)
393
+
394
+ out = self.proj_conv(out)
395
+
396
+ # output is downsampled, upsample back to input size
397
+ if self.q_downsample is not None:
398
+ out = F.interpolate(
399
+ out,
400
+ size=x_input.shape[2:],
401
+ mode='bilinear',
402
+ align_corners=False)
403
+
404
+ out = self.gamma * out + x_input
405
+ return out
406
+
407
+ def init_weights(self):
408
+ for m in self.modules():
409
+ if hasattr(m, 'kaiming_init') and m.kaiming_init:
410
+ kaiming_init(
411
+ m,
412
+ mode='fan_in',
413
+ nonlinearity='leaky_relu',
414
+ bias=0,
415
+ distribution='uniform',
416
+ a=1)
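
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package). With the default attention_type='1111' all four terms are used; num_heads must divide in_channels:

import torch
from mmcv.cnn.bricks.generalized_attention import GeneralizedAttention

attn = GeneralizedAttention(in_channels=256, num_heads=8)   # qk_embed_dim = 256 // 8 = 32
y = attn(torch.randn(1, 256, 32, 32))                       # residual output, same shape as the input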
external/cv/mmcv/cnn/bricks/hsigmoid.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from mmengine.registry import MODELS
12
+
13
+
14
+ @MODELS.register_module()
15
+ class HSigmoid(nn.Module):
16
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
17
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
18
+ Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
19
+
20
+ Note:
21
+ In MMCV v1.4.4, we modified the default value of args to align with
22
+ PyTorch official.
23
+
24
+ Args:
25
+ bias (float): Bias of the input feature map. Default: 3.0.
26
+ divisor (float): Divisor of the input feature map. Default: 6.0.
27
+ min_value (float): Lower bound value. Default: 0.0.
28
+ max_value (float): Upper bound value. Default: 1.0.
29
+
30
+ Returns:
31
+ Tensor: The output tensor.
32
+ """
33
+
34
+ def __init__(self,
35
+ bias: float = 3.0,
36
+ divisor: float = 6.0,
37
+ min_value: float = 0.0,
38
+ max_value: float = 1.0):
39
+ super().__init__()
40
+ warnings.warn(
41
+ 'In MMCV v1.4.4, we modified the default value of args to align '
42
+ 'with PyTorch official. Previous Implementation: '
43
+ 'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
44
+ 'Current Implementation: '
45
+ 'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
46
+ self.bias = bias
47
+ self.divisor = divisor
48
+ assert self.divisor != 0
49
+ self.min_value = min_value
50
+ self.max_value = max_value
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = (x + self.bias) / self.divisor
54
+
55
+ return x.clamp_(self.min_value, self.max_value)
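
A quick numeric check of the clamp above (illustrative; assumes the vendored `mmcv` package):

import torch
from mmcv.cnn.bricks.hsigmoid import HSigmoid

act = HSigmoid()                       # min(max((x + 3) / 6, 0), 1); construction emits the warning above
act(torch.tensor([-4.0, 0.0, 4.0]))    # -> tensor([0.0000, 0.5000, 1.0000])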
external/cv/mmcv/cnn/bricks/hswish.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from mmengine.registry import MODELS
10
+ from mmengine.utils import digit_version
11
+ from mmengine.utils.dl_utils import TORCH_VERSION
12
+
13
+
14
+ class HSwish(nn.Module):
15
+ """Hard Swish Module.
16
+
17
+ This module applies the hard swish function:
18
+
19
+ .. math::
20
+ Hswish(x) = x * ReLU6(x + 3) / 6
21
+
22
+ Args:
23
+ inplace (bool): can optionally do the operation in-place.
24
+ Default: False.
25
+
26
+ Returns:
27
+ Tensor: The output tensor.
28
+ """
29
+
30
+ def __init__(self, inplace: bool = False):
31
+ super().__init__()
32
+ self.act = nn.ReLU6(inplace)
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ return x * self.act(x + 3) / 6
36
+
37
+
38
+ if (TORCH_VERSION == 'parrots'
39
+ or digit_version(TORCH_VERSION) < digit_version('1.7')):
40
+ # Hardswish is not supported when PyTorch version < 1.6.
41
+ # And Hardswish in PyTorch 1.6 does not support inplace.
42
+ MODELS.register_module(module=HSwish)
43
+ else:
44
+ MODELS.register_module(module=nn.Hardswish, name='HSwish')
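
A quick numeric check (illustrative; assumes the vendored `mmcv` package). Note that on PyTorch >= 1.7 the registry maps the name 'HSwish' to nn.Hardswish, as registered above:

import torch
from mmcv.cnn.bricks.hswish import HSwish

act = HSwish()                         # x * ReLU6(x + 3) / 6
act(torch.tensor([-4.0, 0.0, 3.0]))    # -> tensor([0., 0., 3.])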
external/cv/mmcv/cnn/bricks/non_local.py ADDED
@@ -0,0 +1,313 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABCMeta
8
+ from typing import Dict, Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from mmengine.model import constant_init, normal_init
13
+ from mmengine.registry import MODELS
14
+
15
+ from .conv_module import ConvModule
16
+
17
+
18
+ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
19
+ """Basic Non-local module.
20
+
21
+ This module is proposed in
22
+ "Non-local Neural Networks"
23
+ Paper reference: https://arxiv.org/abs/1711.07971
24
+ Code reference: https://github.com/AlexHex7/Non-local_pytorch
25
+
26
+ Args:
27
+ in_channels (int): Channels of the input feature map.
28
+ reduction (int): Channel reduction ratio. Default: 2.
29
+ use_scale (bool): Whether to scale pairwise_weight by
30
+ `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
31
+ Default: True.
32
+ conv_cfg (None | dict): The config dict for convolution layers.
33
+ If not specified, it will use `nn.Conv2d` for convolution layers.
34
+ Default: None.
35
+ norm_cfg (None | dict): The config dict for normalization layers.
36
+ Default: None. (This parameter is only applicable to conv_out.)
37
+ mode (str): Options are `gaussian`, `concatenation`,
38
+ `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
39
+ """
40
+
41
+ def __init__(self,
42
+ in_channels: int,
43
+ reduction: int = 2,
44
+ use_scale: bool = True,
45
+ conv_cfg: Optional[Dict] = None,
46
+ norm_cfg: Optional[Dict] = None,
47
+ mode: str = 'embedded_gaussian',
48
+ **kwargs):
49
+ super().__init__()
50
+ self.in_channels = in_channels
51
+ self.reduction = reduction
52
+ self.use_scale = use_scale
53
+ self.inter_channels = max(in_channels // reduction, 1)
54
+ self.mode = mode
55
+
56
+ if mode not in [
57
+ 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
58
+ ]:
59
+ raise ValueError("Mode should be in 'gaussian', 'concatenation', "
60
+ f"'embedded_gaussian' or 'dot_product', but got "
61
+ f'{mode} instead.')
62
+
63
+ # g, theta, phi are defaulted as `nn.ConvNd`.
64
+ # Here we use ConvModule for potential usage.
65
+ self.g = ConvModule(
66
+ self.in_channels,
67
+ self.inter_channels,
68
+ kernel_size=1,
69
+ conv_cfg=conv_cfg,
70
+ act_cfg=None) # type: ignore
71
+ self.conv_out = ConvModule(
72
+ self.inter_channels,
73
+ self.in_channels,
74
+ kernel_size=1,
75
+ conv_cfg=conv_cfg,
76
+ norm_cfg=norm_cfg,
77
+ act_cfg=None)
78
+
79
+ if self.mode != 'gaussian':
80
+ self.theta = ConvModule(
81
+ self.in_channels,
82
+ self.inter_channels,
83
+ kernel_size=1,
84
+ conv_cfg=conv_cfg,
85
+ act_cfg=None)
86
+ self.phi = ConvModule(
87
+ self.in_channels,
88
+ self.inter_channels,
89
+ kernel_size=1,
90
+ conv_cfg=conv_cfg,
91
+ act_cfg=None)
92
+
93
+ if self.mode == 'concatenation':
94
+ self.concat_project = ConvModule(
95
+ self.inter_channels * 2,
96
+ 1,
97
+ kernel_size=1,
98
+ stride=1,
99
+ padding=0,
100
+ bias=False,
101
+ act_cfg=dict(type='ReLU'))
102
+
103
+ self.init_weights(**kwargs)
104
+
105
+ def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
106
+ if self.mode != 'gaussian':
107
+ for m in [self.g, self.theta, self.phi]:
108
+ normal_init(m.conv, std=std)
109
+ else:
110
+ normal_init(self.g.conv, std=std)
111
+ if zeros_init:
112
+ if self.conv_out.norm_cfg is None:
113
+ constant_init(self.conv_out.conv, 0)
114
+ else:
115
+ constant_init(self.conv_out.norm, 0)
116
+ else:
117
+ if self.conv_out.norm_cfg is None:
118
+ normal_init(self.conv_out.conv, std=std)
119
+ else:
120
+ normal_init(self.conv_out.norm, std=std)
121
+
122
+ def gaussian(self, theta_x: torch.Tensor,
123
+ phi_x: torch.Tensor) -> torch.Tensor:
124
+ # NonLocal1d pairwise_weight: [N, H, H]
125
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
126
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
127
+ pairwise_weight = torch.matmul(theta_x, phi_x)
128
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
129
+ return pairwise_weight
130
+
131
+ def embedded_gaussian(self, theta_x: torch.Tensor,
132
+ phi_x: torch.Tensor) -> torch.Tensor:
133
+ # NonLocal1d pairwise_weight: [N, H, H]
134
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
135
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
136
+ pairwise_weight = torch.matmul(theta_x, phi_x)
137
+ if self.use_scale:
138
+ # theta_x.shape[-1] is `self.inter_channels`
139
+ pairwise_weight /= theta_x.shape[-1]**0.5
140
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
141
+ return pairwise_weight
142
+
143
+ def dot_product(self, theta_x: torch.Tensor,
144
+ phi_x: torch.Tensor) -> torch.Tensor:
145
+ # NonLocal1d pairwise_weight: [N, H, H]
146
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
147
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
148
+ pairwise_weight = torch.matmul(theta_x, phi_x)
149
+ pairwise_weight /= pairwise_weight.shape[-1]
150
+ return pairwise_weight
151
+
152
+ def concatenation(self, theta_x: torch.Tensor,
153
+ phi_x: torch.Tensor) -> torch.Tensor:
154
+ # NonLocal1d pairwise_weight: [N, H, H]
155
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
156
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
157
+ h = theta_x.size(2)
158
+ w = phi_x.size(3)
159
+ theta_x = theta_x.repeat(1, 1, 1, w)
160
+ phi_x = phi_x.repeat(1, 1, h, 1)
161
+
162
+ concat_feature = torch.cat([theta_x, phi_x], dim=1)
163
+ pairwise_weight = self.concat_project(concat_feature)
164
+ n, _, h, w = pairwise_weight.size()
165
+ pairwise_weight = pairwise_weight.view(n, h, w)
166
+ pairwise_weight /= pairwise_weight.shape[-1]
167
+
168
+ return pairwise_weight
169
+
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ # Assume `reduction = 1`, then `inter_channels = C`
172
+ # or `inter_channels = C` when `mode="gaussian"`
173
+
174
+ # NonLocal1d x: [N, C, H]
175
+ # NonLocal2d x: [N, C, H, W]
176
+ # NonLocal3d x: [N, C, T, H, W]
177
+ n = x.size(0)
178
+
179
+ # NonLocal1d g_x: [N, H, C]
180
+ # NonLocal2d g_x: [N, HxW, C]
181
+ # NonLocal3d g_x: [N, TxHxW, C]
182
+ g_x = self.g(x).view(n, self.inter_channels, -1)
183
+ g_x = g_x.permute(0, 2, 1)
184
+
185
+ # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
186
+ # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
187
+ # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
188
+ if self.mode == 'gaussian':
189
+ theta_x = x.view(n, self.in_channels, -1)
190
+ theta_x = theta_x.permute(0, 2, 1)
191
+ if self.sub_sample:
192
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
193
+ else:
194
+ phi_x = x.view(n, self.in_channels, -1)
195
+ elif self.mode == 'concatenation':
196
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
197
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
198
+ else:
199
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
200
+ theta_x = theta_x.permute(0, 2, 1)
201
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
202
+
203
+ pairwise_func = getattr(self, self.mode)
204
+ # NonLocal1d pairwise_weight: [N, H, H]
205
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
206
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
207
+ pairwise_weight = pairwise_func(theta_x, phi_x)
208
+
209
+ # NonLocal1d y: [N, H, C]
210
+ # NonLocal2d y: [N, HxW, C]
211
+ # NonLocal3d y: [N, TxHxW, C]
212
+ y = torch.matmul(pairwise_weight, g_x)
213
+ # NonLocal1d y: [N, C, H]
214
+ # NonLocal2d y: [N, C, H, W]
215
+ # NonLocal3d y: [N, C, T, H, W]
216
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
217
+ *x.size()[2:])
218
+
219
+ output = x + self.conv_out(y)
220
+
221
+ return output
222
+
223
+
224
+ class NonLocal1d(_NonLocalNd):
225
+ """1D Non-local module.
226
+
227
+ Args:
228
+ in_channels (int): Same as `NonLocalND`.
229
+ sub_sample (bool): Whether to apply max pooling after pairwise
230
+ function (Note that the `sub_sample` is applied on spatial only).
231
+ Default: False.
232
+ conv_cfg (None | dict): Same as `NonLocalND`.
233
+ Default: dict(type='Conv1d').
234
+ """
235
+
236
+ def __init__(self,
237
+ in_channels: int,
238
+ sub_sample: bool = False,
239
+ conv_cfg: Dict = dict(type='Conv1d'),
240
+ **kwargs):
241
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
242
+
243
+ self.sub_sample = sub_sample
244
+
245
+ if sub_sample:
246
+ max_pool_layer = nn.MaxPool1d(kernel_size=2)
247
+ self.g = nn.Sequential(self.g, max_pool_layer)
248
+ if self.mode != 'gaussian':
249
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
250
+ else:
251
+ self.phi = max_pool_layer
252
+
253
+
254
+ @MODELS.register_module()
255
+ class NonLocal2d(_NonLocalNd):
256
+ """2D Non-local module.
257
+
258
+ Args:
259
+ in_channels (int): Same as `NonLocalND`.
260
+ sub_sample (bool): Whether to apply max pooling after pairwise
261
+ function (Note that the `sub_sample` is applied on spatial only).
262
+ Default: False.
263
+ conv_cfg (None | dict): Same as `NonLocalND`.
264
+ Default: dict(type='Conv2d').
265
+ """
266
+
267
+ _abbr_ = 'nonlocal_block'
268
+
269
+ def __init__(self,
270
+ in_channels: int,
271
+ sub_sample: bool = False,
272
+ conv_cfg: Dict = dict(type='Conv2d'),
273
+ **kwargs):
274
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
275
+
276
+ self.sub_sample = sub_sample
277
+
278
+ if sub_sample:
279
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
280
+ self.g = nn.Sequential(self.g, max_pool_layer)
281
+ if self.mode != 'gaussian':
282
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
283
+ else:
284
+ self.phi = max_pool_layer
285
+
286
+
287
+ class NonLocal3d(_NonLocalNd):
288
+ """3D Non-local module.
289
+
290
+ Args:
291
+ in_channels (int): Same as `NonLocalND`.
292
+ sub_sample (bool): Whether to apply max pooling after pairwise
293
+ function (Note that the `sub_sample` is applied on spatial only).
294
+ Default: False.
295
+ conv_cfg (None | dict): Same as `NonLocalND`.
296
+ Default: dict(type='Conv3d').
297
+ """
298
+
299
+ def __init__(self,
300
+ in_channels: int,
301
+ sub_sample: bool = False,
302
+ conv_cfg: Dict = dict(type='Conv3d'),
303
+ **kwargs):
304
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
305
+ self.sub_sample = sub_sample
306
+
307
+ if sub_sample:
308
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
309
+ self.g = nn.Sequential(self.g, max_pool_layer)
310
+ if self.mode != 'gaussian':
311
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
312
+ else:
313
+ self.phi = max_pool_layer
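
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package). The block computes pairwise attention over all spatial positions and adds the result back to the input:

import torch
from mmcv.cnn.bricks.non_local import NonLocal2d

block = NonLocal2d(in_channels=64, reduction=2, mode='embedded_gaussian', norm_cfg=dict(type='BN'))
y = block(torch.randn(2, 64, 20, 20))   # residual connection keeps the input shape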
external/cv/mmcv/cnn/bricks/norm.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ from typing import Dict, Tuple, Union
9
+
10
+ import torch.nn as nn
11
+ from mmengine.registry import MODELS
12
+ from mmengine.utils import is_tuple_of
13
+ from mmengine.utils.dl_utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm,
14
+ _InstanceNorm)
15
+
16
+ MODELS.register_module('BN', module=nn.BatchNorm2d)
17
+ MODELS.register_module('BN1d', module=nn.BatchNorm1d)
18
+ MODELS.register_module('BN2d', module=nn.BatchNorm2d)
19
+ MODELS.register_module('BN3d', module=nn.BatchNorm3d)
20
+ MODELS.register_module('SyncBN', module=SyncBatchNorm)
21
+ MODELS.register_module('GN', module=nn.GroupNorm)
22
+ MODELS.register_module('LN', module=nn.LayerNorm)
23
+ MODELS.register_module('IN', module=nn.InstanceNorm2d)
24
+ MODELS.register_module('IN1d', module=nn.InstanceNorm1d)
25
+ MODELS.register_module('IN2d', module=nn.InstanceNorm2d)
26
+ MODELS.register_module('IN3d', module=nn.InstanceNorm3d)
27
+
28
+
29
+ def infer_abbr(class_type):
30
+ """Infer abbreviation from the class name.
31
+
32
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
33
+ the norm type in variable names, e.g, self.bn1, self.gn. This method will
34
+ infer the abbreviation to map class types to abbreviations.
35
+
36
+ Rule 1: If the class has the property "_abbr_", return the property.
37
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
38
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
39
+ "in" respectively.
40
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
41
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
42
+ respectively.
43
+ Rule 4: Otherwise, the abbreviation falls back to "norm".
44
+
45
+ Args:
46
+ class_type (type): The norm layer type.
47
+
48
+ Returns:
49
+ str: The inferred abbreviation.
50
+ """
51
+ if not inspect.isclass(class_type):
52
+ raise TypeError(
53
+ f'class_type must be a type, but got {type(class_type)}')
54
+ if hasattr(class_type, '_abbr_'):
55
+ return class_type._abbr_
56
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
57
+ return 'in'
58
+ elif issubclass(class_type, _BatchNorm):
59
+ return 'bn'
60
+ elif issubclass(class_type, nn.GroupNorm):
61
+ return 'gn'
62
+ elif issubclass(class_type, nn.LayerNorm):
63
+ return 'ln'
64
+ else:
65
+ class_name = class_type.__name__.lower()
66
+ if 'batch' in class_name:
67
+ return 'bn'
68
+ elif 'group' in class_name:
69
+ return 'gn'
70
+ elif 'layer' in class_name:
71
+ return 'ln'
72
+ elif 'instance' in class_name:
73
+ return 'in'
74
+ else:
75
+ return 'norm_layer'
76
+
77
+
78
+ def build_norm_layer(cfg: Dict,
79
+ num_features: int,
80
+ postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
81
+ """Build normalization layer.
82
+
83
+ Args:
84
+ cfg (dict): The norm layer config, which should contain:
85
+
86
+ - type (str): Layer type.
87
+ - layer args: Args needed to instantiate a norm layer.
88
+ - requires_grad (bool, optional): Whether stop gradient updates.
89
+ num_features (int): Number of input channels.
90
+ postfix (int | str): The postfix to be appended into norm abbreviation
91
+ to create named layer.
92
+
93
+ Returns:
94
+ tuple[str, nn.Module]: The first element is the layer name consisting
95
+ of abbreviation and postfix, e.g., bn1, gn. The second element is the
96
+ created norm layer.
97
+ """
98
+ if not isinstance(cfg, dict):
99
+ raise TypeError('cfg must be a dict')
100
+ if 'type' not in cfg:
101
+ raise KeyError('the cfg dict must contain the key "type"')
102
+ cfg_ = cfg.copy()
103
+
104
+ layer_type = cfg_.pop('type')
105
+
106
+ if inspect.isclass(layer_type):
107
+ norm_layer = layer_type
108
+ else:
109
+ # Switch registry to the target scope. If `norm_layer` cannot be found
110
+ # in the registry, fallback to search `norm_layer` in the
111
+ # mmengine.MODELS.
112
+ with MODELS.switch_scope_and_registry(None) as registry:
113
+ norm_layer = registry.get(layer_type)
114
+ if norm_layer is None:
115
+ raise KeyError(f'Cannot find {norm_layer} in registry under '
116
+ f'scope name {registry.scope}')
117
+ abbr = infer_abbr(norm_layer)
118
+
119
+ assert isinstance(postfix, (int, str))
120
+ name = abbr + str(postfix)
121
+
122
+ requires_grad = cfg_.pop('requires_grad', True)
123
+ cfg_.setdefault('eps', 1e-5)
124
+ if norm_layer is not nn.GroupNorm:
125
+ layer = norm_layer(num_features, **cfg_)
126
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
127
+ layer._specify_ddp_gpu_num(1)
128
+ else:
129
+ assert 'num_groups' in cfg_
130
+ layer = norm_layer(num_channels=num_features, **cfg_)
131
+
132
+ for param in layer.parameters():
133
+ param.requires_grad = requires_grad
134
+
135
+ return name, layer
136
+
137
+
138
+ def is_norm(layer: nn.Module,
139
+ exclude: Union[type, tuple, None] = None) -> bool:
140
+ """Check if a layer is a normalization layer.
141
+
142
+ Args:
143
+ layer (nn.Module): The layer to be checked.
144
+ exclude (type | tuple[type]): Types to be excluded.
145
+
146
+ Returns:
147
+ bool: Whether the layer is a norm layer.
148
+ """
149
+ if exclude is not None:
150
+ if not isinstance(exclude, tuple):
151
+ exclude = (exclude, )
152
+ if not is_tuple_of(exclude, type):
153
+ raise TypeError(
154
+ f'"exclude" must be either None or type or a tuple of types, '
155
+ f'but got {type(exclude)}: {exclude}')
156
+
157
+ if exclude and isinstance(layer, exclude):
158
+ return False
159
+
160
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
161
+ return isinstance(layer, all_norm_bases)
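
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package):

import torch.nn as nn
from mmcv.cnn.bricks.norm import build_norm_layer, is_norm

name, layer = build_norm_layer(dict(type='BN', requires_grad=True), num_features=64, postfix=1)
# name == 'bn1'; layer is nn.BatchNorm2d(64, eps=1e-05) with trainable affine parameters
assert is_norm(layer) and isinstance(layer, nn.BatchNorm2d)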
external/cv/mmcv/cnn/bricks/padding.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ from typing import Dict
9
+
10
+ import torch.nn as nn
11
+ from mmengine.registry import MODELS
12
+
13
+ MODELS.register_module('zero', module=nn.ZeroPad2d)
14
+ MODELS.register_module('reflect', module=nn.ReflectionPad2d)
15
+ MODELS.register_module('replicate', module=nn.ReplicationPad2d)
16
+
17
+
18
+ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
19
+ """Build padding layer.
20
+
21
+ Args:
22
+ cfg (dict): The padding layer config, which should contain:
23
+ - type (str): Layer type.
24
+ - layer args: Args needed to instantiate a padding layer.
25
+
26
+ Returns:
27
+ nn.Module: Created padding layer.
28
+ """
29
+ if not isinstance(cfg, dict):
30
+ raise TypeError('cfg must be a dict')
31
+ if 'type' not in cfg:
32
+ raise KeyError('the cfg dict must contain the key "type"')
33
+
34
+ cfg_ = cfg.copy()
35
+ padding_type = cfg_.pop('type')
36
+ if inspect.isclass(padding_type):
37
+ return padding_type(*args, **kwargs, **cfg_)
38
+ # Switch registry to the target scope. If `padding_layer` cannot be found
39
+ # in the registry, fallback to search `padding_layer` in the
40
+ # mmengine.MODELS.
41
+ with MODELS.switch_scope_and_registry(None) as registry:
42
+ padding_layer = registry.get(padding_type)
43
+ if padding_layer is None:
44
+ raise KeyError(f'Cannot find {padding_layer} in registry under scope '
45
+ f'name {registry.scope}')
46
+ layer = padding_layer(*args, **kwargs, **cfg_)
47
+
48
+ return layer
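
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package):

import torch
from mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)   # -> nn.ReflectionPad2d(2)
y = pad(torch.randn(1, 3, 8, 8))                     # spatial size grows from 8x8 to 12x12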
external/cv/mmcv/cnn/bricks/plugin.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ import platform
9
+ from typing import Dict, Tuple, Union
10
+
11
+ import torch.nn as nn
12
+ from mmengine.registry import MODELS
13
+
14
+ if platform.system() == 'Windows':
15
+ import regex as re # type: ignore
16
+ else:
17
+ import re # type: ignore
18
+
19
+
20
+ def infer_abbr(class_type: type) -> str:
21
+ """Infer abbreviation from the class name.
22
+
23
+ This method will infer the abbreviation to map class types to
24
+ abbreviations.
25
+
26
+ Rule 1: If the class has the property "abbr", return the property.
27
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
28
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
29
+
30
+ Args:
31
+ class_type (type): The norm layer type.
32
+
33
+ Returns:
34
+ str: The inferred abbreviation.
35
+ """
36
+
37
+ def camel2snack(word):
38
+ """Convert camel case word into snack case.
39
+
40
+ Modified from `inflection lib
41
+ <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.
42
+
43
+ Example::
44
+
45
+ >>> camel2snack("FancyBlock")
46
+ 'fancy_block'
47
+ """
48
+
49
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
50
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
51
+ word = word.replace('-', '_')
52
+ return word.lower()
53
+
54
+ if not inspect.isclass(class_type):
55
+ raise TypeError(
56
+ f'class_type must be a type, but got {type(class_type)}')
57
+ if hasattr(class_type, '_abbr_'):
58
+ return class_type._abbr_ # type: ignore
59
+ else:
60
+ return camel2snack(class_type.__name__)
61
+
62
+
63
+ def build_plugin_layer(cfg: Dict,
64
+ postfix: Union[int, str] = '',
65
+ **kwargs) -> Tuple[str, nn.Module]:
66
+ """Build plugin layer.
67
+
68
+ Args:
69
+ cfg (dict): cfg should contain:
70
+
71
+ - type (str): identify plugin layer type.
72
+ - layer args: args needed to instantiate a plugin layer.
73
+ postfix (int, str): appended into norm abbreviation to
74
+ create named layer. Default: ''.
75
+
76
+ Returns:
77
+ tuple[str, nn.Module]: The first one is the concatenation of
78
+ abbreviation and postfix. The second is the created plugin layer.
79
+ """
80
+ if not isinstance(cfg, dict):
81
+ raise TypeError('cfg must be a dict')
82
+ if 'type' not in cfg:
83
+ raise KeyError('the cfg dict must contain the key "type"')
84
+ cfg_ = cfg.copy()
85
+
86
+ layer_type = cfg_.pop('type')
87
+ if inspect.isclass(layer_type):
88
+ plugin_layer = layer_type
89
+ else:
90
+ # Switch registry to the target scope. If `plugin_layer` cannot be
91
+ # found in the registry, fallback to search `plugin_layer` in the
92
+ # mmengine.MODELS.
93
+ with MODELS.switch_scope_and_registry(None) as registry:
94
+ plugin_layer = registry.get(layer_type)
95
+ if plugin_layer is None:
96
+ raise KeyError(
97
+ f'Cannot find {plugin_layer} in registry under scope '
98
+ f'name {registry.scope}')
99
+ abbr = infer_abbr(plugin_layer)
100
+
101
+ assert isinstance(postfix, (int, str))
102
+ name = abbr + str(postfix)
103
+
104
+ layer = plugin_layer(**kwargs, **cfg_)
105
+
106
+ return name, layer
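
A minimal usage sketch (illustrative; assumes the vendored `mmcv` package). Any registered module can be built as a named plugin, e.g. the NonLocal2d block registered earlier in this commit:

from mmcv.cnn.bricks.plugin import build_plugin_layer

name, layer = build_plugin_layer(dict(type='NonLocal2d', in_channels=64), postfix='_1')
# name == 'nonlocal_block_1' (NonLocal2d defines _abbr_ = 'nonlocal_block'); layer is a NonLocal2d instance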