Yiming-M committed
Commit 570db9a · 1 Parent(s): f319c12
.gitignore ADDED
@@ -0,0 +1,167 @@
1
+ # MacOS
2
+ **/.DS_Store
3
+
4
+ **/*.pth
5
+ models/clip/_clip/configs/*
6
+ models/clip/_clip/weights/*
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/#use-with-ide
117
+ .pdm.toml
118
+
119
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120
+ __pypackages__/
121
+
122
+ # Celery stuff
123
+ celerybeat-schedule
124
+ celerybeat.pid
125
+
126
+ # SageMath parsed files
127
+ *.sage.py
128
+
129
+ # Environments
130
+ .env
131
+ .venv
132
+ env/
133
+ venv/
134
+ ENV/
135
+ env.bak/
136
+ venv.bak/
137
+
138
+ # Spyder project settings
139
+ .spyderproject
140
+ .spyproject
141
+
142
+ # Rope project settings
143
+ .ropeproject
144
+
145
+ # mkdocs documentation
146
+ /site
147
+
148
+ # mypy
149
+ .mypy_cache/
150
+ .dmypy.json
151
+ dmypy.json
152
+
153
+ # Pyre type checker
154
+ .pyre/
155
+
156
+ # pytype static type analyzer
157
+ .pytype/
158
+
159
+ # Cython debug symbols
160
+ cython_debug/
161
+
162
+ # PyCharm
163
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
166
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167
+ #.idea/
app.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ import numpy as np
+ from PIL import Image
+ import json, os, random
+ import gradio as gr
+ import torchvision.transforms.functional as TF
+ from safetensors.torch import load_file  # Import the load_file function from safetensors
+ from matplotlib import cm
+ from huggingface_hub import hf_hub_download
+
+ from typing import Tuple
+
+ from models import get_model
+
+
+ def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
+     x_sum = torch.sum(x, dim=(-1, -2))
+     x = F.interpolate(x, size=size, mode="bilinear")
+     # Rescale so the resized map keeps the original total count (nan_to_num guards empty maps).
+     scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2)), nan=0.0, posinf=0.0, neginf=0.0)
+     return x * scale_factor
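For illustration only (not part of the committed file), a minimal sketch of what `resize_density_map` does: the map is interpolated to the target size and then rescaled so its sum, i.e. the predicted count, stays the same. It assumes the imports defined above.

```python
import torch

dm = torch.zeros(1, 1, 28, 28)
dm[0, 0, 10, 12] = 3.0                        # toy density map whose entries sum to 3
resized = resize_density_map(dm, (224, 224))  # function defined above
print(resized.shape)                          # torch.Size([1, 1, 224, 224])
print(float(resized.sum()))                   # ~3.0 — the total count is preserved
```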
+
+
+ def init_seeds(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ alpha = 0.8
+ init_seeds(42)
+
+ # -----------------------------
+ # Define the model architecture
+ # -----------------------------
+ truncation = 4
+ reduction = 8
+ granularity = "fine"
+ anchor_points = "average"
+
+ model_name = "clip_vit_l_14"
+ input_size = 224
+
+ # Comment the lines below to test non-CLIP models.
+ prompt_type = "word"
+ num_vpt = 32
+ vpt_drop = 0.
+ deep_vpt = True
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ if truncation is None:  # regression, no truncation.
+     bins, anchor_points = None, None
+ else:
+     with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
+         config = json.load(f)[str(truncation)]["nwpu"]
+     bins = config["bins"][granularity]
+     anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"]
+     bins = [(float(b[0]), float(b[1])) for b in bins]
+     anchor_points = [float(p) for p in anchor_points]
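With the settings above (`truncation = 4`, `reduction = 8`, `granularity = "fine"`, average anchors, dataset key `"nwpu"`), the values read from `configs/reduction_8.json` (added later in this commit) parse as shown in this short illustrative snippet:

```python
# Parsed from configs/reduction_8.json -> ["4"]["nwpu"], "fine" granularity, "average" anchors:
# bins          -> [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, inf)]
# anchor_points -> [0.0, 1.0, 2.0, 3.0, 4.21931]
print(bins)
print(anchor_points)
```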
+
+
+ model = get_model(
+     backbone=model_name,
+     input_size=input_size,
+     reduction=reduction,
+     bins=bins,
+     anchor_points=anchor_points,
+     # CLIP parameters
+     prompt_type=prompt_type,
+     num_vpt=num_vpt,
+     vpt_drop=vpt_drop,
+     deep_vpt=deep_vpt
+ )
+
+ repo_id = "Yiming-M/CLIP-EBC"
+ filename = "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors"
+ weights_path = hf_hub_download(repo_id, filename)
+ # weights_path = os.path.join("CLIP_EBC_ViT_L_14", "model.safetensors")
+ state_dict = load_file(weights_path)
+ new_state_dict = {}
+ for k, v in state_dict.items():
+     new_state_dict[k.replace("model.", "")] = v
+ model.load_state_dict(new_state_dict)
+ model.to(device)
+ model.eval()
+
+
+ # -----------------------------
+ # Preprocessing function
+ # -----------------------------
+ # Adjust the image transforms to match what your model expects.
+ def transform(image: Image.Image):
+     assert isinstance(image, Image.Image), "Input must be a PIL Image"
+     image_tensor = TF.to_tensor(image)
+
+     image_height, image_width = image_tensor.shape[-2:]
+     if image_height < input_size or image_width < input_size:
+         # Find the ratio to resize the image while maintaining the aspect ratio
+         ratio = max(input_size / image_height, input_size / image_width)
+         new_height = int(image_height * ratio) + 1
+         new_width = int(image_width * ratio) + 1
+         image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
+
+     image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
+     return image_tensor.unsqueeze(0)  # Add batch dimension
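A quick hedged sketch (not part of the commit) of what `transform` returns; it relies on the `PIL.Image` import and `input_size = 224` defined above.

```python
from PIL import Image

img = Image.new("RGB", (640, 480))   # width x height; both sides already exceed input_size
batch = transform(img)
print(batch.shape)                   # torch.Size([1, 3, 480, 640]) — normalised, no resizing needed
```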
+
+
+
+ # -----------------------------
+ # Inference function
+ # -----------------------------
+ def predict(image: Image.Image):
+     """
+     Given an input image, preprocess it, run the model to obtain a density map,
+     compute the total crowd count, and prepare the density map for display.
+     """
+     # Preprocess the image
+     input_width, input_height = image.size
+     input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
+
+     with torch.no_grad():
+         density_map = model(input_tensor)  # expected shape: (1, 1, H, W)
+         total_count = density_map.sum().item()
+         resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
+
+     # Normalize the density map for display purposes
+     eps = 1e-8
+     density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)
+
+     # Apply a colormap (e.g., 'jet') to get an RGBA image
+     colormap = cm.get_cmap("jet")
+     # The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8.
+     density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
+     density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")
+
+     # Ensure the original image is in RGBA format.
+     image_rgba = image.convert("RGBA")
+     overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
+
+     return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
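As an illustration of the same pipeline without the Gradio UI (a hedged sketch, not from the committed file); `crowd.jpg` is a hypothetical local image, not a file shipped with this Space:

```python
from PIL import Image

pil_img = Image.open("crowd.jpg").convert("RGB")   # hypothetical input file
original, overlay, caption = predict(pil_img)
overlay.save("crowd_density_overlay.png")
print(caption)                                     # e.g. "Predicted Count: 123.45"
```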
+
+
+ # -----------------------------
+ # Build Gradio Interface using Blocks for a two-column layout
+ # -----------------------------
+ with gr.Blocks() as demo:
+     gr.Markdown("# Crowd Counting Demo")
+     gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
+
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(
+                 label="Input Image",
+                 sources=["upload", "clipboard"],
+                 type="pil",
+             )
+             submit_btn = gr.Button("Predict")
+         with gr.Column():
+             output_img = gr.Image(label="Predicted Density Map", type="pil")
+             output_text = gr.Textbox(label="Total Count")
+
+     submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
+
+     # Optional: add example images. Ensure these files are in your repo.
+     gr.Examples(
+         examples=[
+             ["example1.jpg"],
+             ["example2.jpg"]
+         ],
+         inputs=input_img,
+         label="Try an example"
+     )
+
+ # Launch the app
+ demo.launch()
configs/reduction_16.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "8":{
3
+ "qnrf": {
4
+ "bins": {
5
+ "fine":[
6
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
7
+ [5, 5], [6, 6], [7, 7], [8, "inf"]
8
+ ],
9
+ "dynamic": [
10
+ [0, 0], [1, 1], [2, 2], [3, 3],
11
+ [4, 5], [6, 7], [8, "inf"]
12
+ ],
13
+ "coarse": [
14
+ [0, 0], [1, 2], [3, 4], [5, 6], [7, "inf"]
15
+ ]
16
+ },
17
+ "anchor_points": {
18
+ "fine": {
19
+ "middle": [0, 1, 2, 3, 4, 5, 6, 7, 8],
20
+ "average": [0, 1, 2, 3, 4, 5, 6, 7, 9.23349]
21
+ },
22
+ "dynamic": {
23
+ "middle": [0, 1, 2, 3, 4.5, 6.5, 8],
24
+ "average": [0, 1, 2, 3, 4.29278, 6.31441, 9.23349]
25
+ },
26
+ "coarse": {
27
+ "middle": [0, 1.5, 3.5, 5.5, 7],
28
+ "average": [0, 1.14978, 3.27641, 5.30609, 8.11466]
29
+ }
30
+ }
31
+ }
32
+ }
33
+ }
configs/reduction_32.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "19": {
3
+ "qnrf": {
4
+ "bins": {
5
+ "fine": [
6
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
7
+ [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
8
+ [10, 10], [11, 11], [12, 12], [13, 13], [14, 14],
9
+ [15, 15], [16, 16], [17, 17], [18, 18], [19, "inf"]
10
+ ],
11
+ "dynamic": [
12
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
13
+ [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
14
+ [10, 11], [12, 13], [14, 15], [16, 17], [18, "inf"]
15
+ ],
16
+ "coarse": [
17
+ [0, 0], [1, 2], [3, 4], [5, 6], [7, 8],
18
+ [9, 10], [11, 12], [13, 14], [15, 16], [17, 18],
19
+ [19, "inf"]
20
+ ]
21
+ },
22
+ "anchor_points": {
23
+ "fine": {
24
+ "middle": [
25
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
26
+ 11, 12, 13, 14, 15, 16, 17, 18, 19
27
+ ],
28
+ "average": [
29
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
30
+ 11, 12, 13, 14, 15, 16, 17, 18, 23.01897
31
+ ]
32
+ },
33
+ "dynamic": {
34
+ "middle": [
35
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.5,
36
+ 12.5, 14.5, 16.5, 18
37
+ ],
38
+ "average": [
39
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.42903,
40
+ 12.43320, 14.43341, 16.43521, 21.93548
41
+ ]
42
+ },
43
+ "coarse": {
44
+ "middle": [
45
+ 0, 1.5, 3.5, 5.5, 7.5, 9.5,
46
+ 11.5, 13.5, 15.5, 17.5, 19
47
+ ],
48
+ "average": [
49
+ 0, 1.23498, 3.36108, 5.40298, 7.41406, 9.42356,
50
+ 11.43094, 13.43244, 15.43697, 17.43759, 23.01897
51
+ ]
52
+ }
53
+ }
54
+ }
55
+ }
56
+ }
configs/reduction_8.json ADDED
@@ -0,0 +1,129 @@
1
+ {
2
+ "2": {
3
+ "sha": {
4
+ "bins": {
5
+ "fine": [[0, 0], [1, 1], [2, "inf"]]
6
+ },
7
+ "anchor_points": {
8
+ "fine": {
9
+ "middle": [0, 1, 2],
10
+ "average": [0, 1, 2.24479]
11
+ }
12
+ }
13
+ },
14
+ "shb": {
15
+ "bins": {
16
+ "fine": [[0, 0], [1, 1], [2, "inf"]]
17
+ },
18
+ "anchor_points": {
19
+ "fine": {
20
+ "middle": [0, 1, 2],
21
+ "average": [0, 1, 2.15171]
22
+ }
23
+ }
24
+ },
25
+ "nwpu": {
26
+ "bins": {
27
+ "fine": [[0, 0], [1, 1], [2, "inf"]]
28
+ },
29
+ "anchor_points": {
30
+ "fine": {
31
+ "middle": [0, 1, 2],
32
+ "average": [0, 1, 2.10737]
33
+ }
34
+ }
35
+ },
36
+ "qnrf": {
37
+ "bins": {
38
+ "fine": [[0, 0], [1, 1], [2, "inf"]]
39
+ },
40
+ "anchor_points": {
41
+ "fine": {
42
+ "middle": [0, 1, 2],
43
+ "average": [0, 1, 2.09296]
44
+ }
45
+ }
46
+ },
47
+ "jhu": {
48
+ "bins": {
49
+ "fine": [[0, 0], [1, 1], [2, "inf"]]
50
+ },
51
+ "anchor_points": {
52
+ "fine": {
53
+ "middle": [0, 1, 2],
54
+ "average": [0, 1, 2.18589]
55
+ }
56
+ }
57
+ }
58
+ },
59
+ "4": {
60
+ "sha": {
61
+ "bins": {
62
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
63
+ },
64
+ "anchor_points": {
65
+ "fine": {
66
+ "middle": [0, 1, 2, 3, 4],
67
+ "average": [0, 1, 2, 3, 4.29992]
68
+ }
69
+ }
70
+ },
71
+ "shb": {
72
+ "bins": {
73
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
74
+ },
75
+ "anchor_points": {
76
+ "fine": {
77
+ "middle": [0, 1, 2, 3, 4],
78
+ "average": [0, 1, 2, 3, 4.41009]
79
+ }
80
+ }
81
+ },
82
+ "nwpu": {
83
+ "bins": {
84
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
85
+ },
86
+ "anchor_points": {
87
+ "fine": {
88
+ "middle": [0, 1, 2, 3, 4],
89
+ "average": [0, 1, 2, 3, 4.21931]
90
+ }
91
+ }
92
+ },
93
+ "qnrf": {
94
+ "bins": {
95
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
96
+ },
97
+ "anchor_points": {
98
+ "fine": {
99
+ "middle": [0, 1, 2, 3, 4],
100
+ "average": [0, 1, 2, 3, 4.21937]
101
+ }
102
+ }
103
+ },
104
+ "jhu": {
105
+ "bins": {
106
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
107
+ },
108
+ "anchor_points": {
109
+ "fine": {
110
+ "middle": [0, 1, 2, 3, 4],
111
+ "average": [0, 1, 2, 3, 4.24058]
112
+ }
113
+ }
114
+ }
115
+ },
116
+ "11": {
117
+ "qnrf": {
118
+ "bins": {
119
+ "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, "inf"]]
120
+ },
121
+ "anchor_points": {
122
+ "fine": {
123
+ "middle": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
124
+ "average": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
125
+ }
126
+ }
127
+ }
128
+ }
129
+ }
models/__init__.py ADDED
@@ -0,0 +1,49 @@
+ from typing import List, Tuple, Optional, Any, Union
+
+ from .model import _classifier, _regressor, Classifier, Regressor
+ from .clip import _clip_ebc, CLIP_EBC
+
+
+ clip_names = ["resnet50", "resnet50x4", "resnet50x16", "resnet50x64", "resnet101", "vit_b_16", "vit_b_32", "vit_l_14"]
+
+
+ def get_model(
+     backbone: str,
+     input_size: int,
+     reduction: int,
+     bins: Optional[List[Tuple[float, float]]] = None,
+     anchor_points: Optional[List[float]] = None,
+     **kwargs: Any,
+ ) -> Union[Regressor, Classifier, CLIP_EBC]:
+     backbone = backbone.lower()
+     if "clip" in backbone:
+         backbone = backbone[5:]
+         assert backbone in clip_names, f"Expected backbone to be in {clip_names}, got {backbone}"
+         return _clip_ebc(
+             backbone=backbone,
+             input_size=input_size,
+             reduction=reduction,
+             bins=bins,
+             anchor_points=anchor_points,
+             **kwargs
+         )
+     elif bins is None and anchor_points is None:
+         return _regressor(
+             backbone=backbone,
+             input_size=input_size,
+             reduction=reduction,
+         )
+     else:
+         assert bins is not None and anchor_points is not None, f"Expected bins and anchor_points to be both None or not None, got {bins} and {anchor_points}"
+         return _classifier(
+             backbone=backbone,
+             input_size=input_size,
+             reduction=reduction,
+             bins=bins,
+             anchor_points=anchor_points,
+         )
+
+
+ __all__ = [
+     "get_model",
+ ]
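For reference, a hedged sketch of how `get_model` dispatches. The CLIP call mirrors the route `app.py` takes; the bin/anchor values are placeholders, and the plain-backbone name in the regressor call is hypothetical — valid non-CLIP backbones depend on `models/model.py`, which is not part of this diff.

```python
from models import get_model

# Blockwise-classification (EBC) head on a CLIP ViT backbone, as in app.py:
clip_ebc = get_model(
    backbone="clip_vit_b_16",
    input_size=224,
    reduction=8,
    bins=[(0.0, 0.0), (1.0, 1.0), (2.0, float("inf"))],  # placeholder bins
    anchor_points=[0.0, 1.0, 2.0],                       # placeholder anchors
    prompt_type="word", num_vpt=32, vpt_drop=0.0, deep_vpt=True,
)

# Plain regression head: leave bins and anchor_points as None.
regressor = get_model(backbone="some_cnn_backbone", input_size=224, reduction=8)  # backbone name is hypothetical
```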
models/clip/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .model import CLIP_EBC, _clip_ebc
+
+
+ __all__ = [
+     "CLIP_EBC",
+     "_clip_ebc",
+ ]
models/clip/_clip/__init__.py ADDED
@@ -0,0 +1,273 @@
1
+ import torch
2
+ import os
3
+ from typing import Tuple, Optional, Any, Union
4
+ import json
5
+
6
+ from .utils import tokenize, transform
7
+ from .prepare import prepare
8
+ from .text_encoder import CLIPTextEncoder
9
+ from .image_encoder import ModifiedResNet, VisionTransformer
10
+ from .model import CLIP
11
+
12
+
13
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
14
+
15
+ clip_model_names = [
16
+ "clip_resnet50",
17
+ "clip_resnet101",
18
+ "clip_resnet50x4",
19
+ "clip_resnet50x16",
20
+ "clip_resnet50x64",
21
+ "clip_vit_b_32",
22
+ "clip_vit_b_16",
23
+ "clip_vit_l_14",
24
+ "clip_vit_l_14_336px",
25
+ ]
26
+
27
+ clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names]
28
+ clip_text_encoder_names = [f"clip_text_encoder_{name[5:]}" for name in clip_model_names]
29
+
30
+
31
+ for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
32
+ model_weights_path = os.path.join(curr_dir, "weights", f"{name}.pth")
33
+ model_config_path = os.path.join(curr_dir, "configs", f"{name}.json")
34
+ if not os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")) or not os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")):
35
+ prepare()
36
+ break
37
+
38
+
39
+ for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
40
+ assert os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")), f"Missing {name}.pth in weights folder. Please run models/clip/prepare.py to download the weights."
41
+ assert os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")), f"Missing {name}.json in configs folder. Please run models/clip/prepare.py to download the configs."
42
+
43
+
44
+ def _clip(name: str, input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
45
+ with open(os.path.join(curr_dir, "configs", f"clip_{name}.json"), "r") as f:
46
+ config = json.load(f)
47
+
48
+ model = CLIP(
49
+ embed_dim=config["embed_dim"],
50
+ # vision
51
+ image_resolution=config["image_resolution"],
52
+ vision_layers=config["vision_layers"],
53
+ vision_width=config["vision_width"],
54
+ vision_patch_size=config["vision_patch_size"],
55
+ # text
56
+ context_length=config["context_length"],
57
+ vocab_size=config["vocab_size"],
58
+ transformer_width=config["transformer_width"],
59
+ transformer_heads=config["transformer_heads"],
60
+ transformer_layers=config["transformer_layers"]
61
+ )
62
+ state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_{name}.pth"), map_location="cpu")
63
+ model.load_state_dict(state_dict, strict=True)
64
+
65
+ if input_size is not None:
66
+ input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
67
+ if name.startswith("vit"):
68
+ model.visual.adjust_pos_embed(*input_size)
69
+
70
+ return model
71
+
72
+
73
+ def _resnet(
74
+ name: str,
75
+ reduction: int = 32,
76
+ features_only: bool = False,
77
+ out_indices: Optional[Tuple[int, ...]] = None,
78
+ **kwargs: Any
79
+ ) -> ModifiedResNet:
80
+ with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
81
+ config = json.load(f)
82
+ model = ModifiedResNet(
83
+ layers=config["vision_layers"],
84
+ output_dim=config["embed_dim"],
85
+ input_resolution=config["image_resolution"],
86
+ width=config["vision_width"],
87
+ heads=config["vision_heads"],
88
+ features_only=features_only,
89
+ out_indices=out_indices,
90
+ reduction=reduction
91
+ )
92
+ state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
93
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
94
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
95
+ print(f"Missing keys: {missing_keys}")
96
+ print(f"Unexpected keys: {unexpected_keys}")
97
+ else:
98
+ print(f"All keys matched successfully.")
99
+
100
+ return model
101
+
102
+
103
+ def _vit(name: str, features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
104
+ with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
105
+ config = json.load(f)
106
+ model = VisionTransformer(
107
+ input_resolution=config["image_resolution"],
108
+ patch_size=config["vision_patch_size"],
109
+ output_dim=config["embed_dim"],
110
+ width=config["vision_width"],
111
+ layers=config["vision_layers"],
112
+ heads=config["vision_heads"],
113
+ features_only=features_only
114
+ )
115
+ state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
116
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
117
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
118
+ print(f"Missing keys: {missing_keys}")
119
+ print(f"Unexpected keys: {unexpected_keys}")
120
+ else:
121
+ print(f"All keys matched successfully.")
122
+
123
+ if input_size is not None:
124
+ input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
125
+ model.adjust_pos_embed(*input_size)
126
+ return model
127
+
128
+
129
+ def _text_encoder(name: str) -> CLIPTextEncoder:
130
+ with open(os.path.join(curr_dir, "configs", f"clip_text_encoder_{name}.json"), "r") as f:
131
+ config = json.load(f)
132
+ model = CLIPTextEncoder(
133
+ embed_dim=config["embed_dim"],
134
+ context_length=config["context_length"],
135
+ vocab_size=config["vocab_size"],
136
+ transformer_width=config["transformer_width"],
137
+ transformer_heads=config["transformer_heads"],
138
+ transformer_layers=config["transformer_layers"]
139
+ )
140
+ state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_text_encoder_{name}.pth"), map_location="cpu")
141
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
142
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
143
+ print(f"Missing keys: {missing_keys}")
144
+ print(f"Unexpected keys: {unexpected_keys}")
145
+ else:
146
+ print(f"All keys matched successfully.")
147
+
148
+ return model
149
+
150
+
151
+
152
+ # CLIP models
153
+ def resnet50_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
154
+ return _clip("resnet50", input_size)
155
+
156
+ def resnet101_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
157
+ return _clip("resnet101", input_size)
158
+
159
+ def resnet50x4_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
160
+ return _clip("resnet50x4", input_size)
161
+
162
+ def resnet50x16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
163
+ return _clip("resnet50x16", input_size)
164
+
165
+ def resnet50x64_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
166
+ return _clip("resnet50x64", input_size)
167
+
168
+ def vit_b_32_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
169
+ return _clip("vit_b_32", input_size)
170
+
171
+ def vit_b_16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
172
+ return _clip("vit_b_16", input_size)
173
+
174
+ def vit_l_14_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
175
+ return _clip("vit_l_14", input_size)
176
+
177
+ def vit_l_14_336px_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
178
+ return _clip("vit_l_14_336px", input_size)
179
+
180
+
181
+ # CLIP image encoders
182
+ def resnet50_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
183
+ return _resnet("resnet50", features_only=features_only, out_indices=out_indices, **kwargs)
184
+
185
+ def resnet101_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
186
+ return _resnet("resnet101", features_only=features_only, out_indices=out_indices, **kwargs)
187
+
188
+ def resnet50x4_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
189
+ return _resnet("resnet50x4", features_only=features_only, out_indices=out_indices, **kwargs)
190
+
191
+ def resnet50x16_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
192
+ return _resnet("resnet50x16", features_only=features_only, out_indices=out_indices, **kwargs)
193
+
194
+ def resnet50x64_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
195
+ return _resnet("resnet50x64", features_only=features_only, out_indices=out_indices, **kwargs)
196
+
197
+ def vit_b_32_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
198
+ return _vit("vit_b_32", features_only=features_only, input_size=input_size, **kwargs)
199
+
200
+ def vit_b_16_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
201
+ return _vit("vit_b_16", features_only=features_only, input_size=input_size, **kwargs)
202
+
203
+ def vit_l_14_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
204
+ return _vit("vit_l_14", features_only=features_only, input_size=input_size, **kwargs)
205
+
206
+ def vit_l_14_336px_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
207
+ return _vit("vit_l_14_336px", features_only=features_only, input_size=input_size, **kwargs)
208
+
209
+
210
+ # CLIP text encoders
211
+ def resnet50_txt() -> CLIPTextEncoder:
212
+ return _text_encoder("resnet50")
213
+
214
+ def resnet101_txt() -> CLIPTextEncoder:
215
+ return _text_encoder("resnet101")
216
+
217
+ def resnet50x4_txt() -> CLIPTextEncoder:
218
+ return _text_encoder("resnet50x4")
219
+
220
+ def resnet50x16_txt() -> CLIPTextEncoder:
221
+ return _text_encoder("resnet50x16")
222
+
223
+ def resnet50x64_txt() -> CLIPTextEncoder:
224
+ return _text_encoder("resnet50x64")
225
+
226
+ def vit_b_32_txt() -> CLIPTextEncoder:
227
+ return _text_encoder("vit_b_32")
228
+
229
+ def vit_b_16_txt() -> CLIPTextEncoder:
230
+ return _text_encoder("vit_b_16")
231
+
232
+ def vit_l_14_txt() -> CLIPTextEncoder:
233
+ return _text_encoder("vit_l_14")
234
+
235
+ def vit_l_14_336px_txt() -> CLIPTextEncoder:
236
+ return _text_encoder("vit_l_14_336px")
237
+
238
+
239
+ __all__ = [
240
+ # utils
241
+ "tokenize",
242
+ "transform",
243
+ # clip models
244
+ "resnet50_clip",
245
+ "resnet101_clip",
246
+ "resnet50x4_clip",
247
+ "resnet50x16_clip",
248
+ "resnet50x64_clip",
249
+ "vit_b_32_clip",
250
+ "vit_b_16_clip",
251
+ "vit_l_14_clip",
252
+ "vit_l_14_336px_clip",
253
+ # clip image encoders
254
+ "resnet50_img",
255
+ "resnet101_img",
256
+ "resnet50x4_img",
257
+ "resnet50x16_img",
258
+ "resnet50x64_img",
259
+ "vit_b_32_img",
260
+ "vit_b_16_img",
261
+ "vit_l_14_img",
262
+ "vit_l_14_336px_img",
263
+ # clip text encoders
264
+ "resnet50_txt",
265
+ "resnet101_txt",
266
+ "resnet50x4_txt",
267
+ "resnet50x16_txt",
268
+ "resnet50x64_txt",
269
+ "vit_b_32_txt",
270
+ "vit_b_16_txt",
271
+ "vit_l_14_txt",
272
+ "vit_l_14_336px_txt",
273
+ ]
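A hedged usage sketch (not part of the commit) of the factory functions exported above. Importing `models.clip._clip` runs the weight/config check at the top of this file, which may call `prepare()` and download the original CLIP checkpoints; `tokenize` is assumed to follow the original CLIP signature (list of strings in, a `(N, 77)` tensor of token ids out), since `utils.py` is not shown in this diff.

```python
from models.clip._clip import vit_b_16_img, vit_b_16_txt, tokenize

image_encoder = vit_b_16_img(features_only=True, input_size=224)  # ViT patch features instead of the CLS embedding
text_encoder = vit_b_16_txt()                                     # standalone CLIP text encoder
tokens = tokenize(["a photo of a crowd"])                         # assumed CLIP-style tokenizer
```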
models/clip/_clip/blocks.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ from typing import Optional, Iterable
6
+
7
+
8
+ class LayerNorm(nn.LayerNorm):
9
+ """Subclass torch's LayerNorm to handle fp16."""
10
+
11
+ def forward(self, x: Tensor):
12
+ orig_type = x.dtype
13
+ ret = super().forward(x.type(torch.float32))
14
+ return ret.type(orig_type)
15
+
16
+
17
+ class QuickGELU(nn.Module):
18
+ def forward(self, x: Tensor):
19
+ return x * torch.sigmoid(1.702 * x)
20
+
21
+
22
+ class ResidualAttentionBlock(nn.Module):
23
+ def __init__(self, d_model: int, n_head: int, attn_mask: Tensor = None):
24
+ super().__init__()
25
+ self.attn = nn.MultiheadAttention(d_model, n_head)
26
+ self.ln_1 = LayerNorm(d_model)
27
+ self.mlp = nn.Sequential(OrderedDict([
28
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
29
+ ("gelu", QuickGELU()),
30
+ ("c_proj", nn.Linear(d_model * 4, d_model))
31
+ ]))
32
+ self.ln_2 = LayerNorm(d_model)
33
+ self.attn_mask = attn_mask
34
+
35
+ def attention(self, x: Tensor):
36
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
37
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
38
+
39
+ def forward(self, x: Tensor) -> Tensor:
40
+ x = x + self.attention(self.ln_1(x))
41
+ x = x + self.mlp(self.ln_2(x))
42
+ return x
43
+
44
+
45
+ class Transformer(nn.Module):
46
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: Tensor = None):
47
+ super().__init__()
48
+ self.width = width
49
+ self.layers = layers
50
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
51
+
52
+ def forward(self, x: Tensor):
53
+ return self.resblocks(x)
54
+
55
+
56
+ class Bottleneck(nn.Module):
57
+ expansion = 4
58
+
59
+ def __init__(self, inplanes, planes, stride=1):
60
+ super().__init__()
61
+
62
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
63
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
64
+ self.bn1 = nn.BatchNorm2d(planes)
65
+ self.relu1 = nn.ReLU(inplace=True)
66
+
67
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.relu2 = nn.ReLU(inplace=True)
70
+
71
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
72
+
73
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
74
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
75
+ self.relu3 = nn.ReLU(inplace=True)
76
+
77
+ self.downsample = None
78
+ self.stride = stride
79
+
80
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
81
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
82
+ self.downsample = nn.Sequential(OrderedDict([
83
+ ("-1", nn.AvgPool2d(stride)),
84
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
85
+ ("1", nn.BatchNorm2d(planes * self.expansion))
86
+ ]))
87
+
88
+ def forward(self, x: Tensor):
89
+ identity = x
90
+
91
+ out = self.relu1(self.bn1(self.conv1(x)))
92
+ out = self.relu2(self.bn2(self.conv2(out)))
93
+ out = self.avgpool(out)
94
+ out = self.bn3(self.conv3(out))
95
+
96
+ if self.downsample is not None:
97
+ identity = self.downsample(x)
98
+
99
+ out += identity
100
+ out = self.relu3(out)
101
+ return out
102
+
103
+
104
+ class AttentionPool2d(nn.Module):
105
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
106
+ super().__init__()
107
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
108
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
109
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
110
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
111
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
112
+ self.num_heads = num_heads
113
+
114
+ def forward(self, x):
115
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
116
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
117
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
118
+ x, _ = F.multi_head_attention_forward(
119
+ query=x[:1], key=x, value=x,
120
+ embed_dim_to_check=x.shape[-1],
121
+ num_heads=self.num_heads,
122
+ q_proj_weight=self.q_proj.weight,
123
+ k_proj_weight=self.k_proj.weight,
124
+ v_proj_weight=self.v_proj.weight,
125
+ in_proj_weight=None,
126
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
127
+ bias_k=None,
128
+ bias_v=None,
129
+ add_zero_attn=False,
130
+ dropout_p=0,
131
+ out_proj_weight=self.c_proj.weight,
132
+ out_proj_bias=self.c_proj.bias,
133
+ use_separate_proj_weight=True,
134
+ training=self.training,
135
+ need_weights=False
136
+ )
137
+ return x.squeeze(0)
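A small shape check (a sketch, not from the committed file) for `ResidualAttentionBlock`: it operates on sequences in the `(L, N, D)` layout expected by `nn.MultiheadAttention` without `batch_first`, and preserves the input shape.

```python
import torch
from models.clip._clip.blocks import ResidualAttentionBlock

block = ResidualAttentionBlock(d_model=512, n_head=8)   # no attn_mask -> full attention
x = torch.randn(77, 2, 512)                             # (sequence length, batch, width)
with torch.no_grad():
    y = block(x)
print(y.shape)                                          # torch.Size([77, 2, 512])
```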
models/clip/_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/clip/_clip/image_encoder.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from typing import Tuple, Union, Any, List, Iterable, Optional
6
+
7
+ from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d
8
+
9
+
10
+ class ModifiedResNet(nn.Module):
11
+ """
12
+ A ResNet class that is similar to torchvision's but contains the following changes:
13
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
14
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
15
+ - The final pooling layer is a QKV attention instead of an average pool
16
+ """
17
+ def __init__(
18
+ self,
19
+ layers: Tuple[int, int, int, int],
20
+ output_dim: int,
21
+ input_resolution: int = 224,
22
+ width: int = 64,
23
+ heads: int = 8,
24
+ features_only: bool = False,
25
+ out_indices: Optional[Iterable[int]] = None,
26
+ reduction: int = 32,
27
+ **kwargs: Any,
28
+ ) -> None:
29
+ super().__init__()
30
+ input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
31
+ assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
32
+ self.input_resolution = input_resolution
33
+ self.downsampling_rate = 32 # the rate at which the input is downsampled by the network
34
+
35
+ # the 3-layer stem
36
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(width // 2)
38
+ self.relu1 = nn.ReLU(inplace=True)
39
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
40
+ self.bn2 = nn.BatchNorm2d(width // 2)
41
+ self.relu2 = nn.ReLU(inplace=True)
42
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
43
+ self.bn3 = nn.BatchNorm2d(width)
44
+ self.relu3 = nn.ReLU(inplace=True)
45
+ self.avgpool = nn.AvgPool2d(2)
46
+
47
+ # residual layers
48
+ self._inplanes = width # this is a *mutable* variable used during construction
49
+ self.layer1 = self._make_layer(width, layers[0])
50
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
51
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
52
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2)
53
+
54
+ self.features_only = features_only
55
+ if features_only:
56
+ self.out_indices = out_indices if out_indices is not None else range(5)
57
+ self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices] # map negative indices to positive indices
58
+ self.out_indices = sorted(set(self.out_indices)) # remove duplicates and sort
59
+ assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages"
60
+ self.channels = width * 32 # the ResNet feature dimension
61
+ else:
62
+ self.out_indices = None
63
+ embed_dim = width * 32 # the ResNet feature dimension
64
+ self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim)
65
+ self.channels = output_dim
66
+
67
+ self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate
68
+ self.clip_embed_dim = output_dim
69
+
70
+ def _make_layer(self, planes, blocks, stride=1):
71
+ layers = [Bottleneck(self._inplanes, planes, stride)]
72
+
73
+ self._inplanes = planes * Bottleneck.expansion
74
+ for _ in range(1, blocks):
75
+ layers.append(Bottleneck(self._inplanes, planes))
76
+
77
+ return nn.Sequential(*layers)
78
+
79
+ def _stem(self, x: Tensor) -> Tensor:
80
+ x = self.relu1(self.bn1(self.conv1(x)))
81
+ x = self.relu2(self.bn2(self.conv2(x)))
82
+ x = self.relu3(self.bn3(self.conv3(x)))
83
+ x = self.avgpool(x)
84
+ return x
85
+
86
+ def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
87
+ x = x.type(self.conv1.weight.dtype)
88
+ x = self._stem(x)
89
+
90
+ feats = [x] if self.features_only and 0 in self.out_indices else []
91
+
92
+ x = self.layer1(x)
93
+ if self.features_only and 1 in self.out_indices:
94
+ feats.append(x)
95
+
96
+ x = self.layer2(x)
97
+ if self.features_only and 2 in self.out_indices:
98
+ feats.append(x)
99
+
100
+ x = self.layer3(x)
101
+ if self.features_only and 3 in self.out_indices:
102
+ feats.append(x)
103
+
104
+ x = self.layer4(x)
105
+ if self.features_only and 4 in self.out_indices:
106
+ feats.append(x)
107
+
108
+ if self.features_only:
109
+ if len(self.out_indices) == 1:
110
+ return feats[0]
111
+ else:
112
+ return feats
113
+ else:
114
+ x = self.attnpool(x)
115
+ return x
116
+
117
+
118
+ class VisionTransformer(nn.Module):
119
+ def __init__(
120
+ self,
121
+ input_resolution: Union[int, Tuple[int, int]],
122
+ patch_size: Union[int, Tuple[int, int]],
123
+ output_dim: int,
124
+ width: int,
125
+ layers: int,
126
+ heads: int,
127
+ features_only: bool = False,
128
+ **kwargs: Any,
129
+ ) -> None:
130
+ super().__init__()
131
+ input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
132
+ patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
133
+ assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
134
+ assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}"
135
+ assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid."
136
+ assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}"
137
+ self.input_resolution = input_resolution
138
+ self.patch_size = patch_size
139
+ self.downsampling_rate = patch_size[0]
140
+
141
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
142
+
143
+ scale = width ** -0.5
144
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
145
+ self.num_patches_h = int(input_resolution[0] // patch_size[0])
146
+ self.num_patches_w = int(input_resolution[1] // patch_size[1])
147
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width))
148
+ self.ln_pre = LayerNorm(width)
149
+
150
+ self.transformer = Transformer(width, layers, heads)
151
+ self.ln_post = LayerNorm(width)
152
+
153
+ self.features_only = features_only # if True, return the final patches instead of the CLS token
154
+ if features_only:
155
+ self.channels = width
156
+ else:
157
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
158
+ self.channels = output_dim
159
+
160
+ self.reduction = patch_size[0]
161
+ self.clip_embed_dim = output_dim
162
+
163
+ def adjust_pos_embed(self, h: int, w: int) -> None:
164
+ """
165
+ Permanently adjust the size of the positional embedding matrix.
166
+
167
+ Args:
168
+ h: the height of the original input image.
169
+ w: the width of the original input image.
170
+ """
171
+ assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}"
172
+ if self.input_resolution[0] != h or self.input_resolution[1] != w:
173
+ new_num_patches_h = int(h // self.patch_size[0])
174
+ new_num_patches_w = int(w // self.patch_size[1])
175
+ positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
176
+ positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic", ).squeeze(0) # remove batch dimension
177
+ positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
178
+ self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0))
179
+ self.input_resolution = (h, w)
180
+ self.num_patches_h = new_num_patches_h
181
+ self.num_patches_w = new_num_patches_w
182
+
183
+ def _interpolate_pos_embed(self, h: int, w: int) -> Tensor:
184
+ """
185
+ Interpolate the positional embedding matrix to match the size of the input image.
186
+
187
+ Args:
188
+ h: the required number of patches along the height dimension.
189
+ w: the required number of patches along the width dimension.
190
+ """
191
+ if h == self.num_patches_h and w == self.num_patches_w:
192
+ return self.positional_embedding
193
+ else:
194
+ positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
195
+ positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0) # remove batch dimension
196
+ positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
197
+ positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)
198
+ return positional_embedding
199
+
200
+ def forward(self, x: Tensor) -> Tensor:
201
+ x = self.conv1(x) # shape = [*, width, grid, grid]
202
+ num_patches_h, num_patches_w = x.shape[-2:]
203
+
204
+ positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype)
205
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
206
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
207
+ x = torch.cat([
208
+ self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
209
+ x
210
+ ], dim=1)
211
+ x = x + positional_embedding
212
+ x = self.ln_pre(x)
213
+
214
+ x = x.permute(1, 0, 2) # NLD -> LND. N: batch size, L: sequence length, D: feature dimension
215
+ x = self.transformer(x)
216
+ x = x.permute(1, 0, 2) # LND -> NLD
217
+ x = self.ln_post(x)
218
+
219
+ if self.features_only:
220
+ x = x[:, 1:, :] # remove the CLS token
221
+ x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w)
222
+ else:
223
+ x = x[:, 0, :]
224
+ x = x @ self.proj
225
+ return x
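A hedged sketch of the `features_only` path of `VisionTransformer`: with a 224×224 input and 16×16 patches, the CLS token is dropped and the patch tokens are reshaped back into a 14×14 feature map. `layers=2` is only to keep this toy instance light; it is not a real CLIP configuration.

```python
import torch
from models.clip._clip.image_encoder import VisionTransformer

vit = VisionTransformer(
    input_resolution=224, patch_size=16, output_dim=512,
    width=768, layers=2, heads=12,         # layers=2 only for this illustration
    features_only=True,
)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = vit(x)
print(feats.shape)                          # torch.Size([1, 768, 14, 14]) — a 224/16 = 14 patch grid
```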
models/clip/_clip/model.py ADDED
@@ -0,0 +1,214 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+
5
+ from typing import Tuple, Union
6
+
7
+ from .image_encoder import ModifiedResNet, VisionTransformer
8
+ from .text_encoder import LayerNorm, Transformer
9
+
10
+
11
+ class CLIP(nn.Module):
12
+ def __init__(
13
+ self,
14
+ embed_dim: int,
15
+ # vision
16
+ image_resolution: int,
17
+ vision_layers: Union[Tuple[int, int, int, int], int],
18
+ vision_width: int,
19
+ vision_patch_size: int,
20
+ # text
21
+ context_length: int,
22
+ vocab_size: int,
23
+ transformer_width: int,
24
+ transformer_heads: int,
25
+ transformer_layers: int
26
+ ) -> None:
27
+ super().__init__()
28
+ self.embed_dim = embed_dim
29
+ self.image_resolution = image_resolution
30
+ self.vision_layers = vision_layers
31
+ self.vision_width = vision_width
32
+ self.vision_patch_size = vision_patch_size
33
+ self.context_length = context_length
34
+ self.vocab_size = vocab_size
35
+ self.transformer_width = transformer_width
36
+ self.transformer_heads = transformer_heads
37
+ self.transformer_layers = transformer_layers
38
+
39
+ if isinstance(vision_layers, (tuple, list)):
40
+ vision_heads = vision_width * 32 // 64
41
+ self.visual = ModifiedResNet(
42
+ layers=vision_layers,
43
+ output_dim=embed_dim,
44
+ heads=vision_heads,
45
+ input_resolution=image_resolution,
46
+ width=vision_width,
47
+ features_only=False,
48
+ )
49
+ else:
50
+ vision_heads = vision_width // 64
51
+ self.visual = VisionTransformer(
52
+ input_resolution=image_resolution,
53
+ patch_size=vision_patch_size,
54
+ width=vision_width,
55
+ layers=vision_layers,
56
+ heads=vision_heads,
57
+ output_dim=embed_dim,
58
+ features_only=False,
59
+ )
60
+ self.vision_heads = vision_heads
61
+ self.transformer = Transformer(
62
+ width=transformer_width,
63
+ layers=transformer_layers,
64
+ heads=transformer_heads,
65
+ attn_mask=self.build_attention_mask()
66
+ )
67
+
68
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
69
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
70
+ self.ln_final = LayerNorm(transformer_width)
71
+
72
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
73
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
74
+
75
+ self.initialize_parameters()
76
+
77
+ def initialize_parameters(self):
78
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
79
+ nn.init.normal_(self.positional_embedding, std=0.01)
80
+
81
+ if isinstance(self.visual, ModifiedResNet):
82
+ if self.visual.attnpool is not None:
83
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
84
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
85
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
86
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
87
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
88
+
89
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
90
+ for name, param in resnet_block.named_parameters():
91
+ if name.endswith("bn3.weight"):
92
+ nn.init.zeros_(param)
93
+
94
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
95
+ attn_std = self.transformer.width ** -0.5
96
+ fc_std = (2 * self.transformer.width) ** -0.5
97
+ for block in self.transformer.resblocks:
98
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
99
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
100
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
101
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
102
+
103
+ if self.text_projection is not None:
104
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
105
+
106
+ def build_attention_mask(self):
107
+ # lazily create causal attention mask, with full attention between the vision tokens
108
+ # pytorch uses additive attention mask; fill with -inf
109
+ mask = torch.empty(self.context_length, self.context_length)
110
+ mask.fill_(float("-inf"))
111
+ mask.triu_(1) # zero out the lower diagonal
112
+ return mask
113
+
114
+ @property
115
+ def dtype(self):
116
+ return self.visual.conv1.weight.dtype
117
+
118
+ def encode_image(self, image):
119
+ return self.visual(image.type(self.dtype))
120
+
121
+ def encode_text(self, text):
122
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
123
+
124
+ x = x + self.positional_embedding.type(self.dtype)
125
+ x = x.permute(1, 0, 2) # NLD -> LND
126
+ x = self.transformer(x)
127
+ x = x.permute(1, 0, 2) # LND -> NLD
128
+ x = self.ln_final(x).type(self.dtype)
129
+
130
+ # x.shape = [batch_size, n_ctx, transformer.width]
131
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
132
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
133
+
134
+ return x
135
+
136
+ def forward(self, image, text):
137
+ image_features = self.encode_image(image)
138
+ text_features = self.encode_text(text)
139
+
140
+ # normalized features
141
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
142
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
143
+
144
+ # cosine similarity as logits
145
+ logit_scale = self.logit_scale.exp()
146
+ logits_per_image = logit_scale * image_features @ text_features.t()
147
+ logits_per_text = logits_per_image.t()
148
+
149
+ # shape = [global_batch_size, global_batch_size]
150
+ return logits_per_image, logits_per_text
151
+
152
+
153
+ def convert_weights(model: nn.Module):
154
+ """Convert applicable model parameters to fp16"""
155
+
156
+ def _convert_weights_to_fp16(l):
157
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
158
+ l.weight.data = l.weight.data.half()
159
+ if l.bias is not None:
160
+ l.bias.data = l.bias.data.half()
161
+
162
+ if isinstance(l, nn.MultiheadAttention):
163
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
164
+ tensor = getattr(l, attr)
165
+ if tensor is not None:
166
+ tensor.data = tensor.data.half()
167
+
168
+ for name in ["text_projection", "proj"]:
169
+ if hasattr(l, name):
170
+ attr = getattr(l, name)
171
+ if attr is not None:
172
+ attr.data = attr.data.half()
173
+
174
+ model.apply(_convert_weights_to_fp16)
175
+
176
+
177
+ def build_model(state_dict: dict):
178
+ vit = "visual.proj" in state_dict
179
+
180
+ if vit:
181
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
182
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
183
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
184
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
185
+ image_resolution = vision_patch_size * grid_size
186
+ else:
187
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
188
+ vision_layers = tuple(counts)
189
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
190
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
191
+ vision_patch_size = None
192
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
193
+ image_resolution = output_width * 32
194
+
195
+ embed_dim = state_dict["text_projection"].shape[1]
196
+ context_length = state_dict["positional_embedding"].shape[0]
197
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
198
+ transformer_width = state_dict["ln_final.weight"].shape[0]
199
+ transformer_heads = transformer_width // 64
200
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
201
+
202
+ model = CLIP(
203
+ embed_dim,
204
+ image_resolution, vision_layers, vision_width, vision_patch_size,
205
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
206
+ )
207
+
208
+ for key in ["input_resolution", "context_length", "vocab_size"]:
209
+ if key in state_dict:
210
+ del state_dict[key]
211
+
212
+ convert_weights(model)
213
+ model.load_state_dict(state_dict, strict=False)
214
+ return model.eval()
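For illustration, `CLIP.build_attention_mask` reproduced standalone with a tiny context length: the additive mask is zero on and below the diagonal and `-inf` above it, which is what makes the text transformer causal.

```python
import torch

context_length = 4
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)            # keep -inf strictly above the diagonal, zero elsewhere
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```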
models/clip/_clip/prepare.py ADDED
@@ -0,0 +1,95 @@
1
+ # Prepare the models to speed up loading them later
2
+ import torch
3
+ from torch import nn, Tensor
4
+ import os
5
+ from tqdm import tqdm
6
+ import json
7
+
8
+ from .utils import load
9
+
10
+
11
+ model_name_map = {
12
+ "RN50": "resnet50",
13
+ "RN101": "resnet101",
14
+ "RN50x4": "resnet50x4",
15
+ "RN50x16": "resnet50x16",
16
+ "RN50x64": "resnet50x64",
17
+ "ViT-B/32": "vit_b_32",
18
+ "ViT-B/16": "vit_b_16",
19
+ "ViT-L/14": "vit_l_14",
20
+ "ViT-L/14@336px": "vit_l_14_336px",
21
+ }
22
+
23
+
24
+ class CLIPTextEncoderTemp(nn.Module):
25
+ def __init__(
26
+ self,
27
+ clip: nn.Module,
28
+ ) -> None:
29
+ super().__init__()
30
+ self.context_length = clip.context_length
31
+ self.vocab_size = clip.vocab_size
32
+ self.dtype = clip.dtype
33
+ self.token_embedding = clip.token_embedding
34
+ self.positional_embedding = clip.positional_embedding
35
+ self.transformer = clip.transformer
36
+ self.ln_final = clip.ln_final
37
+ self.text_projection = clip.text_projection
38
+
39
+ def forward(self, text: Tensor) -> None:
40
+ pass
41
+
42
+
43
+ def prepare() -> None:
44
+ print("Preparing CLIP models...")
45
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
46
+ weight_dir = os.path.join(curr_dir, "weights")
47
+ config_dir = os.path.join(curr_dir, "configs")
48
+ os.makedirs(weight_dir, exist_ok=True)
49
+ os.makedirs(config_dir, exist_ok=True)
50
+ device = torch.device("cpu")
51
+
52
+ for model_name in tqdm(["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]):
53
+ model = load(model_name, device=device).to(device)
54
+ image_encoder = model.visual.to(device)
55
+ text_encoder = CLIPTextEncoderTemp(model).to(device)
56
+ torch.save(model.state_dict(), os.path.join(weight_dir, f"clip_{model_name_map[model_name]}.pth"))
57
+ torch.save(image_encoder.state_dict(), os.path.join(weight_dir, f"clip_image_encoder_{model_name_map[model_name]}.pth"))
58
+ torch.save(text_encoder.state_dict(), os.path.join(weight_dir, f"clip_text_encoder_{model_name_map[model_name]}.pth"))
59
+ model_config = {
60
+ "embed_dim": model.embed_dim,
61
+ # vision
62
+ "image_resolution": model.image_resolution,
63
+ "vision_layers": model.vision_layers,
64
+ "vision_width": model.vision_width,
65
+ "vision_patch_size": model.vision_patch_size,
66
+ # text
67
+ "context_length": model.context_length,
68
+ "vocab_size": model.vocab_size,
69
+ "transformer_width": model.transformer_width,
70
+ "transformer_heads": model.transformer_heads,
71
+ "transformer_layers": model.transformer_layers,
72
+ }
73
+ image_encoder_config = {
74
+ "embed_dim": model.embed_dim,
75
+ "image_resolution": model.image_resolution,
76
+ "vision_layers": model.vision_layers,
77
+ "vision_width": model.vision_width,
78
+ "vision_patch_size": model.vision_patch_size,
79
+ "vision_heads": model.vision_heads,
80
+ }
81
+ text_encoder_config = {
82
+ "embed_dim": model.embed_dim,
83
+ "context_length": model.context_length,
84
+ "vocab_size": model.vocab_size,
85
+ "transformer_width": model.transformer_width,
86
+ "transformer_heads": model.transformer_heads,
87
+ "transformer_layers": model.transformer_layers,
88
+ }
89
+ with open(os.path.join(config_dir, f"clip_{model_name_map[model_name]}.json"), "w") as f:
90
+ json.dump(model_config, f, indent=4)
91
+ with open(os.path.join(config_dir, f"clip_image_encoder_{model_name_map[model_name]}.json"), "w") as f:
92
+ json.dump(image_encoder_config, f, indent=4)
93
+ with open(os.path.join(config_dir, f"clip_text_encoder_{model_name_map[model_name]}.json"), "w") as f:
94
+ json.dump(text_encoder_config, f, indent=4)
95
+ print("Done!")
models/clip/_clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
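A quick round-trip through the tokenizer above (assuming the module is importable as `models.clip._clip.simple_tokenizer` and the bundled `bpe_simple_vocab_16e6.txt.gz` sits next to it):

from models.clip._clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                        # uses default_bpe() next to this file
ids = tokenizer.encode("There are twelve people.")   # lower-cased, cleaned, then BPE-encoded
text = tokenizer.decode(ids)                         # "there are twelve people . " (word-end markers become spaces)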
models/clip/_clip/text_encoder.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+
4
+ from .blocks import LayerNorm, Transformer
5
+
6
+
7
+ class CLIPTextEncoder(nn.Module):
8
+ def __init__(
9
+ self,
10
+ embed_dim: int,
11
+ context_length: int,
12
+ vocab_size: int,
13
+ transformer_width: int,
14
+ transformer_heads: int,
15
+ transformer_layers: int,
16
+ ) -> None:
17
+ super().__init__()
18
+ self.context_length = context_length
19
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
20
+ self.transformer = Transformer(
21
+ width=transformer_width,
22
+ layers=transformer_layers,
23
+ heads=transformer_heads,
24
+ attn_mask=self.build_attention_mask(),
25
+ )
26
+ self.vocab_size = vocab_size
28
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
29
+ self.ln_final = LayerNorm(transformer_width)
30
+
31
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
32
+
33
+ def build_attention_mask(self):
34
+ # lazily create the causal attention mask for the text tokens
35
+ # pytorch uses additive attention mask; fill with -inf
36
+ mask = torch.empty(self.context_length, self.context_length)
37
+ mask.fill_(float("-inf"))
38
+ mask.triu_(1) # zero out the lower diagonal
39
+ return mask
40
+
41
+ @property
42
+ def dtype(self):
43
+ return self.transformer.resblocks[0].attn.in_proj_weight.dtype
44
+
45
+ def forward(self, text: Tensor):
46
+ x = self.token_embedding(text).type(self.dtype)
47
+ x = x + self.positional_embedding.type(self.dtype)
48
+ x = x.permute(1, 0, 2) # NLD -> LND
49
+ x = self.transformer(x)
50
+ x = x.permute(1, 0, 2) # LND -> NLD
51
+ x = self.ln_final(x).type(self.dtype)
52
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
53
+ return x
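A hedged usage sketch for the standalone text encoder; the hyperparameters below match the published ViT-B/32 text tower, and the weight file name is the one written by prepare.py (both are assumptions about how this class is meant to be used).

import torch
from models.clip._clip.text_encoder import CLIPTextEncoder
from models.clip._clip.utils import tokenize

encoder = CLIPTextEncoder(
    embed_dim=512, context_length=77, vocab_size=49408,
    transformer_width=512, transformer_heads=8, transformer_layers=12,
)
# positional_embedding / text_projection start uninitialized; load prepared weights first.
state_dict = torch.load("models/clip/_clip/weights/clip_text_encoder_vit_b_32.pth", map_location="cpu")
encoder.load_state_dict(state_dict, strict=False)
with torch.no_grad():
    features = encoder(tokenize(["There are two people.", "There are fifty people."]))  # shape (2, 512)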
models/clip/_clip/utils.py ADDED
@@ -0,0 +1,249 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+ from pkg_resources import packaging
7
+
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ import torch
11
+
12
+ from typing import List, Union
13
+ from tqdm import tqdm
14
+
15
+ from .model import build_model
16
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
17
+
18
+ try:
19
+ from torchvision.transforms import InterpolationMode
20
+ BICUBIC = InterpolationMode.BICUBIC
21
+ except ImportError:
22
+ BICUBIC = Image.BICUBIC
23
+
24
+
25
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
26
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
27
+
28
+
29
+ __all__ = ["available_models", "load", "tokenize"]
30
+ _tokenizer = _Tokenizer()
31
+
32
+
33
+
34
+ _MODELS = {
35
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
36
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
37
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
38
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
39
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
41
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
42
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
43
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
44
+ }
45
+
46
+
47
+ def _download(url: str, root: str):
48
+ os.makedirs(root, exist_ok=True)
49
+ filename = os.path.basename(url)
50
+
51
+ expected_sha256 = url.split("/")[-2]
52
+ download_target = os.path.join(root, filename)
53
+
54
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
55
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
56
+
57
+ if os.path.isfile(download_target):
58
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
59
+ return download_target
60
+ else:
61
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
62
+
63
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
64
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
65
+ while True:
66
+ buffer = source.read(8192)
67
+ if not buffer:
68
+ break
69
+
70
+ output.write(buffer)
71
+ loop.update(len(buffer))
72
+
73
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
74
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
75
+
76
+ return download_target
77
+
78
+
79
+ def _convert_image_to_rgb(image):
80
+ return image.convert("RGB")
81
+
82
+
83
+ def transform(n_px):
84
+ return Compose([
85
+ Resize(n_px, interpolation=BICUBIC),
86
+ CenterCrop(n_px),
87
+ _convert_image_to_rgb,
88
+ ToTensor(),
89
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
90
+ ])
91
+
92
+
93
+ def available_models() -> List[str]:
94
+ """Returns the names of available CLIP models"""
95
+ return list(_MODELS.keys())
96
+
97
+
98
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
99
+ """Load a CLIP model
100
+
101
+ Parameters
102
+ ----------
103
+ name : str
104
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
105
+
106
+ device : Union[str, torch.device]
107
+ The device to put the loaded model
108
+
109
+ jit : bool
110
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
111
+
112
+ download_root: str
113
+ path to download the model files; by default, it uses "~/.cache/clip"
114
+
115
+ Returns
116
+ -------
117
+ model : torch.nn.Module
118
+ The CLIP model
119
+
120
+ Note: unlike the reference OpenAI clip.load, this function returns only the model;
121
+ build the preprocessing pipeline separately, e.g. with transform(n_px) from this module.
122
+ """
123
+ if name in _MODELS:
124
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
125
+ elif os.path.isfile(name):
126
+ model_path = name
127
+ else:
128
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
129
+
130
+ with open(model_path, 'rb') as opened_file:
131
+ try:
132
+ # loading JIT archive
133
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
134
+ state_dict = None
135
+ except RuntimeError:
136
+ # loading saved state dict
137
+ if jit:
138
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
139
+ jit = False
140
+ state_dict = torch.load(opened_file, map_location="cpu")
141
+
142
+ if not jit:
143
+ model = build_model(state_dict or model.state_dict()).to(device)
144
+ if str(device) == "cpu":
145
+ model.float()
146
+ return model
147
+
148
+ # patch the device names
149
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
150
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
151
+
152
+ def _node_get(node: torch._C.Node, key: str):
153
+ """Gets attributes of a node which is polymorphic over return type.
154
+
155
+ From https://github.com/pytorch/pytorch/pull/82628
156
+ """
157
+ sel = node.kindOf(key)
158
+ return getattr(node, sel)(key)
159
+
160
+ def patch_device(module):
161
+ try:
162
+ graphs = [module.graph] if hasattr(module, "graph") else []
163
+ except RuntimeError:
164
+ graphs = []
165
+
166
+ if hasattr(module, "forward1"):
167
+ graphs.append(module.forward1.graph)
168
+
169
+ for graph in graphs:
170
+ for node in graph.findAllNodes("prim::Constant"):
171
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
172
+ node.copyAttributes(device_node)
173
+
174
+ model.apply(patch_device)
175
+ patch_device(model.encode_image)
176
+ patch_device(model.encode_text)
177
+
178
+ # patch dtype to float32 on CPU
179
+ if str(device) == "cpu":
180
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
181
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
182
+ float_node = float_input.node()
183
+
184
+ def patch_float(module):
185
+ try:
186
+ graphs = [module.graph] if hasattr(module, "graph") else []
187
+ except RuntimeError:
188
+ graphs = []
189
+
190
+ if hasattr(module, "forward1"):
191
+ graphs.append(module.forward1.graph)
192
+
193
+ for graph in graphs:
194
+ for node in graph.findAllNodes("aten::to"):
195
+ inputs = list(node.inputs())
196
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
197
+ if _node_get(inputs[i].node(), "value") == 5:
198
+ inputs[i].node().copyAttributes(float_node)
199
+
200
+ model.apply(patch_float)
201
+ patch_float(model.encode_image)
202
+ patch_float(model.encode_text)
203
+
204
+ model.float()
205
+
206
+ return model
207
+
208
+
209
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
210
+ """
211
+ Returns the tokenized representation of given input string(s)
212
+
213
+ Parameters
214
+ ----------
215
+ texts : Union[str, List[str]]
216
+ An input string or a list of input strings to tokenize
217
+
218
+ context_length : int
219
+ The context length to use; all CLIP models use 77 as the context length
220
+
221
+ truncate: bool
222
+ Whether to truncate the text in case its encoding is longer than the context length
223
+
224
+ Returns
225
+ -------
226
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
227
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
228
+ """
229
+ if isinstance(texts, str):
230
+ texts = [texts]
231
+
232
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
233
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
234
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
235
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
236
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
237
+ else:
238
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
239
+
240
+ for i, tokens in enumerate(all_tokens):
241
+ if len(tokens) > context_length:
242
+ if truncate:
243
+ tokens = tokens[:context_length]
244
+ tokens[-1] = eot_token
245
+ else:
246
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
247
+ result[i, :len(tokens)] = torch.tensor(tokens)
248
+
249
+ return result
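End-to-end, the helpers in this file are typically combined as below; a sketch under the assumption that the module is importable as `models.clip._clip.utils`, with `example.jpg` as a placeholder path.

import torch
from PIL import Image
from models.clip._clip.utils import load, transform, tokenize

model = load("ViT-B/32", device="cpu")                       # downloads to ~/.cache/clip on first use
preprocess = transform(224)                                  # 224 px is the ViT-B/32 training resolution
image = preprocess(Image.open("example.jpg")).unsqueeze(0)   # placeholder image path
text = tokenize(["a crowded street", "an empty street"])
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)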
models/clip/model.py ADDED
@@ -0,0 +1,331 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ import math
7
+ from typing import List, Tuple, Union, Optional
8
+
9
+ from . import _clip
10
+ from ..utils import _init_weights, make_resnet_layers, Bottleneck, BasicBlock
11
+ from .utils import format_count
12
+
13
+ curr_dir = os.path.abspath(os.path.dirname(__file__))
14
+
15
+
16
+ # resnet50: reduction, channels, embed_dim = 32, 2048, 1024
17
+ # resnet101: reduction, channels, embed_dim = 32, 2048, 512
18
+ # resnet50x4: reduction, channels, embed_dim = 32, 2560, 640
19
+ # resnet50x16: reduction, channels, embed_dim = 32, 3072, 768
20
+ # resnet50x64: reduction, channels, embed_dim = 32, 4096, 1024
21
+ # vit_b_32: reduction, channels, embed_dim = 32, 768, 512
22
+ # vit_b_16: reduction, channels, embed_dim = 16, 768, 512
23
+ # vit_l_14: reduction, channels, embed_dim = 14, 1024, 768
24
+ # vit_l_14_336px: reduction, channels, embed_dim = 14, 1024, 768
25
+
26
+ resnet_backbones = ["resnet50", "resnet101", "resnet50x4", "resnet50x16", "resnet50x64"]
27
+ vit_backbones = ["vit_b_16", "vit_b_32", "vit_l_14", "vit_l_14_336px"]
28
+
29
+
30
+ class CLIP_EBC(nn.Module):
31
+ def __init__(
32
+ self,
33
+ backbone: str,
34
+ bins: List[Tuple[float, float]],
35
+ anchor_points: List[float],
36
+ reduction: Optional[int] = None,
37
+ freeze_text_encoder: bool = True,
38
+ prompt_type: str = "number",
39
+ input_size: Optional[int] = None,
40
+ num_vpt: Optional[int] = None,
41
+ deep_vpt: Optional[bool] = None,
42
+ vpt_drop: Optional[float] = None,
43
+ decoder_block: Optional[nn.Module] = None,
44
+ decoder_cfg: Optional[List[Union[str, int]]] = None,
45
+ ) -> None:
46
+ super().__init__()
47
+ assert backbone in resnet_backbones + vit_backbones, f"Backbone should be in {resnet_backbones + vit_backbones}, got {backbone}"
48
+ self.backbone = backbone
49
+
50
+ # Image encoder
51
+ if backbone in resnet_backbones:
52
+ self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, out_indices=(-1,), reduction=reduction)
53
+
54
+ else:
55
+ assert input_size is not None, "Expected input_size to be an integer, got None."
56
+ assert num_vpt is not None, "Expected num_vpt to be an integer, got None."
57
+ assert deep_vpt is not None, "Expected deep_vpt to be a boolean, got None."
58
+ assert vpt_drop is not None, "Expected vpt_drop to be a float, got None."
59
+
60
+ self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, input_size=input_size)
61
+ self.image_encoder_depth = len(self.image_encoder.transformer.resblocks)
62
+
63
+ # Use VPT. Freeze the image encoder.
64
+ for param in self.image_encoder.parameters():
65
+ param.requires_grad = False
66
+
67
+ self.num_vpt = num_vpt
68
+ self.deep_vpt = deep_vpt
69
+
70
+ patch_size = self.image_encoder.patch_size[0]
71
+ val = math.sqrt(6. / float(3 * patch_size + self.image_encoder.channels))
72
+
73
+ for idx in range(self.image_encoder_depth if self.deep_vpt else 1):
74
+ setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.image_encoder.channels)))
75
+ nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
76
+ setattr(self, f"vpt_drop_{idx}", nn.Dropout(vpt_drop) if vpt_drop > 0 else nn.Identity())
77
+
78
+ self.encoder_reduction = self.image_encoder.reduction
79
+ self.reduction = self.encoder_reduction if reduction is None else reduction
80
+ self.channels = self.image_encoder.channels
81
+ self.clip_embed_dim = self.image_encoder.clip_embed_dim
82
+
83
+ if decoder_cfg is not None:
84
+ assert decoder_block is not None, "Expected decoder_block to be a nn.Module, got None."
85
+ self.image_decoder = make_resnet_layers(decoder_block, decoder_cfg, in_channels=self.channels, expansion=1, dilation=1)
86
+ self.image_decoder.apply(_init_weights)
87
+ self.channels = decoder_cfg[-1]
88
+ else:
89
+ self.image_decoder = nn.Identity()
90
+
91
+ if self.channels != self.clip_embed_dim:
92
+ self.projection = nn.Conv2d(in_channels=self.channels, out_channels=self.clip_embed_dim, kernel_size=1)
93
+ self.projection.apply(_init_weights)
94
+ else:
95
+ self.projection = nn.Identity()
96
+
97
+ # Text encoder
98
+ assert prompt_type in ["number", "word"], f"Expected prompt_type to be 'number' or 'word', got {prompt_type}"
99
+ self.prompt_type = prompt_type
100
+ self.text_encoder = getattr(_clip, f"{backbone}_txt")()
101
+ self.freeze_text_encoder = freeze_text_encoder
102
+ if self.freeze_text_encoder:
103
+ for param in self.text_encoder.parameters():
104
+ param.requires_grad = False
105
+
106
+ self.bins = bins
107
+ self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1)
108
+
109
+ self._get_text_prompts()
110
+ self._tokenize_text_prompts()
111
+
112
+ if self.freeze_text_encoder:
113
+ self._extract_text_features()
114
+ else:
115
+ self.text_features = None
116
+
117
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
118
+
119
+ def _get_text_prompts(self) -> None:
120
+ bins = [b[0] if b[0] == b[1] else b for b in self.bins]
121
+ self.text_prompts = [format_count(b, self.prompt_type) for b in bins]
122
+ print(f"Initialized model with text prompts: {self.text_prompts}")
123
+
124
+ def _tokenize_text_prompts(self) -> None:
125
+ self.text_prompts = _clip.tokenize(self.text_prompts)
126
+
127
+ def _extract_text_features(self) -> None:
128
+ with torch.no_grad():
129
+ self.text_features = self.text_encoder(self.text_prompts)
130
+
131
+ def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
132
+ if not self.deep_vpt:
133
+ assert layer == 0, f"Expected layer to be 0 when using Shallow Visual Prompt Tuning, got {layer}"
134
+
135
+ vpt = getattr(self, f"vpt_{layer}").to(device)
136
+ vpt = vpt.unsqueeze(0).expand(batch_size, -1, -1)
137
+ vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
138
+ vpt = vpt.permute(1, 0, 2) # (num_vpt, batch_size, hidden_dim)
139
+ assert vpt.shape[1] == batch_size, f"Expected the VPT to have the shape [L_vis B C], got {vpt.shape}."
140
+ return vpt
141
+
142
+ def _forward_vpt(self, x: Tensor) -> Tuple[Tensor]:
143
+ device = x.device
144
+ batch_size, _, height, width = x.shape
145
+ num_h_patches, num_w_patches = height // self.image_encoder.patch_size[0], width // self.image_encoder.patch_size[1]
146
+
147
+ image_features = self.image_encoder.conv1(x)
148
+ image_features = image_features.reshape(batch_size, image_features.shape[1], -1)
149
+ image_features = image_features.permute(0, 2, 1) # (B, num_patches, C)
150
+ image_features = torch.cat([
151
+ self.image_encoder.class_embedding + torch.zeros(batch_size, 1, image_features.shape[-1], dtype=image_features.dtype, device=device),
152
+ image_features,
153
+ ], dim=1) # (B, num_patches + 1, C)
154
+
155
+ pos_embedding = self.image_encoder._interpolate_pos_embed(num_h_patches, num_w_patches)
156
+ image_features = image_features + pos_embedding
157
+ image_features = self.image_encoder.ln_pre(image_features)
158
+ image_features = image_features.permute(1, 0, 2) # (num_patches + 1, B, C)
159
+ assert image_features.shape[0] == num_h_patches * num_w_patches + 1 and image_features.shape[1] == batch_size, f"Expected image_features to have shape [num_patches + 1, B, C], got {image_features.shape}."
160
+
161
+ vpt = self._prepare_vpt(0, batch_size, device)
162
+ for idx in range(self.image_encoder_depth):
163
+ # assemble
164
+ image_features = torch.cat([
165
+ image_features[:1, :, :], # CLS token
166
+ vpt,
167
+ image_features[1:, :, :],
168
+ ], dim=0)
169
+
170
+ # transformer
171
+ image_features = self.image_encoder.transformer.resblocks[idx](image_features)
172
+
173
+ # disassemble
174
+ if idx < self.image_encoder_depth - 1:
175
+ if self.deep_vpt:
176
+ vpt = self._prepare_vpt(idx + 1, batch_size, device)
177
+ else:
178
+ vpt = image_features[1: (self.num_vpt + 1), :, :]
179
+
180
+ image_features = torch.cat([
181
+ image_features[:1, :, :], # CLS token
182
+ image_features[(self.num_vpt + 1):, :, :],
183
+ ], dim=0)
184
+
185
+ image_features = image_features.permute(1, 0, 2) # (B, num_patches + 1, C)
186
+ image_features = self.image_encoder.ln_post(image_features)
187
+ image_features = image_features[:, 1:, :].permute(0, 2, 1) # (B, C, num_patches)
188
+ image_features = image_features.reshape(batch_size, -1, num_h_patches, num_w_patches)
189
+ return image_features
190
+
191
+ def _forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
192
+ device = x.device
193
+
194
+ x = self.image_encoder(x) if self.backbone in resnet_backbones else self._forward_vpt(x)
195
+ if self.reduction != self.encoder_reduction:
196
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
197
+ x = self.image_decoder(x)
198
+ x = self.projection(x)
199
+
200
+ image_features = x.permute(0, 2, 3, 1) # shape (B, H, W, C)
201
+ text_features = self.text_encoder(self.text_prompts.to(device)) if self.text_features is None else self.text_features.to(device) # shape (N, C)
202
+
203
+ image_features = F.normalize(image_features, p=2, dim=-1)
204
+ text_features = F.normalize(text_features, p=2, dim=-1)
205
+
206
+ # cosine similarity as logits
207
+ logit_scale = self.logit_scale.exp()
208
+ logits = logit_scale * image_features @ text_features.t() # (B, H, W, N), logits per image
209
+ logits = logits.permute(0, 3, 1, 2) # (B, N, H, W)
210
+
211
+ probs = logits.softmax(dim=1)
212
+ exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True) # (B, 1, H, W)
213
+
214
+ if self.training:
215
+ return logits, exp
216
+ else:
217
+ return exp
218
+
219
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
220
+ assert len(x.shape) == 4, f"Expected input to have shape (B C H W), got {x.shape}."
221
+ if "vit" in self.backbone:
222
+ image_height, image_width = x.shape[2], x.shape[3]
223
+ window_height, window_width = self.image_encoder.input_resolution
224
+
225
+ if self.training:
226
+ assert (image_height, image_width) == (window_height, window_width), f"Expected input to have shape ({window_height} {window_width}), got ({image_height} {image_width})."
227
+ return self._forward(x)
228
+
229
+ elif (image_height, image_width) == (window_height, window_width): # evaluation, input size = training size
230
+ return self._forward(x)
231
+
232
+ else: # evaluation, input_size != training size, use sliding window prediction
233
+ stride_height, stride_width = window_height, window_width
234
+ reduction = self.reduction
235
+
236
+ num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
237
+ num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
238
+
239
+ windows = []
240
+ for i in range(num_rows):
241
+ for j in range(num_cols):
242
+ x_start, y_start = i * stride_height, j * stride_width
243
+ x_end, y_end = x_start + window_height, y_start + window_width
244
+ if x_end > image_height:
245
+ x_start, x_end = image_height - window_height, image_height
246
+ if y_end > image_width:
247
+ y_start, y_end = image_width - window_width, image_width
248
+
249
+ window = x[:, :, x_start:x_end, y_start:y_end]
250
+ windows.append(window)
251
+
252
+ windows = torch.cat(windows, dim=0).to(x.device) # batched windows, shape: (num_windows, c, h, w)
253
+
254
+ preds = self._forward(windows)
255
+ preds = preds.cpu().detach().numpy()
256
+
257
+ # assemble the density map
258
+ pred_map = np.zeros((preds.shape[1], image_height // reduction, image_width // reduction), dtype=np.float32)
259
+ count_map = np.zeros((preds.shape[1], image_height // reduction, image_width // reduction), dtype=np.float32)
260
+ idx = 0
261
+ for i in range(num_rows):
262
+ for j in range(num_cols):
263
+ x_start, y_start = i * stride_height, j * stride_width
264
+ x_end, y_end = x_start + window_height, y_start + window_width
265
+ if x_end > image_height:
266
+ x_start, x_end = image_height - window_height, image_height
267
+ if y_end > image_width:
268
+ y_start, y_end = image_width - window_width, image_width
269
+
270
+ pred_map[:, (x_start // reduction): (x_end // reduction), (y_start // reduction): (y_end // reduction)] += preds[idx, :, :, :]
271
+ count_map[:, (x_start // reduction): (x_end // reduction), (y_start // reduction): (y_end // reduction)] += 1.
272
+ idx += 1
273
+
274
+ pred_map /= count_map # average the overlapping regions
275
+ return torch.tensor(pred_map).unsqueeze(0) # shape: (1, 1, h // reduction, w // reduction)
276
+
277
+ else:
278
+ return self._forward(x)
279
+
280
+
281
+ def _clip_ebc(
282
+ backbone: str,
283
+ bins: List[Tuple[float, float]],
284
+ anchor_points: List[float],
285
+ reduction: Optional[int] = None,
286
+ freeze_text_encoder: bool = True,
287
+ prompt_type: str = "number",
288
+ input_size: Optional[int] = None,
289
+ num_vpt: Optional[int] = None,
290
+ deep_vpt: Optional[bool] = None,
291
+ vpt_drop: Optional[float] = None,
292
+ decoder_block: Optional[nn.Module] = None,
293
+ decoder_cfg: Optional[List[Union[str, int]]] = None
294
+ ) -> CLIP_EBC:
295
+ if backbone in resnet_backbones:
296
+ decoder_block = Bottleneck
297
+ if decoder_cfg is None:
298
+ if backbone == "resnet50":
299
+ decoder_cfg = [2048]
300
+ elif backbone == "resnet50x4":
301
+ decoder_cfg = [1280]
302
+ elif backbone == "resnet50x16":
303
+ decoder_cfg = [1536]
304
+ elif backbone == "resnet50x64":
305
+ decoder_cfg = [2048]
306
+ else: # backbone == "resnet101"
307
+ decoder_cfg = [2048, 1024]
308
+ else:
309
+ decoder_block = BasicBlock
310
+ if decoder_cfg is None:
311
+ if backbone == "vit_b_16":
312
+ decoder_cfg = [768]
313
+ elif backbone == "vit_b_32":
314
+ decoder_cfg = [768]
315
+ else: # backbone == "vit_l_14"
316
+ decoder_cfg = [1024]
317
+
318
+ return CLIP_EBC(
319
+ backbone=backbone,
320
+ bins=bins,
321
+ anchor_points=anchor_points,
322
+ reduction=reduction,
323
+ freeze_text_encoder=freeze_text_encoder,
324
+ prompt_type=prompt_type,
325
+ input_size=input_size,
326
+ num_vpt=num_vpt,
327
+ deep_vpt=deep_vpt,
328
+ vpt_drop=vpt_drop,
329
+ decoder_block=decoder_block,
330
+ decoder_cfg=decoder_cfg,
331
+ )
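A construction sketch for the counting model; the bins and anchor points below are illustrative placeholders (they come from the training configuration), not values taken from the original experiments.

from models.clip.model import _clip_ebc

model = _clip_ebc(
    backbone="resnet50",
    bins=[(0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (4.0, float("inf"))],   # placeholder bins
    anchor_points=[0.0, 1.0, 2.5, 5.0],                               # one anchor per bin
    reduction=8,
    prompt_type="word",
)
# model(x) returns (logits, expected counts) in training mode and the expected-count map in eval mode.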
models/clip/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ from typing import Union, Tuple
2
+
3
+
4
+ num_to_word = {
5
+ "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
6
+ "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen",
7
+ "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine",
8
+ "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine",
9
+ "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine",
10
+ "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine",
11
+ "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine",
12
+ "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine",
13
+ "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine",
14
+ "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine",
15
+ "100": "one hundred", "200": "two hundred", "300": "three hundred", "400": "four hundred", "500": "five hundred", "600": "six hundred", "700": "seven hundred", "800": "eight hundred", "900": "nine hundred",
16
+ "1000": "one thousand"
17
+ }
18
+
19
+
20
+ def num2word(num: Union[int, str]) -> str:
21
+ """
22
+ Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc.
23
+ """
24
+ num = str(int(num))
25
+ return num_to_word.get(num, num)
26
+
27
+
28
+ def format_count(count: Union[float, Tuple[float, float]], prompt_type: str = "word") -> str:
29
+ if count == 0:
30
+ return "There is no person." if prompt_type == "word" else "There is 0 person."
31
+ elif count == 1:
32
+ return "There is one person." if prompt_type == "word" else "There is 1 person."
33
+ elif isinstance(count, (int, float)):
34
+ return f"There are {num2word(int(count))} people." if prompt_type == "word" else f"There are {int(count)} people."
35
+ elif count[1] == float("inf"):
36
+ return f"There are more than {num2word(int(count[0]))} people." if prompt_type == "word" else f"There are more than {int(count[0])} people."
37
+ else: # count is a tuple of finite numbers
38
+ left, right = int(count[0]), int(count[1])
39
+ left, right = (num2word(left), num2word(right)) if prompt_type == "word" else (left, right)
40
+ return f"There are between {left} and {right} people."
models/encoder/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
2
+ from .vit import vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14
3
+ from .timm_models import _timm_encoder
4
+
5
+
6
+ __all__ = [
7
+ "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
8
+ "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14",
9
+ "_timm_encoder",
10
+ ]
models/encoder/timm_models.py ADDED
@@ -0,0 +1,54 @@
1
+ from timm import create_model, list_models
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from typing import Optional
5
+
6
+ from warnings import warn
7
+
8
+
9
+ class TIMMEncoder(nn.Module):
10
+ def __init__(
11
+ self,
12
+ backbone: str,
13
+ reduction: Optional[int] = None,
14
+ ) -> None:
15
+ super().__init__()
16
+ assert backbone in list_models(), f"Backbone {backbone} not available in timm"
17
+ encoder = create_model(backbone, pretrained=True, features_only=True, out_indices=[-1])
18
+ encoder_reduction = encoder.feature_info.reduction()[-1]
19
+
20
+ if reduction is not None and reduction <= 16:
21
+ if "resnet" in backbone:
22
+ if "resnet18" in backbone or "resnet34" in backbone:
23
+ encoder.layer4[0].conv1.stride = (1, 1)
24
+ encoder.layer4[0].downsample[0].stride = (1, 1)
25
+ else:
26
+ encoder.layer4[0].conv2.stride = (1, 1)
27
+ encoder.layer4[0].downsample[0].stride = (1, 1)
28
+ encoder_reduction = encoder_reduction // 2
29
+
30
+ elif "mobilenetv2" in backbone:
31
+ encoder.blocks[5][0].conv_dw.stride = (1, 1)
32
+ encoder_reduction = encoder_reduction // 2
33
+
34
+ elif "densenet" in backbone:
35
+ encoder.features_transition3.pool = nn.Identity()
36
+ encoder_reduction = encoder_reduction // 2
37
+
38
+ else:
39
+ warn(f"Reduction for {backbone} not handled. Using default reduction of {encoder_reduction}")
40
+
41
+ self.encoder = encoder
42
+ self.encoder_reduction = encoder_reduction
43
+ self.reduction = self.encoder_reduction if reduction is None else reduction
44
+ self.channels = self.encoder.feature_info.channels()[-1]
45
+
46
+ def forward(self, x: Tensor) -> Tensor:
47
+ x = self.encoder(x)[-1]
48
+ if self.encoder_reduction != self.reduction:
49
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
50
+ return x
51
+
52
+
53
+ def _timm_encoder(backbone: str, reduction: Optional[int] = None) -> TIMMEncoder:
54
+ return TIMMEncoder(backbone, reduction)
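Usage sketch for the timm wrapper (any backbone name accepted by `timm.list_models()`; `resnet50` is used here only as an example, and pretrained weights are downloaded on first use):

import torch
from models.encoder.timm_models import _timm_encoder

encoder = _timm_encoder("resnet50", reduction=8)     # last stage stride removed, then upsampled
features = encoder(torch.randn(1, 3, 224, 224))
print(encoder.channels, features.shape)              # 2048, torch.Size([1, 2048, 28, 28])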
models/encoder/vgg.py ADDED
@@ -0,0 +1,69 @@
1
+ from torch import nn, Tensor
2
+ import torch.nn.functional as F
3
+ from torch.hub import load_state_dict_from_url
4
+ from typing import Optional
5
+
6
+ from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls
7
+
8
+
9
+ class VGG(nn.Module):
10
+ def __init__(
11
+ self,
12
+ features: nn.Module,
13
+ reduction: Optional[int] = None,
14
+ ) -> None:
15
+ super().__init__()
16
+ self.features = features
17
+ self.encoder_reduction = 16
18
+ self.reduction = self.encoder_reduction if reduction is None else reduction
19
+ self.channels = 512
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ x = self.features(x)
23
+ if self.encoder_reduction != self.reduction:
24
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
25
+ return x
26
+
27
+
28
+ def _load_weights(model: VGG, url: str) -> VGG:
29
+ state_dict = load_state_dict_from_url(url)
30
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
31
+ print("Loading pre-trained weights")
32
+ if len(missing_keys) > 0:
33
+ print(f"Missing keys: {missing_keys}")
34
+ if len(unexpected_keys) > 0:
35
+ print(f"Unexpected keys: {unexpected_keys}")
36
+ return model
37
+
38
+
39
+ def vgg11(reduction: int = 8) -> VGG:
40
+ model = VGG(make_vgg_layers(vgg_cfgs["A"]), reduction=reduction)
41
+ return _load_weights(model, vgg_urls["vgg11"])
42
+
43
+ def vgg11_bn(reduction: int = 8) -> VGG:
44
+ model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True), reduction=reduction)
45
+ return _load_weights(model, vgg_urls["vgg11_bn"])
46
+
47
+ def vgg13(reduction: int = 8) -> VGG:
48
+ model = VGG(make_vgg_layers(vgg_cfgs["B"]), reduction=reduction)
49
+ return _load_weights(model, vgg_urls["vgg13"])
50
+
51
+ def vgg13_bn(reduction: int = 8) -> VGG:
52
+ model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True), reduction=reduction)
53
+ return _load_weights(model, vgg_urls["vgg13_bn"])
54
+
55
+ def vgg16(reduction: int = 8) -> VGG:
56
+ model = VGG(make_vgg_layers(vgg_cfgs["D"]), reduction=reduction)
57
+ return _load_weights(model, vgg_urls["vgg16"])
58
+
59
+ def vgg16_bn(reduction: int = 8) -> VGG:
60
+ model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True), reduction=reduction)
61
+ return _load_weights(model, vgg_urls["vgg16_bn"])
62
+
63
+ def vgg19(reduction: int = 8) -> VGG:
64
+ model = VGG(make_vgg_layers(vgg_cfgs["E"]), reduction=reduction)
65
+ return _load_weights(model, vgg_urls["vgg19"])
66
+
67
+ def vgg19_bn(reduction: int = 8) -> VGG:
68
+ model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True), reduction=reduction)
69
+ return _load_weights(model, vgg_urls["vgg19_bn"])
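Usage sketch for the VGG encoders (torchvision weights are downloaded on first call; the input size below is arbitrary):

import torch
from models.encoder.vgg import vgg16_bn

encoder = vgg16_bn(reduction=8)
features = encoder(torch.randn(1, 3, 256, 256))
print(features.shape)                                # torch.Size([1, 512, 32, 32]), i.e. 1/8 resolution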
models/encoder/vit.py ADDED
@@ -0,0 +1,526 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+ import torch.nn.functional as F
9
+ from torch.hub import load_state_dict_from_url
10
+ from einops import rearrange
11
+
12
+ from ..utils import Conv2dNormActivation, MLP
13
+ from ..utils import _log_api_usage_once
14
+
15
+
16
+ weights = {
17
+ "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
18
+ "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
19
+ "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
20
+ "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
21
+ "vit_h_14": "https://download.pytorch.org/models/vit_h_14-6kbcf7eb.pth",
22
+ }
23
+
24
+
25
+ class ConvStemConfig(NamedTuple):
26
+ out_channels: int
27
+ kernel_size: int
28
+ stride: int
29
+ norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
30
+ activation_layer: Callable[..., nn.Module] = nn.ReLU
31
+
32
+
33
+ class MLPBlock(MLP):
34
+ """Transformer MLP block."""
35
+
36
+ _version = 2
37
+
38
+ def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
39
+ super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
40
+
41
+ for m in self.modules():
42
+ if isinstance(m, nn.Linear):
43
+ nn.init.xavier_uniform_(m.weight)
44
+ if m.bias is not None:
45
+ nn.init.normal_(m.bias, std=1e-6)
46
+
47
+ def _load_from_state_dict(
48
+ self,
49
+ state_dict,
50
+ prefix,
51
+ local_metadata,
52
+ strict,
53
+ missing_keys,
54
+ unexpected_keys,
55
+ error_msgs,
56
+ ):
57
+ version = local_metadata.get("version", None)
58
+
59
+ if version is None or version < 2:
60
+ # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
61
+ for i in range(2):
62
+ for type in ["weight", "bias"]:
63
+ old_key = f"{prefix}linear_{i+1}.{type}"
64
+ new_key = f"{prefix}{3*i}.{type}"
65
+ if old_key in state_dict:
66
+ state_dict[new_key] = state_dict.pop(old_key)
67
+
68
+ super()._load_from_state_dict(
69
+ state_dict,
70
+ prefix,
71
+ local_metadata,
72
+ strict,
73
+ missing_keys,
74
+ unexpected_keys,
75
+ error_msgs,
76
+ )
77
+
78
+
79
+ class EncoderBlock(nn.Module):
80
+ """Transformer encoder block."""
81
+
82
+ def __init__(
83
+ self,
84
+ num_heads: int,
85
+ hidden_dim: int,
86
+ mlp_dim: int,
87
+ dropout: float,
88
+ attention_dropout: float,
89
+ norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
90
+ ):
91
+ super().__init__()
92
+ self.num_heads = num_heads
93
+
94
+ # Attention block
95
+ self.ln_1 = norm_layer(hidden_dim)
96
+ self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ # MLP block
100
+ self.ln_2 = norm_layer(hidden_dim)
101
+ self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
102
+
103
+ def forward(self, input: Tensor):
104
+ torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
105
+ x = self.ln_1(input)
106
+ x, _ = self.self_attention(x, x, x, need_weights=False)
107
+ x = self.dropout(x)
108
+ x = x + input
109
+
110
+ y = self.ln_2(x)
111
+ y = self.mlp(y)
112
+ return x + y
113
+
114
+
115
+ class Encoder(nn.Module):
116
+ """Transformer Model Encoder for sequence to sequence translation."""
117
+ def __init__(
118
+ self,
119
+ num_h_patches: int,
120
+ num_w_patches: int,
121
+ num_layers: int,
122
+ num_heads: int,
123
+ hidden_dim: int,
124
+ mlp_dim: int,
125
+ dropout: float,
126
+ attention_dropout: float,
127
+ norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
128
+ ):
129
+ super().__init__()
130
+ self.num_h_patches = num_h_patches
131
+ self.num_w_patches = num_w_patches
132
+
133
+ # Note that batch_size is on the first dim because
134
+ # we construct nn.MultiheadAttention with batch_first=True in EncoderBlock
135
+ seq_length = num_h_patches * num_w_patches + 1 # +1 for the class token
136
+ self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
137
+ self.dropout = nn.Dropout(dropout)
138
+ layers: OrderedDict[str, nn.Module] = OrderedDict()
139
+ for i in range(num_layers):
140
+ layers[f"encoder_layer_{i}"] = EncoderBlock(
141
+ num_heads,
142
+ hidden_dim,
143
+ mlp_dim,
144
+ dropout,
145
+ attention_dropout,
146
+ norm_layer,
147
+ )
148
+ self.layers = nn.Sequential(layers)
149
+ self.ln = norm_layer(hidden_dim)
150
+
151
+ def _get_pos_embedding(self, n_h: int, n_w: int) -> Tensor:
152
+ if n_h == self.num_h_patches and n_w == self.num_w_patches:
153
+ return self.pos_embedding
154
+ else:
155
+ pos_embedding = self.pos_embedding[:, 1:, :]
156
+ pos_embedding = rearrange(pos_embedding, "1 (h w) d -> 1 d h w", h=self.num_h_patches, w=self.num_w_patches)
157
+ pos_embedding = F.interpolate(pos_embedding, size=(n_h, n_w), mode="bicubic")
158
+ pos_embedding = rearrange(pos_embedding, "1 d h w -> 1 (h w) d")
159
+ return torch.cat([self.pos_embedding[:, :1, :], pos_embedding], dim=1)
160
+
161
+ def forward(self, input: Tensor, n_h: int, n_w: int) -> Tensor:
162
+ torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
163
+ input = input + self._get_pos_embedding(n_h, n_w)
164
+ return self.ln(self.layers(self.dropout(input)))
165
+
166
+
167
+ class VisionTransformer(nn.Module):
168
+ """Vision Transformer as a feature extractor."""
169
+
170
+ def __init__(
171
+ self,
172
+ image_size: int,
173
+ patch_size: int,
174
+ num_layers: int,
175
+ num_heads: int,
176
+ hidden_dim: int,
177
+ mlp_dim: int,
178
+ dropout: float = 0.0,
179
+ attention_dropout: float = 0.0,
180
+ # num_classes: int = 1000, # No need for the classification head as we only need the features
181
+ reduction: Optional[int] = None,
182
+ representation_size: Optional[int] = None,
183
+ norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
184
+ conv_stem_configs: Optional[List[ConvStemConfig]] = None,
185
+ ):
186
+ super().__init__()
187
+ _log_api_usage_once(self)
188
+ torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
189
+ self.image_size = image_size
190
+ self.patch_size = patch_size
191
+ self.hidden_dim = hidden_dim
192
+ self.mlp_dim = mlp_dim
193
+ self.attention_dropout = attention_dropout
194
+ self.dropout = dropout
195
+ # self.num_classes = num_classes
196
+ self.representation_size = representation_size
197
+ self.norm_layer = norm_layer
198
+
199
+ if conv_stem_configs is not None:
200
+ # As per https://arxiv.org/abs/2106.14881
201
+ seq_proj = nn.Sequential()
202
+ prev_channels = 3
203
+ for i, conv_stem_layer_config in enumerate(conv_stem_configs):
204
+ seq_proj.add_module(
205
+ f"conv_bn_relu_{i}",
206
+ Conv2dNormActivation(
207
+ in_channels=prev_channels,
208
+ out_channels=conv_stem_layer_config.out_channels,
209
+ kernel_size=conv_stem_layer_config.kernel_size,
210
+ stride=conv_stem_layer_config.stride,
211
+ norm_layer=conv_stem_layer_config.norm_layer,
212
+ activation_layer=conv_stem_layer_config.activation_layer,
213
+ ),
214
+ )
215
+ prev_channels = conv_stem_layer_config.out_channels
216
+ seq_proj.add_module(
217
+ "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
218
+ )
219
+ self.conv_proj: nn.Module = seq_proj
220
+ else:
221
+ self.conv_proj = nn.Conv2d(
222
+ in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
223
+ )
224
+
225
+ seq_length = (image_size // patch_size) ** 2
226
+
227
+ # Add a class token
228
+ self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
229
+ seq_length += 1
230
+
231
+ self.encoder = Encoder(
232
+ image_size // patch_size,
233
+ image_size // patch_size,
234
+ num_layers,
235
+ num_heads,
236
+ hidden_dim,
237
+ mlp_dim,
238
+ dropout,
239
+ attention_dropout,
240
+ norm_layer,
241
+ )
242
+ self.seq_length = seq_length
243
+
244
+ # heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
245
+ # if representation_size is None:
246
+ # heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
247
+ # else:
248
+ # heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
249
+ # heads_layers["act"] = nn.Tanh()
250
+ # heads_layers["head"] = nn.Linear(representation_size, num_classes)
251
+
252
+ # self.heads = nn.Sequential(heads_layers)
253
+
254
+ if isinstance(self.conv_proj, nn.Conv2d):
255
+ # Init the patchify stem
256
+ fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
257
+ nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
258
+ if self.conv_proj.bias is not None:
259
+ nn.init.zeros_(self.conv_proj.bias)
260
+ elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
261
+ # Init the last 1x1 conv of the conv stem
262
+ nn.init.normal_(
263
+ self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
264
+ )
265
+ if self.conv_proj.conv_last.bias is not None:
266
+ nn.init.zeros_(self.conv_proj.conv_last.bias)
267
+
268
+ # if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
269
+ # fan_in = self.heads.pre_logits.in_features
270
+ # nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
271
+ # nn.init.zeros_(self.heads.pre_logits.bias)
272
+
273
+ # if isinstance(self.heads.head, nn.Linear):
274
+ # nn.init.zeros_(self.heads.head.weight)
275
+ # nn.init.zeros_(self.heads.head.bias)
276
+
277
+ self.encoder_reduction = self.patch_size
278
+ self.reduction = self.encoder_reduction if reduction is None else reduction
279
+ self.channels = hidden_dim
280
+
281
+ def _process_input(self, x: Tensor) -> Tuple[Tensor, int, int, int]:
282
+ # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
283
+ x = self.conv_proj(x)
284
+ n, _, n_h, n_w = x.shape
285
+ # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
286
+ x = x.reshape(n, self.hidden_dim, n_h * n_w)
287
+
288
+ # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
289
+ # The self attention layer expects inputs in the format (N, S, E)
290
+ # where S is the source sequence length, N is the batch size, E is the
291
+ # embedding dimension
292
+ x = x.permute(0, 2, 1)
293
+
294
+ return x, n, n_h, n_w
295
+
296
+ def forward(self, x: Tensor) -> Tensor:
297
+ # Reshape and permute the input tensor
298
+ x, n, n_h, n_w = self._process_input(x)
299
+
300
+ # Expand the class token to the full batch
301
+ batch_class_token = self.class_token.expand(n, -1, -1)
302
+ x = torch.cat([batch_class_token, x], dim=1)
303
+
304
+ x = self.encoder(x, n_h, n_w) # Allows input image to be of any size.
305
+
306
+ # Classifier "token" as used by standard language architectures
307
+ # x = x[:, 0]
308
+
309
+ # x = self.heads(x)
310
+
311
+ x = x[:, 1:, :]
312
+ x = rearrange(x, "n (h w) d -> n d h w", h=n_h, w=n_w)
313
+ if self.encoder_reduction != self.reduction:
314
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
315
+ return x # To be consistent with timm models
316
+
317
+
318
+ def _vision_transformer(
319
+ patch_size: int,
320
+ num_layers: int,
321
+ num_heads: int,
322
+ hidden_dim: int,
323
+ mlp_dim: int,
324
+ weights: str,
325
+ **kwargs: Any,
326
+ ) -> VisionTransformer:
327
+ image_size = kwargs.pop("image_size", 224)
328
+
329
+ model = VisionTransformer(
330
+ image_size=image_size,
331
+ patch_size=patch_size,
332
+ num_layers=num_layers,
333
+ num_heads=num_heads,
334
+ hidden_dim=hidden_dim,
335
+ mlp_dim=mlp_dim,
336
+ **kwargs,
337
+ )
338
+
339
+ if weights is not None:
340
+ weights = load_state_dict_from_url(weights, progress=kwargs.get("progress", True))
341
+ missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
342
+ if len(missing_keys) > 0:
343
+ print(f"Missing keys: {missing_keys}")
344
+ if len(unexpected_keys) > 0:
345
+ print(f"Unexpected keys: {unexpected_keys}")
346
+
347
+ return model
348
+
349
+
350
+ def interpolate_embeddings(
351
+ image_size: int,
352
+ patch_size: int,
353
+ pos_embedding: Tensor,
354
+ interpolation_mode: str = "bicubic",
355
+ ) -> Tensor:
356
+ """This function helps interpolate positional embeddings during checkpoint loading,
357
+ especially when you want to apply a pre-trained model on images with different resolution.
358
+
359
+ Args:
360
+ image_size (int): Image size of the new model.
361
+ patch_size (int): Patch size of the new model.
362
+ pos_embedding (Tensor): Positional embedding of the pre-trained model, of shape (1, seq_length, hidden_dim).
363
+ interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
365
+
366
+ Returns:
367
+ Tensor: The interpolated positional embedding.
368
+ """
369
+ # Shape of pos_embedding is (1, seq_length, hidden_dim)
370
+ n, seq_length, hidden_dim = pos_embedding.shape
371
+ if n != 1:
372
+ raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
373
+
374
+ new_seq_length = (image_size // patch_size) ** 2 + 1
375
+
376
+ # Need to interpolate the weights for the position embedding.
377
+ # We do this by reshaping the positions embeddings to a 2d grid, performing
378
+ # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
379
+ if new_seq_length != seq_length:
380
+ # The class token embedding shouldn't be interpolated, so we split it up.
381
+ seq_length -= 1
382
+ new_seq_length -= 1
383
+ pos_embedding_token = pos_embedding[:, :1, :]
384
+ pos_embedding_img = pos_embedding[:, 1:, :]
385
+
386
+ # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
387
+ pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
388
+ seq_length_1d = int(math.sqrt(seq_length))
389
+ if seq_length_1d * seq_length_1d != seq_length:
390
+ raise ValueError(
391
+ f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
392
+ )
393
+
394
+ # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
395
+ pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
396
+ new_seq_length_1d = image_size // patch_size
397
+
398
+ # Perform interpolation.
399
+ # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
400
+ new_pos_embedding_img = nn.functional.interpolate(
401
+ pos_embedding_img,
402
+ size=new_seq_length_1d,
403
+ mode=interpolation_mode,
404
+ )
405
+
406
+ # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
407
+ new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
408
+
409
+ # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
410
+ new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
411
+ new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
412
+
413
+ return new_pos_embedding
414
+
415
+ return pos_embedding
416
+
417
+
418
+ def vit_b_16(
419
+ image_size: int = 224,
420
+ reduction: int = 16,
421
+ **kwargs: Any,
422
+ ) -> VisionTransformer:
423
+ vit = _vision_transformer(
424
+ patch_size=16,
425
+ num_layers=12,
426
+ num_heads=12,
427
+ hidden_dim=768,
428
+ mlp_dim=3072,
429
+ weights=weights["vit_b_16"],
430
+ reduction=reduction,
431
+ **kwargs,
432
+ )
433
+ if image_size != 224:
434
+ vit.image_size = image_size
435
+ new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic")
436
+ vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
437
+ return vit
438
+
439
+
440
+ def vit_b_32(
441
+ image_size: int = 224,
442
+ reduction: int = 32,
443
+ **kwargs: Any,
444
+ ) -> VisionTransformer:
445
+ vit = _vision_transformer(
446
+ patch_size=32,
447
+ num_layers=12,
448
+ num_heads=12,
449
+ hidden_dim=768,
450
+ mlp_dim=3072,
451
+ weights=weights["vit_b_32"],
452
+ reduction=reduction,
453
+ **kwargs,
454
+ )
455
+ if image_size != 224:
456
+ vit.image_size = image_size
457
+ new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic")
458
+ vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
459
+ return vit
460
+
461
+
462
+ def vit_l_16(
463
+ image_size: int = 224,
464
+ reduction: int = 16,
465
+ **kwargs: Any,
466
+ ) -> VisionTransformer:
467
+ vit = _vision_transformer(
468
+ patch_size=16,
469
+ num_layers=24,
470
+ num_heads=16,
471
+ hidden_dim=1024,
472
+ mlp_dim=4096,
473
+ weights=weights["vit_l_16"],
474
+ reduction=reduction,
475
+ **kwargs,
476
+ )
477
+ if image_size != 224:
478
+ vit.image_size = image_size
479
+ new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic")
480
+ vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
481
+ return vit
482
+
483
+
484
+ def vit_l_32(
485
+ image_size: int = 224,
486
+ reduction: int = 32,
487
+ **kwargs: Any,
488
+ ) -> VisionTransformer:
489
+ vit = _vision_transformer(
490
+ patch_size=32,
491
+ num_layers=24,
492
+ num_heads=16,
493
+ hidden_dim=1024,
494
+ mlp_dim=4096,
495
+ weights=weights["vit_l_32"],
496
+ reduction=reduction,
497
+ **kwargs,
498
+ )
499
+ if image_size != 224:
500
+ vit.image_size = image_size
501
+ new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic")
502
+ vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
503
+ return vit
504
+
505
+
506
+ def vit_h_14(
507
+ image_size: int = 224,
508
+ reduction: int = 14,
509
+ **kwargs: Any,
510
+ ) -> VisionTransformer:
511
+ vit = _vision_transformer(
512
+ patch_size=14,
513
+ num_layers=32,
514
+ num_heads=16,
515
+ hidden_dim=1280,
516
+ mlp_dim=5120,
517
+ weights=weights["vit_h_14"],
518
+ reduction=reduction,
519
+ **kwargs,
520
+ )
521
+ if image_size != 224:
522
+ vit.image_size = image_size
523
+ new_pos_embedding = interpolate_embeddings(image_size, 14, vit.state_dict()["encoder.pos_embedding"], "bicubic")
524
+ vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
525
+ return vit
526
+
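The factories above return a ViT encoder that keeps the patch tokens as a spatial feature map of shape (N, hidden_dim, H/reduction, W/reduction); when image_size differs from 224 the positional embedding is bicubically interpolated to the new grid. A minimal usage sketch, assuming this module is exported as models.encoder (as the dispatch in models/model.py suggests) and that the pretrained weights can be downloaded:

import torch
from models.encoder import vit_b_16  # assumed export path, mirroring models/model.py

encoder = vit_b_16(image_size=448, reduction=8)  # pos. embedding interpolated from a 14x14 to a 28x28 grid
encoder.eval()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 448, 448))
# patch_size 16 gives a 28x28 token grid; reduction=8 upsamples the map by 2.
print(feats.shape)  # expected: torch.Size([1, 768, 56, 56])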
models/encoder_decoder/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .vgg import vgg11 as vgg11_ae, vgg11_bn as vgg11_bn_ae
2
+ from .vgg import vgg13 as vgg13_ae, vgg13_bn as vgg13_bn_ae
3
+ from .vgg import vgg16 as vgg16_ae, vgg16_bn as vgg16_bn_ae
4
+ from .vgg import vgg19 as vgg19_ae, vgg19_bn as vgg19_bn_ae
5
+ from .resnet import resnet18 as resnet18_ae, resnet34 as resnet34_ae
6
+ from .resnet import resnet50 as resnet50_ae, resnet101 as resnet101_ae, resnet152 as resnet152_ae
7
+
8
+ from .cannet import cannet, cannet_bn
9
+ from .csrnet import csrnet, csrnet_bn
10
+
11
+
12
+ __all__ = [
13
+ "vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae",
14
+ "resnet18_ae", "resnet34_ae", "resnet50_ae", "resnet101_ae", "resnet152_ae",
15
+ "cannet", "cannet_bn",
16
+ "csrnet", "csrnet_bn",
17
+ ]
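The re-exports above give every encoder-decoder backbone a uniform interface: each factory returns an nn.Module exposing reduction and channels, which the heads in models/model.py rely on. A small sketch of that contract, assuming the repository root is on the Python path and the pretrained weights can be downloaded:

from models import encoder_decoder  # assumes the repository root is on PYTHONPATH

for name in ["vgg11_ae", "resnet18_ae"]:
    backbone = getattr(encoder_decoder, name)()  # downloads the pretrained weights
    print(name, backbone.reduction, backbone.channels)  # vgg11_ae: 8 128, resnet18_ae: 32 128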
models/encoder_decoder/cannet.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+
5
+ from typing import List, Optional
6
+
7
+ from ..utils import _init_weights
8
+ from .csrnet import CSRNet, csrnet, csrnet_bn
9
+
10
+ EPS = 1e-6
11
+
12
+
13
+ class ContextualModule(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_channels: int,
17
+ out_channels: int = 512,
18
+ sizes: List[int] = [1, 2, 3, 6],
19
+ ) -> None:
20
+ super().__init__()
21
+ self.scales = nn.ModuleList([self.__make_scale__(in_channels, size) for size in sizes])
22
+ self.bottleneck = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1)
23
+ self.relu = nn.ReLU(inplace=True)
24
+ self.weight_net = nn.Conv2d(in_channels, in_channels, kernel_size=1)
25
+
26
+ def __make_weight__(self, feature: Tensor, scale_feature: Tensor) -> Tensor:
27
+ weight_feature = feature - scale_feature
28
+ weight_feature = self.weight_net(weight_feature)
29
+ return torch.sigmoid(weight_feature)
30
+
31
+ def __make_scale__(self, channels: int, size: int) -> nn.Module:
32
+ return nn.Sequential(
33
+ nn.AdaptiveAvgPool2d(output_size=(size, size)),
34
+ nn.Conv2d(channels, channels, kernel_size=1, bias=False),
35
+ )
36
+
37
+ def forward(self, feature: Tensor) -> Tensor:
38
+ h, w = feature.shape[-2:]
39
+ multi_scales = [F.interpolate(input=scale(feature), size=(h, w), mode="bilinear") for scale in self.scales]
40
+ weights = [self.__make_weight__(feature, scale_feature) for scale_feature in multi_scales]
41
+ multi_scales = sum([multi_scales[i] * weights[i] for i in range(len(weights))]) / (sum(weights) + EPS)
42
+ overall_features = torch.cat([multi_scales, feature], dim=1)
43
+ overall_features = self.bottleneck(overall_features)
44
+ overall_features = self.relu(overall_features)
45
+ return overall_features
46
+
47
+
48
+ class CANNet(nn.Module):
49
+ def __init__(
50
+ self,
51
+ csrnet: CSRNet,
52
+ sizes: List[int] = [1, 2, 3, 6],
53
+ reduction: Optional[int] = 8,
54
+ ) -> None:
55
+ super().__init__()
56
+ assert isinstance(csrnet, CSRNet), f"csrnet should be an instance of CSRNet, got {type(csrnet)}."
57
+ assert isinstance(sizes, (tuple, list)), f"sizes should be a list or tuple, got {type(sizes)}."
58
+ assert len(sizes) > 0, f"Expected at least one size, got {len(sizes)}."
59
+ assert all([isinstance(size, int) for size in sizes]), f"Expected all size to be int, got {sizes}."
60
+ self.sizes = sizes
61
+ self.encoder_reduction = csrnet.encoder_reduction
62
+ self.reduction = self.encoder_reduction if reduction is None else reduction
63
+
64
+ self.features = csrnet.features
65
+ self.decoder = csrnet.decoder
66
+ self.decoder.apply(_init_weights)
67
+ self.context = ContextualModule(512, 512, self.sizes)
68
+ self.context.apply(_init_weights)
69
+
70
+ self.channels = csrnet.channels
71
+
72
+ def forward(self, x: Tensor) -> Tensor:
73
+ x = self.features(x)
74
+ x = self.context(x)
75
+ if self.encoder_reduction != self.reduction:
76
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
77
+ x = self.decoder(x)
78
+ return x
79
+
80
+
81
+ def cannet(sizes=[1, 2, 3, 6], reduction: int = 8) -> CANNet:
82
+ return CANNet(csrnet(), sizes=sizes, reduction=reduction)
83
+
84
+ def cannet_bn(sizes=[1, 2, 3, 6], reduction: int = 8) -> CANNet:
85
+ return CANNet(csrnet_bn(), sizes=sizes, reduction=reduction)
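CANNet reuses a CSRNet's VGG features and dilated decoder, but inserts the ContextualModule, which reweights pooled multi-scale features against the original feature map before decoding. A minimal sketch, assuming models.encoder_decoder is importable and the VGG-16 weights can be downloaded:

import torch
from models.encoder_decoder import cannet

model = cannet(sizes=[1, 2, 3, 6], reduction=8)  # pyramid pool sizes of the contextual module
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 384, 384))
print(out.shape)  # expected: torch.Size([2, 64, 48, 48]), i.e. 1/8 resolution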
models/encoder_decoder/csrnet.py ADDED
@@ -0,0 +1,54 @@
1
+ from torch import nn, Tensor
2
+ import torch.nn.functional as F
3
+ from typing import Optional
4
+
5
+ from ..utils import _init_weights, make_vgg_layers, vgg_urls
6
+ from .vgg import _load_weights
7
+
8
+ EPS = 1e-6
9
+
10
+
11
+ encoder_cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512]
12
+ decoder_cfg = [512, 512, 512, 256, 128, 64]
13
+
14
+
15
+ class CSRNet(nn.Module):
16
+ def __init__(
17
+ self,
18
+ features: nn.Module,
19
+ decoder: nn.Module,
20
+ reduction: Optional[int] = None,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.features = features
24
+ self.features.apply(_init_weights)
25
+ self.decoder = decoder
26
+ self.decoder.apply(_init_weights)
27
+
28
+ self.encoder_reduction = 8
29
+ self.reduction = self.encoder_reduction if reduction is None else reduction
30
+ self.channels = 64
31
+
32
+ def forward(self, x: Tensor) -> Tensor:
33
+ x = self.features(x)
34
+ if self.encoder_reduction != self.reduction:
35
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
36
+ x = self.decoder(x)
37
+ return x
38
+
39
+
40
+ def csrnet(reduction: int = 8) -> CSRNet:
41
+ model = CSRNet(
42
+ make_vgg_layers(encoder_cfg, in_channels=3, batch_norm=False, dilation=1),
43
+ make_vgg_layers(decoder_cfg, in_channels=512, batch_norm=False, dilation=2),
44
+ reduction=reduction
45
+ )
46
+ return _load_weights(model, vgg_urls["vgg16"])
47
+
48
+ def csrnet_bn(reduction: int = 8) -> CSRNet:
49
+ model = CSRNet(
50
+ make_vgg_layers(encoder_cfg, in_channels=3, batch_norm=True, dilation=1),
51
+ make_vgg_layers(decoder_cfg, in_channels=512, batch_norm=True, dilation=2),
52
+ reduction=reduction
53
+ )
54
+ return _load_weights(model, vgg_urls["vgg16"])
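When the requested reduction is smaller than the encoder's stride of 8, the forward pass bilinearly upsamples the VGG features before the dilated decoder. A short sketch of that behaviour, assuming the package is importable and the torchvision VGG-16 weights can be downloaded:

import torch
from models.encoder_decoder import csrnet_bn

model = csrnet_bn(reduction=4)  # encoder stride is 8, so features are upsampled x2
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 64, 64, 64])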
models/encoder_decoder/resnet.py ADDED
@@ -0,0 +1,95 @@
1
+ from torch import nn, Tensor
2
+ import torch.nn.functional as F
3
+ import timm
4
+ from typing import Union, Optional
5
+
6
+ from ..utils import BasicBlock, Bottleneck, make_resnet_layers
7
+ from ..utils import _init_weights
8
+
9
+
10
+ model_configs = {
11
+ "resnet18.tv_in1k": {
12
+ "decoder_channels": [512, 256, 128],
13
+ },
14
+ "resnet34.tv_in1k": {
15
+ "decoder_channels": [512, 256, 128],
16
+ },
17
+ "resnet50.tv_in1k": {
18
+ "decoder_channels": [512, 256, 256, 128],
19
+ },
20
+ "resnet101.tv_in1k": {
21
+ "decoder_channels": [512, 512, 256, 256, 128],
22
+ },
23
+ "resnet152.tv_in1k": {
24
+ "decoder_channels": [512, 512, 512, 256, 256, 128],
25
+ },
26
+ }
27
+
28
+
29
+ class ResNet(nn.Module):
30
+ def __init__(
31
+ self,
32
+ decoder_block: Union[BasicBlock, Bottleneck],
33
+ backbone: str = "resnet34.tv_in1k",
34
+ reduction: Optional[int] = None,
35
+ ) -> None:
36
+ super().__init__()
37
+ assert backbone in model_configs.keys(), f"Backbone should be in {model_configs.keys()}"
38
+ config = model_configs[backbone]
39
+ encoder = timm.create_model(backbone, pretrained=True, features_only=True, out_indices=(-1,))
40
+ encoder_reduction = encoder.feature_info.reduction()[-1]
41
+
42
+ if reduction is not None and reduction <= 16:
43
+ if "resnet18" in backbone or "resnet34" in backbone:
44
+ encoder.layer4[0].conv1.stride = (1, 1)
45
+ encoder.layer4[0].downsample[0].stride = (1, 1)
46
+ else:
47
+ encoder.layer4[0].conv2.stride = (1, 1)
48
+ encoder.layer4[0].downsample[0].stride = (1, 1)
49
+ encoder_reduction = encoder_reduction // 2
50
+
51
+ self.encoder = encoder
52
+ self.encoder_reduction = encoder_reduction
53
+
54
+ encoder_out_channels = self.encoder.feature_info.channels()[-1]
55
+
56
+ decoder_channels = config["decoder_channels"]
57
+ self.decoder = make_resnet_layers(
58
+ block=decoder_block,
59
+ cfg=decoder_channels,
60
+ in_channels=encoder_out_channels,
61
+ dilation=1,
62
+ expansion=1,
63
+ )
64
+ self.decoder.apply(_init_weights)
65
+
66
+ self.reduction = self.encoder_reduction if reduction is None else reduction
67
+ self.channels = decoder_channels[-1]
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ x = self.encoder(x)[-1]
71
+ if self.encoder_reduction != self.reduction:
72
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
73
+ x = self.decoder(x)
74
+
75
+ return x
76
+
77
+
78
+ def resnet18(reduction: int = 32) -> ResNet:
79
+ return ResNet(decoder_block=BasicBlock, backbone="resnet18.tv_in1k", reduction=reduction)
80
+
81
+
82
+ def resnet34(reduction: int = 32) -> ResNet:
83
+ return ResNet(decoder_block=BasicBlock, backbone="resnet34.tv_in1k", reduction=reduction)
84
+
85
+
86
+ def resnet50(reduction: int = 32) -> ResNet:
87
+ return ResNet(decoder_block=Bottleneck, backbone="resnet50.tv_in1k", reduction=reduction)
88
+
89
+
90
+ def resnet101(reduction: int = 32) -> ResNet:
91
+ return ResNet(decoder_block=Bottleneck, backbone="resnet101.tv_in1k", reduction=reduction)
92
+
93
+
94
+ def resnet152(reduction: int = 32) -> ResNet:
95
+ return ResNet(decoder_block=Bottleneck, backbone="resnet152.tv_in1k", reduction=reduction)
models/encoder_decoder/vgg.py ADDED
@@ -0,0 +1,85 @@
1
+ # The model used in the paper Distribution Matching for Crowd Counting.
2
+ # Code adapted from https://github.com/cvlab-stonybrook/DM-Count/blob/master/models.py
3
+ from torch import nn, Tensor
4
+ import torch.nn.functional as F
5
+ from torch.hub import load_state_dict_from_url
6
+ from typing import Optional
7
+
8
+ from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls
9
+ from ..utils import _init_weights
10
+
11
+
12
+
13
+ class VGG(nn.Module):
14
+ def __init__(
15
+ self,
16
+ features: nn.Module,
17
+ reduction: Optional[int] = None,
18
+ ) -> None:
19
+ super().__init__()
20
+ self.features = features
21
+ self.reg_layer = nn.Sequential(
22
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
23
+ nn.ReLU(inplace=True),
24
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
25
+ nn.ReLU(inplace=True),
26
+ )
27
+
28
+ self.reg_layer.apply(_init_weights)
29
+ # Remove the density layer, as the output from this model is not final and will be further processed.
30
+ # self.density_layer = nn.Sequential(nn.Conv2d(128, 1, 1), nn.ReLU())
31
+ self.encoder_reduction = 16
32
+ self.reduction = self.encoder_reduction if reduction is None else reduction
33
+ self.channels = 128
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.features(x)
37
+ if self.encoder_reduction != self.reduction:
38
+ x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
39
+ x = self.reg_layer(x)
40
+ # x = self.density_layer(x)
41
+ return x
42
+
43
+
44
+ def _load_weights(model: VGG, url: str) -> VGG:
45
+ state_dict = load_state_dict_from_url(url)
46
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
47
+ print("Loading pre-trained weights")
48
+ if len(missing_keys) > 0:
49
+ print(f"Missing keys: {missing_keys}")
50
+ if len(unexpected_keys) > 0:
51
+ print(f"Unexpected keys: {unexpected_keys}")
52
+ return model
53
+
54
+
55
+ def vgg11(reduction: int = 8) -> VGG:
56
+ model = VGG(make_vgg_layers(vgg_cfgs["A"]), reduction=reduction)
57
+ return _load_weights(model, vgg_urls["vgg11"])
58
+
59
+ def vgg11_bn(reduction: int = 8) -> VGG:
60
+ model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True), reduction=reduction)
61
+ return _load_weights(model, vgg_urls["vgg11_bn"])
62
+
63
+ def vgg13(reduction: int = 8) -> VGG:
64
+ model = VGG(make_vgg_layers(vgg_cfgs["B"]), reduction=reduction)
65
+ return _load_weights(model, vgg_urls["vgg13"])
66
+
67
+ def vgg13_bn(reduction: int = 8) -> VGG:
68
+ model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True), reduction=reduction)
69
+ return _load_weights(model, vgg_urls["vgg13_bn"])
70
+
71
+ def vgg16(reduction: int = 8) -> VGG:
72
+ model = VGG(make_vgg_layers(vgg_cfgs["D"]), reduction=reduction)
73
+ return _load_weights(model, vgg_urls["vgg16"])
74
+
75
+ def vgg16_bn(reduction: int = 8) -> VGG:
76
+ model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True), reduction=reduction)
77
+ return _load_weights(model, vgg_urls["vgg16_bn"])
78
+
79
+ def vgg19(reduction: int = 8) -> VGG:
80
+ model = VGG(make_vgg_layers(vgg_cfgs["E"]), reduction=reduction)
81
+ return _load_weights(model, vgg_urls["vgg19"])
82
+
83
+ def vgg19_bn(reduction: int = 8) -> VGG:
84
+ model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True), reduction=reduction)
85
+ return _load_weights(model, vgg_urls["vgg19_bn"])
models/model.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import os
4
+ from typing import List, Tuple, Union, Callable
5
+ from functools import partial
6
+
7
+ from .utils import _init_weights
8
+
9
+ from . import encoder
10
+ from . import encoder_decoder
11
+ from .encoder import _timm_encoder
12
+
13
+
14
+ curr_dir = os.path.abspath(os.path.dirname(__file__))
15
+
16
+
17
+ class Regressor(nn.Module):
18
+ def __init__(self, backbone: nn.Module) -> None:
19
+ super().__init__()
20
+ self.backbone = backbone
21
+ self.reduction = backbone.reduction
22
+
23
+ self.regressor = nn.Sequential(
24
+ nn.Conv2d(backbone.channels, 1, kernel_size=1),
25
+ nn.ReLU(inplace=True),
26
+ )
27
+ self.regressor.apply(_init_weights)
28
+ self.bins = None
29
+ self.anchor_points = None
30
+
31
+ def forward(self, x: Tensor) -> Tensor:
32
+ x = self.backbone(x)
33
+ x = self.regressor(x)
34
+ return x
35
+
36
+
37
+ class Classifier(nn.Module):
38
+ def __init__(
39
+ self,
40
+ backbone: nn.Module,
41
+ bins: List[Tuple[float, float]],
42
+ anchor_points: List[float],
43
+ ) -> None:
44
+ super().__init__()
45
+ self.backbone = backbone
46
+ self.reduction = backbone.reduction
47
+
48
+ assert len(bins) == len(anchor_points), f"Expected bins and anchor_points to have the same length, got {len(bins)} and {len(anchor_points)}"
49
+ assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
50
+ assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, anchor_points)), f"Expected anchor_points to be within the range of the corresponding bin, got {bins} and {anchor_points}"
51
+
52
+ self.bins = bins
53
+ self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1)
54
+
55
+ if backbone.channels > 512:
56
+ self.classifier = nn.Sequential(
57
+ nn.Conv2d(backbone.channels, 512, kernel_size=1), # serves as a linear layer for feature vectors at each pixel
58
+ nn.ReLU(inplace=True),
59
+ nn.Conv2d(512, len(self.bins), kernel_size=1),
60
+ )
61
+ else:
62
+ self.classifier = nn.Conv2d(backbone.channels, len(self.bins), kernel_size=1)
63
+
64
+ self.classifier.apply(_init_weights)
65
+
66
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
67
+ x = self.backbone(x)
68
+ x = self.classifier(x) # shape (B, C, H, W), where C = len(bins), x is the logits
69
+
70
+ probs = x.softmax(dim=1) # shape (B, C, H, W)
71
+ exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True) # shape (B, 1, H, W)
72
+ if self.training:
73
+ return x, exp
74
+ else:
75
+ return exp
76
+
77
+
78
+ def _get_backbone(backbone: str, input_size: int, reduction: int) -> Callable:
79
+ assert "clip" not in backbone, f"This function does not support CLIP model, got {backbone}"
80
+
81
+ if backbone in ["vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14"]:
82
+ return partial(getattr(encoder, backbone), image_size=input_size, reduction=reduction)
83
+ elif backbone in ["vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn"]:
84
+ return partial(getattr(encoder, backbone), reduction=reduction)
85
+ elif backbone in ["vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae"]:
86
+ return partial(getattr(encoder_decoder, backbone), reduction=reduction)
87
+ elif backbone in ["resnet18_ae", "resnet34_ae", "resnet50_ae", "resnet101_ae", "resnet152_ae"]:
88
+ return partial(getattr(encoder_decoder, backbone), reduction=reduction)
89
+ elif backbone in ["cannet", "cannet_bn", "csrnet", "csrnet_bn"]:
90
+ return partial(getattr(encoder_decoder, backbone), reduction=reduction)
91
+ else:
92
+ return partial(_timm_encoder, backbone=backbone, reduction=reduction)
93
+
94
+
95
+ def _regressor(
96
+ backbone: str,
97
+ input_size: int,
98
+ reduction: int,
99
+ ) -> Regressor:
100
+ backbone = _get_backbone(backbone.lower(), input_size, reduction)
101
+ return Regressor(backbone())
102
+
103
+
104
+ def _classifier(
105
+ backbone: nn.Module,
106
+ input_size: int,
107
+ reduction: int,
108
+ bins: List[Tuple[float, float]],
109
+ anchor_points: List[float],
110
+ ) -> Classifier:
111
+ backbone = _get_backbone(backbone.lower(), input_size, reduction)
112
+ return Classifier(backbone(), bins, anchor_points)
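Regressor attaches a 1x1 convolution plus ReLU to produce a density map, while Classifier predicts per-pixel logits over count bins and returns the expectation over the anchor points (and also the logits while training). A minimal sketch of the two private builders, with hypothetical bins and anchor points, assuming models.model is importable:

import torch
from models.model import _regressor, _classifier  # private builders defined above

reg = _regressor(backbone="csrnet", input_size=224, reduction=8)
reg.eval()
density = reg(torch.randn(1, 3, 224, 224))  # (1, 1, 28, 28) density map

bins = [(0.0, 0.0), (1.0, 3.0), (4.0, 10.0)]  # hypothetical per-pixel count bins
anchor_points = [0.0, 2.0, 7.0]               # one representative count per bin
cls = _classifier(backbone="csrnet", input_size=224, reduction=8,
                  bins=bins, anchor_points=anchor_points)
cls.eval()
expected = cls(torch.randn(1, 3, 224, 224))   # (1, 1, 28, 28) expected counts in eval mode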
models/utils.py ADDED
@@ -0,0 +1,444 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+ from typing import Callable, Optional, Sequence, Tuple, Union, Any, List, TypeVar
6
+ from types import FunctionType
7
+ from itertools import repeat
8
+ import warnings
9
+ import os
10
+ from collections.abc import Iterable
11
+
12
+ V = TypeVar("V")
13
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
14
+
15
+
16
+ vgg_urls = {
17
+ "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
18
+ "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
19
+ "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
20
+ "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
21
+ "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
22
+ "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
23
+ "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
24
+ "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
25
+ }
26
+
27
+ vgg_cfgs = {
28
+ "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
29
+ "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
30
+ "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512],
31
+ "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512]
32
+ }
33
+
34
+
35
+ def _log_api_usage_once(obj: Any) -> None:
36
+
37
+ """
38
+ Logs API usage (module and name) within an organization.
39
+ In a large ecosystem, it's often useful to track the PyTorch and
40
+ TorchVision API usage. This API provides functionality similar to the
41
+ logging module in the Python stdlib. It can be used for debugging purposes
42
+ to log which methods are used; by default it is inactive, unless the user
43
+ manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
44
+ Please note it is triggered only once for the same API call within a process.
45
+ It does not collect any data from open-source users since it is no-op by default.
46
+ For more information, please refer to
47
+ * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
48
+ * Logging policy: https://github.com/pytorch/vision/issues/5052;
49
+
50
+ Args:
51
+ obj (class instance or method): an object to extract info from.
52
+ """
53
+ module = obj.__module__
54
+ if not module.startswith("torchvision"):
55
+ module = f"torchvision.internal.{module}"
56
+ name = obj.__class__.__name__
57
+ if isinstance(obj, FunctionType):
58
+ name = obj.__name__
59
+ torch._C._log_api_usage_once(f"{module}.{name}")
60
+
61
+
62
+ def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
63
+ """
64
+ Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
65
+ Otherwise, we will make a tuple of length n, all with value of x.
66
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
67
+
68
+ Args:
69
+ x (Any): input value
70
+ n (int): length of the resulting tuple
71
+ """
72
+ if isinstance(x, Iterable):
73
+ return tuple(x)
74
+ return tuple(repeat(x, n))
75
+
76
+
77
+ class ConvNormActivation(torch.nn.Sequential):
78
+ def __init__(
79
+ self,
80
+ in_channels: int,
81
+ out_channels: int,
82
+ kernel_size: Union[int, Tuple[int, ...]] = 3,
83
+ stride: Union[int, Tuple[int, ...]] = 1,
84
+ padding: Optional[Union[int, Tuple[int, ...], str]] = None,
85
+ groups: int = 1,
86
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
87
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
88
+ dilation: Union[int, Tuple[int, ...]] = 1,
89
+ inplace: Optional[bool] = True,
90
+ bias: Optional[bool] = None,
91
+ conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
92
+ ) -> None:
93
+
94
+ if padding is None:
95
+ if isinstance(kernel_size, int) and isinstance(dilation, int):
96
+ padding = (kernel_size - 1) // 2 * dilation
97
+ else:
98
+ _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
99
+ kernel_size = _make_ntuple(kernel_size, _conv_dim)
100
+ dilation = _make_ntuple(dilation, _conv_dim)
101
+ padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
102
+ if bias is None:
103
+ bias = norm_layer is None
104
+
105
+ layers = [
106
+ conv_layer(
107
+ in_channels,
108
+ out_channels,
109
+ kernel_size,
110
+ stride,
111
+ padding,
112
+ dilation=dilation,
113
+ groups=groups,
114
+ bias=bias,
115
+ )
116
+ ]
117
+
118
+ if norm_layer is not None:
119
+ layers.append(norm_layer(out_channels))
120
+
121
+ if activation_layer is not None:
122
+ params = {} if inplace is None else {"inplace": inplace}
123
+ layers.append(activation_layer(**params))
124
+ super().__init__(*layers)
125
+ _log_api_usage_once(self)
126
+ self.out_channels = out_channels
127
+
128
+ if self.__class__ == ConvNormActivation:
129
+ warnings.warn(
130
+ "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
131
+ )
132
+
133
+
134
+ class Conv2dNormActivation(ConvNormActivation):
135
+ """
136
+ Configurable block used for Convolution2d-Normalization-Activation blocks.
137
+
138
+ Args:
139
+ in_channels (int): Number of channels in the input image
140
+ out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
141
+ kernel_size: (int, optional): Size of the convolving kernel. Default: 3
142
+ stride (int, optional): Stride of the convolution. Default: 1
143
+ padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
144
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
145
+ norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
146
+ activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
147
+ dilation (int): Spacing between kernel elements. Default: 1
148
+ inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
149
+ bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
150
+
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ in_channels: int,
156
+ out_channels: int,
157
+ kernel_size: Union[int, Tuple[int, int]] = 3,
158
+ stride: Union[int, Tuple[int, int]] = 1,
159
+ padding: Optional[Union[int, Tuple[int, int], str]] = None,
160
+ groups: int = 1,
161
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
162
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
163
+ dilation: Union[int, Tuple[int, int]] = 1,
164
+ inplace: Optional[bool] = True,
165
+ bias: Optional[bool] = None,
166
+ ) -> None:
167
+
168
+ super().__init__(
169
+ in_channels,
170
+ out_channels,
171
+ kernel_size,
172
+ stride,
173
+ padding,
174
+ groups,
175
+ norm_layer,
176
+ activation_layer,
177
+ dilation,
178
+ inplace,
179
+ bias,
180
+ torch.nn.Conv2d,
181
+ )
182
+
183
+
184
+ class MLP(torch.nn.Sequential):
185
+ """This block implements the multi-layer perceptron (MLP) module.
186
+
187
+ Args:
188
+ in_channels (int): Number of channels of the input
189
+ hidden_channels (List[int]): List of the hidden channel dimensions
190
+ norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
191
+ activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
192
+ inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
193
+ Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
194
+ bias (bool): Whether to use bias in the linear layer. Default ``True``
195
+ dropout (float): The probability for the dropout layer. Default: 0.0
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ in_channels: int,
201
+ hidden_channels: List[int],
202
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
203
+ activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
204
+ inplace: Optional[bool] = None,
205
+ bias: bool = True,
206
+ dropout: float = 0.0,
207
+ ):
208
+ # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
209
+ # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
210
+ params = {} if inplace is None else {"inplace": inplace}
211
+
212
+ layers = []
213
+ in_dim = in_channels
214
+ for hidden_dim in hidden_channels[:-1]:
215
+ layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
216
+ if norm_layer is not None:
217
+ layers.append(norm_layer(hidden_dim))
218
+ layers.append(activation_layer(**params))
219
+ layers.append(torch.nn.Dropout(dropout, **params))
220
+ in_dim = hidden_dim
221
+
222
+ layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
223
+ layers.append(torch.nn.Dropout(dropout, **params))
224
+
225
+ super().__init__(*layers)
226
+ _log_api_usage_once(self)
227
+
228
+
229
+ def conv3x3(
230
+ in_channels: int,
231
+ out_channels: int,
232
+ stride: int = 1,
233
+ groups: int = 1,
234
+ dilation: int = 1,
235
+ ) -> nn.Conv2d:
236
+ """3x3 convolution with padding"""
237
+ return nn.Conv2d(
238
+ in_channels,
239
+ out_channels,
240
+ kernel_size=3,
241
+ stride=stride,
242
+ padding=dilation,
243
+ groups=groups,
244
+ bias=False,
245
+ dilation=dilation,
246
+ )
247
+
248
+
249
+ def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
250
+ """1x1 convolution"""
251
+ return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
252
+
253
+
254
+ class BasicBlock(nn.Module):
255
+ expansion: int = 1
256
+
257
+ def __init__(
258
+ self,
259
+ in_channels: int,
260
+ out_channels: int,
261
+ stride: int = 1,
262
+ groups: int = 1,
263
+ base_width: int = 64,
264
+ dilation: int = 1,
265
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
266
+ **kwargs: Any,
267
+ ) -> None:
268
+ super().__init__()
269
+ if norm_layer is None:
270
+ norm_layer = nn.BatchNorm2d
271
+ if groups != 1 or base_width != 64:
272
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
273
+ if dilation > 1:
274
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
275
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
276
+ self.conv1 = conv3x3(in_channels, out_channels, stride)
277
+ self.bn1 = norm_layer(out_channels)
278
+ self.relu = nn.ReLU(inplace=True)
279
+ self.conv2 = conv3x3(out_channels, out_channels)
280
+ self.bn2 = norm_layer(out_channels)
281
+ self.stride = stride
282
+ if in_channels != out_channels:
283
+ self.downsample = nn.Sequential(
284
+ conv1x1(in_channels, out_channels),
285
+ nn.BatchNorm2d(out_channels),
286
+ )
287
+ else:
288
+ self.downsample = nn.Identity()
289
+
290
+ def forward(self, x: Tensor) -> Tensor:
291
+ identity = x
292
+
293
+ out = self.conv1(x)
294
+ out = self.bn1(out)
295
+ out = self.relu(out)
296
+
297
+ out = self.conv2(out)
298
+ out = self.bn2(out)
299
+
300
+ out += self.downsample(identity)
301
+ out = self.relu(out)
302
+
303
+ return out
304
+
305
+
306
+ class Bottleneck(nn.Module):
307
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
308
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
309
+ # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
310
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
311
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
312
+ def __init__(
313
+ self,
314
+ in_channels: int,
315
+ out_channels: int,
316
+ stride: int = 1,
317
+ groups: int = 1,
318
+ base_width: int = 64,
319
+ dilation: int = 1,
320
+ expansion: int = 4,
321
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
322
+ **kwargs: Any,
323
+ ) -> None:
324
+ super().__init__()
325
+ if norm_layer is None:
326
+ norm_layer = nn.BatchNorm2d
327
+ width = int(out_channels * (base_width / 64.0)) * groups
328
+ self.expansion = expansion
329
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
330
+ self.conv1 = conv1x1(in_channels, width)
331
+ self.bn1 = norm_layer(width)
332
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
333
+ self.bn2 = norm_layer(width)
334
+ self.conv3 = conv1x1(width, out_channels * self.expansion)
335
+ self.bn3 = norm_layer(out_channels * self.expansion)
336
+ self.relu = nn.ReLU(inplace=True)
337
+ self.stride = stride
338
+ if in_channels != out_channels:
339
+ self.downsample = nn.Sequential(
340
+ conv1x1(in_channels, out_channels),
341
+ nn.BatchNorm2d(out_channels),
342
+ )
343
+ else:
344
+ self.downsample = nn.Identity()
345
+
346
+ def forward(self, x: Tensor) -> Tensor:
347
+ identity = x
348
+
349
+ out = self.conv1(x)
350
+ out = self.bn1(out)
351
+ out = self.relu(out)
352
+
353
+ out = self.conv2(out)
354
+ out = self.bn2(out)
355
+ out = self.relu(out)
356
+
357
+ out = self.conv3(out)
358
+ out = self.bn3(out)
359
+
360
+ out += self.downsample(identity)
361
+ out = self.relu(out)
362
+
363
+ return out
364
+
365
+
366
+ def _init_weights(model: nn.Module) -> None:
367
+ for m in model.modules():
368
+ if isinstance(m, nn.Conv2d):
369
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
370
+ if m.bias is not None:
371
+ nn.init.constant_(m.bias, 0.)
372
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
373
+ nn.init.constant_(m.weight, 1.)
374
+ if m.bias is not None:
375
+ nn.init.constant_(m.bias, 0.)
376
+ elif isinstance(m, nn.Linear):
377
+ nn.init.normal_(m.weight, std=0.01)
378
+ if m.bias is not None:
379
+ nn.init.constant_(m.bias, 0.)
380
+
381
+
382
+ class Upsample(nn.Module):
383
+ def __init__(
384
+ self,
385
+ size: Optional[Union[int, Tuple[int, int]]] = None,
386
+ scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
387
+ mode: str = "nearest",
388
+ align_corners: bool = False,
389
+ antialias: bool = False,
390
+ ) -> None:
391
+ super().__init__()
392
+ self.interpolate = partial(
393
+ F.interpolate,
394
+ size=size,
395
+ scale_factor=scale_factor,
396
+ mode=mode,
397
+ align_corners=align_corners,
398
+ antialias=antialias,
399
+ )
400
+
401
+ def forward(self, x: Tensor) -> Tensor:
402
+ return self.interpolate(x)
403
+
404
+
405
+ def make_vgg_layers(cfg: List[Union[str, int]], in_channels: int = 3, batch_norm: bool = False, dilation: int = 1) -> nn.Sequential:
406
+ layers = []
407
+ for v in cfg:
408
+ if v == "M":
409
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
410
+ elif v == "U":
411
+ layers += [Upsample(scale_factor=2, mode="bilinear")]
412
+ else:
413
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=dilation, dilation=dilation)
414
+ if batch_norm:
415
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
416
+ else:
417
+ layers += [conv2d, nn.ReLU(inplace=True)]
418
+ in_channels = v
419
+ return nn.Sequential(*layers)
420
+
421
+
422
+ def make_resnet_layers(
423
+ block: Union[BasicBlock, Bottleneck],
424
+ cfg: List[Union[int, str]],
425
+ in_channels: int,
426
+ dilation: int = 1,
427
+ expansion: int = 1,
428
+ ) -> nn.Sequential:
429
+ layers = []
430
+ for v in cfg:
431
+ if v == "U":
432
+ layers.append(Upsample(scale_factor=2, mode="bilinear"))
433
+ else:
434
+ layers.append(block(
435
+ in_channels=in_channels,
436
+ out_channels=v,
437
+ dilation=dilation,
438
+ expansion=expansion,
439
+ ))
440
+ in_channels = v
441
+
442
+ layers = nn.Sequential(*layers)
443
+ layers.apply(_init_weights)
444
+ return layers
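make_vgg_layers and make_resnet_layers turn flat config lists into sequential stacks, where "M" inserts a stride-2 max-pool and "U" a bilinear 2x upsample. A small self-contained sketch, assuming models.utils is importable:

import torch
from models.utils import make_vgg_layers, make_resnet_layers, BasicBlock

# "M" halves the resolution with max-pooling, "U" doubles it with bilinear upsampling.
enc = make_vgg_layers([64, "M", 128, "M", 256], in_channels=3, batch_norm=True)
dec = make_resnet_layers(BasicBlock, [128, "U", 64], in_channels=256)

x = torch.randn(1, 3, 64, 64)
feats = enc(x)      # (1, 256, 16, 16): two "M" entries
out = dec(feats)    # (1, 64, 32, 32): the "U" entry doubles the resolution once
print(feats.shape, out.shape)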