Hugging Face Space: Sapiens Pose (running on ZeroGPU)
Commit 28c256d · Add initial commit (no parent commits)
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +38 -0
- .gitignore +1 -0
- NOTES.md +11 -0
- README.md +13 -0
- app.py +453 -0
- assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth +3 -0
- assets/checkpoints/sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2 +3 -0
- assets/checkpoints/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2 +3 -0
- assets/images/68204.png +3 -0
- assets/images/68210.png +3 -0
- assets/images/68658.png +3 -0
- assets/images/68666.png +3 -0
- assets/images/68691.png +3 -0
- assets/images/68956.png +3 -0
- assets/images/pexels-amresh444-17315601.png +3 -0
- assets/images/pexels-gabby-k-6311686.png +3 -0
- assets/images/pexels-julia-m-cameron-4145040.png +3 -0
- assets/images/pexels-marcus-aurelius-6787357.png +3 -0
- assets/images/pexels-mo-saeed-3616599-5409085.png +3 -0
- assets/images/pexels-riedelmax-27355495.png +3 -0
- assets/images/pexels-sergeymakashin-5368660.png +3 -0
- assets/images/pexels-vinicius-wiesehofer-289347-4219918.png +3 -0
- assets/rtmdet_m_640-8xb32_coco-person_no_nms.py +20 -0
- build_wheel.py +26 -0
- classes_and_palettes.py +1024 -0
- detector_utils.py +196 -0
- external/cv/.gitignore +125 -0
- external/cv/MANIFEST.in +6 -0
- external/cv/dist/sapiens_cv-1.0.0-cp310-cp310-linux_x86_64.whl +3 -0
- external/cv/mmcv/__init__.py +18 -0
- external/cv/mmcv/arraymisc/__init__.py +9 -0
- external/cv/mmcv/arraymisc/quantization.py +70 -0
- external/cv/mmcv/cnn/__init__.py +33 -0
- external/cv/mmcv/cnn/alexnet.py +68 -0
- external/cv/mmcv/cnn/bricks/__init__.py +37 -0
- external/cv/mmcv/cnn/bricks/activation.py +119 -0
- external/cv/mmcv/cnn/bricks/context_block.py +131 -0
- external/cv/mmcv/cnn/bricks/conv.py +56 -0
- external/cv/mmcv/cnn/bricks/conv2d_adaptive_padding.py +68 -0
- external/cv/mmcv/cnn/bricks/conv_module.py +343 -0
- external/cv/mmcv/cnn/bricks/conv_ws.py +158 -0
- external/cv/mmcv/cnn/bricks/depthwise_separable_conv_module.py +104 -0
- external/cv/mmcv/cnn/bricks/drop.py +72 -0
- external/cv/mmcv/cnn/bricks/generalized_attention.py +416 -0
- external/cv/mmcv/cnn/bricks/hsigmoid.py +55 -0
- external/cv/mmcv/cnn/bricks/hswish.py +44 -0
- external/cv/mmcv/cnn/bricks/non_local.py +313 -0
- external/cv/mmcv/cnn/bricks/norm.py +161 -0
- external/cv/mmcv/cnn/bricks/padding.py +48 -0
- external/cv/mmcv/cnn/bricks/plugin.py +106 -0
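In short, the commit bootstraps a self-contained Gradio demo: app.py drives person detection plus pose estimation, assets/ carries the RTMDet detector, two Sapiens TorchScript checkpoints, and example images (all via Git LFS), and external/ vendors the mm* dependencies, including an mmcv source tree and a prebuilt mmcv wheel (the 50-file listing is cut off inside mmcv's cnn/bricks modules). The per-file diffs follow.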
.gitattributes
ADDED
@@ -0,0 +1,38 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.pt2 filter=lfs diff=lfs merge=lfs -text
+*.whl filter=lfs diff=lfs merge=lfs -text
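Each rule above routes a matching file type through Git LFS (the filter/diff/merge drivers plus -text), which is how the commit can carry multi-gigabyte binaries: the *.png, *.pt2, and *.whl lines cover the sample images, the Sapiens TorchScript checkpoints, and the prebuilt mmcv wheel added below.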
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
NOTES.md
ADDED
@@ -0,0 +1,11 @@
+## Create wheel for mmcv
+```
+cd ./external/engine
+python setup.py bdist_wheel
+
+cd ./external/cv
+MMCV_WITH_OPS=1 python setup.py bdist_wheel
+
+cd ./external/det
+python setup.py bdist_wheel
+```
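build_wheel.py, added later in this commit, scripts these same three bdist_wheel builds, and app.py installs the resulting wheels from each package's dist/ directory at startup whenever mmengine, mmcv, or mmdet is missing.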
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Sapiens Pose
+emoji: 📊
+colorFrom: pink
+colorTo: yellow
+sdk: gradio
+sdk_version: 4.42.0
+app_file: app.py
+pinned: false
+license: cc-by-nc-4.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,453 @@
+import os
+from typing import List
+import spaces
+import gradio as gr
+import numpy as np
+import torch
+import json
+import tempfile
+import torch.nn.functional as F
+from torchvision import transforms
+from PIL import Image
+import cv2
+from gradio.themes.utils import sizes
+from classes_and_palettes import (
+    COCO_KPTS_COLORS,
+    COCO_WHOLEBODY_KPTS_COLORS,
+    GOLIATH_KPTS_COLORS,
+    GOLIATH_SKELETON_INFO,
+    GOLIATH_KEYPOINTS
+)
+
+import os
+import sys
+import subprocess
+import importlib.util
+
+def is_package_installed(package_name):
+    return importlib.util.find_spec(package_name) is not None
+
+def find_wheel(package_path):
+    dist_dir = os.path.join(package_path, "dist")
+    if os.path.exists(dist_dir):
+        wheel_files = [f for f in os.listdir(dist_dir) if f.endswith('.whl')]
+        if wheel_files:
+            return os.path.join(dist_dir, wheel_files[0])
+    return None
+
+def install_from_wheel(package_name, package_path):
+    wheel_file = find_wheel(package_path)
+    if wheel_file:
+        print(f"Installing {package_name} from wheel: {wheel_file}")
+        subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_file])
+    else:
+        print(f"{package_name} wheel not found in {package_path}. Please build it first.")
+        sys.exit(1)
+
+def install_local_packages():
+    packages = [
+        ("mmengine", "./external/engine"),
+        ("mmcv", "./external/cv"),
+        ("mmdet", "./external/det")
+    ]
+
+    for package_name, package_path in packages:
+        if not is_package_installed(package_name):
+            print(f"Installing {package_name}...")
+            install_from_wheel(package_name, package_path)
+        else:
+            print(f"{package_name} is already installed.")
+
+# Run the installation at the start of your app
+install_local_packages()
+
+from detector_utils import (
+    adapt_mmdet_pipeline,
+    init_detector,
+    process_images_detector,
+)
+
+class Config:
+    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
+    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
+    CHECKPOINTS = {
+        "0.3b": "sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2",
+        "1b": "sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2",
+    }
+    DETECTION_CHECKPOINT = os.path.join(CHECKPOINTS_DIR, 'rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth')
+    DETECTION_CONFIG = os.path.join(ASSETS_DIR, 'rtmdet_m_640-8xb32_coco-person_no_nms.py')
+
+class ModelManager:
+    @staticmethod
+    def load_model(checkpoint_name: str):
+        if checkpoint_name is None:
+            return None
+        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
+        model = torch.jit.load(checkpoint_path)
+        model.eval()
+        model.to("cuda")
+        return model
+
+    @staticmethod
+    @torch.inference_mode()
+    def run_model(model, input_tensor):
+        return model(input_tensor)
+
+class ImageProcessor:
+    def __init__(self):
+        self.transform = transforms.Compose([
+            transforms.Resize((1024, 768)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255],
+                                 std=[58.5/255, 57.0/255, 57.5/255])
+        ])
+        self.detector = init_detector(
+            Config.DETECTION_CONFIG, Config.DETECTION_CHECKPOINT, device='cpu'
+        )
+        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
+
+    def detect_persons(self, image: Image.Image):
+        # Convert the PIL image to a batched numpy array for the detector
+        image = np.array(image)
+        image = np.expand_dims(image, axis=0)
+
+        # Perform person detection
+        bboxes_batch = process_images_detector(
+            image,
+            self.detector
+        )
+        bboxes = self.get_person_bboxes(bboxes_batch[0])  # Get bboxes for the first (and only) image
+
+        return bboxes
+
+    def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
+        person_bboxes = []
+        for bbox in bboxes_batch:
+            if len(bbox) == 5:  # [x1, y1, x2, y2, score]
+                if bbox[4] > score_thr:
+                    person_bboxes.append(bbox)
+            elif len(bbox) == 4:  # [x1, y1, x2, y2]
+                person_bboxes.append(bbox + [1.0])  # Add a default score of 1.0
+        return person_bboxes
+
+    @spaces.GPU
+    @torch.inference_mode()
+    def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
+        pose_model = ModelManager.load_model(Config.CHECKPOINTS[model_name])
+
+        result_image = image.copy()
+        all_keypoints = []  # List to store keypoints for all persons
+
+        for bbox in bboxes:
+            cropped_img = self.crop_image(result_image, bbox)
+            input_tensor = self.transform(cropped_img).unsqueeze(0).to("cuda")
+            heatmaps = ModelManager.run_model(pose_model, input_tensor)
+            keypoints = self.heatmaps_to_keypoints(heatmaps[0].cpu().numpy())
+            all_keypoints.append(keypoints)  # Collect keypoints
+            result_image = self.draw_keypoints(result_image, keypoints, bbox, kpt_threshold)
+
+        return result_image, all_keypoints
+
+    def process_image(self, image: Image.Image, model_name: str, kpt_threshold: str):
+        bboxes = self.detect_persons(image)
+        result_image, keypoints = self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))
+        return result_image, keypoints
+
+    def crop_image(self, image, bbox):
+        if len(bbox) == 4:
+            x1, y1, x2, y2 = map(int, bbox)
+        elif len(bbox) >= 5:
+            x1, y1, x2, y2, _ = map(int, bbox[:5])
+        else:
+            raise ValueError(f"Unexpected bbox format: {bbox}")
+
+        crop = image.crop((x1, y1, x2, y2))
+        return crop
+
+    @staticmethod
+    def heatmaps_to_keypoints(heatmaps):
+        num_joints = heatmaps.shape[0]  # Should be 308
+        keypoints = {}
+        for i, name in enumerate(GOLIATH_KEYPOINTS):
+            if i < num_joints:
+                heatmap = heatmaps[i]
+                y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
+                conf = heatmap[y, x]
+                keypoints[name] = (float(x), float(y), float(conf))
+        return keypoints
+
+    @staticmethod
+    def draw_keypoints(image, keypoints, bbox, kpt_threshold):
+        image = np.array(image)
+
+        # Handle both 4 and 5-element bounding boxes
+        if len(bbox) == 4:
+            x1, y1, x2, y2 = map(int, bbox)
+        elif len(bbox) >= 5:
+            x1, y1, x2, y2, _ = map(int, bbox[:5])
+        else:
+            raise ValueError(f"Unexpected bbox format: {bbox}")
+
+        # Calculate adaptive radius and thickness based on bounding box size
+        bbox_width = x2 - x1
+        bbox_height = y2 - y1
+        bbox_size = np.sqrt(bbox_width * bbox_height)
+
+        radius = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
+        thickness = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
+        bbox_thickness = max(1, thickness//4)
+
+        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)
+
+        # Draw keypoints
+        for i, (name, (x, y, conf)) in enumerate(keypoints.items()):
+            if conf > kpt_threshold and i < len(GOLIATH_KPTS_COLORS):
+                x_coord = int(x * bbox_width / 192) + x1
+                y_coord = int(y * bbox_height / 256) + y1
+                color = GOLIATH_KPTS_COLORS[i]
+                cv2.circle(image, (x_coord, y_coord), radius, color, -1)
+
+        # Draw skeleton
+        for _, link_info in GOLIATH_SKELETON_INFO.items():
+            pt1_name, pt2_name = link_info['link']
+            color = link_info['color']
+
+            if pt1_name in keypoints and pt2_name in keypoints:
+                pt1 = keypoints[pt1_name]
+                pt2 = keypoints[pt2_name]
+                if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
+                    x1_coord = int(pt1[0] * bbox_width / 192) + x1
+                    y1_coord = int(pt1[1] * bbox_height / 256) + y1
+                    x2_coord = int(pt2[0] * bbox_width / 192) + x1
+                    y2_coord = int(pt2[1] * bbox_height / 256) + y1
+                    cv2.line(image, (x1_coord, y1_coord), (x2_coord, y2_coord), color, thickness=thickness)
+
+        return Image.fromarray(image)
+
+class GradioInterface:
+    def __init__(self):
+        self.image_processor = ImageProcessor()
+
+    def create_interface(self):
+        app_styles = """
+        <style>
+            /* Global Styles */
+            body, #root {
+                font-family: Helvetica, Arial, sans-serif;
+                background-color: #1a1a1a;
+                color: #fafafa;
+            }
+            /* Header Styles */
+            .app-header {
+                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
+                padding: 24px;
+                border-radius: 8px;
+                margin-bottom: 24px;
+                text-align: center;
+            }
+            .app-title {
+                font-size: 48px;
+                margin: 0;
+                color: #fafafa;
+            }
+            .app-subtitle {
+                font-size: 24px;
+                margin: 8px 0 16px;
+                color: #fafafa;
+            }
+            .app-description {
+                font-size: 16px;
+                line-height: 1.6;
+                opacity: 0.8;
+                margin-bottom: 24px;
+            }
+            /* Button Styles */
+            .publication-links {
+                display: flex;
+                justify-content: center;
+                flex-wrap: wrap;
+                gap: 8px;
+                margin-bottom: 16px;
+            }
+            .publication-link {
+                display: inline-flex;
+                align-items: center;
+                padding: 8px 16px;
+                background-color: #333;
+                color: #fff !important;
+                text-decoration: none !important;
+                border-radius: 20px;
+                font-size: 14px;
+                transition: background-color 0.3s;
+            }
+            .publication-link:hover {
+                background-color: #555;
+            }
+            .publication-link i {
+                margin-right: 8px;
+            }
+            /* Content Styles */
+            .content-container {
+                background-color: #2a2a2a;
+                border-radius: 8px;
+                padding: 24px;
+                margin-bottom: 24px;
+            }
+            /* Image Styles */
+            .image-preview img {
+                max-width: 512px;
+                max-height: 512px;
+                margin: 0 auto;
+                border-radius: 4px;
+                display: block;
+                object-fit: contain;
+            }
+            /* Control Styles */
+            .control-panel {
+                background-color: #333;
+                padding: 16px;
+                border-radius: 8px;
+                margin-top: 16px;
+            }
+            /* Gradio Component Overrides */
+            .gr-button {
+                background-color: #4a4a4a;
+                color: #fff;
+                border: none;
+                border-radius: 4px;
+                padding: 8px 16px;
+                cursor: pointer;
+                transition: background-color 0.3s;
+            }
+            .gr-button:hover {
+                background-color: #5a5a5a;
+            }
+            .gr-input, .gr-dropdown {
+                background-color: #3a3a3a;
+                color: #fff;
+                border: 1px solid #4a4a4a;
+                border-radius: 4px;
+                padding: 8px;
+            }
+            .gr-form {
+                background-color: transparent;
+            }
+            .gr-panel {
+                border: none;
+                background-color: transparent;
+            }
+            /* Override any conflicting styles from Bulma */
+            .button.is-normal.is-rounded.is-dark {
+                color: #fff !important;
+                text-decoration: none !important;
+            }
+        </style>
+        """
+
+        header_html = f"""
+        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
+        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
+        {app_styles}
+        <div class="app-header">
+            <h1 class="app-title">Sapiens: Pose Estimation</h1>
+            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
+            <p class="app-description">
+                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
+                This demo showcases the finetuned pose estimation model. <br>
+            </p>
+            <div class="publication-links">
+                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
+                    <i class="fas fa-file-pdf"></i>arXiv
+                </a>
+                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
+                    <i class="fab fa-github"></i>Code
+                </a>
+                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
+                    <i class="fas fa-globe"></i>Meta
+                </a>
+                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
+                    <i class="fas fa-chart-bar"></i>Results
+                </a>
+            </div>
+            <div class="publication-links">
+                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
+                    <i class="fas fa-user"></i>Demo-Pose
+                </a>
+                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
+                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
+                </a>
+                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
+                    <i class="fas fa-cube"></i>Demo-Depth
+                </a>
+                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
+                    <i class="fas fa-vector-square"></i>Demo-Normal
+                </a>
+            </div>
+        </div>
+        """
+
+        js_func = """
+        function refresh() {
+            const url = new URL(window.location);
+            if (url.searchParams.get('__theme') !== 'dark') {
+                url.searchParams.set('__theme', 'dark');
+                window.location.href = url.href;
+            }
+        }
+        """
+
+        def process_image(image, model_name, kpt_threshold):
+            result_image, keypoints = self.image_processor.process_image(image, model_name, kpt_threshold)
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w') as json_file:
+                json.dump(keypoints, json_file)
+                json_file_path = json_file.name
+            return result_image, json_file_path
+
+        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
+            gr.HTML(header_html)
+            with gr.Row(elem_classes="content-container"):
+                with gr.Column():
+                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
+                    with gr.Row():
+                        model_name = gr.Dropdown(
+                            label="Model Size",
+                            choices=list(Config.CHECKPOINTS.keys()),
+                            value="1b",
+                        )
+                        kpt_threshold = gr.Dropdown(
+                            label="Min Keypoint Confidence",
+                            choices=["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"],
+                            value="0.3",
+                        )
+                    example_model = gr.Examples(
+                        inputs=input_image,
+                        examples_per_page=14,
+                        examples=[
+                            os.path.join(Config.ASSETS_DIR, "images", img)
+                            for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
+                        ],
+                    )
+                with gr.Column():
+                    result_image = gr.Image(label="Pose-308 Result", type="pil", elem_classes="image-preview")
+                    json_output = gr.File(label="Pose-308 Output (.json)")
+                    run_button = gr.Button("Run")
+
+            run_button.click(
+                fn=process_image,
+                inputs=[input_image, model_name, kpt_threshold],
+                outputs=[result_image, json_output],
+            )
+
+        return demo
+
+def main():
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    interface = GradioInterface()
+    demo = interface.create_interface()
+    demo.launch(share=False)
+
+if __name__ == "__main__":
+    main()
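For readers tracing the pose path above: each detected person is cropped, resized to 1024×768, and pushed through the TorchScript model, then heatmaps_to_keypoints takes a per-channel argmax whose coordinates draw_keypoints rescales from the heatmap grid back into the bounding box (the /192 and /256 factors). A minimal self-contained sketch of that decode step on dummy data — the 308×256×192 heatmap shape is an assumption inferred from those scale factors, not a value read from the checkpoint:

```python
import numpy as np

# Dummy stand-in for the model output: (num_joints, H, W).
# 308 channels on a 256x192 grid is assumed from the /256 and /192
# scale factors used in draw_keypoints; the real shape comes from the model.
heatmaps = np.random.rand(308, 256, 192).astype(np.float32)

def decode(heatmaps, bbox):
    """Argmax-decode each heatmap channel and map it into the person bbox."""
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
    keypoints = []
    for hm in heatmaps:
        y, x = np.unravel_index(np.argmax(hm), hm.shape)  # peak location
        keypoints.append((x1 + x * w / hm.shape[1],       # back to image coords
                          y1 + y * h / hm.shape[0],
                          float(hm[y, x])))               # peak value as confidence
    return keypoints

kpts = decode(heatmaps, bbox=(100, 50, 292, 306))
print(len(kpts), kpts[0])
```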
assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b66b27072c6a3cd4f093882df440921987076131fb78a7df7b1cf92d67f41509
+size 99149914
assets/checkpoints/sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21cf7e3e723720d847bee6d3b321bfcdb33268c9f1418d7552552264ae0a5a9b
+size 1319579523
assets/checkpoints/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6218c6be17697157f9e65ee34054a94ab8ca0f637380fa5748c18e04814976e
+size 4677162331
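All three checkpoints are committed as Git LFS pointer files rather than raw binaries: the oid is the SHA-256 of the real blob and size is its byte count, so the RTMDet detector is roughly 99 MB while the 0.3b and 1b Sapiens TorchScript models weigh about 1.3 GB and 4.7 GB.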
assets/images/68204.png
ADDED
(Git LFS image)
assets/images/68210.png
ADDED
(Git LFS image)
assets/images/68658.png
ADDED
(Git LFS image)
assets/images/68666.png
ADDED
(Git LFS image)
assets/images/68691.png
ADDED
(Git LFS image)
assets/images/68956.png
ADDED
(Git LFS image)
assets/images/pexels-amresh444-17315601.png
ADDED
(Git LFS image)
assets/images/pexels-gabby-k-6311686.png
ADDED
(Git LFS image)
assets/images/pexels-julia-m-cameron-4145040.png
ADDED
(Git LFS image)
assets/images/pexels-marcus-aurelius-6787357.png
ADDED
(Git LFS image)
assets/images/pexels-mo-saeed-3616599-5409085.png
ADDED
(Git LFS image)
assets/images/pexels-riedelmax-27355495.png
ADDED
(Git LFS image)
assets/images/pexels-sergeymakashin-5368660.png
ADDED
(Git LFS image)
assets/images/pexels-vinicius-wiesehofer-289347-4219918.png
ADDED
(Git LFS image)
assets/rtmdet_m_640-8xb32_coco-person_no_nms.py
ADDED
@@ -0,0 +1,20 @@
+_base_ = 'mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py'
+
+checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth'  # noqa
+
+model = dict(
+    backbone=dict(
+        init_cfg=dict(
+            type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
+    bbox_head=dict(num_classes=1),
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        nms=None,
+        max_per_img=100))
+
+train_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
+
+val_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
+test_dataloader = val_dataloader
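This config specializes mmdet's stock RTMDet-M recipe to a single person class and sets nms=None at test time, leaving score filtering to the caller (app.py thresholds at 0.3). As a rough smoke-test sketch using stock mmdet 3.x APIs — the Space itself installs its patched wheels from external/ and goes through detector_utils instead, so treat the environment here as an assumption:

```python
from mmdet.apis import init_detector, inference_detector

# Paths as laid out by this commit (repo root as the working directory).
config = 'assets/rtmdet_m_640-8xb32_coco-person_no_nms.py'
weights = 'assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth'

model = init_detector(config, weights, device='cpu')           # build + load weights
result = inference_detector(model, 'assets/images/68204.png')  # run on one image

# mmdet 3.x returns a DetDataSample; with nms=None the head emits many
# overlapping raw boxes, so filter by score before using them.
inst = result.pred_instances
keep = inst.scores > 0.3
print(inst.bboxes[keep].shape, inst.scores[keep])
```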
build_wheel.py
ADDED
@@ -0,0 +1,26 @@
+import os
+import subprocess
+import sys
+
+def build_wheel(package_path):
+    current_dir = os.getcwd()
+    os.chdir(package_path)
+    try:
+        subprocess.check_call([sys.executable, "setup.py", "bdist_wheel"])
+    finally:
+        os.chdir(current_dir)
+
+def main():
+    packages = [
+        "./external/engine",
+        "./external/cv",
+        "./external/det"
+    ]
+
+    for package in packages:
+        print(f"Building wheel for {package}...")
+        build_wheel(package)
+        print(f"Wheel built for {package}")
+
+if __name__ == "__main__":
+    main()
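Run it from the repo root (the ./external paths are relative) as `python build_wheel.py`. Note that, unlike NOTES.md, this script does not set MMCV_WITH_OPS=1 for the mmcv build, so export that variable first if the compiled ops are needed.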
classes_and_palettes.py
ADDED
@@ -0,0 +1,1024 @@
+COCO_KPTS_COLORS = [
+    [51, 153, 255],  # 0: nose
+    [51, 153, 255],  # 1: left_eye
+    [51, 153, 255],  # 2: right_eye
+    [51, 153, 255],  # 3: left_ear
+    [51, 153, 255],  # 4: right_ear
+    [0, 255, 0],  # 5: left_shoulder
+    [255, 128, 0],  # 6: right_shoulder
+    [0, 255, 0],  # 7: left_elbow
+    [255, 128, 0],  # 8: right_elbow
+    [0, 255, 0],  # 9: left_wrist
+    [255, 128, 0],  # 10: right_wrist
+    [0, 255, 0],  # 11: left_hip
+    [255, 128, 0],  # 12: right_hip
+    [0, 255, 0],  # 13: left_knee
+    [255, 128, 0],  # 14: right_knee
+    [0, 255, 0],  # 15: left_ankle
+    [255, 128, 0],  # 16: right_ankle
+]
+
+COCO_WHOLEBODY_KPTS_COLORS = [
+    [51, 153, 255],  # 0: nose
+    [51, 153, 255],  # 1: left_eye
+    [51, 153, 255],  # 2: right_eye
+    [51, 153, 255],  # 3: left_ear
+    [51, 153, 255],  # 4: right_ear
+    [0, 255, 0],  # 5: left_shoulder
+    [255, 128, 0],  # 6: right_shoulder
+    [0, 255, 0],  # 7: left_elbow
+    [255, 128, 0],  # 8: right_elbow
+    [0, 255, 0],  # 9: left_wrist
+    [255, 128, 0],  # 10: right_wrist
+    [0, 255, 0],  # 11: left_hip
+    [255, 128, 0],  # 12: right_hip
+    [0, 255, 0],  # 13: left_knee
+    [255, 128, 0],  # 14: right_knee
+    [0, 255, 0],  # 15: left_ankle
+    [255, 128, 0],  # 16: right_ankle
+    [255, 128, 0],  # 17: left_big_toe
+    [255, 128, 0],  # 18: left_small_toe
+    [255, 128, 0],  # 19: left_heel
+    [255, 128, 0],  # 20: right_big_toe
+    [255, 128, 0],  # 21: right_small_toe
+    [255, 128, 0],  # 22: right_heel
+    [255, 255, 255],  # 23: face-0
+    [255, 255, 255],  # 24: face-1
+    [255, 255, 255],  # 25: face-2
+    [255, 255, 255],  # 26: face-3
+    [255, 255, 255],  # 27: face-4
+    [255, 255, 255],  # 28: face-5
+    [255, 255, 255],  # 29: face-6
+    [255, 255, 255],  # 30: face-7
+    [255, 255, 255],  # 31: face-8
+    [255, 255, 255],  # 32: face-9
+    [255, 255, 255],  # 33: face-10
+    [255, 255, 255],  # 34: face-11
+    [255, 255, 255],  # 35: face-12
+    [255, 255, 255],  # 36: face-13
+    [255, 255, 255],  # 37: face-14
+    [255, 255, 255],  # 38: face-15
+    [255, 255, 255],  # 39: face-16
+    [255, 255, 255],  # 40: face-17
+    [255, 255, 255],  # 41: face-18
+    [255, 255, 255],  # 42: face-19
+    [255, 255, 255],  # 43: face-20
+    [255, 255, 255],  # 44: face-21
+    [255, 255, 255],  # 45: face-22
+    [255, 255, 255],  # 46: face-23
+    [255, 255, 255],  # 47: face-24
+    [255, 255, 255],  # 48: face-25
+    [255, 255, 255],  # 49: face-26
+    [255, 255, 255],  # 50: face-27
+    [255, 255, 255],  # 51: face-28
+    [255, 255, 255],  # 52: face-29
+    [255, 255, 255],  # 53: face-30
+    [255, 255, 255],  # 54: face-31
+    [255, 255, 255],  # 55: face-32
+    [255, 255, 255],  # 56: face-33
+    [255, 255, 255],  # 57: face-34
+    [255, 255, 255],  # 58: face-35
+    [255, 255, 255],  # 59: face-36
+    [255, 255, 255],  # 60: face-37
+    [255, 255, 255],  # 61: face-38
+    [255, 255, 255],  # 62: face-39
+    [255, 255, 255],  # 63: face-40
+    [255, 255, 255],  # 64: face-41
+    [255, 255, 255],  # 65: face-42
+    [255, 255, 255],  # 66: face-43
+    [255, 255, 255],  # 67: face-44
+    [255, 255, 255],  # 68: face-45
+    [255, 255, 255],  # 69: face-46
+    [255, 255, 255],  # 70: face-47
+    [255, 255, 255],  # 71: face-48
+    [255, 255, 255],  # 72: face-49
+    [255, 255, 255],  # 73: face-50
+    [255, 255, 255],  # 74: face-51
+    [255, 255, 255],  # 75: face-52
+    [255, 255, 255],  # 76: face-53
+    [255, 255, 255],  # 77: face-54
+    [255, 255, 255],  # 78: face-55
+    [255, 255, 255],  # 79: face-56
+    [255, 255, 255],  # 80: face-57
+    [255, 255, 255],  # 81: face-58
+    [255, 255, 255],  # 82: face-59
+    [255, 255, 255],  # 83: face-60
+    [255, 255, 255],  # 84: face-61
+    [255, 255, 255],  # 85: face-62
+    [255, 255, 255],  # 86: face-63
+    [255, 255, 255],  # 87: face-64
+    [255, 255, 255],  # 88: face-65
+    [255, 255, 255],  # 89: face-66
+    [255, 255, 255],  # 90: face-67
+    [255, 255, 255],  # 91: left_hand_root
+    [255, 128, 0],  # 92: left_thumb1
+    [255, 128, 0],  # 93: left_thumb2
+    [255, 128, 0],  # 94: left_thumb3
+    [255, 128, 0],  # 95: left_thumb4
+    [255, 153, 255],  # 96: left_forefinger1
+    [255, 153, 255],  # 97: left_forefinger2
+    [255, 153, 255],  # 98: left_forefinger3
+    [255, 153, 255],  # 99: left_forefinger4
+    [102, 178, 255],  # 100: left_middle_finger1
+    [102, 178, 255],  # 101: left_middle_finger2
+    [102, 178, 255],  # 102: left_middle_finger3
+    [102, 178, 255],  # 103: left_middle_finger4
+    [255, 51, 51],  # 104: left_ring_finger1
+    [255, 51, 51],  # 105: left_ring_finger2
+    [255, 51, 51],  # 106: left_ring_finger3
+    [255, 51, 51],  # 107: left_ring_finger4
+    [0, 255, 0],  # 108: left_pinky_finger1
+    [0, 255, 0],  # 109: left_pinky_finger2
+    [0, 255, 0],  # 110: left_pinky_finger3
+    [0, 255, 0],  # 111: left_pinky_finger4
+    [255, 255, 255],  # 112: right_hand_root
+    [255, 128, 0],  # 113: right_thumb1
+    [255, 128, 0],  # 114: right_thumb2
+    [255, 128, 0],  # 115: right_thumb3
+    [255, 128, 0],  # 116: right_thumb4
+    [255, 153, 255],  # 117: right_forefinger1
+    [255, 153, 255],  # 118: right_forefinger2
+    [255, 153, 255],  # 119: right_forefinger3
+    [255, 153, 255],  # 120: right_forefinger4
+    [102, 178, 255],  # 121: right_middle_finger1
+    [102, 178, 255],  # 122: right_middle_finger2
+    [102, 178, 255],  # 123: right_middle_finger3
+    [102, 178, 255],  # 124: right_middle_finger4
+    [255, 51, 51],  # 125: right_ring_finger1
+    [255, 51, 51],  # 126: right_ring_finger2
+    [255, 51, 51],  # 127: right_ring_finger3
+    [255, 51, 51],  # 128: right_ring_finger4
+    [0, 255, 0],  # 129: right_pinky_finger1
+    [0, 255, 0],  # 130: right_pinky_finger2
+    [0, 255, 0],  # 131: right_pinky_finger3
+    [0, 255, 0],  # 132: right_pinky_finger4
+]
+
+
+GOLIATH_KPTS_COLORS = [
+    [51, 153, 255],  # 0: nose
+    [51, 153, 255],  # 1: left_eye
+    [51, 153, 255],  # 2: right_eye
+    [51, 153, 255],  # 3: left_ear
+    [51, 153, 255],  # 4: right_ear
+    [51, 153, 255],  # 5: left_shoulder
+    [51, 153, 255],  # 6: right_shoulder
+    [51, 153, 255],  # 7: left_elbow
+    [51, 153, 255],  # 8: right_elbow
+    [51, 153, 255],  # 9: left_hip
+    [51, 153, 255],  # 10: right_hip
+    [51, 153, 255],  # 11: left_knee
+    [51, 153, 255],  # 12: right_knee
+    [51, 153, 255],  # 13: left_ankle
+    [51, 153, 255],  # 14: right_ankle
+    [51, 153, 255],  # 15: left_big_toe
+    [51, 153, 255],  # 16: left_small_toe
+    [51, 153, 255],  # 17: left_heel
+    [51, 153, 255],  # 18: right_big_toe
+    [51, 153, 255],  # 19: right_small_toe
+    [51, 153, 255],  # 20: right_heel
+    [51, 153, 255],  # 21: right_thumb4
+    [51, 153, 255],  # 22: right_thumb3
+    [51, 153, 255],  # 23: right_thumb2
+    [51, 153, 255],  # 24: right_thumb_third_joint
+    [51, 153, 255],  # 25: right_forefinger4
+    [51, 153, 255],  # 26: right_forefinger3
+    [51, 153, 255],  # 27: right_forefinger2
+    [51, 153, 255],  # 28: right_forefinger_third_joint
+    [51, 153, 255],  # 29: right_middle_finger4
+    [51, 153, 255],  # 30: right_middle_finger3
+    [51, 153, 255],  # 31: right_middle_finger2
+    [51, 153, 255],  # 32: right_middle_finger_third_joint
+    [51, 153, 255],  # 33: right_ring_finger4
+    [51, 153, 255],  # 34: right_ring_finger3
+    [51, 153, 255],  # 35: right_ring_finger2
+    [51, 153, 255],  # 36: right_ring_finger_third_joint
+    [51, 153, 255],  # 37: right_pinky_finger4
+    [51, 153, 255],  # 38: right_pinky_finger3
+    [51, 153, 255],  # 39: right_pinky_finger2
+    [51, 153, 255],  # 40: right_pinky_finger_third_joint
+    [51, 153, 255],  # 41: right_wrist
+    [51, 153, 255],  # 42: left_thumb4
+    [51, 153, 255],  # 43: left_thumb3
+    [51, 153, 255],  # 44: left_thumb2
+    [51, 153, 255],  # 45: left_thumb_third_joint
+    [51, 153, 255],  # 46: left_forefinger4
+    [51, 153, 255],  # 47: left_forefinger3
+    [51, 153, 255],  # 48: left_forefinger2
+    [51, 153, 255],  # 49: left_forefinger_third_joint
+    [51, 153, 255],  # 50: left_middle_finger4
+    [51, 153, 255],  # 51: left_middle_finger3
+    [51, 153, 255],  # 52: left_middle_finger2
+    [51, 153, 255],  # 53: left_middle_finger_third_joint
+    [51, 153, 255],  # 54: left_ring_finger4
+    [51, 153, 255],  # 55: left_ring_finger3
+    [51, 153, 255],  # 56: left_ring_finger2
+    [51, 153, 255],  # 57: left_ring_finger_third_joint
+    [51, 153, 255],  # 58: left_pinky_finger4
+    [51, 153, 255],  # 59: left_pinky_finger3
+    [51, 153, 255],  # 60: left_pinky_finger2
+    [51, 153, 255],  # 61: left_pinky_finger_third_joint
+    [51, 153, 255],  # 62: left_wrist
+    [51, 153, 255],  # 63: left_olecranon
+    [51, 153, 255],  # 64: right_olecranon
+    [51, 153, 255],  # 65: left_cubital_fossa
+    [51, 153, 255],  # 66: right_cubital_fossa
+    [51, 153, 255],  # 67: left_acromion
+    [51, 153, 255],  # 68: right_acromion
+    [51, 153, 255],  # 69: neck
+    [255, 255, 255],  # 70: center_of_glabella
+    [255, 255, 255],  # 71: center_of_nose_root
+    [255, 255, 255],  # 72: tip_of_nose_bridge
+    [255, 255, 255],  # 73: midpoint_1_of_nose_bridge
+    [255, 255, 255],  # 74: midpoint_2_of_nose_bridge
+    [255, 255, 255],  # 75: midpoint_3_of_nose_bridge
+    [255, 255, 255],  # 76: center_of_labiomental_groove
+    [255, 255, 255],  # 77: tip_of_chin
+    [255, 255, 255],  # 78: upper_startpoint_of_r_eyebrow
+    [255, 255, 255],  # 79: lower_startpoint_of_r_eyebrow
+    [255, 255, 255],  # 80: end_of_r_eyebrow
+    [255, 255, 255],  # 81: upper_midpoint_1_of_r_eyebrow
+    [255, 255, 255],  # 82: lower_midpoint_1_of_r_eyebrow
+    [255, 255, 255],  # 83: upper_midpoint_2_of_r_eyebrow
+    [255, 255, 255],  # 84: upper_midpoint_3_of_r_eyebrow
+    [255, 255, 255],  # 85: lower_midpoint_2_of_r_eyebrow
+    [255, 255, 255],  # 86: lower_midpoint_3_of_r_eyebrow
+    [255, 255, 255],  # 87: upper_startpoint_of_l_eyebrow
+    [255, 255, 255],  # 88: lower_startpoint_of_l_eyebrow
+    [255, 255, 255],  # 89: end_of_l_eyebrow
+    [255, 255, 255],  # 90: upper_midpoint_1_of_l_eyebrow
+    [255, 255, 255],  # 91: lower_midpoint_1_of_l_eyebrow
+    [255, 255, 255],  # 92: upper_midpoint_2_of_l_eyebrow
+    [255, 255, 255],  # 93: upper_midpoint_3_of_l_eyebrow
+    [255, 255, 255],  # 94: lower_midpoint_2_of_l_eyebrow
+    [255, 255, 255],  # 95: lower_midpoint_3_of_l_eyebrow
+    [192, 64, 128],  # 96: l_inner_end_of_upper_lash_line
+    [192, 64, 128],  # 97: l_outer_end_of_upper_lash_line
+    [192, 64, 128],  # 98: l_centerpoint_of_upper_lash_line
+    [192, 64, 128],  # 99: l_midpoint_2_of_upper_lash_line
+    [192, 64, 128],  # 100: l_midpoint_1_of_upper_lash_line
+    [192, 64, 128],  # 101: l_midpoint_6_of_upper_lash_line
+    [192, 64, 128],  # 102: l_midpoint_5_of_upper_lash_line
+    [192, 64, 128],  # 103: l_midpoint_4_of_upper_lash_line
+    [192, 64, 128],  # 104: l_midpoint_3_of_upper_lash_line
+    [192, 64, 128],  # 105: l_outer_end_of_upper_eyelid_line
+    [192, 64, 128],  # 106: l_midpoint_6_of_upper_eyelid_line
+    [192, 64, 128],  # 107: l_midpoint_2_of_upper_eyelid_line
+    [192, 64, 128],  # 108: l_midpoint_5_of_upper_eyelid_line
+    [192, 64, 128],  # 109: l_centerpoint_of_upper_eyelid_line
+    [192, 64, 128],  # 110: l_midpoint_4_of_upper_eyelid_line
+    [192, 64, 128],  # 111: l_midpoint_1_of_upper_eyelid_line
+    [192, 64, 128],  # 112: l_midpoint_3_of_upper_eyelid_line
+    [192, 64, 128],  # 113: l_midpoint_6_of_upper_crease_line
+    [192, 64, 128],  # 114: l_midpoint_2_of_upper_crease_line
+    [192, 64, 128],  # 115: l_midpoint_5_of_upper_crease_line
+    [192, 64, 128],  # 116: l_centerpoint_of_upper_crease_line
+    [192, 64, 128],  # 117: l_midpoint_4_of_upper_crease_line
+    [192, 64, 128],  # 118: l_midpoint_1_of_upper_crease_line
+    [192, 64, 128],  # 119: l_midpoint_3_of_upper_crease_line
+    [64, 32, 192],  # 120: r_inner_end_of_upper_lash_line
+    [64, 32, 192],  # 121: r_outer_end_of_upper_lash_line
+    [64, 32, 192],  # 122: r_centerpoint_of_upper_lash_line
+    [64, 32, 192],  # 123: r_midpoint_1_of_upper_lash_line
+    [64, 32, 192],  # 124: r_midpoint_2_of_upper_lash_line
+    [64, 32, 192],  # 125: r_midpoint_3_of_upper_lash_line
+    [64, 32, 192],  # 126: r_midpoint_4_of_upper_lash_line
+    [64, 32, 192],  # 127: r_midpoint_5_of_upper_lash_line
+    [64, 32, 192],  # 128: r_midpoint_6_of_upper_lash_line
+    [64, 32, 192],  # 129: r_outer_end_of_upper_eyelid_line
+    [64, 32, 192],  # 130: r_midpoint_3_of_upper_eyelid_line
+    [64, 32, 192],  # 131: r_midpoint_1_of_upper_eyelid_line
+    [64, 32, 192],  # 132: r_midpoint_4_of_upper_eyelid_line
+    [64, 32, 192],  # 133: r_centerpoint_of_upper_eyelid_line
+    [64, 32, 192],  # 134: r_midpoint_5_of_upper_eyelid_line
+    [64, 32, 192],  # 135: r_midpoint_2_of_upper_eyelid_line
+    [64, 32, 192],  # 136: r_midpoint_6_of_upper_eyelid_line
+    [64, 32, 192],  # 137: r_midpoint_3_of_upper_crease_line
+    [64, 32, 192],  # 138: r_midpoint_1_of_upper_crease_line
+    [64, 32, 192],  # 139: r_midpoint_4_of_upper_crease_line
+    [64, 32, 192],  # 140: r_centerpoint_of_upper_crease_line
+    [64, 32, 192],  # 141: r_midpoint_5_of_upper_crease_line
+    [64, 32, 192],  # 142: r_midpoint_2_of_upper_crease_line
+    [64, 32, 192],  # 143: r_midpoint_6_of_upper_crease_line
+    [64, 192, 128],  # 144: l_inner_end_of_lower_lash_line
+    [64, 192, 128],  # 145: l_outer_end_of_lower_lash_line
+    [64, 192, 128],  # 146: l_centerpoint_of_lower_lash_line
+    [64, 192, 128],  # 147: l_midpoint_2_of_lower_lash_line
+    [64, 192, 128],  # 148: l_midpoint_1_of_lower_lash_line
+    [64, 192, 128],  # 149: l_midpoint_6_of_lower_lash_line
+    [64, 192, 128],  # 150: l_midpoint_5_of_lower_lash_line
+    [64, 192, 128],  # 151: l_midpoint_4_of_lower_lash_line
+    [64, 192, 128],  # 152: l_midpoint_3_of_lower_lash_line
+    [64, 192, 128],  # 153: l_outer_end_of_lower_eyelid_line
+    [64, 192, 128],  # 154: l_midpoint_6_of_lower_eyelid_line
+    [64, 192, 128],  # 155: l_midpoint_2_of_lower_eyelid_line
+    [64, 192, 128],  # 156: l_midpoint_5_of_lower_eyelid_line
+    [64, 192, 128],  # 157: l_centerpoint_of_lower_eyelid_line
+    [64, 192, 128],  # 158: l_midpoint_4_of_lower_eyelid_line
+    [64, 192, 128],  # 159: l_midpoint_1_of_lower_eyelid_line
+    [64, 192, 128],  # 160: l_midpoint_3_of_lower_eyelid_line
+    [64, 192, 32],  # 161: r_inner_end_of_lower_lash_line
+    [64, 192, 32],  # 162: r_outer_end_of_lower_lash_line
+    [64, 192, 32],  # 163: r_centerpoint_of_lower_lash_line
+    [64, 192, 32],  # 164: r_midpoint_1_of_lower_lash_line
+    [64, 192, 32],  # 165: r_midpoint_2_of_lower_lash_line
+    [64, 192, 32],  # 166: r_midpoint_3_of_lower_lash_line
+    [64, 192, 32],  # 167: r_midpoint_4_of_lower_lash_line
+    [64, 192, 32],  # 168: r_midpoint_5_of_lower_lash_line
+    [64, 192, 32],  # 169: r_midpoint_6_of_lower_lash_line
+    [64, 192, 32],  # 170: r_outer_end_of_lower_eyelid_line
+    [64, 192, 32],  # 171: r_midpoint_3_of_lower_eyelid_line
+    [64, 192, 32],  # 172: r_midpoint_1_of_lower_eyelid_line
+    [64, 192, 32],  # 173: r_midpoint_4_of_lower_eyelid_line
+    [64, 192, 32],  # 174: r_centerpoint_of_lower_eyelid_line
+    [64, 192, 32],  # 175: r_midpoint_5_of_lower_eyelid_line
+    [64, 192, 32],  # 176: r_midpoint_2_of_lower_eyelid_line
+    [64, 192, 32],  # 177: r_midpoint_6_of_lower_eyelid_line
+    [0, 192, 0],  # 178: tip_of_nose
+    [0, 192, 0],  # 179: bottom_center_of_nose
+    [0, 192, 0],  # 180: r_outer_corner_of_nose
+    [0, 192, 0],  # 181: l_outer_corner_of_nose
+    [0, 192, 0],  # 182: inner_corner_of_r_nostril
+    [0, 192, 0],  # 183: outer_corner_of_r_nostril
+    [0, 192, 0],  # 184: upper_corner_of_r_nostril
+    [0, 192, 0],  # 185: inner_corner_of_l_nostril
+    [0, 192, 0],  # 186: outer_corner_of_l_nostril
+    [0, 192, 0],  # 187: upper_corner_of_l_nostril
+    [192, 0, 0],  # 188: r_outer_corner_of_mouth
+    [192, 0, 0],  # 189: l_outer_corner_of_mouth
+    [192, 0, 0],  # 190: center_of_cupid_bow
+    [192, 0, 0],  # 191: center_of_lower_outer_lip
+    [192, 0, 0],  # 192: midpoint_1_of_upper_outer_lip
+    [192, 0, 0],  # 193: midpoint_2_of_upper_outer_lip
+    [192, 0, 0],  # 194: midpoint_1_of_lower_outer_lip
+    [192, 0, 0],  # 195: midpoint_2_of_lower_outer_lip
+    [192, 0, 0],  # 196: midpoint_3_of_upper_outer_lip
+    [192, 0, 0],  # 197: midpoint_4_of_upper_outer_lip
+    [192, 0, 0],  # 198: midpoint_5_of_upper_outer_lip
+    [192, 0, 0],  # 199: midpoint_6_of_upper_outer_lip
+    [192, 0, 0],  # 200: midpoint_3_of_lower_outer_lip
+    [192, 0, 0],  # 201: midpoint_4_of_lower_outer_lip
+    [192, 0, 0],  # 202: midpoint_5_of_lower_outer_lip
+    [192, 0, 0],  # 203: midpoint_6_of_lower_outer_lip
+    [0, 192, 192],  # 204: r_inner_corner_of_mouth
+    [0, 192, 192],  # 205: l_inner_corner_of_mouth
+    [0, 192, 192],  # 206: center_of_upper_inner_lip
+    [0, 192, 192],  # 207: center_of_lower_inner_lip
+    [0, 192, 192],  # 208: midpoint_1_of_upper_inner_lip
+    [0, 192, 192],  # 209: midpoint_2_of_upper_inner_lip
+    [0, 192, 192],  # 210: midpoint_1_of_lower_inner_lip
+    [0, 192, 192],  # 211: midpoint_2_of_lower_inner_lip
+    [0, 192, 192],  # 212: midpoint_3_of_upper_inner_lip
+    [0, 192, 192],  # 213: midpoint_4_of_upper_inner_lip
+    [0, 192, 192],  # 214: midpoint_5_of_upper_inner_lip
+    [0, 192, 192],  # 215: midpoint_6_of_upper_inner_lip
+    [0, 192, 192],  # 216: midpoint_3_of_lower_inner_lip
+    [0, 192, 192],  # 217: midpoint_4_of_lower_inner_lip
+    [0, 192, 192],  # 218: midpoint_5_of_lower_inner_lip
+    [0, 192, 192],  # 219: midpoint_6_of_lower_inner_lip. teeth removed
+    [200, 200, 0],  # 256: l_top_end_of_inferior_crus
+    [200, 200, 0],  # 257: l_top_end_of_superior_crus
+    [200, 200, 0],  # 258: l_start_of_antihelix
+    [200, 200, 0],  # 259: l_end_of_antihelix
+    [200, 200, 0],  # 260: l_midpoint_1_of_antihelix
+    [200, 200, 0],  # 261: l_midpoint_1_of_inferior_crus
+    [200, 200, 0],  # 262: l_midpoint_2_of_antihelix
+    [200, 200, 0],  # 263: l_midpoint_3_of_antihelix
+    [200, 200, 0],  # 264: l_point_1_of_inner_helix
+    [200, 200, 0],  # 265: l_point_2_of_inner_helix
+    [200, 200, 0],  # 266: l_point_3_of_inner_helix
+    [200, 200, 0],  # 267: l_point_4_of_inner_helix
+    [200, 200, 0],  # 268: l_point_5_of_inner_helix
+    [200, 200, 0],  # 269: l_point_6_of_inner_helix
+    [200, 200, 0],  # 270: l_point_7_of_inner_helix
+    [200, 200, 0],  # 271: l_highest_point_of_antitragus
+    [200, 200, 0],  # 272: l_bottom_point_of_tragus
+    [200, 200, 0],  # 273: l_protruding_point_of_tragus
+    [200, 200, 0],  # 274: l_top_point_of_tragus
+    [200, 200, 0],  # 275: l_start_point_of_crus_of_helix
+    [200, 200, 0],  # 276: l_deepest_point_of_concha
+    [200, 200, 0],  # 277: l_tip_of_ear_lobe
+    [200, 200, 0],  # 278: l_midpoint_between_22_15
+    [200, 200, 0],  # 279: l_bottom_connecting_point_of_ear_lobe
+    [200, 200, 0],  # 280: l_top_connecting_point_of_helix
+    [200, 200, 0],  # 281: l_point_8_of_inner_helix
+    [0, 200, 200],  # 282: r_top_end_of_inferior_crus
+    [0, 200, 200],  # 283: r_top_end_of_superior_crus
+    [0, 200, 200],  # 284: r_start_of_antihelix
+    [0, 200, 200],  # 285: r_end_of_antihelix
+    [0, 200, 200],  # 286: r_midpoint_1_of_antihelix
+    [0, 200, 200],  # 287: r_midpoint_1_of_inferior_crus
+    [0, 200, 200],  # 288: r_midpoint_2_of_antihelix
+    [0, 200, 200],  # 289: r_midpoint_3_of_antihelix
+    [0, 200, 200],  # 290: r_point_1_of_inner_helix
+    [0, 200, 200],  # 291: r_point_8_of_inner_helix
+    [0, 200, 200],  # 292: r_point_3_of_inner_helix
+    [0, 200, 200],  # 293: r_point_4_of_inner_helix
+    [0, 200, 200],  # 294: r_point_5_of_inner_helix
+    [0, 200, 200],  # 295: r_point_6_of_inner_helix
+    [0, 200, 200],  # 296: r_point_7_of_inner_helix
+    [0, 200, 200],  # 297: r_highest_point_of_antitragus
+    [0, 200, 200],  # 298: r_bottom_point_of_tragus
+    [0, 200, 200],  # 299: r_protruding_point_of_tragus
+    [0, 200, 200],  # 300: r_top_point_of_tragus
+    [0, 200, 200],  # 301: r_start_point_of_crus_of_helix
+    [0, 200, 200],  # 302: r_deepest_point_of_concha
+    [0, 200, 200],  # 303: r_tip_of_ear_lobe
+    [0, 200, 200],  # 304: r_midpoint_between_22_15
+    [0, 200, 200],  # 305: r_bottom_connecting_point_of_ear_lobe
+    [0, 200, 200],  # 306: r_top_connecting_point_of_helix
+    [0, 200, 200],  # 307: r_point_2_of_inner_helix
+    [128, 192, 64],  # 308: l_center_of_iris
+    [128, 192, 64],  # 309: l_border_of_iris_3
+    [128, 192, 64],  # 310: l_border_of_iris_midpoint_1
+    [128, 192, 64],  # 311: l_border_of_iris_12
+    [128, 192, 64],  # 312: l_border_of_iris_midpoint_4
+    [128, 192, 64],  # 313: l_border_of_iris_9
+    [128, 192, 64],  # 314: l_border_of_iris_midpoint_3
+    [128, 192, 64],  # 315: l_border_of_iris_6
+    [128, 192, 64],  # 316: l_border_of_iris_midpoint_2
+    [192, 32, 64],  # 317: r_center_of_iris
+    [192, 32, 64],  # 318: r_border_of_iris_3
+    [192, 32, 64],  # 319: r_border_of_iris_midpoint_1
+    [192, 32, 64],  # 320: r_border_of_iris_12
+    [192, 32, 64],  # 321: r_border_of_iris_midpoint_4
+    [192, 32, 64],  # 322: r_border_of_iris_9
+    [192, 32, 64],  # 323: r_border_of_iris_midpoint_3
+    [192, 32, 64],  # 324: r_border_of_iris_6
+    [192, 32, 64],  # 325: r_border_of_iris_midpoint_2
+    [192, 128, 64],  # 326: l_center_of_pupil
+    [192, 128, 64],  # 327: l_border_of_pupil_3
+    [192, 128, 64],  # 328: l_border_of_pupil_midpoint_1
+    [192, 128, 64],  # 329: l_border_of_pupil_12
+    [192, 128, 64],  # 330: l_border_of_pupil_midpoint_4
+    [192, 128, 64],  # 331: l_border_of_pupil_9
+    [192, 128, 64],  # 332: l_border_of_pupil_midpoint_3
+    [192, 128, 64],  # 333: l_border_of_pupil_6
+    [192, 128, 64],  # 334: l_border_of_pupil_midpoint_2
+    [32, 192, 192],  # 335: r_center_of_pupil
+    [32, 192, 192],  # 336: r_border_of_pupil_3
+    [32, 192, 192],  # 337: r_border_of_pupil_midpoint_1
+    [32, 192, 192],  # 338: r_border_of_pupil_12
+    [32, 192, 192],  # 339: r_border_of_pupil_midpoint_4
+    [32, 192, 192],  # 340: r_border_of_pupil_9
+    [32, 192, 192],  # 341: r_border_of_pupil_midpoint_3
+    [32, 192, 192],  # 342: r_border_of_pupil_6
+    [32, 192, 192],  # 343: r_border_of_pupil_midpoint_2
+]
+
+GOLIATH_KEYPOINTS = [
+    "nose",
+    "left_eye",
+    "right_eye",
+    "left_ear",
+    "right_ear",
+    "left_shoulder",
+    "right_shoulder",
+    "left_elbow",
+    "right_elbow",
+    "left_hip",
+    "right_hip",
+    "left_knee",
+    "right_knee",
+    "left_ankle",
+    "right_ankle",
+    "left_big_toe",
+    "left_small_toe",
+    "left_heel",
+    "right_big_toe",
+    "right_small_toe",
+    "right_heel",
+    "right_thumb4",
+    "right_thumb3",
+    "right_thumb2",
+    "right_thumb_third_joint",
+    "right_forefinger4",
+    "right_forefinger3",
+    "right_forefinger2",
+    "right_forefinger_third_joint",
+    "right_middle_finger4",
+    "right_middle_finger3",
+    "right_middle_finger2",
+    "right_middle_finger_third_joint",
+    "right_ring_finger4",
+    "right_ring_finger3",
+    "right_ring_finger2",
+    "right_ring_finger_third_joint",
+    "right_pinky_finger4",
+    "right_pinky_finger3",
+    "right_pinky_finger2",
+    "right_pinky_finger_third_joint",
+    "right_wrist",
+    "left_thumb4",
+    "left_thumb3",
+    "left_thumb2",
+    "left_thumb_third_joint",
+    "left_forefinger4",
+    "left_forefinger3",
+    "left_forefinger2",
+    "left_forefinger_third_joint",
+    "left_middle_finger4",
+    "left_middle_finger3",
+    "left_middle_finger2",
+    "left_middle_finger_third_joint",
+    "left_ring_finger4",
+    "left_ring_finger3",
+    "left_ring_finger2",
+    "left_ring_finger_third_joint",
+    "left_pinky_finger4",
+    "left_pinky_finger3",
+    "left_pinky_finger2",
+    "left_pinky_finger_third_joint",
+    "left_wrist",
+    "left_olecranon",
+    "right_olecranon",
+    "left_cubital_fossa",
+    "right_cubital_fossa",
+    "left_acromion",
+    "right_acromion",
+    "neck",
+    "center_of_glabella",
+    "center_of_nose_root",
+    "tip_of_nose_bridge",
+    "midpoint_1_of_nose_bridge",
+    "midpoint_2_of_nose_bridge",
+    "midpoint_3_of_nose_bridge",
+    "center_of_labiomental_groove",
+    "tip_of_chin",
+    "upper_startpoint_of_r_eyebrow",
+    "lower_startpoint_of_r_eyebrow",
+    "end_of_r_eyebrow",
+    "upper_midpoint_1_of_r_eyebrow",
+    "lower_midpoint_1_of_r_eyebrow",
+    "upper_midpoint_2_of_r_eyebrow",
+    "upper_midpoint_3_of_r_eyebrow",
+    "lower_midpoint_2_of_r_eyebrow",
+    "lower_midpoint_3_of_r_eyebrow",
+    "upper_startpoint_of_l_eyebrow",
+    "lower_startpoint_of_l_eyebrow",
+    "end_of_l_eyebrow",
+    "upper_midpoint_1_of_l_eyebrow",
+    "lower_midpoint_1_of_l_eyebrow",
+    "upper_midpoint_2_of_l_eyebrow",
+    "upper_midpoint_3_of_l_eyebrow",
+    "lower_midpoint_2_of_l_eyebrow",
+    "lower_midpoint_3_of_l_eyebrow",
+    "l_inner_end_of_upper_lash_line",
+    "l_outer_end_of_upper_lash_line",
+    "l_centerpoint_of_upper_lash_line",
+    "l_midpoint_2_of_upper_lash_line",
+    "l_midpoint_1_of_upper_lash_line",
+    "l_midpoint_6_of_upper_lash_line",
+    "l_midpoint_5_of_upper_lash_line",
+    "l_midpoint_4_of_upper_lash_line",
+    "l_midpoint_3_of_upper_lash_line",
+    "l_outer_end_of_upper_eyelid_line",
+    "l_midpoint_6_of_upper_eyelid_line",
+    "l_midpoint_2_of_upper_eyelid_line",
+    "l_midpoint_5_of_upper_eyelid_line",
+    "l_centerpoint_of_upper_eyelid_line",
+    "l_midpoint_4_of_upper_eyelid_line",
+    "l_midpoint_1_of_upper_eyelid_line",
+    "l_midpoint_3_of_upper_eyelid_line",
+    "l_midpoint_6_of_upper_crease_line",
+    "l_midpoint_2_of_upper_crease_line",
+    "l_midpoint_5_of_upper_crease_line",
(diff view truncated here)
"right_thumb2",
|
494 |
+
"right_thumb_third_joint",
|
495 |
+
"right_forefinger4",
|
496 |
+
"right_forefinger3",
|
497 |
+
"right_forefinger2",
|
498 |
+
"right_forefinger_third_joint",
|
499 |
+
"right_middle_finger4",
|
500 |
+
"right_middle_finger3",
|
501 |
+
"right_middle_finger2",
|
502 |
+
"right_middle_finger_third_joint",
|
503 |
+
"right_ring_finger4",
|
504 |
+
"right_ring_finger3",
|
505 |
+
"right_ring_finger2",
|
506 |
+
"right_ring_finger_third_joint",
|
507 |
+
"right_pinky_finger4",
|
508 |
+
"right_pinky_finger3",
|
509 |
+
"right_pinky_finger2",
|
510 |
+
"right_pinky_finger_third_joint",
|
511 |
+
"right_wrist",
|
512 |
+
"left_thumb4",
|
513 |
+
"left_thumb3",
|
514 |
+
"left_thumb2",
|
515 |
+
"left_thumb_third_joint",
|
516 |
+
"left_forefinger4",
|
517 |
+
"left_forefinger3",
|
518 |
+
"left_forefinger2",
|
519 |
+
"left_forefinger_third_joint",
|
520 |
+
"left_middle_finger4",
|
521 |
+
"left_middle_finger3",
|
522 |
+
"left_middle_finger2",
|
523 |
+
"left_middle_finger_third_joint",
|
524 |
+
"left_ring_finger4",
|
525 |
+
"left_ring_finger3",
|
526 |
+
"left_ring_finger2",
|
527 |
+
"left_ring_finger_third_joint",
|
528 |
+
"left_pinky_finger4",
|
529 |
+
"left_pinky_finger3",
|
530 |
+
"left_pinky_finger2",
|
531 |
+
"left_pinky_finger_third_joint",
|
532 |
+
"left_wrist",
|
533 |
+
"left_olecranon",
|
534 |
+
"right_olecranon",
|
535 |
+
"left_cubital_fossa",
|
536 |
+
"right_cubital_fossa",
|
537 |
+
"left_acromion",
|
538 |
+
"right_acromion",
|
539 |
+
"neck",
|
540 |
+
"center_of_glabella",
|
541 |
+
"center_of_nose_root",
|
542 |
+
"tip_of_nose_bridge",
|
543 |
+
"midpoint_1_of_nose_bridge",
|
544 |
+
"midpoint_2_of_nose_bridge",
|
545 |
+
"midpoint_3_of_nose_bridge",
|
546 |
+
"center_of_labiomental_groove",
|
547 |
+
"tip_of_chin",
|
548 |
+
"upper_startpoint_of_r_eyebrow",
|
549 |
+
"lower_startpoint_of_r_eyebrow",
|
550 |
+
"end_of_r_eyebrow",
|
551 |
+
"upper_midpoint_1_of_r_eyebrow",
|
552 |
+
"lower_midpoint_1_of_r_eyebrow",
|
553 |
+
"upper_midpoint_2_of_r_eyebrow",
|
554 |
+
"upper_midpoint_3_of_r_eyebrow",
|
555 |
+
"lower_midpoint_2_of_r_eyebrow",
|
556 |
+
"lower_midpoint_3_of_r_eyebrow",
|
557 |
+
"upper_startpoint_of_l_eyebrow",
|
558 |
+
"lower_startpoint_of_l_eyebrow",
|
559 |
+
"end_of_l_eyebrow",
|
560 |
+
"upper_midpoint_1_of_l_eyebrow",
|
561 |
+
"lower_midpoint_1_of_l_eyebrow",
|
562 |
+
"upper_midpoint_2_of_l_eyebrow",
|
563 |
+
"upper_midpoint_3_of_l_eyebrow",
|
564 |
+
"lower_midpoint_2_of_l_eyebrow",
|
565 |
+
"lower_midpoint_3_of_l_eyebrow",
|
566 |
+
"l_inner_end_of_upper_lash_line",
|
567 |
+
"l_outer_end_of_upper_lash_line",
|
568 |
+
"l_centerpoint_of_upper_lash_line",
|
569 |
+
"l_midpoint_2_of_upper_lash_line",
|
570 |
+
"l_midpoint_1_of_upper_lash_line",
|
571 |
+
"l_midpoint_6_of_upper_lash_line",
|
572 |
+
"l_midpoint_5_of_upper_lash_line",
|
573 |
+
"l_midpoint_4_of_upper_lash_line",
|
574 |
+
"l_midpoint_3_of_upper_lash_line",
|
575 |
+
"l_outer_end_of_upper_eyelid_line",
|
576 |
+
"l_midpoint_6_of_upper_eyelid_line",
|
577 |
+
"l_midpoint_2_of_upper_eyelid_line",
|
578 |
+
"l_midpoint_5_of_upper_eyelid_line",
|
579 |
+
"l_centerpoint_of_upper_eyelid_line",
|
580 |
+
"l_midpoint_4_of_upper_eyelid_line",
|
581 |
+
"l_midpoint_1_of_upper_eyelid_line",
|
582 |
+
"l_midpoint_3_of_upper_eyelid_line",
|
583 |
+
"l_midpoint_6_of_upper_crease_line",
|
584 |
+
"l_midpoint_2_of_upper_crease_line",
|
585 |
+
"l_midpoint_5_of_upper_crease_line",
|
586 |
+
"l_centerpoint_of_upper_crease_line",
|
587 |
+
"l_midpoint_4_of_upper_crease_line",
|
588 |
+
"l_midpoint_1_of_upper_crease_line",
|
589 |
+
"l_midpoint_3_of_upper_crease_line",
|
590 |
+
"r_inner_end_of_upper_lash_line",
|
591 |
+
"r_outer_end_of_upper_lash_line",
|
592 |
+
"r_centerpoint_of_upper_lash_line",
|
593 |
+
"r_midpoint_1_of_upper_lash_line",
|
594 |
+
"r_midpoint_2_of_upper_lash_line",
|
595 |
+
"r_midpoint_3_of_upper_lash_line",
|
596 |
+
"r_midpoint_4_of_upper_lash_line",
|
597 |
+
"r_midpoint_5_of_upper_lash_line",
|
598 |
+
"r_midpoint_6_of_upper_lash_line",
|
599 |
+
"r_outer_end_of_upper_eyelid_line",
|
600 |
+
"r_midpoint_3_of_upper_eyelid_line",
|
601 |
+
"r_midpoint_1_of_upper_eyelid_line",
|
602 |
+
"r_midpoint_4_of_upper_eyelid_line",
|
603 |
+
"r_centerpoint_of_upper_eyelid_line",
|
604 |
+
"r_midpoint_5_of_upper_eyelid_line",
|
605 |
+
"r_midpoint_2_of_upper_eyelid_line",
|
606 |
+
"r_midpoint_6_of_upper_eyelid_line",
|
607 |
+
"r_midpoint_3_of_upper_crease_line",
|
608 |
+
"r_midpoint_1_of_upper_crease_line",
|
609 |
+
"r_midpoint_4_of_upper_crease_line",
|
610 |
+
"r_centerpoint_of_upper_crease_line",
|
611 |
+
"r_midpoint_5_of_upper_crease_line",
|
612 |
+
"r_midpoint_2_of_upper_crease_line",
|
613 |
+
"r_midpoint_6_of_upper_crease_line",
|
614 |
+
"l_inner_end_of_lower_lash_line",
|
615 |
+
"l_outer_end_of_lower_lash_line",
|
616 |
+
"l_centerpoint_of_lower_lash_line",
|
617 |
+
"l_midpoint_2_of_lower_lash_line",
|
618 |
+
"l_midpoint_1_of_lower_lash_line",
|
619 |
+
"l_midpoint_6_of_lower_lash_line",
|
620 |
+
"l_midpoint_5_of_lower_lash_line",
|
621 |
+
"l_midpoint_4_of_lower_lash_line",
|
622 |
+
"l_midpoint_3_of_lower_lash_line",
|
623 |
+
"l_outer_end_of_lower_eyelid_line",
|
624 |
+
"l_midpoint_6_of_lower_eyelid_line",
|
625 |
+
"l_midpoint_2_of_lower_eyelid_line",
|
626 |
+
"l_midpoint_5_of_lower_eyelid_line",
|
627 |
+
"l_centerpoint_of_lower_eyelid_line",
|
628 |
+
"l_midpoint_4_of_lower_eyelid_line",
|
629 |
+
"l_midpoint_1_of_lower_eyelid_line",
|
630 |
+
"l_midpoint_3_of_lower_eyelid_line",
|
631 |
+
"r_inner_end_of_lower_lash_line",
|
632 |
+
"r_outer_end_of_lower_lash_line",
|
633 |
+
"r_centerpoint_of_lower_lash_line",
|
634 |
+
"r_midpoint_1_of_lower_lash_line",
|
635 |
+
"r_midpoint_2_of_lower_lash_line",
|
636 |
+
"r_midpoint_3_of_lower_lash_line",
|
637 |
+
"r_midpoint_4_of_lower_lash_line",
|
638 |
+
"r_midpoint_5_of_lower_lash_line",
|
639 |
+
"r_midpoint_6_of_lower_lash_line",
|
640 |
+
"r_outer_end_of_lower_eyelid_line",
|
641 |
+
"r_midpoint_3_of_lower_eyelid_line",
|
642 |
+
"r_midpoint_1_of_lower_eyelid_line",
|
643 |
+
"r_midpoint_4_of_lower_eyelid_line",
|
644 |
+
"r_centerpoint_of_lower_eyelid_line",
|
645 |
+
"r_midpoint_5_of_lower_eyelid_line",
|
646 |
+
"r_midpoint_2_of_lower_eyelid_line",
|
647 |
+
"r_midpoint_6_of_lower_eyelid_line",
|
648 |
+
"tip_of_nose",
|
649 |
+
"bottom_center_of_nose",
|
650 |
+
"r_outer_corner_of_nose",
|
651 |
+
"l_outer_corner_of_nose",
|
652 |
+
"inner_corner_of_r_nostril",
|
653 |
+
"outer_corner_of_r_nostril",
|
654 |
+
"upper_corner_of_r_nostril",
|
655 |
+
"inner_corner_of_l_nostril",
|
656 |
+
"outer_corner_of_l_nostril",
|
657 |
+
"upper_corner_of_l_nostril",
|
658 |
+
"r_outer_corner_of_mouth",
|
659 |
+
"l_outer_corner_of_mouth",
|
660 |
+
"center_of_cupid_bow",
|
661 |
+
"center_of_lower_outer_lip",
|
662 |
+
"midpoint_1_of_upper_outer_lip",
|
663 |
+
"midpoint_2_of_upper_outer_lip",
|
664 |
+
"midpoint_1_of_lower_outer_lip",
|
665 |
+
"midpoint_2_of_lower_outer_lip",
|
666 |
+
"midpoint_3_of_upper_outer_lip",
|
667 |
+
"midpoint_4_of_upper_outer_lip",
|
668 |
+
"midpoint_5_of_upper_outer_lip",
|
669 |
+
"midpoint_6_of_upper_outer_lip",
|
670 |
+
"midpoint_3_of_lower_outer_lip",
|
671 |
+
"midpoint_4_of_lower_outer_lip",
|
672 |
+
"midpoint_5_of_lower_outer_lip",
|
673 |
+
"midpoint_6_of_lower_outer_lip",
|
674 |
+
"r_inner_corner_of_mouth",
|
675 |
+
"l_inner_corner_of_mouth",
|
676 |
+
"center_of_upper_inner_lip",
|
677 |
+
"center_of_lower_inner_lip",
|
678 |
+
"midpoint_1_of_upper_inner_lip",
|
679 |
+
"midpoint_2_of_upper_inner_lip",
|
680 |
+
"midpoint_1_of_lower_inner_lip",
|
681 |
+
"midpoint_2_of_lower_inner_lip",
|
682 |
+
"midpoint_3_of_upper_inner_lip",
|
683 |
+
"midpoint_4_of_upper_inner_lip",
|
684 |
+
"midpoint_5_of_upper_inner_lip",
|
685 |
+
"midpoint_6_of_upper_inner_lip",
|
686 |
+
"midpoint_3_of_lower_inner_lip",
|
687 |
+
"midpoint_4_of_lower_inner_lip",
|
688 |
+
"midpoint_5_of_lower_inner_lip",
|
689 |
+
"midpoint_6_of_lower_inner_lip",
|
690 |
+
"l_top_end_of_inferior_crus",
|
691 |
+
"l_top_end_of_superior_crus",
|
692 |
+
"l_start_of_antihelix",
|
693 |
+
"l_end_of_antihelix",
|
694 |
+
"l_midpoint_1_of_antihelix",
|
695 |
+
"l_midpoint_1_of_inferior_crus",
|
696 |
+
"l_midpoint_2_of_antihelix",
|
697 |
+
"l_midpoint_3_of_antihelix",
|
698 |
+
"l_point_1_of_inner_helix",
|
699 |
+
"l_point_2_of_inner_helix",
|
700 |
+
"l_point_3_of_inner_helix",
|
701 |
+
"l_point_4_of_inner_helix",
|
702 |
+
"l_point_5_of_inner_helix",
|
703 |
+
"l_point_6_of_inner_helix",
|
704 |
+
"l_point_7_of_inner_helix",
|
705 |
+
"l_highest_point_of_antitragus",
|
706 |
+
"l_bottom_point_of_tragus",
|
707 |
+
"l_protruding_point_of_tragus",
|
708 |
+
"l_top_point_of_tragus",
|
709 |
+
"l_start_point_of_crus_of_helix",
|
710 |
+
"l_deepest_point_of_concha",
|
711 |
+
"l_tip_of_ear_lobe",
|
712 |
+
"l_midpoint_between_22_15",
|
713 |
+
"l_bottom_connecting_point_of_ear_lobe",
|
714 |
+
"l_top_connecting_point_of_helix",
|
715 |
+
"l_point_8_of_inner_helix",
|
716 |
+
"r_top_end_of_inferior_crus",
|
717 |
+
"r_top_end_of_superior_crus",
|
718 |
+
"r_start_of_antihelix",
|
719 |
+
"r_end_of_antihelix",
|
720 |
+
"r_midpoint_1_of_antihelix",
|
721 |
+
"r_midpoint_1_of_inferior_crus",
|
722 |
+
"r_midpoint_2_of_antihelix",
|
723 |
+
"r_midpoint_3_of_antihelix",
|
724 |
+
"r_point_1_of_inner_helix",
|
725 |
+
"r_point_8_of_inner_helix",
|
726 |
+
"r_point_3_of_inner_helix",
|
727 |
+
"r_point_4_of_inner_helix",
|
728 |
+
"r_point_5_of_inner_helix",
|
729 |
+
"r_point_6_of_inner_helix",
|
730 |
+
"r_point_7_of_inner_helix",
|
731 |
+
"r_highest_point_of_antitragus",
|
732 |
+
"r_bottom_point_of_tragus",
|
733 |
+
"r_protruding_point_of_tragus",
|
734 |
+
"r_top_point_of_tragus",
|
735 |
+
"r_start_point_of_crus_of_helix",
|
736 |
+
"r_deepest_point_of_concha",
|
737 |
+
"r_tip_of_ear_lobe",
|
738 |
+
"r_midpoint_between_22_15",
|
739 |
+
"r_bottom_connecting_point_of_ear_lobe",
|
740 |
+
"r_top_connecting_point_of_helix",
|
741 |
+
"r_point_2_of_inner_helix",
|
742 |
+
"l_center_of_iris",
|
743 |
+
"l_border_of_iris_3",
|
744 |
+
"l_border_of_iris_midpoint_1",
|
745 |
+
"l_border_of_iris_12",
|
746 |
+
"l_border_of_iris_midpoint_4",
|
747 |
+
"l_border_of_iris_9",
|
748 |
+
"l_border_of_iris_midpoint_3",
|
749 |
+
"l_border_of_iris_6",
|
750 |
+
"l_border_of_iris_midpoint_2",
|
751 |
+
"r_center_of_iris",
|
752 |
+
"r_border_of_iris_3",
|
753 |
+
"r_border_of_iris_midpoint_1",
|
754 |
+
"r_border_of_iris_12",
|
755 |
+
"r_border_of_iris_midpoint_4",
|
756 |
+
"r_border_of_iris_9",
|
757 |
+
"r_border_of_iris_midpoint_3",
|
758 |
+
"r_border_of_iris_6",
|
759 |
+
"r_border_of_iris_midpoint_2",
|
760 |
+
"l_center_of_pupil",
|
761 |
+
"l_border_of_pupil_3",
|
762 |
+
"l_border_of_pupil_midpoint_1",
|
763 |
+
"l_border_of_pupil_12",
|
764 |
+
"l_border_of_pupil_midpoint_4",
|
765 |
+
"l_border_of_pupil_9",
|
766 |
+
"l_border_of_pupil_midpoint_3",
|
767 |
+
"l_border_of_pupil_6",
|
768 |
+
"l_border_of_pupil_midpoint_2",
|
769 |
+
"r_center_of_pupil",
|
770 |
+
"r_border_of_pupil_3",
|
771 |
+
"r_border_of_pupil_midpoint_1",
|
772 |
+
"r_border_of_pupil_12",
|
773 |
+
"r_border_of_pupil_midpoint_4",
|
774 |
+
"r_border_of_pupil_9",
|
775 |
+
"r_border_of_pupil_midpoint_3",
|
776 |
+
"r_border_of_pupil_6",
|
777 |
+
"r_border_of_pupil_midpoint_2"
|
778 |
+
]
|
779 |
+
|
780 |
+
GOLIATH_SKELETON_INFO = {
|
781 |
+
0:
|
782 |
+
dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]),
|
783 |
+
1:
|
784 |
+
dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]),
|
785 |
+
2:
|
786 |
+
dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]),
|
787 |
+
3:
|
788 |
+
dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]),
|
789 |
+
4:
|
790 |
+
dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]),
|
791 |
+
5:
|
792 |
+
dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]),
|
793 |
+
6:
|
794 |
+
dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]),
|
795 |
+
7:
|
796 |
+
dict(
|
797 |
+
link=('left_shoulder', 'right_shoulder'),
|
798 |
+
id=7,
|
799 |
+
color=[51, 153, 255]),
|
800 |
+
8:
|
801 |
+
dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]),
|
802 |
+
9:
|
803 |
+
dict(
|
804 |
+
link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]),
|
805 |
+
10:
|
806 |
+
dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]),
|
807 |
+
11:
|
808 |
+
dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]),
|
809 |
+
12:
|
810 |
+
dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]),
|
811 |
+
13:
|
812 |
+
dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]),
|
813 |
+
14:
|
814 |
+
dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]),
|
815 |
+
15:
|
816 |
+
dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]),
|
817 |
+
16:
|
818 |
+
dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]),
|
819 |
+
17:
|
820 |
+
dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]),
|
821 |
+
18:
|
822 |
+
dict(
|
823 |
+
link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255]),
|
824 |
+
19:
|
825 |
+
dict(link=('left_ankle', 'left_big_toe'), id=19, color=[0, 255, 0]),
|
826 |
+
20:
|
827 |
+
dict(link=('left_ankle', 'left_small_toe'), id=20, color=[0, 255, 0]),
|
828 |
+
21:
|
829 |
+
dict(link=('left_ankle', 'left_heel'), id=21, color=[0, 255, 0]),
|
830 |
+
22:
|
831 |
+
dict(
|
832 |
+
link=('right_ankle', 'right_big_toe'), id=22, color=[255, 128, 0]),
|
833 |
+
23:
|
834 |
+
dict(
|
835 |
+
link=('right_ankle', 'right_small_toe'),
|
836 |
+
id=23,
|
837 |
+
color=[255, 128, 0]),
|
838 |
+
24:
|
839 |
+
dict(link=('right_ankle', 'right_heel'), id=24, color=[255, 128, 0]),
|
840 |
+
25:
|
841 |
+
dict(
|
842 |
+
link=('left_wrist', 'left_thumb_third_joint'), id=25, color=[255, 128,
|
843 |
+
0]),
|
844 |
+
26:
|
845 |
+
dict(link=('left_thumb_third_joint', 'left_thumb2'), id=26, color=[255, 128, 0]),
|
846 |
+
27:
|
847 |
+
dict(link=('left_thumb2', 'left_thumb3'), id=27, color=[255, 128, 0]),
|
848 |
+
28:
|
849 |
+
dict(link=('left_thumb3', 'left_thumb4'), id=28, color=[255, 128, 0]),
|
850 |
+
29:
|
851 |
+
dict(
|
852 |
+
link=('left_wrist', 'left_forefinger_third_joint'),
|
853 |
+
id=29,
|
854 |
+
color=[255, 153, 255]),
|
855 |
+
30:
|
856 |
+
dict(
|
857 |
+
link=('left_forefinger_third_joint', 'left_forefinger2'),
|
858 |
+
id=30,
|
859 |
+
color=[255, 153, 255]),
|
860 |
+
31:
|
861 |
+
dict(
|
862 |
+
link=('left_forefinger2', 'left_forefinger3'),
|
863 |
+
id=31,
|
864 |
+
color=[255, 153, 255]),
|
865 |
+
32:
|
866 |
+
dict(
|
867 |
+
link=('left_forefinger3', 'left_forefinger4'),
|
868 |
+
id=32,
|
869 |
+
color=[255, 153, 255]),
|
870 |
+
33:
|
871 |
+
dict(
|
872 |
+
link=('left_wrist', 'left_middle_finger_third_joint'),
|
873 |
+
id=33,
|
874 |
+
color=[102, 178, 255]),
|
875 |
+
34:
|
876 |
+
dict(
|
877 |
+
link=('left_middle_finger_third_joint', 'left_middle_finger2'),
|
878 |
+
id=34,
|
879 |
+
color=[102, 178, 255]),
|
880 |
+
35:
|
881 |
+
dict(
|
882 |
+
link=('left_middle_finger2', 'left_middle_finger3'),
|
883 |
+
id=35,
|
884 |
+
color=[102, 178, 255]),
|
885 |
+
36:
|
886 |
+
dict(
|
887 |
+
link=('left_middle_finger3', 'left_middle_finger4'),
|
888 |
+
id=36,
|
889 |
+
color=[102, 178, 255]),
|
890 |
+
37:
|
891 |
+
dict(
|
892 |
+
link=('left_wrist', 'left_ring_finger_third_joint'),
|
893 |
+
id=37,
|
894 |
+
color=[255, 51, 51]),
|
895 |
+
38:
|
896 |
+
dict(
|
897 |
+
link=('left_ring_finger_third_joint', 'left_ring_finger2'),
|
898 |
+
id=38,
|
899 |
+
color=[255, 51, 51]),
|
900 |
+
39:
|
901 |
+
dict(
|
902 |
+
link=('left_ring_finger2', 'left_ring_finger3'),
|
903 |
+
id=39,
|
904 |
+
color=[255, 51, 51]),
|
905 |
+
40:
|
906 |
+
dict(
|
907 |
+
link=('left_ring_finger3', 'left_ring_finger4'),
|
908 |
+
id=40,
|
909 |
+
color=[255, 51, 51]),
|
910 |
+
41:
|
911 |
+
dict(
|
912 |
+
link=('left_wrist', 'left_pinky_finger_third_joint'),
|
913 |
+
id=41,
|
914 |
+
color=[0, 255, 0]),
|
915 |
+
42:
|
916 |
+
dict(
|
917 |
+
link=('left_pinky_finger_third_joint', 'left_pinky_finger2'),
|
918 |
+
id=42,
|
919 |
+
color=[0, 255, 0]),
|
920 |
+
43:
|
921 |
+
dict(
|
922 |
+
link=('left_pinky_finger2', 'left_pinky_finger3'),
|
923 |
+
id=43,
|
924 |
+
color=[0, 255, 0]),
|
925 |
+
44:
|
926 |
+
dict(
|
927 |
+
link=('left_pinky_finger3', 'left_pinky_finger4'),
|
928 |
+
id=44,
|
929 |
+
color=[0, 255, 0]),
|
930 |
+
45:
|
931 |
+
dict(
|
932 |
+
link=('right_wrist', 'right_thumb_third_joint'),
|
933 |
+
id=45,
|
934 |
+
color=[255, 128, 0]),
|
935 |
+
46:
|
936 |
+
dict(
|
937 |
+
link=('right_thumb_third_joint', 'right_thumb2'), id=46, color=[255, 128, 0]),
|
938 |
+
47:
|
939 |
+
dict(
|
940 |
+
link=('right_thumb2', 'right_thumb3'), id=47, color=[255, 128, 0]),
|
941 |
+
48:
|
942 |
+
dict(
|
943 |
+
link=('right_thumb3', 'right_thumb4'), id=48, color=[255, 128, 0]),
|
944 |
+
49:
|
945 |
+
dict(
|
946 |
+
link=('right_wrist', 'right_forefinger_third_joint'),
|
947 |
+
id=49,
|
948 |
+
color=[255, 153, 255]),
|
949 |
+
50:
|
950 |
+
dict(
|
951 |
+
link=('right_forefinger_third_joint', 'right_forefinger2'),
|
952 |
+
id=50,
|
953 |
+
color=[255, 153, 255]),
|
954 |
+
51:
|
955 |
+
dict(
|
956 |
+
link=('right_forefinger2', 'right_forefinger3'),
|
957 |
+
id=51,
|
958 |
+
color=[255, 153, 255]),
|
959 |
+
52:
|
960 |
+
dict(
|
961 |
+
link=('right_forefinger3', 'right_forefinger4'),
|
962 |
+
id=52,
|
963 |
+
color=[255, 153, 255]),
|
964 |
+
53:
|
965 |
+
dict(
|
966 |
+
link=('right_wrist', 'right_middle_finger_third_joint'),
|
967 |
+
id=53,
|
968 |
+
color=[102, 178, 255]),
|
969 |
+
54:
|
970 |
+
dict(
|
971 |
+
link=('right_middle_finger_third_joint', 'right_middle_finger2'),
|
972 |
+
id=54,
|
973 |
+
color=[102, 178, 255]),
|
974 |
+
55:
|
975 |
+
dict(
|
976 |
+
link=('right_middle_finger2', 'right_middle_finger3'),
|
977 |
+
id=55,
|
978 |
+
color=[102, 178, 255]),
|
979 |
+
56:
|
980 |
+
dict(
|
981 |
+
link=('right_middle_finger3', 'right_middle_finger4'),
|
982 |
+
id=56,
|
983 |
+
color=[102, 178, 255]),
|
984 |
+
57:
|
985 |
+
dict(
|
986 |
+
link=('right_wrist', 'right_ring_finger_third_joint'),
|
987 |
+
id=57,
|
988 |
+
color=[255, 51, 51]),
|
989 |
+
58:
|
990 |
+
dict(
|
991 |
+
link=('right_ring_finger_third_joint', 'right_ring_finger2'),
|
992 |
+
id=58,
|
993 |
+
color=[255, 51, 51]),
|
994 |
+
59:
|
995 |
+
dict(
|
996 |
+
link=('right_ring_finger2', 'right_ring_finger3'),
|
997 |
+
id=59,
|
998 |
+
color=[255, 51, 51]),
|
999 |
+
60:
|
1000 |
+
dict(
|
1001 |
+
link=('right_ring_finger3', 'right_ring_finger4'),
|
1002 |
+
id=60,
|
1003 |
+
color=[255, 51, 51]),
|
1004 |
+
61:
|
1005 |
+
dict(
|
1006 |
+
link=('right_wrist', 'right_pinky_finger_third_joint'),
|
1007 |
+
id=61,
|
1008 |
+
color=[0, 255, 0]),
|
1009 |
+
62:
|
1010 |
+
dict(
|
1011 |
+
link=('right_pinky_finger_third_joint', 'right_pinky_finger2'),
|
1012 |
+
id=62,
|
1013 |
+
color=[0, 255, 0]),
|
1014 |
+
63:
|
1015 |
+
dict(
|
1016 |
+
link=('right_pinky_finger2', 'right_pinky_finger3'),
|
1017 |
+
id=63,
|
1018 |
+
color=[0, 255, 0]),
|
1019 |
+
64:
|
1020 |
+
dict(
|
1021 |
+
link=('right_pinky_finger3', 'right_pinky_finger4'),
|
1022 |
+
id=64,
|
1023 |
+
color=[0, 255, 0])
|
1024 |
+
}
|
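For orientation, the two tables above are meant to be used together: every link in GOLIATH_SKELETON_INFO names two entries of GOLIATH_KEYPOINTS, so drawing a skeleton reduces to name-to-index lookups. A minimal sketch (my own, not part of this commit), assuming a hypothetical `keypoints` array of shape (len(GOLIATH_KEYPOINTS), 2) with predicted (x, y) coordinates and assuming the color triplets are RGB:

import cv2
import numpy as np
from classes_and_palettes import GOLIATH_KEYPOINTS, GOLIATH_SKELETON_INFO

NAME_TO_INDEX = {name: i for i, name in enumerate(GOLIATH_KEYPOINTS)}

def draw_goliath_skeleton(image: np.ndarray, keypoints: np.ndarray) -> np.ndarray:
    # `keypoints` is a hypothetical (344, 2) array aligned with GOLIATH_KEYPOINTS.
    for info in GOLIATH_SKELETON_INFO.values():
        src, dst = info['link']
        p1 = tuple(int(v) for v in keypoints[NAME_TO_INDEX[src]])
        p2 = tuple(int(v) for v in keypoints[NAME_TO_INDEX[dst]])
        r, g, b = info['color']  # assumed RGB in the tables above
        cv2.line(image, p1, p2, (b, g, r), thickness=2)  # cv2 expects BGR
    return image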
detector_utils.py
ADDED
@@ -0,0 +1,196 @@
+from typing import List, Optional, Sequence, Union
+
+import torch
+import cv2
+import numpy as np
+from mmcv.ops import RoIPool
+from mmengine.dataset import Compose, pseudo_collate
+from mmengine.device import get_device
+from mmengine.registry import init_default_scope
+from mmdet.apis import inference_detector, init_detector
+from mmdet.structures import DetDataSample, SampleList
+from mmdet.utils import get_test_pipeline_cfg
+
+
+ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
+
+def nms(dets: np.ndarray, thr: float):
+    """Greedily select boxes with high confidence and overlap <= thr.
+
+    Args:
+        dets (np.ndarray): [[x1, y1, x2, y2, score]].
+        thr (float): Retain overlap < thr.
+
+    Returns:
+        list: Indexes to keep.
+    """
+    if len(dets) == 0:
+        return []
+
+    x1 = dets[:, 0]
+    y1 = dets[:, 1]
+    x2 = dets[:, 2]
+    y2 = dets[:, 3]
+    scores = dets[:, 4]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while len(order) > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= thr)[0]
+        order = order[inds + 1]
+
+    return keep
+
+def adapt_mmdet_pipeline(cfg):
+    """Converts pipeline types in MMDetection's test dataloader to use the
+    'mmdet' namespace.
+
+    Args:
+        cfg (ConfigDict): Configuration dictionary for MMDetection.
+
+    Returns:
+        ConfigDict: Configuration dictionary with updated pipeline types.
+    """
+    # use lazy import to avoid hard dependence on mmdet
+    from mmdet.datasets import transforms
+
+    if 'test_dataloader' not in cfg:
+        return cfg
+
+    pipeline = cfg.test_dataloader.dataset.pipeline
+    for trans in pipeline:
+        if trans['type'] in dir(transforms):
+            trans['type'] = 'mmdet.' + trans['type']
+
+    return cfg
+
+
+# Note: this local definition intentionally shadows the `inference_detector`
+# imported from mmdet.apis above; it adds tensor-input handling and
+# bfloat16 autocast around the forward pass.
+def inference_detector(
+    model: torch.nn.Module,
+    imgs: ImagesType,
+    test_pipeline: Optional[Compose] = None,
+    text_prompt: Optional[str] = None,
+    custom_entities: bool = False,
+) -> Union[DetDataSample, SampleList]:
+    """Inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        imgs (str, ndarray, Sequence[str/ndarray]):
+            Either image files or loaded images.
+        test_pipeline (:obj:`Compose`): Test pipeline.
+
+    Returns:
+        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
+        If imgs is a list or tuple, the same length list type results
+        will be returned, otherwise return the detection results directly.
+    """
+    if isinstance(imgs, torch.Tensor):
+        if imgs.is_cuda:
+            imgs = imgs.cpu()
+
+        # Remove batch dimension and transpose
+        imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()
+
+        # Ensure the data type is appropriate (uint8 for most image processing functions)
+        imgs = (imgs * 255).astype(np.uint8)
+
+    if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
+        is_batch = True
+    else:
+        imgs = [imgs]
+        is_batch = False
+
+    cfg = model.cfg
+
+    if test_pipeline is None:
+        cfg = cfg.copy()
+        test_pipeline = get_test_pipeline_cfg(cfg)
+        if isinstance(imgs[0], np.ndarray):
+            # Calling this method across libraries will result
+            # in module unregistered error if not prefixed with mmdet.
+            test_pipeline[0].type = "mmdet.LoadImageFromNDArray"
+
+        test_pipeline = Compose(test_pipeline)
+
+    if model.data_preprocessor.device.type == "cpu":
+        for m in model.modules():
+            assert not isinstance(
+                m, RoIPool
+            ), "CPU inference with RoIPool is not supported currently."
+
+    result_list = []
+    for i, img in enumerate(imgs):
+        # prepare data
+        if isinstance(img, np.ndarray):
+            # TODO: remove img_id.
+            data_ = dict(img=img, img_id=0)
+        else:
+            # TODO: remove img_id.
+            data_ = dict(img_path=img, img_id=0)
+
+        if text_prompt:
+            data_["text"] = text_prompt
+            data_["custom_entities"] = custom_entities
+
+        # build the data pipeline
+        data_ = test_pipeline(data_)
+
+        data_["inputs"] = [data_["inputs"]]
+        data_["data_samples"] = [data_["data_samples"]]
+
+        # forward the model
+        with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
+            results = model.test_step(data_)[0]
+
+        result_list.append(results)
+
+    if not is_batch:
+        return result_list[0]
+    else:
+        return result_list
+
+
+def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
+    bboxes = np.concatenate(
+        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
+    )
+    bboxes = bboxes[
+        np.logical_and(
+            pred_instance.labels == det_cat_id,
+            pred_instance.scores > bbox_thr,
+        )
+    ]
+    bboxes = bboxes[nms(bboxes, nms_thr), :4]
+    return bboxes
+
+
+def process_images_detector(imgs, detector):
+    """Detect person bounding boxes in a batch of images."""
+    # predict bbox
+    det_results = inference_detector(detector, imgs)
+    pred_instances = list(
+        map(lambda det_result: det_result.pred_instances.numpy(), det_results)
+    )
+    bboxes_batch = list(
+        map(
+            lambda pred_instance: process_one_image_bbox(
+                pred_instance, 0, 0.3, 0.3  # det_cat_id=0, bbox_thr=0.3, nms_thr=0.3
+            ),
+            pred_instances,
+        )
+    )
+
+    return bboxes_batch
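A usage sketch for the helpers above (my own wiring; the config and checkpoint paths are guessed from this commit's assets/ directory rather than taken from app.py): init_detector loads the RTMDet person detector, adapt_mmdet_pipeline fixes the pipeline namespace, and process_images_detector returns one (N, 4) xyxy array per input image.

import cv2
from mmdet.apis import init_detector
from detector_utils import adapt_mmdet_pipeline, process_images_detector

# Paths and device are assumptions for illustration only.
detector = init_detector(
    'assets/rtmdet_m_640-8xb32_coco-person_no_nms.py',
    'assets/checkpoints/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
    device='cuda:0')
detector.cfg = adapt_mmdet_pipeline(detector.cfg)

img = cv2.imread('assets/images/68204.png')
(bboxes,) = process_images_detector([img], detector)
print(bboxes.shape)  # (num_person_boxes, 4), xyxy format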
external/cv/.gitignore
ADDED
@@ -0,0 +1,125 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# PyTorch checkpoint
+*.pth
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+#dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+mlu-ops/
+mlu-ops.*
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/en/_build/
+docs/en/api/generated/
+docs/zh_cn/_build/
+docs/zh_cn/api/generated/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# editors and IDEs
+.idea/
+.vscode/
+
+# custom
+.DS_Store
+
+# datasets and logs and checkpoints
+data/
+work_dir/
+
+src/
external/cv/MANIFEST.in
ADDED
@@ -0,0 +1,6 @@
+include requirements/runtime.txt
+include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
+include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp
+include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp
+include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm
+recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm
external/cv/dist/sapiens_cv-1.0.0-cp310-cp310-linux_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:746f2be13eefdfe43a59d9c415e03a4b0b922e6ce487b76a572a376ae76c9300
+size 30006791
external/cv/mmcv/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .arraymisc import *
+from .image import *
+from .transforms import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - op
+# - utils
external/cv/mmcv/arraymisc/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
external/cv/mmcv/arraymisc/quantization.py
ADDED
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import numpy as np
+
+
+def quantize(arr: np.ndarray,
+             min_val: Union[int, float],
+             max_val: Union[int, float],
+             levels: int,
+             dtype=np.int64) -> np.ndarray:
+    """Quantize an array of (-inf, inf) to [0, levels-1].
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (int or float): Minimum value to be clipped.
+        max_val (int or float): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the quantized array.
+
+    Returns:
+        np.ndarray: Quantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    arr = np.clip(arr, min_val, max_val) - min_val
+    quantized_arr = np.minimum(
+        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+    return quantized_arr
+
+
+def dequantize(arr: np.ndarray,
+               min_val: Union[int, float],
+               max_val: Union[int, float],
+               levels: int,
+               dtype=np.float64) -> np.ndarray:
+    """Dequantize an array.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (int or float): Minimum value to be clipped.
+        max_val (int or float): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        np.ndarray: Dequantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+                                                   min_val) / levels + min_val
+
+    return dequantized_arr
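A quick round trip through the two functions above (my own example, not from the repo): quantize maps the clipped range [min_val, max_val] onto integer bins 0..levels-1, and dequantize returns bin centers, so the reconstruction error is at most half a bin width.

import numpy as np
from mmcv.arraymisc import quantize, dequantize

arr = np.array([-1.2, 0.0, 0.3, 0.97])
q = quantize(arr, min_val=-1.0, max_val=1.0, levels=16)    # ints in [0, 15]
rec = dequantize(q, min_val=-1.0, max_val=1.0, levels=16)  # bin centers
bin_width = 2.0 / 16
assert np.all(np.abs(np.clip(arr, -1.0, 1.0) - rec) <= bin_width / 2 + 1e-9)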
external/cv/mmcv/cnn/__init__.py
ADDED
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+                     build_activation_layer, build_conv_layer,
+                     build_norm_layer, build_padding_layer, build_plugin_layer,
+                     build_upsample_layer, conv_ws_2d, is_norm)
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .rfsearch import Conv2dRFSearchOp, RFSearchHook
+from .utils import fuse_conv_bn, get_model_complexity_info
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+    'ConvModule', 'build_activation_layer', 'build_conv_layer',
+    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+    'build_plugin_layer', 'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d',
+    'ContextBlock', 'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention',
+    'Scale', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d',
+    'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'fuse_conv_bn',
+    'get_model_complexity_info', 'Conv2dRFSearchOp', 'RFSearchHook'
+]
external/cv/mmcv/cnn/alexnet.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from mmengine.runner import load_checkpoint
+
+
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    Args:
+        num_classes (int): number of classes for classification.
+    """
+
+    def __init__(self, num_classes: int = -1):
+        super().__init__()
+        self.num_classes = num_classes
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained: Optional[str] = None) -> None:
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            # use default initializer
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        x = self.features(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), 256 * 6 * 6)
+            x = self.classifier(x)
+
+        return x
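A shape-only smoke test (my own example): with num_classes <= 0 the module returns convolutional features; with num_classes > 0 it assumes 224x224 inputs so the flattened features match the 256 * 6 * 6 classifier input.

import torch
from mmcv.cnn import AlexNet

backbone = AlexNet()                      # num_classes=-1: features only
feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)                        # torch.Size([1, 256, 6, 6])

classifier = AlexNet(num_classes=10)
logits = classifier(torch.randn(1, 3, 224, 224))
print(logits.shape)                       # torch.Size([1, 10])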
external/cv/mmcv/cnn/bricks/__init__.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .scale import LayerScale, Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                       Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+    'ConvModule', 'build_activation_layer', 'build_conv_layer',
+    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+    'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+    'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
+    'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
+    'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
+    'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
+]
external/cv/mmcv/cnn/bricks/activation.py
ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.registry import MODELS
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+for module in [
+        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+        nn.Sigmoid, nn.Tanh
+]:
+    MODELS.register_module(module=module)
+
+if digit_version(torch.__version__) >= digit_version('1.7.0'):
+    MODELS.register_module(module=nn.SiLU, name='SiLU')
+else:
+
+    class SiLU(nn.Module):
+        """Sigmoid Weighted Linear Unit."""
+
+        def __init__(self, inplace=False):
+            super().__init__()
+            self.inplace = inplace
+
+        def forward(self, inputs) -> torch.Tensor:
+            if self.inplace:
+                return inputs.mul_(torch.sigmoid(inputs))
+            else:
+                return inputs * torch.sigmoid(inputs)
+
+    MODELS.register_module(module=SiLU, name='SiLU')
+
+
+@MODELS.register_module(name='Clip')
+@MODELS.register_module()
+class Clamp(nn.Module):
+    """Clamp activation layer.
+
+    This activation function is to clamp the feature map value within
+    :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+    Args:
+        min (Number | optional): Lower-bound of the range to be clamped to.
+            Default to -1.
+        max (Number | optional): Upper-bound of the range to be clamped to.
+            Default to 1.
+    """
+
+    def __init__(self, min: float = -1., max: float = 1.):
+        super().__init__()
+        self.min = min
+        self.max = max
+
+    def forward(self, x) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Clamped tensor.
+        """
+        return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+    r"""Applies the Gaussian Error Linear Units function:
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+        or digit_version(TORCH_VERSION) < digit_version('1.4')):
+    MODELS.register_module(module=GELU)
+else:
+    MODELS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg: Dict) -> nn.Module:
+    """Build activation layer.
+
+    Args:
+        cfg (dict): The activation layer config, which should contain:
+
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate an activation layer.
+
+    Returns:
+        nn.Module: Created activation layer.
+    """
+    return MODELS.build(cfg)
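For illustration (my own example), activations are built from config dicts via the registry entries populated above; string types resolve to the registered classes, including the 'Clip' alias for Clamp.

import torch
from mmcv.cnn import build_activation_layer

relu = build_activation_layer(dict(type='ReLU', inplace=True))
clip = build_activation_layer(dict(type='Clip', min=0.0, max=6.0))

x = torch.linspace(-2.0, 8.0, steps=6)
print(relu(x.clone()))  # negatives zeroed
print(clip(x))          # values clamped to [0, 6]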
external/cv/mmcv/cnn/bricks/context_block.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import torch
+from mmengine.model import constant_init, kaiming_init
+from mmengine.registry import MODELS
+from torch import nn
+
+
+def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
+    if isinstance(m, nn.Sequential):
+        constant_init(m[-1], val=0)
+    else:
+        constant_init(m, val=0)
+
+
+@MODELS.register_module()
+class ContextBlock(nn.Module):
+    """ContextBlock module in GCNet.
+
+    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+    (https://arxiv.org/abs/1904.11492) for details.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        ratio (float): Ratio of channels of transform bottleneck.
+        pooling_type (str): Pooling method for context modeling.
+            Options are 'att' and 'avg', stand for attention pooling and
+            average pooling respectively. Default: 'att'.
+        fusion_types (Sequence[str]): Fusion method for feature fusion,
+            Options are 'channel_add', 'channel_mul', stand for channelwise
+            addition and multiplication respectively. Default: ('channel_add',)
+    """
+
+    _abbr_ = 'context_block'
+
+    def __init__(self,
+                 in_channels: int,
+                 ratio: float,
+                 pooling_type: str = 'att',
+                 fusion_types: tuple = ('channel_add', )):
+        super().__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        valid_fusion_types = ['channel_add', 'channel_mul']
+        assert all([f in valid_fusion_types for f in fusion_types])
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+        self.in_channels = in_channels
+        self.ratio = ratio
+        self.planes = int(in_channels * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        if 'channel_add' in fusion_types:
+            self.channel_add_conv = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+        else:
+            self.channel_add_conv = None
+        if 'channel_mul' in fusion_types:
+            self.channel_mul_conv = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+        else:
+            self.channel_mul_conv = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.pooling_type == 'att':
+            kaiming_init(self.conv_mask, mode='fan_in')
+            self.conv_mask.inited = True
+
+        if self.channel_add_conv is not None:
+            last_zero_init(self.channel_add_conv)
+        if self.channel_mul_conv is not None:
+            last_zero_init(self.channel_mul_conv)
+
+    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channel, height, width = x.size()
+        if self.pooling_type == 'att':
+            input_x = x
+            # [N, C, H * W]
+            input_x = input_x.view(batch, channel, height * width)
+            # [N, 1, C, H * W]
+            input_x = input_x.unsqueeze(1)
+            # [N, 1, H, W]
+            context_mask = self.conv_mask(x)
+            # [N, 1, H * W]
+            context_mask = context_mask.view(batch, 1, height * width)
+            # [N, 1, H * W]
+            context_mask = self.softmax(context_mask)
+            # [N, 1, H * W, 1]
+            context_mask = context_mask.unsqueeze(-1)
+            # [N, 1, C, 1]
+            context = torch.matmul(input_x, context_mask)
+            # [N, C, 1, 1]
+            context = context.view(batch, channel, 1, 1)
+        else:
+            # [N, C, 1, 1]
+            context = self.avg_pool(x)
+
+        return context
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        out = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+            out = out * channel_mul_term
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            channel_add_term = self.channel_add_conv(context)
+            out = out + channel_add_term
+
+        return out
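One property worth noting (my own check, following from the zero init above): because last_zero_init zeroes the final conv of each fusion branch, a freshly constructed ContextBlock is an identity mapping, so it can be dropped into a pretrained network without perturbing it.

import torch
from mmcv.cnn import ContextBlock

block = ContextBlock(in_channels=64, ratio=1. / 4)
x = torch.randn(2, 64, 32, 32)
out = block(x)
print(out.shape)               # torch.Size([2, 64, 32, 32])
print(torch.allclose(out, x))  # True: the channel_add branch starts at zero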
external/cv/mmcv/cnn/bricks/conv.py
ADDED
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import inspect
+from typing import Dict, Optional
+
+from mmengine.registry import MODELS
+from torch import nn
+
+MODELS.register_module('Conv1d', module=nn.Conv1d)
+MODELS.register_module('Conv2d', module=nn.Conv2d)
+MODELS.register_module('Conv3d', module=nn.Conv3d)
+MODELS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
+    """Build convolution layer.
+
+    Args:
+        cfg (None or dict): The conv layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a conv layer.
+        args (argument list): Arguments passed to the `__init__`
+            method of the corresponding conv layer.
+        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+            method of the corresponding conv layer.
+
+    Returns:
+        nn.Module: Created conv layer.
+    """
+    if cfg is None:
+        cfg_ = dict(type='Conv2d')
+    else:
+        if not isinstance(cfg, dict):
+            raise TypeError('cfg must be a dict')
+        if 'type' not in cfg:
+            raise KeyError('the cfg dict must contain the key "type"')
+        cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if inspect.isclass(layer_type):
+        return layer_type(*args, **kwargs, **cfg_)  # type: ignore
+    # Switch registry to the target scope. If `conv_layer` cannot be found
+    # in the registry, fallback to search `conv_layer` in the
+    # mmengine.MODELS.
+    with MODELS.switch_scope_and_registry(None) as registry:
+        conv_layer = registry.get(layer_type)
+    if conv_layer is None:
+        raise KeyError(f'Cannot find {layer_type} in registry under scope '
+                       f'name {registry.scope}')
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+    return layer
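Three equivalent ways to call the builder above (my own example): a config dict whose type string resolves through the registry, cfg=None (which falls back to Conv2d), and passing a class directly as the type.

from torch import nn
from mmcv.cnn import build_conv_layer

conv = build_conv_layer(dict(type='Conv2d'), 3, 16, kernel_size=3, padding=1)
default_conv = build_conv_layer(None, 3, 16, kernel_size=3)  # falls back to Conv2d
conv1d = build_conv_layer(dict(type=nn.Conv1d), 3, 16, kernel_size=3)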
external/cv/mmcv/cnn/bricks/conv2d_adaptive_padding.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple, Union

import torch
from mmengine.registry import MODELS
from torch import nn
from torch.nn import functional as F


@MODELS.register_module()
class Conv2dAdaptivePadding(nn.Conv2d):
    """Implementation of 2D convolution in tensorflow with `padding` as
    "same", which applies padding to input (if needed) so that input image
    gets fully covered by filter and stride you specified. For stride 1, this
    will ensure that output image size is same as input. For stride of 2,
    output dimensions will be half, for example.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        output_h = math.ceil(img_h / stride_h)
        output_w = math.ceil(img_w / stride_w)
        pad_h = (
            max((output_h - 1) * self.stride[0] +
                (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
        pad_w = (
            max((output_w - 1) * self.stride[1] +
                (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
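The padding arithmetic above yields an output of exactly ceil(H / stride) x ceil(W / stride) for any kernel size. A quick shape check, as a sketch (assuming the vendored `mmcv` is importable):

import torch
from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding

# 31 is odd and the 5x5 kernel would normally shrink the map; adaptive
# padding keeps the TensorFlow 'same' shape contract instead.
conv = Conv2dAdaptivePadding(3, 8, kernel_size=5, stride=2)
out = conv(torch.randn(1, 3, 31, 31))
assert out.shape[-2:] == (16, 16)  # ceil(31 / 2) == 16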
external/cv/mmcv/cnn/bricks/conv_module.py
ADDED
@@ -0,0 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from functools import partial
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import constant_init, kaiming_init
from mmengine.registry import MODELS
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm

from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer


def efficient_conv_bn_eval_forward(bn: _BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.

    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    """
    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var +
                               bn.eps).reshape([-1] + [1] *
                                               (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() *\
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


@MODELS.register_module()
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
       supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
        efficient_conv_bn_eval (bool): Whether use efficient conv when the
            consecutive bn is in eval mode (either training or testing), as
            proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
    """

    _abbr_ = 'conv_block'

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: Union[bool, str] = 'auto',
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Optional[Dict] = None,
                 act_cfg: Optional[Dict] = dict(type='ReLU'),
                 inplace: bool = True,
                 with_spectral_norm: bool = False,
                 padding_mode: str = 'zeros',
                 order: tuple = ('conv', 'norm', 'act'),
                 efficient_conv_bn_eval: bool = False):
        super().__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            self.padding_layer = build_padding_layer(pad_cfg, padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(
                norm_cfg, norm_channels)  # type: ignore
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None  # type: ignore

        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()  # type: ignore
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        # 1. It is mainly for customized conv layers with their own
        #    initialization manners by calling their own ``init_weights()``,
        #    and we do not want ConvModule to override the initialization.
        # 2. For customized conv layers without their own initialization
        #    manners (that is, they don't have their own ``init_weights()``)
        #    and PyTorch's conv layers, they will be initialized by
        #    this method with default ``kaiming_init``.
        # Note: For PyTorch's conv layers, they will be overwritten by our
        #    initialization implementation using default ``kaiming_init``.
        if not hasattr(self.conv, 'init_weights'):
            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
                nonlinearity = 'leaky_relu'
                a = self.act_cfg.get('negative_slope', 0.01)
            else:
                nonlinearity = 'relu'
                a = 0
            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self,
                x: torch.Tensor,
                activate: bool = True,
                norm: bool = True) -> torch.Tensor:
        layer_index = 0
        while layer_index < len(self.order):
            layer = self.order[layer_index]
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                # if the next operation is norm and we have a norm layer in
                # eval mode and we have enabled `efficient_conv_bn_eval` for
                # the conv operator, then activate the optimized forward and
                # skip the next norm operator since it has been fused
                if layer_index + 1 < len(self.order) and \
                        self.order[layer_index + 1] == 'norm' and norm and \
                        self.with_norm and not self.norm.training and \
                        self.efficient_conv_bn_eval_forward is not None:
                    self.conv.forward = partial(
                        self.efficient_conv_bn_eval_forward, self.norm,
                        self.conv)
                    layer_index += 1
                    x = self.conv(x)
                    del self.conv.forward
                else:
                    x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
            layer_index += 1
        return x

    def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
        # efficient_conv_bn_eval works for conv + bn
        # with `track_running_stats` option
        if efficient_conv_bn_eval and self.norm \
                and isinstance(self.norm, _BatchNorm) \
                and self.norm.track_running_stats:
            self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward  # noqa: E501
        else:
            self.efficient_conv_bn_eval_forward = None  # type: ignore

    @staticmethod
    def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                            bn: torch.nn.modules.batchnorm._BatchNorm,
                            efficient_conv_bn_eval=True) -> 'ConvModule':
        """Create a ConvModule from a conv and a bn module."""
        self = ConvModule.__new__(ConvModule)
        super(ConvModule, self).__init__()

        self.conv_cfg = None
        self.norm_cfg = None
        self.act_cfg = None
        self.inplace = False
        self.with_spectral_norm = False
        self.with_explicit_padding = False
        self.order = ('conv', 'norm', 'act')

        self.with_norm = True
        self.with_activation = False
        self.with_bias = conv.bias is not None

        # build convolution layer
        self.conv = conv
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        # build normalization layers
        self.norm_name, norm = 'bn', bn
        self.add_module(self.norm_name, norm)

        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)

        return self
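A usage sketch for ConvModule (assuming the vendored `mmcv` is importable): with a norm config, `bias='auto'` resolves to False and the forward pass applies the layers in the declared order.

import torch
from mmcv.cnn import ConvModule

block = ConvModule(3, 16, 3, padding=1,
                   norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
y = block(torch.randn(2, 3, 32, 32))  # conv -> BN -> ReLU
assert y.shape == (2, 16, 32, 32)
assert block.conv.bias is None  # bias='auto' dropped the redundant conv bias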
external/cv/mmcv/cnn/bricks/conv_ws.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS


def conv_ws_2d(input: torch.Tensor,
               weight: torch.Tensor,
               bias: Optional[torch.Tensor] = None,
               stride: Union[int, Tuple[int, int]] = 1,
               padding: Union[int, Tuple[int, int]] = 0,
               dilation: Union[int, Tuple[int, int]] = 1,
               groups: int = 1,
               eps: float = 1e-5) -> torch.Tensor:
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)


@MODELS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 eps: float = 1e-5):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                          self.dilation, self.groups, self.eps)


@MODELS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
                              local_metadata: Dict, strict: bool,
                              missing_keys: List[str],
                              unexpected_keys: List[str],
                              error_msgs: List[str]) -> None:
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """

        self.weight_gamma.data.fill_(-1)
        local_missing_keys: List = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
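Weight standardization rescales each output-channel filter to (W_i - mean_i) / (std_i + eps) before the convolution; only the effective kernel changes, so the call signature mirrors F.conv2d. A small sketch (assuming the vendored `mmcv` is importable):

import torch
from mmcv.cnn.bricks.conv_ws import conv_ws_2d

weight = torch.randn(8, 3, 3, 3)
x = torch.randn(1, 3, 16, 16)
y = conv_ws_2d(x, weight, padding=1)  # convolves with standardized weights
assert y.shape == (1, 8, 16, 16)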
external/cv/mmcv/cnn/bricks/depthwise_separable_conv_module.py
ADDED
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from .conv_module import ConvModule


class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution module.

    See https://arxiv.org/pdf/1704.04861.pdf for details.

    This module can replace a ConvModule with the conv block replaced by two
    conv blocks: a depthwise conv block and a pointwise conv block. The
    depthwise conv block contains depthwise-conv/norm/activation layers. The
    pointwise conv block contains pointwise-conv/norm/activation layers. It
    should be noted that there will be norm/activation layer in the depthwise
    conv block if `norm_cfg` and `act_cfg` are specified.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``. Default: 1.
        norm_cfg (dict): Default norm config for both depthwise ConvModule and
            pointwise ConvModule. Default: None.
        act_cfg (dict): Default activation config for both depthwise ConvModule
            and pointwise ConvModule. Default: dict(type='ReLU').
        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        kwargs (optional): Other shared arguments for depthwise and pointwise
            ConvModule. See ConvModule for ref.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 norm_cfg: Optional[Dict] = None,
                 act_cfg: Dict = dict(type='ReLU'),
                 dw_norm_cfg: Union[Dict, str] = 'default',
                 dw_act_cfg: Union[Dict, str] = 'default',
                 pw_norm_cfg: Union[Dict, str] = 'default',
                 pw_act_cfg: Union[Dict, str] = 'default',
                 **kwargs):
        super().__init__()
        assert 'groups' not in kwargs, 'groups should not be specified'

        # if norm/activation config of depthwise/pointwise ConvModule is not
        # specified, use default config.
        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg

        # depthwise convolution
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=dw_norm_cfg,  # type: ignore
            act_cfg=dw_act_cfg,  # type: ignore
            **kwargs)

        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            norm_cfg=pw_norm_cfg,  # type: ignore
            act_cfg=pw_act_cfg,  # type: ignore
            **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x
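The factorization is where the savings come from: a k x k depthwise conv costs C_in * k^2 weights and the 1x1 pointwise conv costs C_in * C_out, versus C_in * C_out * k^2 for a full convolution. A comparison sketch (assuming the vendored `mmcv` is importable):

import torch.nn as nn
from mmcv.cnn import DepthwiseSeparableConvModule

sep = DepthwiseSeparableConvModule(64, 128, 3, padding=1)
full = nn.Conv2d(64, 128, 3, padding=1)
n_sep = sum(p.numel() for p in sep.parameters())
n_full = sum(p.numel() for p in full.parameters())
print(n_sep, n_full)  # 8960 vs 73856 with these shapes, roughly 8x fewer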
external/cv/mmcv/cnn/bricks/drop.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from mmengine.registry import MODELS


def drop_path(x: torch.Tensor,
              drop_prob: float = 0.,
              training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    """
    if not training:
        return x
    keep_prob = 1 - drop_prob
    # handle tensors with different dimensions, not just 4D tensors.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    output = x.div(keep_prob) * random_tensor.floor()
    return output


@MODELS.register_module()
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501

    Args:
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

    def __init__(self, drop_prob: float = 0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return drop_path(x, self.drop_prob, self.training)


@MODELS.register_module()
class Dropout(nn.Dropout):
    """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
    ``DropPath``

    Args:
        drop_prob (float): Probability of the elements to be
            zeroed. Default: 0.5.
        inplace (bool): Do the operation inplace or not. Default: False.
    """

    def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
        super().__init__(p=drop_prob, inplace=inplace)


def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
    """Builder for drop out layers."""
    return MODELS.build(cfg, default_args=default_args)
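drop_path zeroes whole samples with probability drop_prob and scales the survivors by 1 / (1 - drop_prob), so the expected activation is unchanged. A small sketch (assuming the vendored `mmcv` is importable):

import torch
from mmcv.cnn.bricks.drop import DropPath

dp = DropPath(drop_prob=0.2).train()  # drop_path is a no-op in eval mode
y = dp(torch.ones(10000, 8))
# each sample row is either all zeros or all 1 / 0.8 = 1.25
print(sorted(y[:, 0].unique().tolist()))  # [0.0, 1.25]
print(y.mean().item())                    # close to 1.0 in expectation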
external/cv/mmcv/cnn/bricks/generalized_attention.py
ADDED
@@ -0,0 +1,416 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import kaiming_init
from mmengine.registry import MODELS


@MODELS.register_module()
class GeneralizedAttention(nn.Module):
    """GeneralizedAttention module.

    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
    (https://arxiv.org/abs/1904.05873) for details.

    Args:
        in_channels (int): Channels of the input feature map.
        spatial_range (int): The spatial range. -1 indicates no spatial range
            constraint. Default: -1.
        num_heads (int): The head number of empirical_attention module.
            Default: 9.
        position_embedding_dim (int): The position embedding dimension.
            Default: -1.
        position_magnitude (int): A multiplier acting on coord difference.
            Default: 1.
        kv_stride (int): The feature stride acting on key/value feature map.
            Default: 2.
        q_stride (int): The feature stride acting on query feature map.
            Default: 1.
        attention_type (str): A binary indicator string for indicating which
            items in generalized empirical_attention module are used.
            Default: '1111'.

            - '1000' indicates 'query and key content' (appr - appr) item,
            - '0100' indicates 'query content and relative position'
              (appr - position) item,
            - '0010' indicates 'key content only' (bias - appr) item,
            - '0001' indicates 'relative position only' (bias - position) item.
    """

    _abbr_ = 'gen_attention_block'

    def __init__(self,
                 in_channels: int,
                 spatial_range: int = -1,
                 num_heads: int = 9,
                 position_embedding_dim: int = -1,
                 position_magnitude: int = 1,
                 kv_stride: int = 2,
                 q_stride: int = 1,
                 attention_type: str = '1111'):

        super().__init__()

        # hard range means local range for non-local operation
        self.position_embedding_dim = (
            position_embedding_dim
            if position_embedding_dim > 0 else in_channels)

        self.position_magnitude = position_magnitude
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.spatial_range = spatial_range
        self.kv_stride = kv_stride
        self.q_stride = q_stride
        self.attention_type = [bool(int(_)) for _ in attention_type]
        self.qk_embed_dim = in_channels // num_heads
        out_c = self.qk_embed_dim * num_heads

        if self.attention_type[0] or self.attention_type[1]:
            self.query_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_c,
                kernel_size=1,
                bias=False)
            self.query_conv.kaiming_init = True

        if self.attention_type[0] or self.attention_type[2]:
            self.key_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_c,
                kernel_size=1,
                bias=False)
            self.key_conv.kaiming_init = True

        self.v_dim = in_channels // num_heads
        self.value_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.v_dim * num_heads,
            kernel_size=1,
            bias=False)
        self.value_conv.kaiming_init = True

        if self.attention_type[1] or self.attention_type[3]:
            self.appr_geom_fc_x = nn.Linear(
                self.position_embedding_dim // 2, out_c, bias=False)
            self.appr_geom_fc_x.kaiming_init = True

            self.appr_geom_fc_y = nn.Linear(
                self.position_embedding_dim // 2, out_c, bias=False)
            self.appr_geom_fc_y.kaiming_init = True

        if self.attention_type[2]:
            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
            appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
            self.appr_bias = nn.Parameter(appr_bias_value)

        if self.attention_type[3]:
            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
            geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
            self.geom_bias = nn.Parameter(geom_bias_value)

        self.proj_conv = nn.Conv2d(
            in_channels=self.v_dim * num_heads,
            out_channels=in_channels,
            kernel_size=1,
            bias=True)
        self.proj_conv.kaiming_init = True
        self.gamma = nn.Parameter(torch.zeros(1))

        if self.spatial_range >= 0:
            # only works when non local is after 3*3 conv
            if in_channels == 256:
                max_len = 84
            elif in_channels == 512:
                max_len = 42

            max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
            local_constraint_map = np.ones(
                (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
            for iy in range(max_len):
                for ix in range(max_len):
                    local_constraint_map[
                        iy, ix,
                        max((iy - self.spatial_range) //
                            self.kv_stride, 0):min((iy + self.spatial_range +
                                                    1) // self.kv_stride +
                                                   1, max_len),
                        max((ix - self.spatial_range) //
                            self.kv_stride, 0):min((ix + self.spatial_range +
                                                    1) // self.kv_stride +
                                                   1, max_len)] = 0

            self.local_constraint_map = nn.Parameter(
                torch.from_numpy(local_constraint_map).byte(),
                requires_grad=False)

        if self.q_stride > 1:
            self.q_downsample = nn.AvgPool2d(
                kernel_size=1, stride=self.q_stride)
        else:
            self.q_downsample = None

        if self.kv_stride > 1:
            self.kv_downsample = nn.AvgPool2d(
                kernel_size=1, stride=self.kv_stride)
        else:
            self.kv_downsample = None

        self.init_weights()

    def get_position_embedding(self,
                               h,
                               w,
                               h_kv,
                               w_kv,
                               q_stride,
                               kv_stride,
                               device,
                               dtype,
                               feat_dim,
                               wave_length=1000):
        # the default type of Tensor is float32, leading to type mismatch
        # in fp16 mode. Cast it to support fp16 mode.
        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
        h_idxs = h_idxs.view((h, 1)) * q_stride

        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
        w_idxs = w_idxs.view((w, 1)) * q_stride

        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
            device=device, dtype=dtype)
        h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride

        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
            device=device, dtype=dtype)
        w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride

        # (h, h_kv, 1)
        h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
        h_diff *= self.position_magnitude

        # (w, w_kv, 1)
        w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
        w_diff *= self.position_magnitude

        feat_range = torch.arange(0, feat_dim / 4).to(
            device=device, dtype=dtype)

        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
        dim_mat = dim_mat**((4. / feat_dim) * feat_range)
        dim_mat = dim_mat.view((1, 1, -1))

        embedding_x = torch.cat(
            ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)

        embedding_y = torch.cat(
            ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)

        return embedding_x, embedding_y

    def forward(self, x_input: torch.Tensor) -> torch.Tensor:
        num_heads = self.num_heads

        # use empirical_attention
        if self.q_downsample is not None:
            x_q = self.q_downsample(x_input)
        else:
            x_q = x_input
        n, _, h, w = x_q.shape

        if self.kv_downsample is not None:
            x_kv = self.kv_downsample(x_input)
        else:
            x_kv = x_input
        _, _, h_kv, w_kv = x_kv.shape

        if self.attention_type[0] or self.attention_type[1]:
            proj_query = self.query_conv(x_q).view(
                (n, num_heads, self.qk_embed_dim, h * w))
            proj_query = proj_query.permute(0, 1, 3, 2)

        if self.attention_type[0] or self.attention_type[2]:
            proj_key = self.key_conv(x_kv).view(
                (n, num_heads, self.qk_embed_dim, h_kv * w_kv))

        if self.attention_type[1] or self.attention_type[3]:
            position_embed_x, position_embed_y = self.get_position_embedding(
                h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
                x_input.device, x_input.dtype, self.position_embedding_dim)
            # (n, num_heads, w, w_kv, dim)
            position_feat_x = self.appr_geom_fc_x(position_embed_x).\
                view(1, w, w_kv, num_heads, self.qk_embed_dim).\
                permute(0, 3, 1, 2, 4).\
                repeat(n, 1, 1, 1, 1)

            # (n, num_heads, h, h_kv, dim)
            position_feat_y = self.appr_geom_fc_y(position_embed_y).\
                view(1, h, h_kv, num_heads, self.qk_embed_dim).\
                permute(0, 3, 1, 2, 4).\
                repeat(n, 1, 1, 1, 1)

            position_feat_x /= math.sqrt(2)
            position_feat_y /= math.sqrt(2)

        # accelerate for saliency only
        if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
            appr_bias = self.appr_bias.\
                view(1, num_heads, 1, self.qk_embed_dim).\
                repeat(n, 1, 1, 1)

            energy = torch.matmul(appr_bias, proj_key).\
                view(n, num_heads, 1, h_kv * w_kv)

            h = 1
            w = 1
        else:
            # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
            if not self.attention_type[0]:
                energy = torch.zeros(
                    n,
                    num_heads,
                    h,
                    w,
                    h_kv,
                    w_kv,
                    dtype=x_input.dtype,
                    device=x_input.device)

            # attention_type[0]: appr - appr
            # attention_type[1]: appr - position
            # attention_type[2]: bias - appr
            # attention_type[3]: bias - position
            if self.attention_type[0] or self.attention_type[2]:
                if self.attention_type[0] and self.attention_type[2]:
                    appr_bias = self.appr_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim)
                    energy = torch.matmul(proj_query + appr_bias, proj_key).\
                        view(n, num_heads, h, w, h_kv, w_kv)

                elif self.attention_type[0]:
                    energy = torch.matmul(proj_query, proj_key).\
                        view(n, num_heads, h, w, h_kv, w_kv)

                elif self.attention_type[2]:
                    appr_bias = self.appr_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim).\
                        repeat(n, 1, 1, 1)

                    energy += torch.matmul(appr_bias, proj_key).\
                        view(n, num_heads, 1, 1, h_kv, w_kv)

            if self.attention_type[1] or self.attention_type[3]:
                if self.attention_type[1] and self.attention_type[3]:
                    geom_bias = self.geom_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim)

                    proj_query_reshape = (proj_query + geom_bias).\
                        view(n, num_heads, h, w, self.qk_embed_dim)

                    energy_x = torch.matmul(
                        proj_query_reshape.permute(0, 1, 3, 2, 4),
                        position_feat_x.permute(0, 1, 2, 4, 3))
                    energy_x = energy_x.\
                        permute(0, 1, 3, 2, 4).unsqueeze(4)

                    energy_y = torch.matmul(
                        proj_query_reshape,
                        position_feat_y.permute(0, 1, 2, 4, 3))
                    energy_y = energy_y.unsqueeze(5)

                    energy += energy_x + energy_y

                elif self.attention_type[1]:
                    proj_query_reshape = proj_query.\
                        view(n, num_heads, h, w, self.qk_embed_dim)
                    proj_query_reshape = proj_query_reshape.\
                        permute(0, 1, 3, 2, 4)
                    position_feat_x_reshape = position_feat_x.\
                        permute(0, 1, 2, 4, 3)
                    position_feat_y_reshape = position_feat_y.\
                        permute(0, 1, 2, 4, 3)

                    energy_x = torch.matmul(proj_query_reshape,
                                            position_feat_x_reshape)
                    energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)

                    energy_y = torch.matmul(proj_query_reshape,
                                            position_feat_y_reshape)
                    energy_y = energy_y.unsqueeze(5)

                    energy += energy_x + energy_y

                elif self.attention_type[3]:
                    geom_bias = self.geom_bias.\
                        view(1, num_heads, self.qk_embed_dim, 1).\
                        repeat(n, 1, 1, 1)

                    position_feat_x_reshape = position_feat_x.\
                        view(n, num_heads, w * w_kv, self.qk_embed_dim)

                    position_feat_y_reshape = position_feat_y.\
                        view(n, num_heads, h * h_kv, self.qk_embed_dim)

                    energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
                    energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)

                    energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
                    energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)

                    energy += energy_x + energy_y

            energy = energy.view(n, num_heads, h * w, h_kv * w_kv)

        if self.spatial_range >= 0:
            cur_local_constraint_map = \
                self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
                contiguous().\
                view(1, 1, h*w, h_kv*w_kv)

            energy = energy.masked_fill_(cur_local_constraint_map,
                                         float('-inf'))

        attention = F.softmax(energy, 3)

        proj_value = self.value_conv(x_kv)
        proj_value_reshape = proj_value.\
            view((n, num_heads, self.v_dim, h_kv * w_kv)).\
            permute(0, 1, 3, 2)

        out = torch.matmul(attention, proj_value_reshape).\
            permute(0, 1, 3, 2).\
            contiguous().\
            view(n, self.v_dim * self.num_heads, h, w)

        out = self.proj_conv(out)

        # output is downsampled, upsample back to input size
        if self.q_downsample is not None:
            out = F.interpolate(
                out,
                size=x_input.shape[2:],
                mode='bilinear',
                align_corners=False)

        out = self.gamma * out + x_input
        return out

    def init_weights(self):
        for m in self.modules():
            if hasattr(m, 'kaiming_init') and m.kaiming_init:
                kaiming_init(
                    m,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    bias=0,
                    distribution='uniform',
                    a=1)
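The `attention_type` bit string toggles the four energy terms independently, and because `gamma` starts at zero the block is an identity-plus-residual at initialization. A usage sketch (assuming the vendored `mmcv` is importable); '0010' exercises the key-content-only fast path with the default kv_stride=2 downsampling:

import torch
from mmcv.cnn.bricks.generalized_attention import GeneralizedAttention

attn = GeneralizedAttention(256, num_heads=8, attention_type='0010')
x = torch.randn(1, 256, 20, 20)
out = attn(x)
assert out.shape == x.shape
assert torch.allclose(out, x)  # gamma == 0 at init, so output == input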
external/cv/mmcv/cnn/bricks/hsigmoid.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import torch
import torch.nn as nn
from mmengine.registry import MODELS


@MODELS.register_module()
class HSigmoid(nn.Module):
    """Hard Sigmoid Module. Apply the hard sigmoid function:
    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
    Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)

    Note:
        In MMCV v1.4.4, we modified the default value of args to align with
        PyTorch official.

    Args:
        bias (float): Bias of the input feature map. Default: 3.0.
        divisor (float): Divisor of the input feature map. Default: 6.0.
        min_value (float): Lower bound value. Default: 0.0.
        max_value (float): Upper bound value. Default: 1.0.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 bias: float = 3.0,
                 divisor: float = 6.0,
                 min_value: float = 0.0,
                 max_value: float = 1.0):
        super().__init__()
        warnings.warn(
            'In MMCV v1.4.4, we modified the default value of args to align '
            'with PyTorch official. Previous Implementation: '
            'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
            'Current Implementation: '
            'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
        self.bias = bias
        self.divisor = divisor
        assert self.divisor != 0
        self.min_value = min_value
        self.max_value = max_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = (x + self.bias) / self.divisor

        return x.clamp_(self.min_value, self.max_value)
|
external/cv/mmcv/cnn/bricks/hswish.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


class HSwish(nn.Module):
    """Hard Swish Module.

    This module applies the hard swish function:

    .. math::
        Hswish(x) = x * ReLU6(x + 3) / 6

    Args:
        inplace (bool): can optionally do the operation in-place.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.act = nn.ReLU6(inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(x + 3) / 6


if (TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.7')):
    # Hardswish is not supported when PyTorch version < 1.6.
    # And Hardswish in PyTorch 1.6 does not support inplace.
    MODELS.register_module(module=HSwish)
else:
    MODELS.register_module(module=nn.Hardswish, name='HSwish')
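On PyTorch >= 1.7 the registered 'HSwish' entry is therefore just `nn.Hardswish`; the fallback class computes the same function, as a quick check shows (a sketch, assuming the vendored `mmcv` is importable):

import torch
import torch.nn as nn
from mmcv.cnn.bricks.hswish import HSwish

x = torch.linspace(-5, 5, steps=11)
assert torch.allclose(HSwish()(x), nn.Hardswish()(x))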
external/cv/mmcv/cnn/bricks/non_local.py
ADDED
@@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABCMeta
from typing import Dict, Optional

import torch
import torch.nn as nn
from mmengine.model import constant_init, normal_init
from mmengine.registry import MODELS

from .conv_module import ConvModule


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971
    Code reference: https://github.com/AlexHex7/Non-local_pytorch

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `gaussian`, `concatenation`,
            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
    """

    def __init__(self,
                 in_channels: int,
                 reduction: int = 2,
                 use_scale: bool = True,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Optional[Dict] = None,
                 mode: str = 'embedded_gaussian',
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode

        if mode not in [
                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)  # type: ignore
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        if self.mode != 'gaussian':
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)

        if self.mode == 'concatenation':
            self.concat_project = ConvModule(
                self.inter_channels * 2,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=dict(type='ReLU'))

        self.init_weights(**kwargs)

    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(m.conv, std=std)
        else:
            normal_init(self.g.conv, std=std)
        if zeros_init:
            if self.conv_out.norm_cfg is None:
                constant_init(self.conv_out.conv, 0)
            else:
                constant_init(self.conv_out.norm, 0)
        else:
            if self.conv_out.norm_cfg is None:
                normal_init(self.conv_out.conv, std=std)
            else:
                normal_init(self.conv_out.norm, std=std)

    def gaussian(self, theta_x: torch.Tensor,
                 phi_x: torch.Tensor) -> torch.Tensor:
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def embedded_gaussian(self, theta_x: torch.Tensor,
                          phi_x: torch.Tensor) -> torch.Tensor:
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x: torch.Tensor,
                    phi_x: torch.Tensor) -> torch.Tensor:
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def concatenation(self, theta_x: torch.Tensor,
                      phi_x: torch.Tensor) -> torch.Tensor:
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        pairwise_weight /= pairwise_weight.shape[-1]

        return pairwise_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assume `reduction = 1`, then `inter_channels = C`
        # or `inter_channels = C` when `mode="gaussian"`

        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output


class NonLocal1d(_NonLocalNd):
    """1D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer


@MODELS.register_module()
class NonLocal2d(_NonLocalNd):
    """2D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer


class NonLocal3d(_NonLocalNd):
    """3D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer
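A minimal shape-preservation sketch (illustrative, assuming this vendored mmcv exposes NonLocal2d from mmcv.cnn as upstream mmcv does): because `forward` ends with `output = x + self.conv_out(y)`, the block is residual and its output shape always matches the input.

# Hedged sketch: import path assumed from this package layout.
import torch
from mmcv.cnn import NonLocal2d

block = NonLocal2d(in_channels=16, reduction=2, mode='embedded_gaussian')
x = torch.randn(2, 16, 20, 20)
out = block(x)
assert out.shape == x.shape  # residual design keeps [N, C, H, W]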
external/cv/mmcv/cnn/bricks/norm.py
ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Dict, Tuple, Union

import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.utils import is_tuple_of
from mmengine.utils.dl_utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm,
                                                     _InstanceNorm)

MODELS.register_module('BN', module=nn.BatchNorm2d)
MODELS.register_module('BN1d', module=nn.BatchNorm1d)
MODELS.register_module('BN2d', module=nn.BatchNorm2d)
MODELS.register_module('BN3d', module=nn.BatchNorm3d)
MODELS.register_module('SyncBN', module=SyncBatchNorm)
MODELS.register_module('GN', module=nn.GroupNorm)
MODELS.register_module('LN', module=nn.LayerNorm)
MODELS.register_module('IN', module=nn.InstanceNorm2d)
MODELS.register_module('IN1d', module=nn.InstanceNorm1d)
MODELS.register_module('IN2d', module=nn.InstanceNorm2d)
MODELS.register_module('IN3d', module=nn.InstanceNorm3d)


def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    When we build a norm layer with `build_norm_layer()`, we want to preserve
    the norm type in variable names, e.g., self.bn1, self.gn. This method will
    infer the abbreviation to map class types to abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
    InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
    "in" respectively.
    Rule 3: If the class name contains "batch", "group", "layer" or "instance",
    the abbreviation of this layer will be "bn", "gn", "ln" and "in"
    respectively.
    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".

    Args:
        class_type (type): The norm layer type.

    Returns:
        str: The inferred abbreviation.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    if issubclass(class_type, _InstanceNorm):  # IN is a subclass of BN
        return 'in'
    elif issubclass(class_type, _BatchNorm):
        return 'bn'
    elif issubclass(class_type, nn.GroupNorm):
        return 'gn'
    elif issubclass(class_type, nn.LayerNorm):
        return 'ln'
    else:
        class_name = class_type.__name__.lower()
        if 'batch' in class_name:
            return 'bn'
        elif 'group' in class_name:
            return 'gn'
        elif 'layer' in class_name:
            return 'ln'
        elif 'instance' in class_name:
            return 'in'
        else:
            return 'norm_layer'


def build_norm_layer(cfg: Dict,
                     num_features: int,
                     postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Whether the layer parameters
              require gradient updates. Default: True.
        num_features (int): Number of input channels.
        postfix (int | str): The postfix to be appended into norm abbreviation
            to create named layer.

    Returns:
        tuple[str, nn.Module]: The first element is the layer name consisting
        of abbreviation and postfix, e.g., bn1, gn. The second element is the
        created norm layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')

    if inspect.isclass(layer_type):
        norm_layer = layer_type
    else:
        # Switch registry to the target scope. If `norm_layer` cannot be found
        # in the registry, fallback to search `norm_layer` in the
        # mmengine.MODELS.
        with MODELS.switch_scope_and_registry(None) as registry:
            norm_layer = registry.get(layer_type)
        if norm_layer is None:
            raise KeyError(f'Cannot find {layer_type} in registry under '
                           f'scope name {registry.scope}')
    abbr = infer_abbr(norm_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if norm_layer is not nn.GroupNorm:
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer


def is_norm(layer: nn.Module,
            exclude: Union[type, tuple, None] = None) -> bool:
    """Check if a layer is a normalization layer.

    Args:
        layer (nn.Module): The layer to be checked.
        exclude (type | tuple[type]): Types to be excluded.

    Returns:
        bool: Whether the layer is a norm layer.
    """
    if exclude is not None:
        if not isinstance(exclude, tuple):
            exclude = (exclude, )
        if not is_tuple_of(exclude, type):
            raise TypeError(
                f'"exclude" must be either None or type or a tuple of types, '
                f'but got {type(exclude)}: {exclude}')

    if exclude and isinstance(layer, exclude):
        return False

    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
    return isinstance(layer, all_norm_bases)
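A short build_norm_layer sketch (illustrative; the import path is assumed from this package layout): GN requires `num_groups` and receives `num_features` as `num_channels`, while `requires_grad` is stripped from the config and applied to the layer's parameters.

# Hedged sketch of the API defined above.
import torch.nn as nn
from mmcv.cnn import build_norm_layer

name, gn = build_norm_layer(dict(type='GN', num_groups=4), num_features=32)
assert name == 'gn' and isinstance(gn, nn.GroupNorm)

name, bn = build_norm_layer(dict(type='BN', requires_grad=False), 64, postfix=1)
assert name == 'bn1'
assert all(not p.requires_grad for p in bn.parameters())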
external/cv/mmcv/cnn/bricks/padding.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Dict

import torch.nn as nn
from mmengine.registry import MODELS

MODELS.register_module('zero', module=nn.ZeroPad2d)
MODELS.register_module('reflect', module=nn.ReflectionPad2d)
MODELS.register_module('replicate', module=nn.ReplicationPad2d)


def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build padding layer.

    Args:
        cfg (dict): The padding layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a padding layer.

    Returns:
        nn.Module: Created padding layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')
    if inspect.isclass(padding_type):
        return padding_type(*args, **kwargs, **cfg_)
    # Switch registry to the target scope. If `padding_layer` cannot be found
    # in the registry, fallback to search `padding_layer` in the
    # mmengine.MODELS.
    with MODELS.switch_scope_and_registry(None) as registry:
        padding_layer = registry.get(padding_type)
    if padding_layer is None:
        raise KeyError(f'Cannot find {padding_type} in registry under scope '
                       f'name {registry.scope}')
    layer = padding_layer(*args, **kwargs, **cfg_)

    return layer
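A build_padding_layer sketch (illustrative; import path assumed): positional args after `cfg` are forwarded to the layer constructor, so the `2` below becomes the padding size of `nn.ReflectionPad2d`.

# Hedged sketch of the API defined above.
import torch
from mmcv.cnn import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # -> nn.ReflectionPad2d(2)
x = torch.randn(1, 3, 8, 8)
assert pad(x).shape == (1, 3, 12, 12)  # 8 + 2 * 2 on each spatial dim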
external/cv/mmcv/cnn/bricks/plugin.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import platform
from typing import Dict, Tuple, Union

import torch.nn as nn
from mmengine.registry import MODELS

if platform.system() == 'Windows':
    import regex as re  # type: ignore
else:
    import re  # type: ignore


def infer_abbr(class_type: type) -> str:
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.
    """

    def camel2snack(word):
        """Convert a camel case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snack("FancyBlock")
            'fancy_block'
        """

        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_  # type: ignore
    else:
        return camel2snack(class_type.__name__)


def build_plugin_layer(cfg: Dict,
                       postfix: Union[int, str] = '',
                       **kwargs) -> Tuple[str, nn.Module]:
    """Build plugin layer.

    Args:
        cfg (dict): cfg should contain:

            - type (str): identify plugin layer type.
            - layer args: args needed to instantiate a plugin layer.
        postfix (int | str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
        tuple[str, nn.Module]: The first one is the concatenation of
        abbreviation and postfix. The second is the created plugin layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        plugin_layer = layer_type
    else:
        # Switch registry to the target scope. If `plugin_layer` cannot be
        # found in the registry, fallback to search `plugin_layer` in the
        # mmengine.MODELS.
        with MODELS.switch_scope_and_registry(None) as registry:
            plugin_layer = registry.get(layer_type)
        if plugin_layer is None:
            raise KeyError(
                f'Cannot find {layer_type} in registry under scope '
                f'name {registry.scope}')
    abbr = infer_abbr(plugin_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    layer = plugin_layer(**kwargs, **cfg_)

    return name, layer
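A build_plugin_layer sketch (illustrative; import path assumed): `NonLocal2d` above defines `_abbr_ = 'nonlocal_block'`, so `infer_abbr` uses that property and the postfix is appended to form the layer name.

# Hedged sketch of the API defined above.
from mmcv.cnn import build_plugin_layer

name, layer = build_plugin_layer(
    dict(type='NonLocal2d', in_channels=16), postfix='_1')
assert name == 'nonlocal_block_1'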